In [1]:
import torch
import torch.nn as nn
from torch_geometric.nn import GCNConv, global_mean_pool as gep

# GNN Model definition for molecule and protein graphs
class GNNNet(torch.nn.Module):
    def __init__(self, n_output=1, num_features_pro=54, num_features_mol=78, output_dim=128, dropout=0.2):
        super(GNNNet, self).__init__()

        print('GNNNet Loaded')
        # Molecule GNN layers
        self.mol_conv1 = GCNConv(num_features_mol, num_features_mol)
        self.mol_conv2 = GCNConv(num_features_mol, num_features_mol * 2)
        self.mol_conv3 = GCNConv(num_features_mol * 2, num_features_mol * 4)
        self.mol_fc_g1 = torch.nn.Linear(num_features_mol * 4, 1024)
        self.mol_fc_g2 = torch.nn.Linear(1024, output_dim)

        # Protein GNN layers
        self.pro_conv1 = GCNConv(num_features_pro, num_features_pro)
        self.pro_conv2 = GCNConv(num_features_pro * 2, num_features_pro * 4)
        self.pro_fc_g1 = torch.nn.Linear(num_features_pro * 4, 1024)
        self.pro_fc_g2 = torch.nn.Linear(1024, output_dim)

        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(dropout)

        # Combined dense layers
        self.fc1 = nn.Linear(2 * output_dim, 1024)
        self.fc2 = nn.Linear(1024, 512)
        self.out = nn.Linear(512, n_output)

    def forward(self, data_mol, data_pro):
        # Molecule forward pass
        mol_x, mol_edge_index, mol_batch = data_mol.x, data_mol.edge_index, data_mol.batch
        mol_x = self.mol_conv1(mol_x, mol_edge_index)
        mol_x = self.relu(mol_x)
        mol_x = self.mol_conv2(mol_x, mol_edge_index)
        mol_x = self.relu(mol_x)
        mol_x = self.mol_conv3(mol_x, mol_edge_index)
        mol_x = gep(mol_x, mol_batch)  # global pooling
        mol_x = self.relu(self.mol_fc_g1(mol_x))
        mol_x = self.dropout(mol_x)
        mol_x = self.mol_fc_g2(mol_x)

        # Protein forward pass
        pro_x, pro_edge_index, pro_batch = data_pro.x, data_pro.edge_index, data_pro.batch
        pro_x = self.pro_conv1(pro_x, pro_edge_index)
        pro_x = self.relu(pro_x)
        pro_x = self.pro_conv2(pro_x, pro_edge_index)
        pro_x = self.relu(pro_x)
        pro_x = gep(pro_x, pro_batch)  # global pooling
        pro_x = self.relu(self.pro_fc_g1(pro_x))
        pro_x = self.dropout(pro_x)
        pro_x = self.pro_fc_g2(pro_x)

        # Concatenate molecule and protein features
        xc = torch.cat((mol_x, pro_x), dim=1)
        xc = self.fc1(xc)
        xc = self.relu(xc)
        xc = self.dropout(xc)
        xc = self.fc2(xc)
        xc = self.relu(xc)
        xc = self.dropout(xc)
        out = self.out(xc)

        return out


In [None]:
import os
import torch
import pickle
import pandas as pd
from torch_geometric.data import Data

def load_graph(path):
    # Load graphs from .pkl for molecules and .pt for proteins
    with open(path, 'rb') as f:
        return pickle.load(f)

def prepare_dataset(filtered_dataset, molecule_graph_dir, protein_graph_dir):
    dataset = []
    
    for index, row in filtered_dataset.iterrows():
        # Load molecule graph based on Drug_ID
        mol_graph_path = os.path.join(molecule_graph_dir, f"{row['Drug_ID']}_graph.pkl")
        mol_graph = load_graph(mol_graph_path)

        # Load protein graph based on Target_ID
        pro_graph_path = os.path.join(protein_graph_dir, f"{row['Target_ID']}_graph.pt")
        pro_graph = torch.load(pro_graph_path)
        
        # Load target (affinity value)
        target = torch.tensor([row['Y']], dtype=torch.float)
        
        # Append tuple (mol_graph, pro_graph, target) to dataset
        dataset.append((mol_graph, pro_graph, target))
    
    return dataset

# Example usage for dataset preparation
molecule_graph_dir = 'molecule_graphs/'  # Directory where molecule graphs are stored
protein_graph_dir = 'ProteinGraphs/'  # Directory where protein graphs are stored
filtered_dataset_path = 'filtered_KibaDataSet.csv'  # Path to the filtered dataset CSV

# Load filtered dataset CSV
filtered_dataset = pd.read_csv(filtered_dataset_path)

