In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GCNConv, global_mean_pool

# GCN layer as per the authors' style
class GCNLayer(nn.Module):
    """One graph-convolution stage: GCNConv -> BatchNorm1d -> ReLU -> dropout.

    Args:
        in_channels: Input feature width handed to ``GCNConv``.
        out_channels: Output feature width of ``GCNConv``; also the number of
            features normalised by ``BatchNorm1d``.
        dropout: Drop probability passed to ``F.dropout`` (default 0.2).
    """

    def __init__(self, in_channels, out_channels, dropout=0.2):
        super().__init__()
        self.dropout = dropout
        self.gcn = GCNConv(in_channels, out_channels)
        self.batch_norm = nn.BatchNorm1d(out_channels)

    def forward(self, x, edge_index):
        """Run the stage on one (possibly batched) graph.

        Args:
            x: Node features, forwarded to ``GCNConv``.
            edge_index: Graph connectivity, forwarded to ``GCNConv``.

        Returns:
            The convolved features after batch norm, ReLU and dropout.
        """
        normalised = self.batch_norm(self.gcn(x, edge_index))
        activated = F.relu(normalised)
        # Dropout is only active while the module is in training mode.
        return F.dropout(activated, p=self.dropout, training=self.training)


In [2]:
class _GraphEncoder(nn.Module):
    """Three stacked ``GCNLayer`` stages followed by global mean pooling.

    Feature width grows ``num_features -> hidden_dim -> 2*hidden_dim ->
    4*hidden_dim``; the pooled output therefore has ``4*hidden_dim`` columns.
    Shared by ``MoleculeGNN`` and ``ProteinGNN`` so the two encoders cannot
    drift apart by accident.

    Args:
        num_features: Width of the input node features.
        hidden_dim: Base hidden width (default 128).
        dropout: Drop probability forwarded to every ``GCNLayer`` (default 0.2).
    """

    def __init__(self, num_features, hidden_dim=128, dropout=0.2):
        super().__init__()
        self.layer1 = GCNLayer(num_features, hidden_dim, dropout)
        self.layer2 = GCNLayer(hidden_dim, hidden_dim * 2, dropout)
        self.layer3 = GCNLayer(hidden_dim * 2, hidden_dim * 4, dropout)

    def forward(self, x, edge_index, batch):
        """Encode node features and pool them with ``global_mean_pool``.

        Args:
            x: Node features, passed through the three layers in order.
            edge_index: Graph connectivity, passed unchanged to every layer.
            batch: Forwarded to ``global_mean_pool`` together with the final
                node embeddings.

        Returns:
            The result of ``global_mean_pool`` on the last layer's output.
        """
        for layer in (self.layer1, self.layer2, self.layer3):
            x = layer(x, edge_index)
        # Global mean pooling
        return global_mean_pool(x, batch)


class MoleculeGNN(_GraphEncoder):
    """Graph encoder applied to the molecule input.

    Same constructor and ``forward`` signature as ``_GraphEncoder``; kept as
    its own class so it can diverge from the protein encoder later.
    """


class ProteinGNN(_GraphEncoder):
    """Graph encoder applied to the protein input.

    Same constructor and ``forward`` signature as ``_GraphEncoder``; kept as
    its own class so it can diverge from the molecule encoder later.
    """


In [3]:
class DTA_GNN(nn.Module):
    """Two-branch model: a molecule encoder and a protein encoder feed a dense head.

    The outputs of ``MoleculeGNN`` and ``ProteinGNN`` are concatenated along
    dim 1 and passed through two ReLU + dropout linear layers and a final
    linear layer with a single output unit.

    Args:
        mol_input_dim: Node feature width given to ``MoleculeGNN``.
        prot_input_dim: Node feature width given to ``ProteinGNN``.
        hidden_dim: Base hidden width shared by both encoders and the head
            (default 128).
        dropout: Drop probability used by the encoders and the head
            (default 0.2).
    """

    def __init__(self, mol_input_dim, prot_input_dim, hidden_dim=128, dropout=0.2):
        super().__init__()
        self.dropout = dropout
        self.mol_gnn = MoleculeGNN(mol_input_dim, hidden_dim, dropout)
        self.prot_gnn = ProteinGNN(prot_input_dim, hidden_dim, dropout)

        # The head expects the concatenated embeddings to be 8 * hidden_dim wide.
        fused_width = hidden_dim * 8
        mid_width = hidden_dim * 4
        last_width = hidden_dim * 2
        self.fc1 = nn.Linear(fused_width, mid_width)
        self.fc2 = nn.Linear(mid_width, last_width)
        self.out = nn.Linear(last_width, 1)

    def forward(self, mol_data, prot_data):
        """Compute the model output for paired molecule / protein inputs.

        Args:
            mol_data: Object exposing ``x``, ``edge_index`` and ``batch``
                attributes for the molecule branch.
            prot_data: Object exposing ``x``, ``edge_index`` and ``batch``
                attributes for the protein branch.

        Returns:
            Output of the final linear layer (last dimension of size 1).
        """
        mol_embedding = self.mol_gnn(
            mol_data.x, mol_data.edge_index, mol_data.batch
        )
        prot_embedding = self.prot_gnn(
            prot_data.x, prot_data.edge_index, prot_data.batch
        )
        hidden = torch.cat([mol_embedding, prot_embedding], dim=1)

        # Both hidden dense layers share the same ReLU -> dropout treatment.
        for dense in (self.fc1, self.fc2):
            hidden = F.dropout(
                F.relu(dense(hidden)), p=self.dropout, training=self.training
            )
        return self.out(hidden)


In [4]:
from lifelines.utils import concordance_index
from sklearn.metrics import mean_squared_error
from scipy.stats import pearsonr

def calculate_metrics(y_true, y_pred):
    """Score predictions with concordance index, MSE and Pearson correlation.

    Both inputs are flattened with their own ``.flatten()`` method before
    being handed to the metric functions, so they must expose that method
    (e.g. NumPy arrays).

    Args:
        y_true: Ground-truth values.
        y_pred: Predicted values.

    Returns:
        A ``(ci, mse, pearson_corr)`` tuple: the results of
        ``concordance_index``, ``mean_squared_error`` and the first element
        returned by ``pearsonr``, in that order.
    """
    targets, preds = y_true.flatten(), y_pred.flatten()

    concordance = concordance_index(targets, preds)
    squared_error = mean_squared_error(targets, preds)
    # pearsonr also returns a p-value, which is not reported here.
    correlation, _p_value = pearsonr(targets, preds)

    return concordance, squared_error, correlation


ModuleNotFoundError: No module named 'lifelines'