In [1]:
import os
import time
import pickle
import requests
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
from tqdm import tqdm
tqdm.pandas()

from rdkit import RDLogger
from rdkit import Chem
from rdkit.Chem import AllChem, Draw, rdMolTransforms
from rdkit.Chem.MolStandardize import rdMolStandardize
from rdkit.Chem.Draw import IPythonConsole
IPythonConsole.drawOptions.addAtomIndices = True
IPythonConsole.molSize = 300,300

from sklearn.metrics import roc_auc_score, f1_score, accuracy_score, confusion_matrix, matthews_corrcoef
from sklearn.model_selection import StratifiedKFold
import networkx as nx

import torch
import torch.nn.functional as F
import torch_geometric
from torch_geometric.nn import SAGEConv, GCNConv, GATConv, GINConv
from torch_geometric.data import Data, Dataset, InMemoryDataset#, DataLoader depreciated, use below
from torch_geometric.loader import DataLoader
from torch.utils.data import Subset

In [2]:
class GNNModel(torch.nn.Module):
    """Three-layer graph network with a mean-pooled, single-output readout.

    The convolution operator is supplied by the caller, so the same skeleton
    serves any layer class that can be built as ``conv_type(in_dim, out_dim)``
    and called as ``layer(x, edge_index)``.

    Args:
        conv_type: Layer class used for all three convolutions.
        hidden_channels: Base width; the convolutions output
            ``2 * hidden_channels``, ``hidden_channels`` and
            ``hidden_channels // 2`` features respectively.
        num_features: Width of the input node features.
        batch_norm: If true, ``BatchNorm1d`` follows the first two convolutions.
        weight_init: If true, ``_init_weights`` is applied to every submodule.
        dropout_rate: Dropout probability applied after the third convolution;
            values ``<= 0`` disable dropout.
    """

    def __init__(self, conv_type, hidden_channels, num_features, batch_norm=False, weight_init=True, dropout_rate=0.0):
        super().__init__()
        wide = hidden_channels * 2
        narrow = hidden_channels // 2

        # Layers are created in the order conv1, bn1, conv2, bn2, conv3,
        # dropout, lin so that submodule registration stays stable.
        self.conv1 = conv_type(num_features, wide)
        self.bn1 = torch.nn.BatchNorm1d(wide) if batch_norm else None

        self.conv2 = conv_type(wide, hidden_channels)
        self.bn2 = torch.nn.BatchNorm1d(hidden_channels) if batch_norm else None

        # The last convolution has no normalisation layer.
        self.conv3 = conv_type(hidden_channels, narrow)

        self.dropout = torch.nn.Dropout(dropout_rate) if dropout_rate > 0 else None
        self.lin = torch.nn.Linear(narrow, 1)

        if weight_init:
            self.apply(self._init_weights)

    def _init_weights(self, module):
        """Xavier-initialise ``module`` if it is a linear or message-passing layer."""
        if isinstance(module, torch.nn.Linear):
            torch.nn.init.xavier_uniform_(module.weight)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
            return

        if not isinstance(module, torch_geometric.nn.MessagePassing):
            return

        # Convolution layers: only tensors with more than one dimension are
        # touched, which leaves bias vectors as they were.
        for tensor in module.parameters():
            if tensor.dim() > 1:
                torch.nn.init.xavier_uniform_(tensor)

    def forward(self, x, edge_index, batch):
        """Return one output row per graph in the batch.

        Args:
            x: Node feature matrix passed to the first convolution.
            edge_index: Connectivity passed to every convolution.
            batch: Node-to-graph assignment used by the mean pooling.
        """
        stages = (
            (self.conv1, self.bn1),
            (self.conv2, self.bn2),
            (self.conv3, None),
        )
        for conv, norm in stages:
            x = conv(x, edge_index)
            if norm is not None:
                x = norm(x)
            x = F.relu(x)

        # Dropout is applied once, after the final convolution block.
        if self.dropout is not None:
            x = self.dropout(x)

        pooled = torch_geometric.nn.global_mean_pool(x, batch)
        return self.lin(pooled)

