# Import libraries

In [None]:
import os
import shutil
import numpy as np
import pandas as pd
from tqdm import tqdm
from rdkit import Chem
from rdkit.Chem import rdFingerprintGenerator, rdmolops
from rdkit.Avalon import pyAvalonTools
import deepchem as dc
import torch
from torch.nn import Linear, BatchNorm1d, ModuleList
from torch_geometric.data import Dataset, Data
from torch_geometric.loader import DataLoader
from torch_geometric.nn import TransformerConv, global_mean_pool, summary
from sklearn.preprocessing import MinMaxScaler
from sklearn.metrics import mean_absolute_error, mean_squared_error, r2_score
import matplotlib.pyplot as plt

In [None]:
# Seed every random number generator the notebook relies on (NumPy, torch CPU,
# all CUDA devices) so that repeated runs produce the same results.
random_seed = 0
for seed_fn in (np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
    seed_fn(random_seed)

# Give up cuDNN autotuning in exchange for deterministic convolution kernels.
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True

# Dataset class

In [None]:
class ChromophoresDataset(Dataset):
    """PyTorch Geometric dataset of chromophore graphs plus molecule-level features.

    Every CSV row that can be featurized becomes one ``Data`` object holding:

    * ``x``, ``edge_index``, ``edge_attr``: graph features from a DeepChem featurizer
      (computed on the chromophore with explicit hydrogens added),
    * ``mol_features``: a ``(1, n_features)`` molecular descriptor / fingerprint tensor,
    * ``y``: the MinMax-scaled target as a ``(1, 1)`` tensor,
    * ``smiles`` / ``solvent_smiles``: the raw ``Chromophore`` / ``Solvent`` strings.

    The CSV is copied to ``./Datasets/<dataset_name>/raw``. Graphs are written to
    ``./Datasets/<dataset_name>/processed/<dataset_name>_<i>.pt``, with ``i`` counting
    only successfully featurized rows. Row ``i`` of ``raw/mol_features.csv`` holds the
    molecular features of graph ``i``.

    Args:
        data_file_path: Path to the source CSV. Its ``Chromophore``, ``Solvent`` and
            ``target_column`` columns are read.
        dataset_name: Folder name under ``./Datasets`` and prefix of the graph files.
        target_column: Name of the regression target column.
        gcn_featurizer_name: ``'MGCF'``, ``'PMGF'`` or ``'DMPNNF'``. Any other value
            falls back to the DMPNN featurizer.
        mol_featurizer_name: ``'Mordred'``, ``'MorganFP'``, ``'AvalonFP'``, ``'APFP'``,
            ``'TTFP'``, ``'LayeredFP'``, ``'PatternFP'`` or ``'RDKitFP'``. Any other
            value falls back to the RDKit fingerprint.
        transform, pre_transform, pre_filter: Forwarded to ``torch_geometric.data.Dataset``.
    """

    def __init__(self, data_file_path, dataset_name, target_column, gcn_featurizer_name, mol_featurizer_name, transform=None, pre_transform=None, pre_filter=None):
        self.dataset_name = dataset_name
        self.data_file_name = os.path.basename(data_file_path)
        self.target_column = target_column
        self.output_scaler = MinMaxScaler(feature_range=(0, 1))
        self.gcn_featurizer_name = gcn_featurizer_name
        self.mol_featurizer_name = mol_featurizer_name
        self.mol_features_scaler = MinMaxScaler(feature_range=(0, 1))
        # Where the dataset is stored; PyG splits this folder into 'raw' and 'processed'.
        root = './Datasets/' + dataset_name
        os.makedirs(root + '/raw', exist_ok=True)
        shutil.copy2(data_file_path, root + '/raw')
        # Attributes above must exist before this call: PyG invokes process() from here.
        super(ChromophoresDataset, self).__init__(root, transform, pre_transform, pre_filter)

    @property
    def raw_file_names(self):
        """Raw CSV name; because the file already exists in 'raw', no download is triggered."""
        return self.data_file_name

    @property
    def processed_file_names(self):
        """Names PyG checks for in 'processed'.

        These ``data_<i>.pt`` names never match the ``<dataset_name>_<i>.pt`` files that
        ``process()`` writes. As a result PyG calls ``process()`` on every instantiation,
        and the scalers are fitted there. ``process()`` itself skips featurization
        when a previous run completed.
        """
        self.data = pd.read_csv(self.raw_paths[0]).reset_index()
        return [f'data_{i}.pt' for i in list(self.data.index)]

    def process(self):
        """Fit the scalers and, unless a finished run exists, featurize every CSV row.

        ``output_scaler`` is fitted on the whole target column of the CSV.
        ``mol_features_scaler`` is fitted on the features of the saved graphs.
        Rows that cannot be featurized are skipped, and the skip count is reported
        as a warning.

        Raises:
            RuntimeError: if no row at all could be featurized.
        """
        import warnings

        # Read the data file and fit the output scaler.
        self.data = pd.read_csv(self.raw_paths[0])
        self.output_scaler.fit(self.data[self.target_column].to_numpy().reshape(-1, 1))

        # mol_features.csv is written only after featurization finishes, so it marks a
        # completed run. An interrupted run (graphs but no CSV) is processed again.
        mol_features_path = os.path.join(self.raw_dir, 'mol_features.csv')
        if os.path.exists(mol_features_path) and self.len() > 0:
            self.mol_features_scaler.fit(np.loadtxt(mol_features_path, delimiter=',', ndmin=2))
            return

        # Remove graphs left over from an interrupted run so len() stays accurate.
        for file_name in self._graph_files():
            os.remove(os.path.join(self.processed_dir, file_name))

        gcn_featurizer = self._make_gcn_featurizer()
        mol_featurizer = self._make_mol_featurizer()

        mol_features_rows = []
        n_skipped = 0
        for _, row in tqdm(self.data.iterrows(), total=self.data.shape[0]):
            try:
                data, mol_features = self._featurize_row(row, gcn_featurizer, mol_featurizer)
            except Exception:
                # Best effort: molecules RDKit/DeepChem cannot handle are skipped.
                n_skipped += 1
                continue
            graph_index = len(mol_features_rows)
            torch.save(data, os.path.join(self.processed_dir, f'{self.dataset_name}_{graph_index}.pt'))
            # Append only after the graph is saved so row i always matches graph i.
            mol_features_rows.append(mol_features)

        if n_skipped:
            warnings.warn(f'{n_skipped} of {len(self.data)} rows in {self.data_file_name} '
                          'could not be featurized and were skipped.')
        if not mol_features_rows:
            raise RuntimeError(f'No molecule in {self.data_file_name} could be featurized.')

        # Save the molecular features (one row per graph) and fit the feature scaler.
        mol_features_arr = np.array(mol_features_rows)
        np.savetxt(mol_features_path, mol_features_arr, delimiter=',')
        self.mol_features_scaler.fit(mol_features_arr)

    def _make_gcn_featurizer(self):
        """Return the DeepChem graph featurizer selected by ``gcn_featurizer_name``."""
        if self.gcn_featurizer_name == 'MGCF':
            return dc.feat.MolGraphConvFeaturizer(use_edges=True, use_chirality=True, use_partial_charge=True)
        if self.gcn_featurizer_name == 'PMGF':
            return dc.feat.PagtnMolGraphFeaturizer(max_length=5)
        return dc.feat.DMPNNFeaturizer(is_adding_hs=False)  # 'DMPNNF' and any other name

    def _make_mol_featurizer(self):
        """Return a DeepChem molecular featurizer, or None when RDKit fingerprints are used."""
        if self.mol_featurizer_name == 'Mordred':
            return dc.feat.MordredDescriptors(ignore_3D=True)
        if self.mol_featurizer_name == 'MorganFP':
            return dc.feat.CircularFingerprint(size=2048, radius=3)
        return None

    def _mol_features(self, mol, smiles, mol_featurizer):
        """Return the molecule-level features of one molecule as a ``(1, n)`` array."""
        if mol_featurizer is not None:  # DeepChem Mordred descriptors / Morgan fingerprint
            return mol_featurizer.featurize(smiles)
        name = self.mol_featurizer_name
        if name == 'AvalonFP':
            fp = np.array(pyAvalonTools.GetAvalonFP(mol, nBits=2048))
        elif name == 'APFP':
            fp = rdFingerprintGenerator.GetAtomPairGenerator(fpSize=2048).GetFingerprintAsNumPy(mol)
        elif name == 'TTFP':
            fp = rdFingerprintGenerator.GetTopologicalTorsionGenerator(fpSize=2048).GetFingerprintAsNumPy(mol)
        elif name == 'LayeredFP':
            fp = np.array(rdmolops.LayeredFingerprint(mol, fpSize=2048))
        elif name == 'PatternFP':
            fp = np.array(rdmolops.PatternFingerprint(mol, fpSize=2048))
        else:  # 'RDKitFP' and any unrecognized name
            fp = rdFingerprintGenerator.GetRDKitFPGenerator(fpSize=2048).GetFingerprintAsNumPy(mol)
        return fp.reshape(1, -1)

    def _featurize_row(self, row, gcn_featurizer, mol_featurizer):
        """Build the ``Data`` object and the flat molecular-feature vector for one CSV row.

        Raises on any featurization problem; ``process()`` skips such rows.
        """
        smiles = row['Chromophore']
        mol = Chem.MolFromSmiles(smiles)
        if mol is None:
            raise ValueError(f'Invalid SMILES: {smiles}')
        mol = Chem.AddHs(mol)

        # Graph featurization for the chromophore.
        graph = gcn_featurizer.featurize(mol)[0]
        node_feats = torch.tensor(np.array(graph.node_features))
        self.n_gcn_features = node_feats.shape[1]
        edge_index = torch.tensor(np.array(graph.edge_index))
        edge_attr = torch.tensor(np.array(graph.edge_features))

        # Molecule-level fingerprint / descriptors for the chromophore.
        mol_features = self._mol_features(mol, smiles, mol_featurizer)
        mol_features_tensor = torch.tensor(mol_features)
        self.n_mol_features = mol_features_tensor.shape[1]

        data = Data(x=node_feats,
                    edge_index=edge_index,
                    edge_attr=edge_attr,
                    y=self.get_output(row[self.target_column]),
                    smiles=smiles,
                    solvent_smiles=row['Solvent'],
                    mol_features=mol_features_tensor
                    )
        return data, mol_features.reshape(-1)

    def get_output(self, output):
        """MinMax-scale one raw target value and return it as a ``(1, 1)`` tensor."""
        output = self.output_scaler.transform(np.array([output]).reshape(-1, 1))
        return torch.tensor(output)

    def _graph_files(self):
        """Return the names of the ``<dataset_name>_<i>.pt`` graph files in 'processed'."""
        if not os.path.isdir(self.processed_dir):
            return []
        prefix = f'{self.dataset_name}_'
        return [name for name in os.listdir(self.processed_dir)
                if name.startswith(prefix) and name.endswith('.pt') and name[len(prefix):-3].isdigit()]

    def len(self):
        """Number of graphs written by ``process()`` (other files in 'processed' are ignored)."""
        return len(self._graph_files())

    def get(self, idx):
        """Load graph ``idx`` from disk (``weights_only=False`` because it stores a ``Data`` object)."""
        data = torch.load(os.path.join(self.processed_dir, f'{self.dataset_name}_{idx}.pt'), weights_only=False)
        return data

    @property
    def num_node_features(self):
        """Node feature size of the selected graph featurizer (None if unknown)."""
        if (self.gcn_featurizer_name == 'MGCF'):
            return 33
        elif (self.gcn_featurizer_name == 'PMGF'):
            return 94
        elif (self.gcn_featurizer_name == 'DMPNNF'):
            return 133

    @property
    def num_edge_features(self):
        """Edge feature size of the selected graph featurizer (None if unknown)."""
        if (self.gcn_featurizer_name == 'MGCF'):
            return 11
        elif (self.gcn_featurizer_name == 'PMGF'):
            return 42
        elif (self.gcn_featurizer_name == 'DMPNNF'):
            return 14

    @property
    def num_mol_features(self):
        """Length of the molecule-level feature vector (None if the featurizer is unknown)."""
        if (self.mol_featurizer_name == 'Mordred'):
            return 1613
        elif (self.mol_featurizer_name == 'MorganFP'):
            return 2048
        elif (self.mol_featurizer_name == 'AvalonFP'):
            return 2048
        elif (self.mol_featurizer_name == 'APFP'):
            return 2048
        elif (self.mol_featurizer_name == 'TTFP'):
            return 2048
        elif (self.mol_featurizer_name == 'LayeredFP'):
            return 2048
        elif (self.mol_featurizer_name == 'PatternFP'):
            return 2048
        elif (self.mol_featurizer_name == 'RDKitFP'):
            return 2048

# Load data

In [None]:
# Path to the dataset file
data_file_path = './Datasets/Abs_unique.csv'   # Options: 'Abs_unique.csv', 'Abs_with_solvent.csv', 'Ems_unique.csv', 'Ems_with_solvent.csv'

# Choose molecular fingerprint and graph featurizer
gcn_featurizer_name = 'MGCF'  # Options: 'MGCF', 'PMGF', 'DMPNNF'
mol_featurizer_name = 'AvalonFP'  # Options: 'Mordred', 'MorganFP', 'AvalonFP', 'APFP', 'TTFP', 'LayeredFP', 'PatternFP', 'RDKitFP'

# Define the target column (must be a column of the chosen data file)
target_column = 'ABS'   # Options: 'ABS', 'EMS'

# Derive the dataset name from the data file. Each file/featurizer combination then gets
# its own cache folder, so switching files cannot silently reuse another file's graphs.
data_file_stem = os.path.splitext(os.path.basename(data_file_path))[0]
chromophores_dataset_name = f'{data_file_stem}_{gcn_featurizer_name}_{mol_featurizer_name}'

# Copy the data file into the dataset folder and featurize it (or reuse finished graphs)
chromophores_dataset = ChromophoresDataset(data_file_path, chromophores_dataset_name, target_column, gcn_featurizer_name, mol_featurizer_name)

In [None]:
# Train-validation-test splitting (60% / 20% / 20%)
n_total = len(chromophores_dataset)
train_size = int(n_total * 0.6)
val_size = int(n_total * 0.2)
test_size = n_total - train_size - val_size

# A dedicated, seeded generator keeps the split identical on every run, however much of
# the global torch RNG was consumed before this cell (e.g. when the cell is re-executed).
split_generator = torch.Generator().manual_seed(random_seed)
train_dataset, val_dataset, test_dataset = torch.utils.data.random_split(
    chromophores_dataset, [train_size, val_size, test_size], generator=split_generator)

In [None]:
# DataLoaders
batch_size = 32
# Reshuffle the training graphs every epoch. Validation/test keep a fixed order, since
# they are only evaluated.
train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_dataloader = DataLoader(val_dataset, batch_size=batch_size)
test_dataloader = DataLoader(test_dataset, batch_size=batch_size)

# Hybrid GNN model

In [None]:
# GNN model definition
class HybridGNNModel(torch.nn.Module):
    """Regressor combining a graph-transformer branch with a molecular-feature MLP branch.

    * GCN branch: ``n_gcn_layers`` repetitions of (``TransformerConv`` -> ``Linear`` + ReLU
      -> ``BatchNorm1d``). The last projection outputs ``n_gcn_outputs`` features.
      Node embeddings are then mean-pooled per graph.
    * MLP branch: ``Linear`` + ReLU layers mapping ``mol_features`` to ``n_mlp_outputs``.
    * Predictor: the concatenated branch outputs (GCN first) pass through
      ``n_predictor_layers`` ``Linear`` + ReLU layers and a final ``Linear`` with 1 output.

    A branch is disabled by setting its output size to 0; at least one must be > 0.

    Raises:
        ValueError: if both ``n_gcn_outputs`` and ``n_mlp_outputs`` are 0.
    """

    def __init__(self, n_gcn_inputs, n_gcn_hiddens, n_gcn_layers, n_gcn_heads, n_gcn_outputs, edge_dim,
                 n_mlp_inputs, n_mlp_hiddens, n_mlp_layers, n_mlp_outputs,
                 n_predictor_hiddens, n_predictor_layers):
        super(HybridGNNModel, self).__init__()

        # Check output size (ValueError is an Exception, so existing handlers still catch it)
        if n_gcn_outputs == 0 and n_mlp_outputs == 0:
            raise ValueError("The total output size of GCN and MLP modules cannot be 0!")

        # GCN layers (can be modified to use different GCN architectures).
        # Modules are stored as (conv, linear, batchnorm) triplets; forward() relies on this.
        self.gcn = torch.nn.ModuleList()
        if n_gcn_outputs > 0:
            n_layers = max(n_gcn_layers, 1)
            for i in range(n_layers):
                conv_inputs = n_gcn_inputs if i == 0 else n_gcn_hiddens
                # The last projection must emit n_gcn_outputs (also when there is only one layer).
                proj_outputs = n_gcn_outputs if i == n_layers - 1 else n_gcn_hiddens
                self.gcn.append(TransformerConv(conv_inputs, n_gcn_hiddens, n_gcn_heads, edge_dim=edge_dim))
                self.gcn.append(Linear(n_gcn_hiddens * n_gcn_heads, proj_outputs))
                self.gcn.append(BatchNorm1d(proj_outputs))

        # MLP layers: input layer, (n_mlp_layers - 1) hidden layers, output layer
        self.mlp = ModuleList()
        if n_mlp_outputs > 0:
            self.mlp.append(Linear(n_mlp_inputs, n_mlp_hiddens))
            for _ in range(1, n_mlp_layers):
                self.mlp.append(Linear(n_mlp_hiddens, n_mlp_hiddens))
            self.mlp.append(Linear(n_mlp_hiddens, n_mlp_outputs))

        # Predictor layers
        self.predictor = ModuleList()
        if n_predictor_layers > 0:
            self.predictor.append(Linear(n_gcn_outputs + n_mlp_outputs, n_predictor_hiddens))
            for _ in range(1, n_predictor_layers):
                self.predictor.append(Linear(n_predictor_hiddens, n_predictor_hiddens))

        if n_predictor_layers > 0:
            self.out = Linear(n_predictor_hiddens, 1)
        else:
            self.out = Linear(n_gcn_outputs + n_mlp_outputs, 1)

    def forward(self, x, edge_index, edge_attr, batch_index, mol_features):
        """Return a ``(n_graphs, 1)`` prediction.

        ``batch_index`` maps each node to its graph (used for mean pooling).
        ``mol_features`` holds one row of molecule-level features per graph.
        Inputs of a disabled branch are ignored.
        """
        embeddings = []

        if len(self.gcn) > 0:
            h1 = x
            for i in range(0, len(self.gcn), 3):
                conv, linear, norm = self.gcn[i], self.gcn[i + 1], self.gcn[i + 2]
                h1 = conv(h1, edge_index, edge_attr)
                h1 = norm(torch.relu(linear(h1)))
            embeddings.append(global_mean_pool(h1, batch_index))

        if len(self.mlp) > 0:
            h2 = mol_features
            for linear in self.mlp:
                h2 = torch.relu(linear(h2))
            embeddings.append(h2)

        # The constructor guarantees at least one branch is enabled.
        h = torch.cat(embeddings, dim=1)

        for linear in self.predictor:
            h = torch.relu(linear(h))

        return self.out(h)

In [None]:
# Initialize the model with the parameters

# Input sizes are fixed by the featurizers chosen for the dataset.
gcn_n_inputs = chromophores_dataset.num_node_features
edge_dim = chromophores_dataset.num_edge_features
mlp_n_inputs = chromophores_dataset.num_mol_features

# Graph-transformer branch
gcn_n_hiddens, gcn_n_layers, gcn_n_heads, gcn_n_outputs = 128, 3, 3, 50
# Molecular-feature MLP branch
mlp_n_hiddens, mlp_n_layers, mlp_n_outputs = 128, 3, 50
# Predictor head
predictor_n_hiddens, predictor_n_layers = 128, 1

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = HybridGNNModel(
    n_gcn_inputs=gcn_n_inputs,
    n_gcn_hiddens=gcn_n_hiddens,
    n_gcn_layers=gcn_n_layers,
    n_gcn_heads=gcn_n_heads,
    n_gcn_outputs=gcn_n_outputs,
    edge_dim=edge_dim,
    n_mlp_inputs=mlp_n_inputs,
    n_mlp_hiddens=mlp_n_hiddens,
    n_mlp_layers=mlp_n_layers,
    n_mlp_outputs=mlp_n_outputs,
    n_predictor_hiddens=predictor_n_hiddens,
    n_predictor_layers=predictor_n_layers,
).to(device)

In [None]:
# Load the molecular features written during processing and fit the scaler on the
# training split only, so validation/test statistics do not leak into preprocessing.
# Assumes row i of mol_features.csv belongs to graph i, the index space of
# Subset.indices; the dataset writes one row per saved graph, in order.
# NOTE: the target scaler (chromophores_dataset.output_scaler) is still fit on all rows.
mol_features_arr = np.loadtxt('./Datasets/' + chromophores_dataset.dataset_name + '/raw/mol_features.csv', delimiter=",", ndmin=2)
mol_features_scaler = MinMaxScaler(feature_range=(0, 1))
mol_features_scaler.fit(mol_features_arr[train_dataset.indices])

In [None]:
# Print model summary, using the first graph of the dataset as a sample input
graph_data = chromophores_dataset[0]
x = graph_data.x.float().to(device)
edge_index = graph_data.edge_index.long().to(device)
edge_attr = graph_data.edge_attr.float().to(device)
# Pooling expects one batch entry per node; they are all 0 because this is a single graph.
batch_index = torch.zeros(x.size(0), dtype=torch.long, device=device)
# Scale the molecular features the same way train()/validation() do.
mol_features = graph_data.mol_features.numpy()
mol_features_scaled = mol_features_scaler.transform(mol_features)
mol_features_scaled = torch.tensor(mol_features_scaled).float().to(device)

print(summary(model, x, edge_index, edge_attr, batch_index, mol_features_scaled))

# Training

In [None]:
# Initialize training history, optimizer, scheduler, and loss function
train_losses = []
val_losses = []
trained_epochs = 0
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
# Multiply the learning rate by 0.99 after every epoch.
scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.99)
loss_fn = torch.nn.MSELoss()