# Prepare the dataset with molecule, protein graphs, and affinity scores
prepared_dataset = prepare_dataset(filtered_dataset, molecule_graph_dir, protein_graph_dir)

#


  pro_graph = torch.load(pro_graph_path)


In [2]:
import numpy as np
from sklearn.metrics import mean_squared_error
from scipy.stats import pearsonr

def get_mse(labels, preds):
    return mean_squared_error(labels, preds)

def get_pearson(labels, preds):
    return pearsonr(labels, preds)[0]

def get_ci(labels, preds):
    # Concordance Index (CI) implementation
    n = 0
    h_sum = 0
    for i in range(len(labels)):
        for j in range(i + 1, len(labels)):
            if labels[i] != labels[j]:
                n += 1
                if (preds[i] < preds[j] and labels[i] < labels[j]) or (preds[i] > preds[j] and labels[i] > labels[j]):
                    h_sum += 1
                elif preds[i] == preds[j]:
                    h_sum += 0.5
    return h_sum / n if n > 0 else 0.5


In [3]:
from torch_geometric.data import DataLoader
from sklearn.model_selection import KFold
import torch.optim as optim
from torch.nn import MSELoss

def train_5fold_cross_validation(data, molecule_graphs, protein_graphs, num_epochs=1000, n_splits=5, lr=0.001):
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Print if the model is on GPU or CPU
    if torch.cuda.is_available():
        print("Model is running on GPU.")
    else:
        print("Model is running on CPU.")

    
    kfold = KFold(n_splits=n_splits, shuffle=True)
    
    # Prepare dataset and dataloaders
    results = []
    loss_fn = MSELoss()
    
    for fold, (train_idx, test_idx) in enumerate(kfold.split(data)):
        print(f'Fold {fold + 1}/{n_splits}')
        
        train_data = data.iloc[train_idx]
        test_data = data.iloc[test_idx]
        
        train_loader = DataLoader(prepare_dataset(train_data, molecule_graphs, protein_graphs), batch_size=32, shuffle=True)
        test_loader = DataLoader(prepare_dataset(test_data, molecule_graphs, protein_graphs), batch_size=32, shuffle=False)

        # Initialize model and optimizer
        model = GNNNet().to(device)
        optimizer = optim.Adam(model.parameters(), lr=lr)

        for epoch in range(num_epochs):
            model.train()
            for batch in train_loader:
                mol_data, pro_data, target = batch
                optimizer.zero_grad()
                output = model(mol_data.to(device), pro_data.to(device))
                loss = loss_fn(output, target.to(device))
                loss.backward()
                optimizer.step()

            print(f"Epoch {epoch+1}/{num_epochs} - Loss: {loss.item()}")

        # Evaluate on test set
        model.eval()
        total_preds, total_labels = [], []
        with torch.no_grad():
            for batch in test_loader:
                mol_data, pro_data, target = batch
                output = model(mol_data.to(device), pro_data.to(device))
                total_preds.append(output.cpu().numpy())
                total_labels.append(target.cpu().numpy())

        mse = get_mse(total_labels, total_preds)
        ci = get_ci(total_labels, total_preds)
        pearson = get_pearson(total_labels, total_preds)
        print(f"Fold {fold+1} - MSE: {mse}, CI: {ci}, Pearson: {pearson}")

        # Save the results for this fold
        results.append((mse, ci, pearson))
    
    return results


In [4]:
import os
import torch
import pickle
from torch_geometric.data import Data

def load_graph(path):
    with open(path, 'rb') as f:
        return pickle.load(f)
    print("graph is loaded ")

def prepare_dataset(data, molecule_graphs, protein_graphs):
    dataset = []
    
    for index, row in data.iterrows():
        mol_graph_path = os.path.join(molecule_graphs, f"{row['Drug_ID']}_graph.pkl")
        pro_graph_path = os.path.join(protein_graphs, f"{row['Target_ID']}_graph.pt")
        
        mol_graph = load_graph(mol_graph_path)
        pro_graph = torch.load(pro_graph_path)
        target = torch.tensor([row['Y']], dtype=torch.float)
        
        dataset.append((mol_graph, pro_graph, target))

    print("Dataset is ready")
    
    return dataset


In [None]:
# Example usage
molecule_graphs = 'molecule_graphs/'
protein_graphs = 'ProteinGraphs/'
filtered_dataset_path = 'filtered_KibaDataSet.csv'

# Load filtered dataset
import pandas as pd
data = pd.read_csv(filtered_dataset_path)

# Run 5-fold cross-validation training
results = train_5fold_cross_validation(data, molecule_graphs, protein_graphs)


Model is running on GPU.
Fold 1/5


  pro_graph = torch.load(pro_graph_path)