class GINModel(torch.nn.Module):
    """Three ``GINConv`` layers followed by mean pooling and a linear readout.

    Each ``GINConv`` wraps a two-layer perceptron (Linear -> ReLU -> Linear).

    Args:
        hidden_channels: Base width; the convolutions output
            ``2 * hidden_channels``, ``hidden_channels`` and
            ``hidden_channels // 2`` features respectively.
        num_features: Width of the input node features.
        batch_norm: If true, ``BatchNorm1d`` + ReLU follow the first two
            convolutions.
        weight_init: If true, ``_init_weights`` is applied to every submodule.
        dropout_rate: Dropout probability applied after the third convolution;
            values ``<= 0`` disable dropout.
    """

    @staticmethod
    def _make_mlp(in_dim, out_dim):
        """Build the perceptron wrapped by a single ``GINConv``."""
        return torch.nn.Sequential(
            torch.nn.Linear(in_dim, out_dim),
            torch.nn.ReLU(),
            torch.nn.Linear(out_dim, out_dim),
        )

    def __init__(self, hidden_channels, num_features, batch_norm=False, weight_init=True, dropout_rate=0.0):
        super().__init__()
        wide = hidden_channels * 2
        narrow = hidden_channels // 2

        # Creation order conv1, bn1, conv2, bn2, conv3, dropout, lin is kept
        # so that submodule registration stays stable.
        self.conv1 = torch_geometric.nn.GINConv(self._make_mlp(num_features, wide))
        self.bn1 = torch.nn.BatchNorm1d(wide) if batch_norm else None

        self.conv2 = torch_geometric.nn.GINConv(self._make_mlp(wide, hidden_channels))
        self.bn2 = torch.nn.BatchNorm1d(hidden_channels) if batch_norm else None

        self.conv3 = torch_geometric.nn.GINConv(self._make_mlp(hidden_channels, narrow))

        self.dropout = torch.nn.Dropout(dropout_rate) if dropout_rate > 0 else None
        self.lin = torch.nn.Linear(narrow, 1)

        # Optional Xavier initialisation of every linear layer.
        if weight_init:
            self.apply(self._init_weights)

    def _init_weights(self, module):
        """Xavier-initialise linear weights (zero biases) reachable from ``module``."""
        if isinstance(module, torch.nn.Linear):
            targets = [module]
        elif isinstance(module, GINConv):
            # Walk the perceptron owned by the convolution.
            targets = [sub for sub in module.nn if isinstance(sub, torch.nn.Linear)]
        else:
            targets = []

        for linear in targets:
            torch.nn.init.xavier_uniform_(linear.weight)
            if linear.bias is not None:
                torch.nn.init.zeros_(linear.bias)

    def forward(self, x, edge_index, batch):
        """Return one output row per graph in the batch.

        Args:
            x: Node feature matrix passed to the first convolution.
            edge_index: Connectivity passed to every convolution.
            batch: Node-to-graph assignment used by the mean pooling.
        """
        for conv, norm in ((self.conv1, self.bn1), (self.conv2, self.bn2)):
            x = conv(x, edge_index)
            # The extra ReLU is applied only together with batch normalisation.
            if norm is not None:
                x = F.relu(norm(x))

        # Third convolution: no normalisation and no extra activation.
        x = self.conv3(x, edge_index)
        if self.dropout is not None:
            x = self.dropout(x)

        pooled = torch_geometric.nn.global_mean_pool(x, batch)
        return self.lin(pooled)

In [3]:
def train(model, loader, optimiser, criterion, device, threshold = 0.5): # added
    """Run one optimisation pass over ``loader`` and report loss and accuracy.

    The model output is treated as a raw logit (the models in this notebook
    end in a bare ``Linear`` layer and are trained with ``BCEWithLogitsLoss``),
    so it is passed through a sigmoid before being compared with ``threshold``.
    Thresholding the logit directly would make ``threshold=0.5`` correspond to
    a probability of roughly 0.62.

    Args:
        model: Module called as ``model(data.x, data.edge_index, data.batch)``.
        loader: Iterable of batches exposing ``x``, ``edge_index``, ``batch``,
            ``y`` and ``to(device)``; must also provide ``len()`` and a
            ``dataset`` attribute with ``len()``.
        optimiser: Optimiser updating the parameters of ``model``.
        criterion: Loss called as ``criterion(logits, targets)``.
        device: Device every batch is moved to.
        threshold: Probability above which a sample is predicted positive.

    Returns:
        Tuple ``(mean loss per batch, fraction of correctly classified samples)``.
    """
    model.train()
    total_loss = 0.0
    correct = 0
    for data in loader:
        data = data.to(device)
        optimiser.zero_grad()
        out = model(data.x, data.edge_index, data.batch)
        target = data.y.view(-1, 1).float()
        loss = criterion(out, target)
        loss.backward()
        optimiser.step()
        total_loss += loss.item()
        # Convert logits to probabilities before applying the decision threshold.
        pred = (torch.sigmoid(out.detach()) > threshold).float()
        correct += pred.eq(target).sum().item()
    return total_loss / len(loader), correct / len(loader.dataset)