In [None]:
# Training function
def train(dataloader):
    """Run one optimization epoch and return the mean per-batch RMSE.

    The value averages ``sqrt(loss_fn(output, y))`` over batches; it is not the
    epoch-level RMSE. Returns NaN when ``dataloader`` yields no batches.
    Uses the notebook globals ``model``, ``optimizer``, ``loss_fn``,
    ``mol_features_scaler`` and ``device``.
    """
    model.train()
    total_loss = 0.0
    n_batches = 0
    for batch in dataloader:
        batch = batch.to(device)
        # The scikit-learn scaler works on CPU data; scale there and move the result back.
        mol_scaled = torch.tensor(
            mol_features_scaler.transform(batch.mol_features.cpu()),
            dtype=torch.float32, device=device
        )
        output = model(batch.x.float(), batch.edge_index.long(), batch.edge_attr.float(),
                       batch.batch, mol_scaled)
        loss = loss_fn(output, batch.y.float())
        optimizer.zero_grad(set_to_none=True)
        loss.backward()
        optimizer.step()
        total_loss += torch.sqrt(loss).item()
        n_batches += 1
    return total_loss / n_batches if n_batches else float('nan')

In [None]:
# Validation function
@torch.no_grad()
def validation(dataloader):
    """Evaluate the model (no gradients) and return the mean per-batch RMSE.

    Mirrors ``train()`` without the optimization step. Returns NaN when
    ``dataloader`` yields no batches.
    """
    model.eval()
    total_loss = 0.0
    n_batches = 0
    for batch in dataloader:
        batch = batch.to(device)
        # The scikit-learn scaler works on CPU data; scale there and move the result back.
        mol_scaled = torch.tensor(
            mol_features_scaler.transform(batch.mol_features.cpu()),
            dtype=torch.float32, device=device
        )
        output = model(batch.x.float(), batch.edge_index.long(), batch.edge_attr.float(),
                       batch.batch, mol_scaled)
        loss = loss_fn(output, batch.y.float())
        total_loss += torch.sqrt(loss).item()
        n_batches += 1
    return total_loss / n_batches if n_batches else float('nan')

