In [1]:
import copy
import csv
import itertools
import os
import pickle
import time

import optuna
import requests
from scipy.stats import mannwhitneyu
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
from tqdm import tqdm
tqdm.pandas()

from rdkit import RDLogger
from rdkit import Chem
from rdkit.Chem import AllChem, Draw, rdMolTransforms
from rdkit.Chem.MolStandardize import rdMolStandardize
from rdkit.Chem.Draw import IPythonConsole
IPythonConsole.drawOptions.addAtomIndices = True
IPythonConsole.molSize = 300,300
 
from sklearn.metrics import roc_auc_score, f1_score, accuracy_score, confusion_matrix, matthews_corrcoef
from sklearn.model_selection import StratifiedKFold
import networkx as nx

import torch
import torch.nn.functional as F
import torch_geometric
from torch_geometric.nn import SAGEConv, GCNConv, GATConv, GINConv
from torch_geometric.data import Data, Dataset, InMemoryDataset
from torch_geometric.loader import DataLoader
from torch.utils.data import Subset

In [2]:
## Minimal preprocessing (Mols-To-Graph function: atomic number / bond order features only)

class Graph_basic(InMemoryDataset):
    """Minimal molecule-to-graph dataset.

    Each dataframe row is turned into a PyG ``Data`` object whose only node
    feature is the atomic number and whose only edge feature is the bond order
    (``GetBondTypeAsDouble``). Every bond is stored twice, once per direction.

    Rows that cannot be converted are skipped. Their dataframe index is printed
    and appended to ``error_indices``. This list is only filled when
    ``process()`` actually runs: if the cache file already exists, processing is
    skipped and ``error_indices`` stays empty.

    Args:
        dataframe: table holding the SMILES and label columns.
        root: dataset root; the cache file lives in ``self.processed_dir``.
        smiles_col: name of the SMILES column.
        label_col: name of the label column.
        test: if True, read/write ``data_test.pt`` instead of ``data.pt``.
        transform: forwarded to ``InMemoryDataset``.
        pre_transform: forwarded to ``InMemoryDataset``; note that
            ``process()`` does not apply it.
    """

    # Widths of the per-node / per-edge feature vectors built below.
    _NUM_NODE_FEATURES = 1
    _NUM_EDGE_FEATURES = 1

    def __init__(self, dataframe, root, smiles_col='smiles', label_col='label', test=False, transform=None, pre_transform=None):
        self.test = test
        self.dataframe = dataframe
        self.smiles_col = smiles_col
        self.label_col = label_col
        self._data = None  # lazily loaded list of Data objects (see _load_processed)
        self.error_indices = []  # dataframe indices of rows that failed in process()
        super(Graph_basic, self).__init__(root, transform, pre_transform)

    @property
    def raw_file_names(self):
        """Placeholder: graphs are built from ``dataframe``; nothing is downloaded."""
        return 'dataframe'

    @property
    def processed_file_names(self):
        """Cache file looked up in ``processed_dir``; if it exists, ``process`` is skipped."""
        return ['data_test.pt' if self.test else 'data.pt']

    def download(self):
        pass

    def _processed_path(self):
        """Path of the cache file for this split (train or test)."""
        return os.path.join(self.processed_dir, self.processed_file_names[0])

    def _load_processed(self):
        """Load the cached list of ``Data`` objects on first use and return it."""
        if self._data is None:
            # NOTE(review): torch>=2.6 defaults torch.load to weights_only=True,
            # which may refuse pickled Data objects - confirm installed versions.
            self._data = torch.load(self._processed_path())
        return self._data

    def process(self):
        """Convert every dataframe row to a ``Data`` graph and save the list to the cache file."""
        data_list = []
        for index, mol in tqdm(self.dataframe.iterrows(), total=self.dataframe.shape[0]):
            try:
                smiles = mol[self.smiles_col]
                mol_obj = Chem.MolFromSmiles(smiles)
                if mol_obj is None:
                    raise ValueError(f"could not parse SMILES {smiles!r}")
                node_feats = self._get_node_features(mol_obj)
                edge_feats = self._get_edge_features(mol_obj)
                edge_index = self._get_adjacency_info(mol_obj)
                label = self._get_labels(mol[self.label_col])

                data = Data(x=node_feats,
                            edge_index=edge_index,
                            edge_attr=edge_feats,
                            y=label,
                            smiles=smiles)
                data_list.append(data)
            except Exception as e:
                # Best-effort: skip the molecule but keep a record of it.
                print(f"Error processing molecule at index {index}: {e}")
                self.error_indices.append(index)

        torch.save(data_list, self._processed_path())

    def _get_node_features(self, mol):
        """
        Return a float tensor of shape [Number of Nodes, 1]
        with atomic number as the only node feature.
        """
        all_node_feats = [[atom.GetAtomicNum()] for atom in mol.GetAtoms()]
        # reshape keeps the 2-D shape even for a molecule with no atoms
        all_node_feats = np.asarray(all_node_feats, dtype=float).reshape(-1, self._NUM_NODE_FEATURES)
        return torch.tensor(all_node_feats, dtype=torch.float)

    def _get_edge_features(self, mol):
        """
        Return a float tensor of shape [Number of directed edges, 1]
        with bond order as the only edge feature.

        Each bond contributes two identical rows (one per direction), in the
        same order as ``_get_adjacency_info``. Bond-less molecules yield [0, 1].
        """
        all_edge_feats = []
        for bond in mol.GetBonds():
            edge_feats = [bond.GetBondTypeAsDouble()]
            all_edge_feats += [edge_feats, edge_feats]

        all_edge_feats = np.asarray(all_edge_feats, dtype=float).reshape(-1, self._NUM_EDGE_FEATURES)
        return torch.tensor(all_edge_feats, dtype=torch.float)

    def _get_adjacency_info(self, mol):
        """Return a long tensor of shape [2, 2 * num_bonds] listing both directions of every bond."""
        edge_indices = []
        for bond in mol.GetBonds():
            i = bond.GetBeginAtomIdx()
            j = bond.GetEndAtomIdx()
            edge_indices += [[i, j], [j, i]]

        # view(-1, 2) also handles the empty (bond-less) case -> shape [2, 0]
        return torch.tensor(edge_indices, dtype=torch.long).view(-1, 2).t().contiguous()

    def _get_labels(self, label):
        """Return the label as an int64 tensor of shape [1, 1]."""
        return torch.tensor(np.asarray([[label]]), dtype=torch.int64)

    def len(self):
        # Count the graphs actually saved. Rows that failed in process() are
        # missing from the cache, so dataframe.shape[0] can overcount and make
        # iteration / fold construction index past the end of the list.
        return len(self._load_processed())

    def get(self, idx):
        return self._load_processed()[idx]

## Customised preprocessing (Mols-To-Graph function)