def evaluate(model, loader, criterion, device, threshold = 0.5): # added
    """Score ``model`` on ``loader`` without updating its parameters.

    The model output is treated as a raw logit (the models in this notebook
    end in a bare ``Linear`` layer and are trained with ``BCEWithLogitsLoss``),
    so it is converted to a probability with a sigmoid before ``threshold`` is
    applied. Accuracy, F1 and MCC are all computed from those thresholded
    probabilities; ROC AUC is computed from the probabilities themselves.

    Args:
        model: Module called as ``model(data.x, data.edge_index, data.batch)``.
        loader: Iterable of batches exposing ``x``, ``edge_index``, ``batch``,
            ``y`` and ``to(device)``; must also provide ``len()`` and a
            ``dataset`` attribute with ``len()``.
        criterion: Loss called as ``criterion(logits, targets)``.
        device: Device every batch is moved to.
        threshold: Probability above which a sample is predicted positive.

    Returns:
        Tuple ``(mean loss per batch, accuracy, ROC AUC, F1, MCC)``.
    """
    model.eval()
    total_loss = 0.0
    logits, labels = [], []
    with torch.no_grad():
        for data in loader:
            data = data.to(device)
            out = model(data.x, data.edge_index, data.batch)
            target = data.y.view(-1, 1)
            total_loss += criterion(out, target.float()).item()
            logits.append(out)
            labels.append(target)

    logits = torch.cat(logits, dim=0).cpu()
    labels = torch.cat(labels, dim=0).cpu()

    # Logits -> probabilities, then a single hard decision shared by all
    # threshold-based metrics.
    probs = torch.sigmoid(logits)
    hard = (probs > threshold).float()
    correct = hard.eq(labels.float()).sum().item()

    # Flatten the (N, 1) columns into 1-D arrays for the metric functions.
    y_true = labels.view(-1).numpy()
    y_prob = probs.view(-1).numpy()
    y_pred = hard.view(-1).numpy()

    return (
        total_loss / len(loader),
        correct / len(loader.dataset),
        roc_auc_score(y_true, y_prob),
        f1_score(y_true, y_pred),
        matthews_corrcoef(y_true, y_pred),  # modified for MCC
    )

class EarlyStopping:
    """Patience-based early stopping for a quantity where LOWER is better.

    The monitored value is negated internally, so pass something to minimise
    (for example a validation loss). To monitor a metric where higher is
    better, pass its negative.

    Behaviour per call:
      * value at least as good as the best so far -> new best, checkpoint
        refreshed, patience counter reset;
      * value worse than the best by more than ``delta`` -> patience counter
        incremented, ``early_stop`` set once it reaches ``patience``;
      * value worse than the best but within ``delta`` -> tolerated: the
        counter is reset, but the best score and checkpoint are kept.

    Attributes are documented inline in ``__init__``.
    """

    def __init__(self, patience=10, delta=0):
        self.patience = patience   # consecutive bad epochs allowed before stopping
        self.delta = delta         # tolerated degradation relative to the best score
        self.best_score = None     # best negated value seen so far
        self.best_model = None     # independent copy of the model at the best score
        self.early_stop = False    # becomes True once patience is exhausted
        self.counter = 0           # current run of bad epochs

    def __call__(self, val_score, model):
        """Record ``val_score`` for this epoch and update the stopping state."""
        score = -val_score
        if self.best_score is None or score >= self.best_score:
            self.best_score = score
            self.save_checkpoint(val_score, model)
            self.counter = 0
        elif score < self.best_score - self.delta:
            self.counter += 1
            if self.counter >= self.patience:
                self.early_stop = True
        else:
            # Slightly worse but within tolerance: do not count it against
            # patience, and do not overwrite the better checkpoint.
            self.counter = 0

    def save_checkpoint(self, val_score, model):
        """Store an independent snapshot of ``model`` in ``self.best_model``."""
        import copy

        # A plain reference would keep changing as training continues, so the
        # "best" model would silently become the latest one.
        self.best_model = copy.deepcopy(model)