In [None]:
# Training loop: train for `epochs` more epochs, continuing the count from earlier runs
epochs = 100
first_epoch = trained_epochs + 1
t = tqdm(range(first_epoch, first_epoch + epochs), total=epochs, desc="Training")
for epoch in t:
    train_loss = train(train_dataloader)
    val_loss = validation(val_dataloader)
    train_losses.append(train_loss)
    val_losses.append(val_loss)

    # Decay the learning rate once per epoch.
    scheduler.step()
    trained_epochs += 1

In [None]:
# Plot the training and validation loss curves
figure = plt.figure(figsize=(10, 5))
plt.plot(train_losses, label='Train loss')
plt.plot(val_losses, label='Validation loss')
plt.xlabel('Epoch')
# train()/validation() report the mean per-batch RMSE of the MinMax-scaled target.
plt.ylabel('Mean batch RMSE (scaled target)')
plt.legend()
plt.grid(True)

In [None]:
# Test function
@torch.no_grad()
def test(dataloader):
    """Run the model over ``dataloader`` and collect everything needed for evaluation.

    Returns:
        (targets, predictions, chromophore SMILES, solvent SMILES). The first two are CPU
        tensors on the scaled target range, concatenated over all batches; the SMILES
        lists follow the same order.
    """
    model.eval()
    targets, predictions = [], []
    smiles_arr, solvent_smiles_arr = [], []
    for batch in dataloader:
        batch = batch.to(device)
        features = mol_features_scaler.transform(batch.mol_features.cpu())
        mol_scaled = torch.tensor(features, dtype=torch.float32, device=device)
        prediction = model(batch.x.float(), batch.edge_index.long(), batch.edge_attr.float(),
                           batch.batch, mol_scaled)
        targets.append(batch.y.float().cpu())
        predictions.append(prediction.cpu())
        smiles_arr += batch.smiles
        solvent_smiles_arr += batch.solvent_smiles
    return torch.cat(targets, dim=0), torch.cat(predictions, dim=0), smiles_arr, solvent_smiles_arr