class Graph_custom(InMemoryDataset):
    """Molecule-to-graph dataset with richer, conformer-aware features.

    Processing steps for each row:
      1. Parse the SMILES in ``smiles_col`` and add hydrogens.
      2. Embed a 3D conformer (seeded, random starting coordinates) and
         MMFF-optimise it.
      3. Remove the hydrogens again and compute Gasteiger charges.

    Each graph is stored as a PyG ``Data`` object with:
      * ``x``: 12 node features per heavy atom (see ``_get_node_features``);
      * ``edge_attr``: 3 features per directed bond: bond order, bond length
        taken from the optimised conformer (rounded to 3 decimals), and ring
        membership;
      * ``edge_index``: both directions of every bond;
      * ``y``: label as an int64 tensor of shape [1, 1];
      * ``smiles``: the input SMILES string.

    Rows that fail anywhere in this pipeline are skipped. Their dataframe index
    is printed and recorded in ``error_indices``. This only happens when
    ``process()`` actually runs: if the cache file exists, processing is skipped.

    Args:
        dataframe: table holding the SMILES and label columns.
        root: dataset root; the cache file lives in ``self.processed_dir``.
        smiles_col: name of the SMILES column.
        label_col: name of the label column.
        test: if True, read/write ``data_test.pt`` instead of ``data.pt``.
        transform: forwarded to ``InMemoryDataset``.
        pre_transform: forwarded to ``InMemoryDataset``; note that
            ``process()`` does not apply it.
    """

    # Widths of the per-node / per-edge feature vectors built below.
    _NUM_NODE_FEATURES = 12
    _NUM_EDGE_FEATURES = 3

    def __init__(self, dataframe, root, smiles_col='smiles', label_col='label', test=False, transform=None, pre_transform=None):
        self.test = test
        self.dataframe = dataframe
        self.smiles_col = smiles_col
        self.label_col = label_col
        self._data = None  # lazily loaded list of Data objects (see _load_processed)
        self.error_indices = []  # dataframe indices of rows that failed in process()
        super(Graph_custom, self).__init__(root, transform, pre_transform)

    @property
    def raw_file_names(self):
        """Placeholder: graphs are built from ``dataframe``; nothing is downloaded."""
        return 'dataframe'

    @property
    def processed_file_names(self):
        """Cache file looked up in ``processed_dir``; if it exists, ``process`` is skipped."""
        return ['data_test.pt' if self.test else 'data.pt']

    def download(self):
        pass

    def _processed_path(self):
        """Path of the cache file for this split (train or test)."""
        return os.path.join(self.processed_dir, self.processed_file_names[0])

    def _load_processed(self):
        """Load the cached list of ``Data`` objects on first use and return it."""
        if self._data is None:
            # NOTE(review): torch>=2.6 defaults torch.load to weights_only=True,
            # which may refuse pickled Data objects - confirm installed versions.
            self._data = torch.load(self._processed_path())
        return self._data

    def process(self):
        """Build one conformer-featurised ``Data`` graph per row and save the list to the cache file."""
        data_list = []
        for index, mol in tqdm(self.dataframe.iterrows(), total=self.dataframe.shape[0]):
            try:
                smiles = mol[self.smiles_col]
                mol_obj = Chem.MolFromSmiles(smiles)
                if mol_obj is None:
                    raise ValueError(f"could not parse SMILES {smiles!r}")
                mol_obj = Chem.AddHs(mol_obj)
                # Random starting coordinates + many attempts: used to avoid
                # "Bad Conformer ID" failures on hard-to-embed molecules.
                conf_id = AllChem.EmbedMolecule(mol_obj, randomSeed=42,
                                                useRandomCoords=True, maxAttempts=5000)
                if conf_id < 0:
                    raise ValueError("3D conformer embedding failed")
                AllChem.MMFFOptimizeMolecule(mol_obj)
                mol_obj = Chem.RemoveHs(mol_obj)
                AllChem.ComputeGasteigerCharges(mol_obj)

                node_feats = self._get_node_features(mol_obj)
                edge_feats = self._get_edge_features(mol_obj)
                edge_index = self._get_adjacency_info(mol_obj)
                label = self._get_labels(mol[self.label_col])

                data = Data(x=node_feats,
                            edge_index=edge_index,
                            edge_attr=edge_feats,
                            y=label,
                            smiles=smiles)
                data_list.append(data)
            except Exception as e:
                # Best-effort: skip the molecule but keep a record of it.
                print(f"Error processing molecule at index {index}: {e}")
                self.error_indices.append(index)

        torch.save(data_list, self._processed_path())

    def _get_node_features(self, mol):
        """
        Return a float tensor of shape [Number of Nodes, 12]. The columns are:
        atomic number, degree, formal charge, hybridization, aromaticity,
        total H count, radical electrons, in-ring, chiral tag,
        Gasteiger charge, total valence and explicit valence.
        """
        all_node_feats = []
        for atom in mol.GetAtoms():
            all_node_feats.append([
                atom.GetAtomicNum(),
                atom.GetDegree(),               # directly-bonded neighbours
                atom.GetFormalCharge(),
                int(atom.GetHybridization()),   # enum value, e.g. SP3
                atom.GetIsAromatic(),
                atom.GetTotalNumHs(),
                atom.GetNumRadicalElectrons(),
                atom.IsInRing(),
                int(atom.GetChiralTag()),       # enum value
                # NOTE(review): Gasteiger charges can be NaN for atoms without
                # parameters - consider sanitising before training.
                atom.GetDoubleProp("_GasteigerCharge"),
                atom.GetTotalValence(),
                atom.GetExplicitValence(),
            ])

        # reshape keeps the 2-D shape even for a molecule with no atoms
        all_node_feats = np.asarray(all_node_feats, dtype=float).reshape(-1, self._NUM_NODE_FEATURES)
        return torch.tensor(all_node_feats, dtype=torch.float)

    def _get_edge_features(self, mol):
        """
        Return a float tensor of shape [Number of directed edges, 3]. The columns
        are bond order, conformer bond length (3 dp) and in-ring.

        Each bond contributes two identical rows (one per direction), in the
        same order as ``_get_adjacency_info``. Bond-less molecules yield [0, 3].
        """
        conf = mol.GetConformer()  # force-field optimised conformer, used for bond lengths
        all_edge_feats = []
        for bond in mol.GetBonds():
            edge_feats = [
                bond.GetBondTypeAsDouble(),
                np.round(rdMolTransforms.GetBondLength(conf, bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()), 3),
                bond.IsInRing(),
            ]
            all_edge_feats += [edge_feats, edge_feats]

        all_edge_feats = np.asarray(all_edge_feats, dtype=float).reshape(-1, self._NUM_EDGE_FEATURES)
        return torch.tensor(all_edge_feats, dtype=torch.float)

    def _get_adjacency_info(self, mol):
        """Return a long tensor of shape [2, 2 * num_bonds] listing both directions of every bond."""
        edge_indices = []
        for bond in mol.GetBonds():
            i = bond.GetBeginAtomIdx()
            j = bond.GetEndAtomIdx()
            edge_indices += [[i, j], [j, i]]

        # view(-1, 2) also handles the empty (bond-less) case -> shape [2, 0]
        return torch.tensor(edge_indices, dtype=torch.long).view(-1, 2).t().contiguous()

    def _get_labels(self, label):
        """Return the label as an int64 tensor of shape [1, 1]."""
        return torch.tensor(np.asarray([[label]]), dtype=torch.int64)

    def len(self):
        # Count the graphs actually saved. Rows that failed in process() are
        # missing from the cache, so dataframe.shape[0] can overcount and make
        # iteration / fold construction index past the end of the list.
        return len(self._load_processed())

    def get(self, idx):
        return self._load_processed()[idx]

In [3]:
class GNNModel(torch.nn.Module):
    """Three-layer message-passing graph classifier producing one logit per graph.

    Channel widths run ``num_features -> 2*hidden -> hidden -> hidden//2``.
    Each convolution is followed by ReLU, with optional BatchNorm before the
    first two ReLUs. After the third layer there is optional dropout, then
    global mean pooling and a linear head with a single output.

    Args:
        conv_type: convolution class, called as ``conv_type(in_channels, out_channels)``.
        hidden_channels: base hidden width.
        num_features: node feature dimension.
        batch_norm: insert BatchNorm1d after conv1 and conv2.
        weight_init: apply Xavier-uniform initialisation (see ``_init_weights``).
        dropout_rate: dropout probability after conv3; no dropout layer when 0.
    """

    def __init__(self, conv_type, hidden_channels, num_features, batch_norm=False, weight_init=True, dropout_rate=0.0):
        super(GNNModel, self).__init__()
        wide, mid, narrow = hidden_channels * 2, hidden_channels, hidden_channels // 2

        self.conv1 = conv_type(num_features, wide)
        self.bn1 = torch.nn.BatchNorm1d(wide) if batch_norm else None  # BatchNorm1D -> GraphNorm
        self.conv2 = conv_type(wide, mid)
        self.bn2 = torch.nn.BatchNorm1d(mid) if batch_norm else None
        self.conv3 = conv_type(mid, narrow)

        self.dropout = torch.nn.Dropout(dropout_rate) if dropout_rate > 0 else None
        self.lin = torch.nn.Linear(narrow, 1)

        if weight_init:
            self.apply(self._init_weights)

    def _init_weights(self, module):
        """Xavier-uniform weights / zero bias for Linear; Xavier for matrix params of conv layers."""
        if isinstance(module, torch.nn.Linear):
            torch.nn.init.xavier_uniform_(module.weight)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
            return
        if isinstance(module, torch_geometric.nn.MessagePassing):
            for weight in (p for p in module.parameters() if p.dim() > 1):
                torch.nn.init.xavier_uniform_(weight)

    def forward(self, x, edge_index, batch):
        for conv, norm in ((self.conv1, self.bn1), (self.conv2, self.bn2)):
            x = conv(x, edge_index)
            if norm:
                x = norm(x)
            x = F.relu(x)

        x = F.relu(self.conv3(x, edge_index))
        if self.dropout:
            x = self.dropout(x)

        pooled = torch_geometric.nn.global_mean_pool(x, batch)
        return self.lin(pooled)