In [4]:
def load_datasets_and_models(pytorch_custom_dataset, batch_norm, weight_init, batch_size=32, hidden_channels=64, dropout_rate=0.3, seed=42):
    """Build stratified 10-fold loaders and the four models to compare.

    For every outer fold the non-test samples are split once more (first split
    of a stratified 5-fold) into an inner training set and a validation set,
    so each fold yields three disjoint subsets of the dataset.

    Args:
        pytorch_custom_dataset: Dataset supporting ``len()``, iteration over
            items whose ``y`` holds a single label (read via ``y.item()``),
            and a ``num_features`` attribute.
        batch_norm: Forwarded to every model constructor.
        weight_init: Forwarded to every model constructor.
        batch_size: Batch size of all loaders.
        hidden_channels: Forwarded to every model constructor.
        dropout_rate: Forwarded to every model constructor.
        seed: Seed for torch and for both stratified splitters.

    Returns:
        Tuple ``(data_loaders, models)`` where ``data_loaders`` is a list of
        ``(train_loader, val_loader, test_loader)`` triples, one per outer
        fold, and ``models`` maps a display name to a model instance.
    """
    # Set random seed for reproducibility
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)

    # Outer split: 10-fold stratified cross-validation over dataset indices.
    num_samples = len(pytorch_custom_dataset)
    indices = np.arange(num_samples)
    targets = np.array([data.y.item() for data in pytorch_custom_dataset])
    skf_outer = StratifiedKFold(n_splits=10, shuffle=True, random_state=seed)

    data_loaders = []

    for train_indices, test_indices in skf_outer.split(indices, targets):
        # Inner split of the non-test samples into train and validation parts.
        skf_inner = StratifiedKFold(n_splits=5, shuffle=True, random_state=seed)
        inner_train_pos, val_pos = next(skf_inner.split(train_indices, targets[train_indices]))

        # split() returns POSITIONS into the array it was given, not dataset
        # indices. Map them back through train_indices; using the positions
        # directly would select unrelated samples, including test samples.
        inner_train_indices = train_indices[inner_train_pos]
        val_indices = train_indices[val_pos]

        inner_train_dataset = Subset(pytorch_custom_dataset, inner_train_indices.tolist())
        val_dataset = Subset(pytorch_custom_dataset, val_indices.tolist())
        test_dataset = Subset(pytorch_custom_dataset, test_indices.tolist())

        # Create data loaders
        # NOTE(review): the training loader is deliberately left unshuffled to
        # match the existing behaviour — confirm that this is intended.
        train_loader = DataLoader(inner_train_dataset, batch_size=batch_size, shuffle=False)
        val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
        test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

        data_loaders.append((train_loader, val_loader, test_loader))

    # Define model parameters
    num_features = pytorch_custom_dataset.num_features

    # Define models dictionary
    models = {
        'GCN_Optimised': GNNModel(GCNConv, hidden_channels, num_features, batch_norm, weight_init, dropout_rate=dropout_rate),
        'GAT_Optimised': GNNModel(GATConv, hidden_channels, num_features, batch_norm, weight_init, dropout_rate=dropout_rate),
        'GraphSAGE_Optimised': GNNModel(SAGEConv, hidden_channels, num_features, batch_norm, weight_init, dropout_rate=dropout_rate),
        'GIN_Optimised': GINModel(hidden_channels, num_features, batch_norm, weight_init, dropout_rate=dropout_rate)
    }

    return data_loaders, models