In [None]:
# Get the output scaler fitted on the target column by the dataset
output_scaler = chromophores_dataset.output_scaler

# Run the model forward on the test split
y_test_scaled, y_pred_scaled, test_smiles_arr, test_solvent_smiles_arr = test(test_dataloader)

# Undo the MinMax scaling so the metrics are in the original target units
y_test = output_scaler.inverse_transform(y_test_scaled.numpy().reshape(-1, 1)).reshape(-1)
y_pred = output_scaler.inverse_transform(y_pred_scaled.numpy().reshape(-1, 1)).reshape(-1)

# Evaluate the model using different metrics
mae = mean_absolute_error(y_test, y_pred)
mse = mean_squared_error(y_test, y_pred)
rmse = mse**0.5
r2 = r2_score(y_test, y_pred)

print(f'MAE: {mae}\nMSE: {mse}\nRMSE: {rmse}\nR2: {r2}')

In [None]:
# Visualization: predicted vs. actual test values (original target units)
scatter_plot = plt.figure(figsize=(8, 8))
ax = scatter_plot.add_subplot()
ax.scatter(y_test, y_pred)
ax.set_xlabel('actual value')
ax.set_ylabel('predicted value')
ax.grid(True)

In [None]:
# Save the model checkpoint
def save_checkpoint(model_name):
    """Save the model and training state to ``./Checkpoints/<model_name>_<trained_epochs>.ckpt``.

    Besides ``'state_dict'`` (the only key of older checkpoints), the optimizer and
    scheduler state, the loss history and the epoch counter are stored so that
    training can be resumed where it stopped.

    Returns:
        A message naming the written file.
    """
    checkpoint = {
        'state_dict': model.state_dict(),
        'optimizer': optimizer.state_dict(),
        'scheduler': scheduler.state_dict(),
        'train_losses': list(train_losses),
        'val_losses': list(val_losses),
        'trained_epochs': trained_epochs,
    }
    checkpoint_dir = './Checkpoints'
    checkpoint_name = f'{model_name}_{trained_epochs}.ckpt'
    os.makedirs(checkpoint_dir, exist_ok=True)
    torch.save(checkpoint, os.path.join(checkpoint_dir, checkpoint_name))
    return f'Model was saved to {checkpoint_dir}/{checkpoint_name}.'