class GINModel(torch.nn.Module):
    """Three-layer GIN graph classifier producing one logit per graph.

    Each GINConv wraps a two-layer MLP (Linear -> ReLU -> Linear). Channel
    widths run ``num_features -> 2*hidden -> hidden -> hidden//2``. Every
    convolution is followed by ReLU, with optional BatchNorm before the first
    two, the same layer pattern as GNNModel. After the third layer there is
    optional dropout, then global mean pooling and a linear head.

    Args:
        hidden_channels: base hidden width.
        num_features: node feature dimension.
        batch_norm: insert BatchNorm1d after conv1 and conv2.
        weight_init: apply Xavier-uniform initialisation (see ``_init_weights``).
        dropout_rate: dropout probability after conv3; no dropout layer when 0.
    """

    def __init__(self, hidden_channels, num_features, batch_norm=False, weight_init=True, dropout_rate=0.0):
        super(GINModel, self).__init__()

        def mlp(input_dim, output_dim):
            # MLP used inside each GINConv
            return torch.nn.Sequential(
                torch.nn.Linear(input_dim, output_dim),
                torch.nn.ReLU(),
                torch.nn.Linear(output_dim, output_dim)
            )

        self.conv1 = torch_geometric.nn.GINConv(mlp(num_features, hidden_channels*2))
        self.bn1 = torch.nn.BatchNorm1d(hidden_channels*2) if batch_norm else None

        self.conv2 = torch_geometric.nn.GINConv(mlp(hidden_channels*2, hidden_channels))
        self.bn2 = torch.nn.BatchNorm1d(hidden_channels) if batch_norm else None

        self.conv3 = torch_geometric.nn.GINConv(mlp(hidden_channels, hidden_channels//2))

        self.dropout = torch.nn.Dropout(dropout_rate) if dropout_rate > 0 else None
        self.lin = torch.nn.Linear(hidden_channels//2, 1)

        if weight_init:
            self.apply(self._init_weights)

    def _init_weights(self, module):
        """Xavier-uniform weights / zero bias for every Linear layer."""
        if isinstance(module, torch.nn.Linear):
            torch.nn.init.xavier_uniform_(module.weight)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, GINConv):
            # self.apply already visits these inner Linear layers; this branch
            # re-initialises them the same way.
            for layer in module.nn:
                if isinstance(layer, torch.nn.Linear):
                    torch.nn.init.xavier_uniform_(layer.weight)
                    if layer.bias is not None:
                        torch.nn.init.zeros_(layer.bias)

    def forward(self, x, edge_index, batch):
        # ReLU must run regardless of batch_norm. Previously it sat inside the
        # `if self.bn*:` branch, so with batch_norm=False the GIN layers were
        # stacked with no nonlinearity between them.
        x = self.conv1(x, edge_index)
        if self.bn1 is not None:
            x = self.bn1(x)
        x = F.relu(x)

        x = self.conv2(x, edge_index)
        if self.bn2 is not None:
            x = self.bn2(x)
        x = F.relu(x)

        x = F.relu(self.conv3(x, edge_index))  # same post-conv3 ReLU as GNNModel
        if self.dropout is not None:
            x = self.dropout(x)

        x = torch_geometric.nn.global_mean_pool(x, batch)
        return self.lin(x)

def train(model, loader, optimiser, criterion, device, threshold = 0.5):
    """Run one training epoch and return (mean per-batch loss, accuracy).

    ``model`` must return raw logits of shape [num_graphs, 1]; the models in
    this notebook end in a plain Linear layer and are trained with
    BCEWithLogitsLoss. Predictions are sigmoid(logit) > ``threshold``, so
    ``threshold`` is a probability.

    Args:
        model: module called as ``model(x, edge_index, batch)``.
        loader: iterable of batches with ``x``, ``edge_index``, ``batch``, ``y``
            and ``.to(device)``; ``loader.dataset`` gives the sample count.
        optimiser: optimiser over ``model``'s parameters.
        criterion: loss applied to (logits, float targets).
        device: device each batch is moved to.
        threshold: probability cut-off for a positive prediction.

    Returns:
        Tuple of (loss averaged over batches, accuracy over ``loader.dataset``).
    """
    model.train()
    total_loss = 0.0
    correct = 0
    for data in loader:
        data = data.to(device)
        optimiser.zero_grad()
        out = model(data.x, data.edge_index, data.batch)
        target = data.y.view(-1, 1)
        loss = criterion(out, target.float())
        loss.backward()
        optimiser.step()
        total_loss += loss.item()
        # Threshold probabilities, not logits: logit > 0.5 would mean p > ~0.62.
        pred = (torch.sigmoid(out.detach()) > threshold).float()
        correct += pred.eq(target).sum().item()
    return total_loss / len(loader), correct / len(loader.dataset)

def evaluate(model, loader, criterion, device, threshold = 0.5):
    """Evaluate ``model`` on ``loader`` without gradient tracking.

    ``model`` must return raw logits of shape [num_graphs, 1] (see ``train``).
    Hard predictions are sigmoid(logit) > ``threshold``. ROC-AUC is computed
    from the logits directly; they rank samples the same way as probabilities
    and avoid float32 sigmoid saturation ties.

    Returns:
        Tuple (mean per-batch loss, accuracy, ROC-AUC, F1, MCC).

    Raises:
        ValueError: from ``roc_auc_score`` if ``loader`` contains only one class.
    """
    model.eval()
    total_loss = 0.0
    correct = 0
    logit_chunks, label_chunks = [], []
    with torch.no_grad():
        for data in loader:
            data = data.to(device)
            out = model(data.x, data.edge_index, data.batch)
            target = data.y.view(-1, 1)
            total_loss += criterion(out, target.float()).item()
            # Threshold probabilities, not logits.
            pred = (torch.sigmoid(out) > threshold).float()
            correct += pred.eq(target).sum().item()
            logit_chunks.append(out)
            label_chunks.append(target)

    # Flatten to 1-D for sklearn, which expects y of shape (n_samples,).
    logits = torch.cat(logit_chunks, dim=0).view(-1).cpu()
    y_true = torch.cat(label_chunks, dim=0).view(-1).cpu().numpy()
    y_pred = (torch.sigmoid(logits) > threshold).float().numpy()
    return (
        total_loss / len(loader),
        correct / len(loader.dataset),
        roc_auc_score(y_true, logits.numpy()),
        f1_score(y_true, y_pred),
        matthews_corrcoef(y_true, y_pred),
    )

class EarlyStopping:
    """Signal when a validation metric has stopped improving and keep the best model.

    A call counts as an improvement when the (direction-adjusted) score is at
    least ``best + delta``. Otherwise ``counter`` grows, and ``early_stop`` is
    set once it reaches ``patience``.

    Args:
        patience: consecutive non-improving calls tolerated before stopping.
        delta: minimum gain over the best score that counts as improvement.
        mode: ``'max'`` when larger ``val_score`` is better (e.g. ROC-AUC, which
            is what this notebook passes); ``'min'`` when smaller is better
            (e.g. a loss).

    Attributes:
        best_model: deep copy of the model from its best-scoring call
            (``None`` until the first call).

    Raises:
        ValueError: if ``mode`` is not ``'max'`` or ``'min'``.
    """

    def __init__(self, patience=10, delta=0, mode='max'):
        if mode not in ('max', 'min'):
            raise ValueError(f"mode must be 'max' or 'min', got {mode!r}")
        self.patience = patience
        self.delta = delta
        self.mode = mode
        self.best_score = None
        self.early_stop = False
        self.counter = 0
        self.best_model = None

    def __call__(self, val_score, model):
        # Internally always maximise. Previously the score was negated
        # unconditionally, which inverted the logic for higher-is-better
        # metrics like ROC-AUC.
        score = val_score if self.mode == 'max' else -val_score
        if self.best_score is not None and score < self.best_score + self.delta:
            self.counter += 1
            if self.counter >= self.patience:
                self.early_stop = True
        else:
            self.best_score = score
            self.save_checkpoint(val_score, model)
            self.counter = 0

    def save_checkpoint(self, val_score, model):
        # Snapshot the model. Storing the reference would track all later
        # training updates, making best_model equal to the final model.
        self.best_model = copy.deepcopy(model)

In [5]:
def get_fold_indices(pytorch_custom_dataset, seed=42, n_outer_splits=4, n_inner_splits=5):
    """Build nested stratified K-fold splits over a graph dataset.

    Stratification uses each graph's scalar label ``data.y.item()``. Inner
    splits partition only the outer training indices.

    Args:
        pytorch_custom_dataset: iterable of graphs, each with a one-element ``y``.
        seed: ``random_state`` for both the outer and inner shuffled splitters.
        n_outer_splits: number of outer folds (default 4).
        n_inner_splits: number of inner folds per outer fold (default 5).

    Returns:
        List with one ``(train_indices, test_indices, inner_folds)`` tuple per
        outer fold. ``inner_folds`` is a list of
        ``(inner_train_indices, val_indices)`` pairs. All indices are positions
        in ``pytorch_custom_dataset``.
    """
    targets = np.array([data.y.item() for data in pytorch_custom_dataset])
    indices = np.arange(len(targets))

    skf_outer = StratifiedKFold(n_splits=n_outer_splits, shuffle=True, random_state=seed)
    outer_folds = []

    for train_indices, test_indices in skf_outer.split(indices, targets):
        train_targets = targets[train_indices]
        skf_inner = StratifiedKFold(n_splits=n_inner_splits, shuffle=True, random_state=seed)
        inner_folds = [
            (train_indices[inner_train_idx], train_indices[val_idx])
            for inner_train_idx, val_idx in skf_inner.split(np.arange(len(train_indices)), train_targets)
        ]
        outer_folds.append((train_indices, test_indices, inner_folds))
    return outer_folds

def objective(hyperparameters, model_class, conv_type, pytorch_custom_dataset, device, results_dir, model_name, outer_fold, inner_fold, fold_indices, use_early_stopping):
    """Train one hyperparameter configuration on one inner CV split.

    Args:
        hyperparameters: dict with 'learning_rate', 'hidden_channels',
            'batch_size' and 'dropout_rate'.
        model_class: GNNModel or GINModel.
        conv_type: convolution class for GNNModel (not passed to GINModel).
        pytorch_custom_dataset: graph dataset indexed by ``fold_indices``.
        device: torch device to train on.
        results_dir, model_name: accepted for interface compatibility; unused.
        outer_fold, inner_fold: which split of ``fold_indices`` to use.
        fold_indices: output of ``get_fold_indices``.
        use_early_stopping: if True, track validation ROC-AUC each epoch,
            stop after 10 non-improving epochs, and score the model held by
            ``EarlyStopping.best_model``.

    Returns:
        Validation ROC-AUC of the selected model. The outer test fold is not used.
    """
    learning_rate = hyperparameters['learning_rate']
    hidden_channels = hyperparameters['hidden_channels']
    batch_size = hyperparameters['batch_size']
    dropout_rate = hyperparameters['dropout_rate']
    num_epochs = 100

    torch.manual_seed(1)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(1)

    _, _, inner_folds = fold_indices[outer_fold]
    inner_train_indices, val_indices = inner_folds[inner_fold]

    # Shuffle training batches every epoch. StratifiedKFold yields indices in
    # ascending order, so without shuffling each epoch sees the same batches in
    # original row order. The torch seed above keeps the shuffle reproducible.
    train_loader = DataLoader(Subset(pytorch_custom_dataset, inner_train_indices), batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(Subset(pytorch_custom_dataset, val_indices), batch_size=batch_size, shuffle=False)

    model_kwargs = dict(
        hidden_channels=hidden_channels,
        num_features=pytorch_custom_dataset[0].num_features,
        batch_norm=False,
        weight_init=True,
        dropout_rate=dropout_rate,
    )
    if model_class != GINModel:
        model_kwargs['conv_type'] = conv_type  # GINModel builds its own convolutions
    model = model_class(**model_kwargs).to(device)

    optimiser = torch.optim.Adam(model.parameters(), lr=learning_rate)
    criterion = torch.nn.BCEWithLogitsLoss()

    early_stopping = EarlyStopping(patience = 10, delta=0)

    for epoch in range(1, num_epochs + 1):
        train(model, train_loader, optimiser, criterion, device)
        if not use_early_stopping:
            continue  # per-epoch validation only feeds early stopping
        val_auc = evaluate(model, val_loader, criterion, device)[2]
        early_stopping(val_auc, model)
        if early_stopping.early_stop:
            print(f"Early stopping at epoch {epoch}")
            break

    best_model = early_stopping.best_model if use_early_stopping else model
    return evaluate(best_model, val_loader, criterion, device)[2]

def grid_search(model_class, conv_type, pytorch_custom_dataset, results_dir, model_name, use_early_stopping):
    """Exhaustive hyperparameter search over every inner CV split.

    For every outer fold and each of its inner folds, every combination of
    batch size, hidden width, learning rate and dropout rate is trained via
    ``objective``, and its validation ROC-AUC is recorded.

    Args:
        model_class, conv_type: forwarded to ``objective``.
        pytorch_custom_dataset: dataset to split with ``get_fold_indices``.
        results_dir: created if missing (nothing is written here).
        model_name: label stored in every result row.
        use_early_stopping: forwarded to ``objective``.

    Returns:
        List of dicts, one per (outer fold, inner fold, configuration), with keys
        model_name, outer_fold, inner_fold, batch_size, hidden_channels,
        learning_rate, dropout_rate, val_auc.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    os.makedirs(results_dir, exist_ok=True)

    fold_indices = get_fold_indices(pytorch_custom_dataset, seed=42)
    n_outer = len(fold_indices)

    batch_size_list = [32, 64, 128]
    hidden_channels_list = [64, 128, 256]
    learning_rate_list = [0.001, 0.0001, 0.00001]
    dropout_rate_list = [0.1, 0.2, 0.3, 0.4, 0.5]

    hyperparameter_grid = list(itertools.product(
        batch_size_list,
        hidden_channels_list,
        learning_rate_list,
        dropout_rate_list
    ))

    results = []
    # Loop bounds come from fold_indices itself, so they cannot drift from the
    # split counts chosen in get_fold_indices.
    for outer_fold, (_, _, inner_folds) in enumerate(fold_indices):
        print(f"Outer Fold {outer_fold + 1}/{n_outer}")
        n_inner = len(inner_folds)

        for inner_fold in range(n_inner):
            print(f"  Inner Fold {inner_fold + 1}/{n_inner}")

            for batch_size, hidden_channels, learning_rate, dropout_rate in hyperparameter_grid:
                hyperparameters_dict = {
                    'batch_size': batch_size,
                    'hidden_channels': hidden_channels,
                    'learning_rate': learning_rate,
                    'dropout_rate': dropout_rate
                }

                val_auc = objective(
                    hyperparameters=hyperparameters_dict,
                    model_class=model_class,
                    conv_type=conv_type,
                    pytorch_custom_dataset=pytorch_custom_dataset,
                    device=device,
                    results_dir=results_dir,
                    model_name=model_name,
                    outer_fold=outer_fold,
                    inner_fold=inner_fold,
                    fold_indices=fold_indices,
                    use_early_stopping=use_early_stopping
                )
                print(f"Testing hyperparameters: bs={batch_size}, hc={hidden_channels}, lr={learning_rate}, dr={dropout_rate} / val_auc: {val_auc}")

                results.append({
                    'model_name': model_name,
                    'outer_fold': outer_fold,
                    'inner_fold': inner_fold,
                    'batch_size': batch_size,
                    'hidden_channels': hidden_channels,
                    'learning_rate': learning_rate,
                    'dropout_rate': dropout_rate,
                    'val_auc': val_auc
                })

    return results

def run_grid_search(models, pytorch_custom_dataset, results_dir, use_early_stopping):
    """Grid-search every entry of ``models`` and concatenate all result rows.

    The convolution class is picked from the model's name. The first of
    "GCN", "GAT", "GraphSAGE" found in the name (checked in that order) selects
    GCNConv, GATConv or SAGEConv. Any other name (e.g. GIN) gets ``None``.
    """
    conv_by_keyword = (("GCN", GCNConv), ("GAT", GATConv), ("GraphSAGE", SAGEConv))

    all_results = []
    for name, model_class in models.items():
        print(f"Running Grid Search for {name} model...")
        conv_type = next((conv for keyword, conv in conv_by_keyword if keyword in name), None)
        all_results.extend(grid_search(
            model_class=model_class,
            conv_type=conv_type,
            pytorch_custom_dataset=pytorch_custom_dataset,
            results_dir=results_dir,
            model_name=name,
            use_early_stopping=use_early_stopping
        ))
    return all_results


#### Drug-induced liver injury

In [None]:
# Generated from DILIGeNN Data Preprocessing.ipynb
df_dili_filtered_cleaned = pd.read_csv(
    'chem_data/DILI/DILIst_standardised_cleaned.csv',
    index_col=0,
)

# Minimal graphs: raw `smiles` column, atomic number / bond order only
dlst_min = Graph_basic(
    dataframe=df_dili_filtered_cleaned,
    root='custom_data/minimal_feature/dlst_min',
    smiles_col='smiles',
    label_col='label',
)

# Custom graphs: standardised `smiles_std` column, conformer-derived features
dlst_cus_2_std = Graph_custom(
    dataframe=df_dili_filtered_cleaned,
    root='custom_data/DILI/cus2_std/',
    smiles_col='smiles_std',
    label_col='label',
)

In [None]:
# dlst_cus_2_std: grid search on the custom-feature DILIst graphs
models = {f'{arch}_Optimised': GNNModel for arch in ('GCN', 'GAT', 'GraphSAGE')}
models['GIN_Optimised'] = GINModel

results_dir = './results/grid_search/dlst'

all_results = run_grid_search(models, dlst_cus_2_std, results_dir, use_early_stopping=True)

# Persist one row per (fold, configuration) as CSV and as a pickled DataFrame
results_df = pd.DataFrame(all_results)
results_df.to_csv(os.path.join(results_dir, 'grid_search_results.csv'), index=False)
with open(os.path.join(results_dir, 'grid_search_results.pkl'), 'wb') as f:
    pickle.dump(results_df, f)

In [None]:
# dlst_min: grid search on the minimal-feature DILIst graphs
models = {f'{arch}_Optimised': GNNModel for arch in ('GCN', 'GAT', 'GraphSAGE')}
models['GIN_Optimised'] = GINModel

results_dir = './results/grid_search/dlst_min'

all_results = run_grid_search(models, dlst_min, results_dir, use_early_stopping=True)

# Persist one row per (fold, configuration) as CSV and as a pickled DataFrame
results_df = pd.DataFrame(all_results)
results_df.to_csv(os.path.join(results_dir, 'grid_search_results.csv'), index=False)
with open(os.path.join(results_dir, 'grid_search_results.pkl'), 'wb') as f:
    pickle.dump(results_df, f)

#### Blood-Brain Barrier Permeability

In [None]:
df_bbbp_filtered_cleaned = pd.read_csv(
    'chem_data/bbbp_standardised_cleaned.csv',
    index_col=0,
)

# Minimal graphs: raw `smiles` column, atomic number / bond order only
bbbp_min = Graph_basic(
    dataframe=df_bbbp_filtered_cleaned,
    root='custom_data/minimal_feature/bbbp_min',
    smiles_col='smiles',
    label_col='label',
)

# Custom graphs: standardised `smiles_std` column, conformer-derived features
bbbp_cus_2_std = Graph_custom(
    dataframe=df_bbbp_filtered_cleaned,
    root='custom_data/bbbp/cus2_std/',
    smiles_col='smiles_std',
    label_col='label',
)

In [None]:
# bbbp_cus_2_std: grid search on the custom-feature BBBP graphs
models = {f'{arch}_Optimised': GNNModel for arch in ('GCN', 'GAT', 'GraphSAGE')}
models['GIN_Optimised'] = GINModel

results_dir = './results/grid_search/bbbp'

all_results = run_grid_search(models, bbbp_cus_2_std, results_dir, use_early_stopping=True)

# Persist one row per (fold, configuration) as CSV and as a pickled DataFrame
results_df = pd.DataFrame(all_results)
results_df.to_csv(os.path.join(results_dir, 'grid_search_results.csv'), index=False)
with open(os.path.join(results_dir, 'grid_search_results.pkl'), 'wb') as f:
    pickle.dump(results_df, f)

In [None]:
# bbbp_min: grid search on the minimal-feature BBBP graphs
models = {f'{arch}_Optimised': GNNModel for arch in ('GCN', 'GAT', 'GraphSAGE')}
models['GIN_Optimised'] = GINModel

results_dir = './results/grid_search/bbbp_min'

all_results = run_grid_search(models, bbbp_min, results_dir, use_early_stopping=True)

# Persist one row per (fold, configuration) as CSV and as a pickled DataFrame
results_df = pd.DataFrame(all_results)
results_df.to_csv(os.path.join(results_dir, 'grid_search_results.csv'), index=False)
with open(os.path.join(results_dir, 'grid_search_results.pkl'), 'wb') as f:
    pickle.dump(results_df, f)

#### BACE Activity

In [None]:
df_bace_filtered_cleaned = pd.read_csv('chem_data/bace_standardised_cleaned.csv', index_col = 0)

# Minimal graphs: raw `smiles` column, atomic number / bond order only
bace_min = Graph_basic(dataframe = df_bace_filtered_cleaned,
                       smiles_col = 'smiles',
                       label_col = 'label',
                       root = 'custom_data/minimal_feature/bace_min'
                       )

# Custom graphs: standardised `smiles_std` column, conformer-derived features
bace_cus_2_std = Graph_custom(dataframe = df_bace_filtered_cleaned,
                              smiles_col = 'smiles_std',
                              label_col = 'label',
                              root = 'custom_data/bace/cus2_std/'
                              )

In [None]:
# bace_cus_2_std: grid search on the custom-feature BACE graphs
models = {f'{arch}_Optimised': GNNModel for arch in ('GCN', 'GAT', 'GraphSAGE')}
models['GIN_Optimised'] = GINModel

results_dir = './results/grid_search/bace'

all_results = run_grid_search(models, bace_cus_2_std, results_dir, use_early_stopping=True)

# Persist one row per (fold, configuration) as CSV and as a pickled DataFrame
results_df = pd.DataFrame(all_results)
results_df.to_csv(os.path.join(results_dir, 'grid_search_results.csv'), index=False)
with open(os.path.join(results_dir, 'grid_search_results.pkl'), 'wb') as f:
    pickle.dump(results_df, f)

In [None]:
# bace_min: grid search on the minimal-feature BACE graphs
models = {f'{arch}_Optimised': GNNModel for arch in ('GCN', 'GAT', 'GraphSAGE')}
models['GIN_Optimised'] = GINModel

results_dir = './results/grid_search/bace_min'

all_results = run_grid_search(models, bace_min, results_dir, use_early_stopping=True)

# Persist one row per (fold, configuration) as CSV and as a pickled DataFrame
results_df = pd.DataFrame(all_results)
results_df.to_csv(os.path.join(results_dir, 'grid_search_results.csv'), index=False)
with open(os.path.join(results_dir, 'grid_search_results.pkl'), 'wb') as f:
    pickle.dump(results_df, f)

#### ClinTox

In [None]:
class Graph_custom(InMemoryDataset):
    """PyG dataset of molecular graphs with 3D-conformer-derived features, built from a DataFrame.

    For every row, the SMILES in ``smiles_col`` is parsed, hydrogens are added, a conformer is
    embedded (``randomSeed=42``, random coordinates) and MMFF-optimised, hydrogens are removed
    again and Gasteiger charges are computed. The molecule is then featurised as:

    * 12 node features per heavy atom (see ``_get_node_features``);
    * 3 edge features per bond, stored once per direction (see ``_get_edge_features``);
    * ``y`` from ``label_col`` (string labels are parsed with ``ast.literal_eval``).

    All graphs are saved as a single list in ``processed_dir`` (``data.pt`` or ``data_test.pt``).
    Rows that fail any step are skipped; their DataFrame index is appended to ``error_indices``
    (only while ``process`` actually runs — not when an existing processed file is reused).
    """

    def __init__(self, dataframe, root, smiles_col='smiles', label_col='label', test=False, transform=None, pre_transform=None):
        self.test = test
        self.dataframe = dataframe
        self.smiles_col = smiles_col
        self.label_col = label_col
        self._data = None  # lazily loaded list of Data objects (see _load_data_list)
        self.error_indices = []
        super(Graph_custom, self).__init__(root, transform, pre_transform)

    @property
    def raw_file_names(self):
        """Placeholder raw file name; raw data come from the DataFrame, ``download`` is a no-op."""
        return 'dataframe'

    @property
    def processed_file_names(self):
        """Processed file(s); if they already exist in ``processed_dir``, ``process`` is skipped."""
        return ['data_test.pt' if self.test else 'data.pt']

    def download(self):
        pass

    def _processed_path(self):
        """Path of the cached graph list for this split (test or train)."""
        return os.path.join(self.processed_dir, 'data_test.pt' if self.test else 'data.pt')

    def _load_data_list(self):
        """Load the cached graph list on first use and return it."""
        if self._data is None:
            # NOTE(review): torch>=2.6 defaults torch.load(weights_only=True), which may refuse
            # pickled Data objects; pass weights_only=False here if loading fails.
            self._data = torch.load(self._processed_path())
        return self._data

    def process(self):
        """Featurise every DataFrame row and save the successful graphs as one list."""
        data_list = []
        for index, mol in tqdm(self.dataframe.iterrows(), total=self.dataframe.shape[0]):
            try:
                mol_obj = Chem.MolFromSmiles(mol[self.smiles_col])
                if mol_obj is None:
                    raise ValueError(f"RDKit could not parse SMILES {mol[self.smiles_col]!r}")

                # A 3D conformer is required for the bond-length edge feature.
                mol_obj = Chem.AddHs(mol_obj)
                conf_id = AllChem.EmbedMolecule(mol_obj, randomSeed=42,
                                                useRandomCoords=True, maxAttempts=5000)
                if conf_id == -1:
                    raise ValueError("3D conformer embedding failed")
                AllChem.MMFFOptimizeMolecule(mol_obj)
                mol_obj = Chem.RemoveHs(mol_obj)
                AllChem.ComputeGasteigerCharges(mol_obj)

                node_feats = self._get_node_features(mol_obj)
                edge_feats = self._get_edge_features(mol_obj)
                edge_index = self._get_adjacency_info(mol_obj)

                label = self._get_labels(mol[self.label_col])

                data = Data(x=node_feats,
                            edge_index=edge_index,
                            edge_attr=edge_feats,
                            y=label,
                            smiles=mol[self.smiles_col])
                data_list.append(data)
            except Exception as e:
                # Best-effort: record and skip molecules that cannot be featurised.
                print(f"Error processing molecule at index {index}: {e}")
                self.error_indices.append(index)

        torch.save(data_list, self._processed_path())

    def _get_node_features(self, mol):
        """
        Return a float tensor of shape [num_atoms, 12] with, per atom: atomic number, degree,
        formal charge, hybridisation, aromaticity, total Hs, radical electrons, in-ring flag,
        chiral tag, Gasteiger charge, total valence, explicit valence.
        Requires ``_GasteigerCharge`` to have been computed on ``mol``.
        """
        all_node_feats = []

        for atom in mol.GetAtoms():
            node_feats = [
                atom.GetAtomicNum(),
                atom.GetDegree(),
                atom.GetFormalCharge(),
                int(atom.GetHybridization()),
                int(atom.GetIsAromatic()),
                atom.GetTotalNumHs(),
                atom.GetNumRadicalElectrons(),
                int(atom.IsInRing()),
                int(atom.GetChiralTag()),
                # NOTE(review): Gasteiger charges can be NaN for unparameterised atoms; not handled.
                atom.GetDoubleProp("_GasteigerCharge"),
                atom.GetTotalValence(),
                atom.GetExplicitValence(),
            ]
            all_node_feats.append(node_feats)

        all_node_feats = np.asarray(all_node_feats)
        return torch.tensor(all_node_feats, dtype=torch.float)

    def _get_edge_features(self, mol):
        """
        Return a float tensor of shape [2 * num_bonds, 3] (each bond listed once per direction)
        with bond type as double, bond length from the conformer (3 d.p.) and in-ring flag.
        """
        conf = mol.GetConformer()
        all_edge_feats = []

        for bond in mol.GetBonds():
            edge_feats = [
                bond.GetBondTypeAsDouble(),
                np.round(rdMolTransforms.GetBondLength(conf, bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()), 3),
                int(bond.IsInRing()),
            ]
            all_edge_feats += [edge_feats, edge_feats]

        # reshape keeps a 2-D [0, 3] tensor for molecules without bonds (single heavy atom).
        all_edge_feats = np.asarray(all_edge_feats, dtype=float).reshape(-1, 3)
        return torch.tensor(all_edge_feats, dtype=torch.float)

    def _get_adjacency_info(self, mol):
        """Return a long tensor of shape [2, 2 * num_bonds] listing both directions of every bond."""
        edge_indices = []
        for bond in mol.GetBonds():
            i = bond.GetBeginAtomIdx()
            j = bond.GetEndAtomIdx()
            edge_indices += [[i, j], [j, i]]

        return torch.tensor(edge_indices, dtype=torch.long).reshape(-1, 2).t().contiguous()

    def _get_labels(self, label):
        """Convert a label (scalar, list, or its string repr such as "[1, 0]") to an int64 tensor."""
        import ast  # not among the notebook's top-level imports

        if isinstance(label, str):
            label = ast.literal_eval(label)
        label = np.array(label, dtype=np.int64)
        return torch.tensor(label, dtype=torch.int64)

    def len(self):
        # Count the graphs actually saved: rows that failed in process() are not in the list,
        # so the DataFrame row count would overshoot and break indexing / iteration.
        return len(self._load_data_list())

    def get(self, idx):
        return self._load_data_list()[idx]

class Graph_basic(InMemoryDataset):
    """PyG dataset of minimal molecular graphs built from a DataFrame of SMILES.

    Node feature: atomic number only. Edge feature: bond type (as double) only, stored once per
    bond direction. ``y`` comes from ``label_col`` (string labels parsed with ``ast.literal_eval``).
    All graphs are saved as one list in ``processed_dir`` (``data.pt`` or ``data_test.pt``).
    Rows that fail are skipped and their DataFrame index appended to ``error_indices``
    (only while ``process`` actually runs).
    """

    def __init__(self, dataframe, root, smiles_col='smiles', label_col='label', test=False, transform=None, pre_transform=None):
        self.test = test
        self.dataframe = dataframe
        self.smiles_col = smiles_col
        self.label_col = label_col
        self._data = None  # lazily loaded list of Data objects (see _load_data_list)
        self.error_indices = []
        super(Graph_basic, self).__init__(root, transform, pre_transform)

    @property
    def raw_file_names(self):
        """Placeholder raw file name; raw data come from the DataFrame, ``download`` is a no-op."""
        return 'dataframe'

    @property
    def processed_file_names(self):
        """Processed file(s); if they already exist in ``processed_dir``, ``process`` is skipped."""
        return ['data_test.pt' if self.test else 'data.pt']

    def download(self):
        pass

    def _processed_path(self):
        """Path of the cached graph list for this split (test or train)."""
        return os.path.join(self.processed_dir, 'data_test.pt' if self.test else 'data.pt')

    def _load_data_list(self):
        """Load the cached graph list on first use and return it."""
        if self._data is None:
            # NOTE(review): torch>=2.6 defaults torch.load(weights_only=True), which may refuse
            # pickled Data objects; pass weights_only=False here if loading fails.
            self._data = torch.load(self._processed_path())
        return self._data

    def process(self):
        """Featurise every DataFrame row and save the successful graphs as one list."""
        data_list = []
        for index, mol in tqdm(self.dataframe.iterrows(), total=self.dataframe.shape[0]):
            try:
                mol_obj = Chem.MolFromSmiles(mol[self.smiles_col])
                if mol_obj is None:
                    raise ValueError(f"RDKit could not parse SMILES {mol[self.smiles_col]!r}")

                node_feats = self._get_node_features(mol_obj)
                edge_feats = self._get_edge_features(mol_obj)
                edge_index = self._get_adjacency_info(mol_obj)
                label = self._get_labels(mol[self.label_col])

                data = Data(x=node_feats,
                            edge_index=edge_index,
                            edge_attr=edge_feats,
                            y=label,
                            smiles=mol[self.smiles_col])
                data_list.append(data)
            except Exception as e:
                # Best-effort: record and skip molecules that cannot be featurised.
                print(f"Error processing molecule at index {index}: {e}")
                self.error_indices.append(index)

        torch.save(data_list, self._processed_path())

    def _get_node_features(self, mol):
        """
        Return a float tensor of shape [num_atoms, 1] holding each atom's atomic number.
        """
        all_node_feats = [[atom.GetAtomicNum()] for atom in mol.GetAtoms()]
        all_node_feats = np.asarray(all_node_feats)
        return torch.tensor(all_node_feats, dtype=torch.float)

    def _get_edge_features(self, mol):
        """
        Return a float tensor of shape [2 * num_bonds, 1] holding the bond type (as double),
        listed once per bond direction.
        """
        all_edge_feats = []

        for bond in mol.GetBonds():
            edge_feats = [bond.GetBondTypeAsDouble()]
            all_edge_feats += [edge_feats, edge_feats]

        # reshape keeps a 2-D [0, 1] tensor for molecules without bonds (single heavy atom).
        all_edge_feats = np.asarray(all_edge_feats, dtype=float).reshape(-1, 1)
        return torch.tensor(all_edge_feats, dtype=torch.float)

    def _get_adjacency_info(self, mol):
        """Return a long tensor of shape [2, 2 * num_bonds] listing both directions of every bond."""
        edge_indices = []
        for bond in mol.GetBonds():
            i = bond.GetBeginAtomIdx()
            j = bond.GetEndAtomIdx()
            edge_indices += [[i, j], [j, i]]

        return torch.tensor(edge_indices, dtype=torch.long).reshape(-1, 2).t().contiguous()

    def _get_labels(self, label):
        """Convert a label (scalar, list, or its string repr such as "[1, 0]") to an int64 tensor."""
        import ast  # not among the notebook's top-level imports

        if isinstance(label, str):
            label = ast.literal_eval(label)
        label = np.array(label, dtype=np.int64)
        return torch.tensor(label, dtype=torch.int64)

    def len(self):
        # Count the graphs actually saved: rows that failed in process() are not in the list,
        # so the DataFrame row count would overshoot and break indexing / iteration.
        return len(self._load_data_list())

    def get(self, idx):
        return self._load_data_list()[idx]

In [None]:
# Model Classes

class GNNModel(torch.nn.Module):
    """Three-layer message-passing graph classifier with a pluggable convolution class.

    Widths taper ``num_features -> 2h -> h -> h // 2`` (h = ``hidden_channels``). Node
    embeddings are mean-pooled per graph and mapped to ``num_classes`` logits by a linear head.

    Args:
        conv_type: convolution class, called as ``conv_type(in_channels, out_channels)``.
        hidden_channels: base width h.
        num_features: size of the input node features.
        num_classes: number of output logits.
        batch_norm: if True, BatchNorm1d follows the first two convolutions.
        weight_init: if True, re-initialise weights with Xavier-uniform (Linear biases zeroed).
        dropout_rate: dropout after the last convolution; disabled when 0.
    """

    def __init__(self, conv_type, hidden_channels, num_features, num_classes=3, batch_norm=False, weight_init=True, dropout_rate=0.0):
        super().__init__()
        wide = hidden_channels * 2
        narrow = hidden_channels // 2

        # Submodules are created in the same order as before so RNG use and state_dict keys match.
        self.conv1 = conv_type(num_features, wide)
        self.bn1 = torch.nn.BatchNorm1d(wide) if batch_norm else None

        self.conv2 = conv_type(wide, hidden_channels)
        self.bn2 = torch.nn.BatchNorm1d(hidden_channels) if batch_norm else None

        self.conv3 = conv_type(hidden_channels, narrow)

        self.dropout = torch.nn.Dropout(dropout_rate) if dropout_rate > 0 else None
        self.lin = torch.nn.Linear(narrow, num_classes)
        if weight_init:
            self.apply(self._init_weights)

    def _init_weights(self, module):
        """Xavier-init Linear layers (zero bias) and every >=2-D parameter of message-passing layers."""
        if isinstance(module, torch.nn.Linear):
            torch.nn.init.xavier_uniform_(module.weight)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
            return
        if not isinstance(module, torch_geometric.nn.MessagePassing):
            return
        for weight in module.parameters():
            if weight.dim() > 1:
                torch.nn.init.xavier_uniform_(weight)

    def forward(self, x, edge_index, batch):
        """Return one row of class logits per graph in ``batch``."""
        for conv, norm in ((self.conv1, self.bn1), (self.conv2, self.bn2)):
            x = conv(x, edge_index)
            if norm is not None:
                x = norm(x)
            x = F.relu(x)

        x = F.relu(self.conv3(x, edge_index))
        if self.dropout is not None:
            x = self.dropout(x)

        pooled = torch_geometric.nn.global_mean_pool(x, batch)
        return self.lin(pooled)

class GINModel(torch.nn.Module):
    """Three-layer GIN graph classifier (each GINConv wraps a Linear-ReLU-Linear MLP).

    Widths taper ``num_features -> 2h -> h -> h // 2`` (h = ``hidden_channels``). Node
    embeddings are mean-pooled per graph and mapped to ``num_classes`` logits.

    Args:
        hidden_channels: base width h.
        num_features: size of the input node features.
        num_classes: number of output logits.
        batch_norm: if True, BatchNorm1d follows the first two convolutions.
        weight_init: if True, Xavier-init all Linear layers (biases zeroed).
        dropout_rate: dropout after the last convolution; disabled when 0.
    """

    def __init__(self, hidden_channels, num_features, num_classes=3, batch_norm=False, weight_init=True, dropout_rate=0.0):
        super(GINModel, self).__init__()

        def mlp(input_dim, output_dim):
            """Two-layer MLP used as the update function of a GINConv."""
            return torch.nn.Sequential(
                torch.nn.Linear(input_dim, output_dim),
                torch.nn.ReLU(),
                torch.nn.Linear(output_dim, output_dim)
            )

        self.conv1 = torch_geometric.nn.GINConv(mlp(num_features, hidden_channels*2))
        self.bn1 = torch.nn.BatchNorm1d(hidden_channels*2) if batch_norm else None

        self.conv2 = torch_geometric.nn.GINConv(mlp(hidden_channels*2, hidden_channels))
        self.bn2 = torch.nn.BatchNorm1d(hidden_channels) if batch_norm else None

        self.conv3 = torch_geometric.nn.GINConv(mlp(hidden_channels, hidden_channels//2))

        self.dropout = torch.nn.Dropout(dropout_rate) if dropout_rate > 0 else None
        self.lin = torch.nn.Linear(hidden_channels//2, num_classes)

        if weight_init:
            self.apply(self._init_weights)

    def _init_weights(self, module):
        """Xavier-init Linear layers with zero bias (including those inside each GINConv MLP)."""
        if isinstance(module, torch.nn.Linear):
            torch.nn.init.xavier_uniform_(module.weight)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, GINConv):
            for layer in module.nn:
                if isinstance(layer, torch.nn.Linear):
                    torch.nn.init.xavier_uniform_(layer.weight)
                    if layer.bias is not None:
                        torch.nn.init.zeros_(layer.bias)

    def forward(self, x, edge_index, batch):
        """Return one row of class logits per graph in ``batch``."""
        # ReLU is applied after every convolution whether or not batch norm is enabled.
        # Previously it sat inside the `if self.bn*` branches and was skipped when batch_norm=False.
        x = self.conv1(x, edge_index)
        if self.bn1 is not None:
            x = self.bn1(x)
        x = F.relu(x)

        x = self.conv2(x, edge_index)
        if self.bn2 is not None:
            x = self.bn2(x)
        x = F.relu(x)

        x = self.conv3(x, edge_index)
        x = F.relu(x)  # matches GNNModel's activation after the last convolution
        if self.dropout is not None:
            x = self.dropout(x)

        x = torch_geometric.nn.global_mean_pool(x, batch)
        return self.lin(x)


# Training and Evaluation Functions

def train(model, loader, optimiser, criterion, device):
    """Run one training epoch.

    Args:
        model: called as ``model(x, edge_index, batch)`` and returning class logits.
        loader: iterable of batches exposing ``x``, ``edge_index``, ``batch``, ``y`` and ``to``;
            ``loader.dataset`` gives the sample count used for accuracy.
        optimiser: optimiser over ``model``'s parameters.
        criterion: loss taking (logits, class-index targets), e.g. CrossEntropyLoss.
        device: device each batch is moved to.

    Returns:
        (mean loss per batch, accuracy over all samples in ``loader.dataset``).
    """
    model.train()
    total_loss = 0.0
    correct = 0
    for data in loader:
        data = data.to(device)
        # Flatten to 1-D class indices. squeeze() would produce a 0-d target for a batch
        # holding a single graph, which CrossEntropyLoss rejects.
        target = data.y.reshape(-1)
        optimiser.zero_grad()
        out = model(data.x, data.edge_index, data.batch)
        loss = criterion(out, target)
        loss.backward()
        optimiser.step()
        total_loss += loss.item()
        correct += out.argmax(dim=1).eq(target).sum().item()
    accuracy = correct / len(loader.dataset)
    return total_loss / len(loader), accuracy

def evaluate(model, loader, criterion, device):
    """Evaluate ``model`` on ``loader`` without gradient tracking.

    Returns:
        (mean loss per batch, accuracy, ROC-AUC, macro F1, MCC).
        ROC-AUC uses the positive-class probability when the model has 2 outputs and one-vs-rest
        macro AUC otherwise. It is NaN when sklearn cannot compute it, e.g. when only one class
        is present in the labels.
    """
    model.eval()
    total_loss = 0.0
    correct = 0
    preds = []
    labels = []
    prob_batches = []
    with torch.no_grad():
        for data in loader:
            data = data.to(device)
            # 1-D class indices; squeeze() breaks on a single-graph batch.
            target = data.y.reshape(-1)
            out = model(data.x, data.edge_index, data.batch)
            total_loss += criterion(out, target).item()
            pred = out.argmax(dim=1)
            correct += pred.eq(target).sum().item()
            preds.extend(pred.cpu().numpy())
            labels.extend(target.cpu().numpy())
            # Probabilities come from this same pass. The old second pass ran outside
            # no_grad and repeated the forward computation.
            prob_batches.append(F.softmax(out, dim=1).cpu())
    accuracy = correct / len(loader.dataset)

    try:
        all_probs = torch.cat(prob_batches, dim=0).numpy()
        if all_probs.shape[1] == 2:
            # Binary case: sklearn expects a 1-D score for the positive class.
            auc = roc_auc_score(labels, all_probs[:, 1])
        else:
            auc = roc_auc_score(labels, all_probs, multi_class='ovr', average='macro')
    except ValueError:
        auc = float('nan')
    f1 = f1_score(labels, preds, average='macro')
    mcc = matthews_corrcoef(labels, preds)
    return total_loss / len(loader), accuracy, auc, f1, mcc


# Early Stopping Class

class EarlyStopping:
    """Patience-based early stopping for a monitored value where LOWER is better (e.g. a loss).

    To monitor a higher-is-better metric such as ROC-AUC, pass its negation.

    Args:
        patience: consecutive non-improving calls after which ``early_stop`` becomes True.
        delta: minimum decrease of the monitored value that counts as an improvement.

    Attributes:
        best_model: a deep copy of the model taken at the best call so far. It is a copy, not a
            reference, so later training does not overwrite it.
    """

    def __init__(self, patience=10, delta=0):
        self.patience = patience
        self.delta = delta
        self.best_score = None
        self.early_stop = False
        self.counter = 0
        self.best_model = None

    def __call__(self, val_score, model):
        score = -val_score
        if score != score:
            # NaN cannot be ranked: never an improvement, but keep a snapshot so callers
            # reading best_model always get a model.
            if self.best_model is None:
                self.save_checkpoint(val_score, model)
            self._register_no_improvement()
        elif self.best_score is None or score >= self.best_score + self.delta:
            # Improvement must be at least `delta`. The old `best - delta` test accepted
            # slightly worse scores and lowered best_score.
            self.best_score = score
            self.save_checkpoint(val_score, model)
            self.counter = 0
        else:
            self._register_no_improvement()

    def _register_no_improvement(self):
        """Count a non-improving call and raise the stop flag once patience is exhausted."""
        self.counter += 1
        if self.counter >= self.patience:
            self.early_stop = True

    def save_checkpoint(self, val_score, model):
        import copy  # local: not among the notebook's top-level imports
        self.best_model = copy.deepcopy(model)


# Cross-validation and Grid search

def get_fold_indices(pytorch_custom_dataset, seed=42, n_outer_splits=4, n_inner_splits=5):
    """Build stratified nested cross-validation splits over the dataset positions.

    Each graph must have a single-value ``y`` (a class index); it is used for stratification.

    Args:
        pytorch_custom_dataset: indexable/iterable dataset of graphs.
        seed: random_state for both the outer and inner StratifiedKFold shuffles.
        n_outer_splits: number of outer (train/test) folds.
        n_inner_splits: number of inner (train/validation) folds within each outer train set.

    Returns:
        list of ``(train_indices, test_indices, inner_folds)``, one per outer fold, where
        ``inner_folds`` is a list of ``(inner_train_indices, val_indices)``. All are numpy arrays
        of positions into ``pytorch_custom_dataset``.
    """
    num_samples = len(pytorch_custom_dataset)
    indices = np.arange(num_samples)
    targets = np.array([data.y.item() for data in pytorch_custom_dataset])

    skf_outer = StratifiedKFold(n_splits=n_outer_splits, shuffle=True, random_state=seed)
    outer_folds = []

    for train_indices, test_indices in skf_outer.split(indices, targets):
        skf_inner = StratifiedKFold(n_splits=n_inner_splits, shuffle=True, random_state=seed)
        # The inner split yields positions within train_indices; map them back to dataset positions.
        inner_folds = [
            (train_indices[inner_train_idx], train_indices[val_idx])
            for inner_train_idx, val_idx in skf_inner.split(train_indices, targets[train_indices])
        ]
        outer_folds.append((train_indices, test_indices, inner_folds))
    return outer_folds

def objective(hyperparameters, model_class, conv_type, pytorch_custom_dataset, device, results_dir, model_name, outer_fold, inner_fold, fold_indices, use_early_stopping, num_classes=3):
    """Train one hyper-parameter configuration on one inner CV split and return its validation ROC-AUC.

    Args:
        hyperparameters: dict with 'learning_rate', 'hidden_channels', 'batch_size', 'dropout_rate'.
        model_class: GINModel or a GNNModel-style class that also takes ``conv_type``.
        conv_type: convolution class for GNNModel (ignored for GINModel).
        pytorch_custom_dataset: dataset whose ``y`` values are class indices.
        device: torch device for training.
        results_dir, model_name: accepted for interface compatibility; unused here.
        outer_fold, inner_fold: which split of ``fold_indices`` to use.
        fold_indices: output of ``get_fold_indices``.
        use_early_stopping: stop on validation ROC-AUC (patience 10) and evaluate the best epoch.
        num_classes: number of output logits (default 3, as before).

    Returns:
        Validation ROC-AUC of the selected model (NaN if it could not be computed).
    """
    learning_rate = hyperparameters['learning_rate']
    hidden_channels = hyperparameters['hidden_channels']
    batch_size = hyperparameters['batch_size']
    dropout_rate = hyperparameters['dropout_rate']
    num_epochs = 100

    torch.manual_seed(1)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(1)

    # The outer test fold is held out from hyper-parameter selection.
    _, _, inner_folds = fold_indices[outer_fold]
    inner_train_indices, val_indices = inner_folds[inner_fold]

    # Shuffle training batches every epoch. The indices from StratifiedKFold are sorted, so
    # without shuffling the batch composition would follow the DataFrame order and never change.
    train_loader = DataLoader(Subset(pytorch_custom_dataset, inner_train_indices), batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(Subset(pytorch_custom_dataset, val_indices), batch_size=batch_size, shuffle=False)

    num_features = pytorch_custom_dataset[0].num_features

    model_kwargs = dict(
        hidden_channels=hidden_channels,
        num_features=num_features,
        num_classes=num_classes,
        batch_norm=False,
        weight_init=True,
        dropout_rate=dropout_rate,
    )
    if model_class == GINModel:
        model = model_class(**model_kwargs).to(device)
    else:
        model = model_class(conv_type=conv_type, **model_kwargs).to(device)

    optimiser = torch.optim.Adam(model.parameters(), lr=learning_rate)
    criterion = torch.nn.CrossEntropyLoss()

    early_stopping = EarlyStopping(patience=10, delta=0)

    for epoch in range(1, num_epochs + 1):
        train(model, train_loader, optimiser, criterion, device)
        _, _, val_auc, _, _ = evaluate(model, val_loader, criterion, device)
        if use_early_stopping:
            # EarlyStopping treats LOWER as better (it negates its input), so hand it the
            # negated AUC. Passing AUC directly rewarded falling AUC. An undefined AUC maps
            # to +inf so it never counts as an improvement.
            monitored = float('inf') if np.isnan(val_auc) else -val_auc
            early_stopping(monitored, model)
            if early_stopping.early_stop:
                print(f"Early stopping at epoch {epoch}")
                break

    best_model = early_stopping.best_model if use_early_stopping else model
    _, _, val_auc, _, _ = evaluate(best_model, val_loader, criterion, device)

    return val_auc

def grid_search(model_class, conv_type, pytorch_custom_dataset, results_dir, model_name, use_early_stopping):
    """Exhaustive hyper-parameter search over every inner CV split of every outer fold.

    Each (outer fold, inner fold, configuration) is trained via ``objective`` and its validation
    ROC-AUC recorded. The outer test folds are not evaluated here. ``results_dir`` is created if
    missing.

    Returns:
        list of dicts, one per (outer fold, inner fold, configuration), containing the model name,
        fold numbers, hyper-parameters and ``val_auc``.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    os.makedirs(results_dir, exist_ok=True)

    fold_indices = get_fold_indices(pytorch_custom_dataset, seed=42)
    # Loop bounds come from the splits themselves so they cannot drift from get_fold_indices.
    n_outer = len(fold_indices)

    batch_size_list = [32, 64, 128]
    hidden_channels_list = [64, 128, 256]
    learning_rate_list = [0.001, 0.0001, 0.00001]
    dropout_rate_list = [0.1, 0.2, 0.3, 0.4, 0.5]

    hyperparameter_grid = list(itertools.product(
        batch_size_list,
        hidden_channels_list,
        learning_rate_list,
        dropout_rate_list
    ))

    results = []

    for outer_fold, (_, _, inner_folds) in enumerate(fold_indices):
        print(f"Outer Fold {outer_fold + 1}/{n_outer}")
        n_inner = len(inner_folds)

        for inner_fold in range(n_inner):
            print(f"  Inner Fold {inner_fold + 1}/{n_inner}")

            for batch_size, hidden_channels, learning_rate, dropout_rate in hyperparameter_grid:
                hyperparameters_dict = {
                    'batch_size': batch_size,
                    'hidden_channels': hidden_channels,
                    'learning_rate': learning_rate,
                    'dropout_rate': dropout_rate
                }

                val_auc = objective(
                    hyperparameters=hyperparameters_dict,
                    model_class=model_class,
                    conv_type=conv_type,
                    pytorch_custom_dataset=pytorch_custom_dataset,
                    device=device,
                    results_dir=results_dir,
                    model_name=model_name,
                    outer_fold=outer_fold,
                    inner_fold=inner_fold,
                    fold_indices=fold_indices,
                    use_early_stopping=use_early_stopping
                )
                print(f"Testing hyperparameters: bs={batch_size}, hc={hidden_channels}, lr={learning_rate}, dr={dropout_rate} / val_auc: {val_auc}")

                results.append({
                    'model_name': model_name,
                    'outer_fold': outer_fold,
                    'inner_fold': inner_fold,
                    'batch_size': batch_size,
                    'hidden_channels': hidden_channels,
                    'learning_rate': learning_rate,
                    'dropout_rate': dropout_rate,
                    'val_auc': val_auc
                })

    return results

def run_grid_search(models, pytorch_custom_dataset, results_dir, use_early_stopping):
    """Run ``grid_search`` for every entry of ``models`` and concatenate the result rows.

    The convolution class is inferred from the model name: the first of "GCN", "GAT",
    "GraphSAGE" found as a substring selects GCNConv, GATConv or SAGEConv respectively.
    Any other name (e.g. the GIN entry) gets ``conv_type=None``.

    Args:
        models: mapping of display name -> model class.
        pytorch_custom_dataset: dataset passed through to ``grid_search``.
        results_dir: directory passed through to ``grid_search``.
        use_early_stopping: passed through to ``grid_search``.

    Returns:
        list of result dicts from all models, in ``models`` iteration order.
    """
    combined = []

    for model_name, model_cls in models.items():
        print(f"Running Grid Search for {model_name} model...")

        # Ordered lookup: the first matching keyword wins; None for GINModel.
        keyword_to_conv = (("GCN", GCNConv), ("GAT", GATConv), ("GraphSAGE", SAGEConv))
        conv = next((layer for keyword, layer in keyword_to_conv if keyword in model_name), None)

        combined.extend(grid_search(
            model_class=model_cls,
            conv_type=conv,
            pytorch_custom_dataset=pytorch_custom_dataset,
            results_dir=results_dir,
            model_name=model_name,
            use_early_stopping=use_early_stopping
        ))

    return combined

def convert_labels_to_class_indices(pytorch_custom_dataset):
    """Replace each graph's two-element binary label with a single class index, in place.

    Mapping: (1, 0) -> 0, (0, 1) -> 1, (1, 1) -> 2. The result is stored as a 1-element long
    tensor. Graphs whose label already holds a single value are left untouched, so calling
    this again on the same dataset is safe.

    Raises:
        ValueError: if a two-element label is not one of the mapped combinations.
    """
    label_mapping = {
        (1, 0): 0,
        (0, 1): 1,
        (1, 1): 2,
    }
    for data in pytorch_custom_dataset:
        flat = data.y.reshape(-1)
        if flat.numel() == 1:
            continue  # already a class index (e.g. this function was run before)
        label_tuple = tuple(int(v) for v in flat.tolist())
        if label_tuple not in label_mapping:
            raise ValueError(
                f"Unexpected label {label_tuple}; expected one of {sorted(label_mapping)}"
            )
        data.y = torch.tensor([label_mapping[label_tuple]], dtype=torch.long)

In [None]:
# Generated from DILIGeNN Data Preprocessing.ipynb
df_clintox_filtered_cleaned = pd.read_csv('chem_data/clintox2_standardised_cleaned.csv', index_col = 0)

clintox_min = Graph_basic(dataframe = df_clintox_filtered_cleaned,
                              smiles_col = 'smiles',
                             label_col = 'label',
                             root = 'custom_data/minimal_feature/clintox2_min'
                            ) 
convert_labels_to_class_indices(clintox_min)

clintox_cus_2_std = Graph_custom(dataframe = df_clintox_filtered_cleaned,
                              smiles_col = 'smiles_std',
                             label_col = 'label',
                             root = 'custom_data/clintox2/cus2_std/'
                            )
convert_labels_to_class_indices(clintox_cus_2_std)

In [None]:
# Grid search of all four architectures on the full-feature ClinTox graphs.
# run_grid_search picks GCNConv / GATConv / SAGEConv from the key substring; GIN gets conv_type=None.
models = {
    'GCN_Optimised': GNNModel,
    'GAT_Optimised': GNNModel,
    'GraphSAGE_Optimised': GNNModel,
    'GIN_Optimised': GINModel
}

pytorch_custom_dataset = clintox_cus_2_std
results_dir = './results/grid_search/clintox2'

all_results = run_grid_search(
    models,
    pytorch_custom_dataset,
    results_dir,
    use_early_stopping = True
)

# One row per (model, outer fold, inner fold, hyper-parameter set); saved as CSV and as a pickle.
results_df = pd.DataFrame(all_results)
results_df.to_csv(os.path.join(results_dir, 'grid_search_results.csv'), index=False)

with open(os.path.join(results_dir, 'grid_search_results.pkl'), 'wb') as f:
    pickle.dump(results_df, f)

In [None]:
# Same grid search as above, but on the minimal-feature ClinTox graphs; separate results directory.
models = {
    'GCN_Optimised': GNNModel,
    'GAT_Optimised': GNNModel,
    'GraphSAGE_Optimised': GNNModel,
    'GIN_Optimised': GINModel
}

pytorch_custom_dataset = clintox_min
results_dir = './results/grid_search/clintox2_min'

all_results = run_grid_search(
    models,
    pytorch_custom_dataset,
    results_dir,
    use_early_stopping = True
)

# One row per (model, outer fold, inner fold, hyper-parameter set); saved as CSV and as a pickle.
results_df = pd.DataFrame(all_results)
results_df.to_csv(os.path.join(results_dir, 'grid_search_results.csv'), index=False)

with open(os.path.join(results_dir, 'grid_search_results.pkl'), 'wb') as f:
    pickle.dump(results_df, f)