def run_model(models, data_loaders, num_runs=5, num_epochs=30, learning_rate=0.0001, 
              use_early_stopping=False, patience=10, delta=0.001):
    """Train and test every model on every fold, repeated ``num_runs`` times.

    Each (run, fold) pair starts from freshly re-initialised weights, seeded
    with the run index, so nothing learned on one fold carries over to the
    next.

    Args:
        models: Mapping of display name to model instance. Instances are
            reused in place and re-initialised before each fold.
        data_loaders: Sequence of ``(train_loader, val_loader, test_loader)``
            triples, one per fold.
        num_runs: Number of repetitions; the run index is used as the seed.
        num_epochs: Maximum number of training epochs per fold.
        learning_rate: Learning rate of the Adam optimiser.
        use_early_stopping: If true, training stops on validation AUC and the
            model held by the early-stopping helper is used for testing.
        patience: Forwarded to ``EarlyStopping``.
        delta: Forwarded to ``EarlyStopping``.

    Returns:
        Dict mapping each model name to lists under the keys ``accuracy``,
        ``f1``, ``auc``, ``mcc`` (one entry per run and fold) and
        ``train_loss``, ``val_loss`` (one entry per trained epoch).
    """
    results = {name: {'accuracy': [], 'f1': [], 'auc': [], 'mcc': [], 'train_loss': [], 'val_loss': []}
               for name in models}

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    threshold = 0.5  # For binary classification

    for name, model_instance in models.items():
        print(f"\nRunning {name} model...")
        for run in range(num_runs):
            print(f"Run {run + 1}/{num_runs}")

            for fold, (train_loader, val_loader, test_loader) in enumerate(data_loaders):
                print(f"Fold {fold + 1}/{len(data_loaders)}")
                start_time = time.time()

                # Set random seed
                seed = run
                torch.manual_seed(seed)
                if torch.cuda.is_available():
                    torch.cuda.manual_seed(seed)

                # The same instance is reused for every fold, so all of its
                # state must be wiped first. Otherwise weights trained on an
                # earlier fold (whose training data contains this fold's test
                # samples) leak into this one.
                model = model_instance.to(device)
                _reset_model_parameters(model)

                optimiser = torch.optim.Adam(model.parameters(), lr=learning_rate)
                criterion = torch.nn.BCEWithLogitsLoss()

                early_stopping = EarlyStopping(patience=patience, delta=delta)
                for epoch in range(1, num_epochs + 1):
                    train_loss, train_acc = train(model, train_loader, optimiser, criterion, device, threshold=threshold)
                    val_loss, val_acc, val_auc, _, _ = evaluate(model, val_loader, criterion, device, threshold=threshold)

                    results[name]['train_loss'].append(train_loss)
                    results[name]['val_loss'].append(val_loss)

                    if use_early_stopping:
                        # EarlyStopping negates its input, i.e. it treats lower
                        # values as better. AUC is higher-is-better, so it is
                        # passed negated.
                        early_stopping(-val_auc, model)
                        if early_stopping.early_stop:
                            print(f"Early stopping at epoch {epoch}")
                            break

                best_model = early_stopping.best_model if use_early_stopping else model

                test_loss, test_acc, test_auc, test_f1, test_mcc = evaluate(best_model, test_loader, criterion, device, threshold=threshold)

                results[name]['accuracy'].append(test_acc)
                results[name]['f1'].append(test_f1)
                results[name]['auc'].append(test_auc)
                results[name]['mcc'].append(test_mcc)

                # Validation and training figures below are from the last
                # epoch that was run.
                print(f'{name} Fold {fold + 1} Run {run + 1}\n'
                      f'Test Acc: {test_acc:.4f}, Test F1: {test_f1:.4f}, Test AUC: {test_auc:.4f}, Test MCC: {test_mcc:.4f}\n'
                      f'Val Acc: {val_acc:.4f}, Val Loss: {val_loss:.4f}, Val AUC: {val_auc:.4f}\n'
                      f'Train Acc: {train_acc:.4f}, Train Loss: {train_loss:.4f}')
                end_time = time.time()
                total_time = (end_time - start_time) / 60
                print(f"Fold total time taken: {total_time:.2f} minutes")

    return results


def _reset_model_parameters(model):
    """Re-initialise all learnable state of ``model`` in place.

    First every submodule that offers ``reset_parameters()`` is reset, which
    also clears normalisation running statistics. Then Xavier-uniform is
    applied to every parameter with more than one dimension and every
    parameter whose name ends in ``bias`` is zeroed. One-dimensional weights
    (such as batch-norm scales) are left at their reset values, because
    Xavier initialisation is undefined for them.
    """
    for module in model.modules():
        reset = getattr(module, 'reset_parameters', None)
        if callable(reset):
            reset()

    for param_name, param in model.named_parameters():
        if param.dim() > 1:
            torch.nn.init.xavier_uniform_(param)
        elif param_name.endswith('bias'):
            torch.nn.init.zeros_(param)