save_checkpoint(chromophores_dataset_name)

In [None]:
# Load the checkpoint
def load_checkpoint(checkpoint_path):
    """Load a checkpoint written by ``save_checkpoint`` into the global model.

    Optimizer/scheduler state, loss history and epoch counter are restored when the
    checkpoint contains them. For checkpoints holding only ``'state_dict'``, the
    history is reset to empty lists and the counter to 0.

    Returns:
        A message naming the loaded file.
    """
    global train_losses
    global val_losses
    global trained_epochs

    # map_location lets a checkpoint saved on GPU be loaded on a CPU-only machine.
    # NOTE: weights_only=False unpickles arbitrary objects; only load trusted checkpoints.
    checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=False)
    model.load_state_dict(checkpoint['state_dict'])
    if 'optimizer' in checkpoint:
        optimizer.load_state_dict(checkpoint['optimizer'])
    if 'scheduler' in checkpoint:
        scheduler.load_state_dict(checkpoint['scheduler'])

    train_losses = list(checkpoint.get('train_losses', []))
    val_losses = list(checkpoint.get('val_losses', []))
    trained_epochs = checkpoint.get('trained_epochs', 0)

    return f'Checkpoint loaded: {os.path.basename(checkpoint_path)}.'

# save_checkpoint writes ./Checkpoints/<name>_<epoch>.ckpt; change the epoch to load another one.
checkpoint_epoch = trained_epochs
load_checkpoint(f'./Checkpoints/{chromophores_dataset_name}_{checkpoint_epoch}.ckpt')

In [None]:
# Export losses: one row per trained epoch, written next to the notebook
df = pd.DataFrame({
    'train_losses': train_losses,
    'val_losses': val_losses,
})
file_path = f'./{chromophores_dataset.dataset_name}_{trained_epochs}_losses.csv'
df.to_csv(file_path)

In [None]:
# Export the scatter-plot data: one row per test molecule with its SMILES strings
df = pd.DataFrame({
    'y_test': y_test.tolist(),
    'y_pred': y_pred.tolist(),
    'Chromophore': test_smiles_arr,
    'Solvent': test_solvent_smiles_arr,
})
file_path = f'./{chromophores_dataset.dataset_name}_{trained_epochs}_eval.csv'
df.to_csv(file_path)