In [2]:
# Install PyTorch Geometric
!pip install torch-scatter torch-sparse torch-cluster torch-spline-conv torch-geometric -f https://data.pyg.org/whl/torch-2.0.0+cu118.html


Looking in links: https://data.pyg.org/whl/torch-2.0.0+cu118.html
Collecting torch-scatter
  Downloading https://data.pyg.org/whl/torch-2.0.0%2Bcu118/torch_scatter-2.1.2%2Bpt20cu118-cp311-cp311-linux_x86_64.whl (10.2 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m10.2/10.2 MB[0m [31m91.9 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting torch-sparse
  Downloading https://data.pyg.org/whl/torch-2.0.0%2Bcu118/torch_sparse-0.6.18%2Bpt20cu118-cp311-cp311-linux_x86_64.whl (4.9 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m4.9/4.9 MB[0m [31m49.5 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting torch-cluster
  Downloading https://data.pyg.org/whl/torch-2.0.0%2Bcu118/torch_cluster-1.6.3%2Bpt20cu118-cp311-cp311-linux_x86_64.whl (3.3 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m3.3/3.3 MB[0m [31m112.8 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting torch-spline-conv
  Downloading https://data.pyg.org/whl/torch-2.0.0%2Bcu118/torch_

In [4]:

# ===================================
# Step 1: Import libraries
# ===================================
import pandas as pd
import numpy as np
import torch
from torch import nn
from torch_geometric.data import Data
from torch_geometric.loader import DataLoader
from torch_geometric.nn import GATConv, global_mean_pool
import pickle
from sklearn.metrics import (
    roc_auc_score,
    average_precision_score,
    precision_score,
    recall_score,
    f1_score
)
import torch.nn.functional as F

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# ===================================
# Step 2: Load features and similarity matrices
# ===================================
# Load drug features
with open('graph2seq_features.pkl', 'rb') as f:
    graph2seq_features = pickle.load(f)

drug_ids = list(graph2seq_features.keys())
id_to_idx = {drug_id: idx for idx, drug_id in enumerate(drug_ids)}

# Load similarity matrices
target_sim = np.load('target_similarity_matrix.npy')  # assume pre-saved
enzyme_sim = np.load('enzyme_similarity_matrix.npy')
smiles_sim = np.load('smiles_similarity_matrix.npy')

# ===================================
# Step 3: Build node features
# ===================================
x = torch.tensor([graph2seq_features[drug_id] for drug_id in drug_ids], dtype=torch.float32)

# ===================================
# Step 4: Build edges with real weights
# ===================================
edge_index = []
edge_attr = []

n = len(drug_ids)

threshold = 0.5  # Only keep high similarities (tuneable)

for i in range(n):
    for j in range(i+1, n):
        sim_score = (target_sim[i,j] + enzyme_sim[i,j] + smiles_sim[i,j]) / 3  # average multi-view
        if sim_score > threshold:
            edge_index.append([i, j])
            edge_index.append([j, i])
            edge_attr.append([sim_score])
            edge_attr.append([sim_score])

edge_index = torch.tensor(edge_index, dtype=torch.long).t().contiguous()
edge_attr = torch.tensor(edge_attr, dtype=torch.float32)

# ===================================
# Step 5: Load interaction labels
# ===================================
events = pd.read_csv('events_extract.csv')
id1_idx = events['id1'].map(id_to_idx).values
id2_idx = events['id2'].map(id_to_idx).values
labels = torch.tensor(events['label'].values, dtype=torch.long)

# ===================================
# Step 6: GAT Model Definition
# ===================================

class GNNEncoder(nn.Module):
    def __init__(self, input_dim, hidden_dim, heads=4):
        super(GNNEncoder, self).__init__()
        self.conv1 = GATConv(input_dim, hidden_dim, heads=heads, edge_dim=1)
        self.conv2 = GATConv(hidden_dim*heads, hidden_dim, heads=1, edge_dim=1)

    def forward(self, x, edge_index, edge_attr):
        x = self.conv1(x, edge_index, edge_attr)
        x = torch.relu(x)
        x = self.conv2(x, edge_index, edge_attr)
        x = torch.relu(x)
        return x

class DDI_GNN(nn.Module):
    def __init__(self, encoder, hidden_dim, num_classes):
        super(DDI_GNN, self).__init__()
        self.encoder = encoder
        self.mlp = nn.Sequential(
            nn.Linear(hidden_dim*2, 256),
            nn.ReLU(),
            nn.Linear(256, num_classes)
        )

    def forward(self, x, edge_index, edge_attr, idx1, idx2):
        z = self.encoder(x, edge_index, edge_attr)
        emb1 = z[idx1]
        emb2 = z[idx2]
        out = torch.cat([emb1, emb2], dim=-1)
        return self.mlp(out)

# ===================================
# Step 7: Training Loop
# ===================================

model = DDI_GNN(GNNEncoder(x.shape[1], 128), hidden_dim=128, num_classes=labels.max().item()+1).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
criterion = nn.CrossEntropyLoss()

epochs = 30
batch_size = 128

for epoch in range(epochs):
    model.train()
    optimizer.zero_grad()

    idx = torch.randperm(len(labels))
    id1_idx_batch = id1_idx[idx]
    id2_idx_batch = id2_idx[idx]
    labels_batch = labels[idx]

    preds = model(x.to(device), edge_index.to(device), edge_attr.to(device), id1_idx_batch, id2_idx_batch)
    loss = criterion(preds, labels_batch.to(device))
    loss.backward()
    optimizer.step()

    acc = (preds.argmax(dim=1) == labels_batch.to(device)).float().mean()
    print(f"Epoch {epoch+1} | Loss: {loss.item():.4f} | Accuracy: {acc.item():.4f}")



  x = torch.tensor([graph2seq_features[drug_id] for drug_id in drug_ids], dtype=torch.float32)


Epoch 1 | Loss: 4.0073 | Accuracy: 0.0008
Epoch 2 | Loss: 3.9381 | Accuracy: 0.2364
Epoch 3 | Loss: 3.8391 | Accuracy: 0.4494
Epoch 4 | Loss: 3.6814 | Accuracy: 0.4746
Epoch 5 | Loss: 3.4424 | Accuracy: 0.4768
Epoch 6 | Loss: 3.1185 | Accuracy: 0.4771
Epoch 7 | Loss: 2.7877 | Accuracy: 0.4771
Epoch 8 | Loss: 2.6295 | Accuracy: 0.4771
Epoch 9 | Loss: 2.6246 | Accuracy: 0.4771
Epoch 10 | Loss: 2.5480 | Accuracy: 0.4771
Epoch 11 | Loss: 2.3927 | Accuracy: 0.4771
Epoch 12 | Loss: 2.2588 | Accuracy: 0.4771
Epoch 13 | Loss: 2.2209 | Accuracy: 0.4771
Epoch 14 | Loss: 2.2394 | Accuracy: 0.4518
Epoch 15 | Loss: 2.2537 | Accuracy: 0.1735
Epoch 16 | Loss: 2.2147 | Accuracy: 0.1762
Epoch 17 | Loss: 2.1329 | Accuracy: 0.3789
Epoch 18 | Loss: 2.0549 | Accuracy: 0.4771
Epoch 19 | Loss: 2.0115 | Accuracy: 0.4771
Epoch 20 | Loss: 1.9998 | Accuracy: 0.4771
Epoch 21 | Loss: 1.9943 | Accuracy: 0.4771
Epoch 22 | Loss: 1.9786 | Accuracy: 0.4771
Epoch 23 | Loss: 1.9558 | Accuracy: 0.4771
Epoch 24 | Loss: 1.9

In [6]:
model.eval()

all_probs = []
all_preds = []
all_labels = []

with torch.no_grad():
    # Forward pass
    preds = model(
        x.to(device),
        edge_index.to(device),
        edge_attr.to(device),
        torch.tensor(id1_idx, dtype=torch.long).to(device),
        torch.tensor(id2_idx, dtype=torch.long).to(device)
    )

    probs = F.softmax(preds, dim=1).cpu().numpy()
    labels_np = labels.cpu().numpy()
    preds_np = np.argmax(probs, axis=1)

    all_probs.extend(probs)
    all_preds.extend(preds_np)
    all_labels.extend(labels_np)

all_probs = np.array(all_probs)
all_preds = np.array(all_preds)
all_labels = np.array(all_labels)

# Determine number of classes
num_classes = all_probs.shape[1]

# ----- ROC AUC and AUPR -----
roc_auc = roc_auc_score(all_labels, all_probs, multi_class='ovr')
aupr = average_precision_score(all_labels, all_probs, average='macro')

# ----- Precision, Recall, F1 -----
precision = precision_score(all_labels, all_preds, average='macro')
recall = recall_score(all_labels, all_preds, average='macro')
f1 = f1_score(all_labels, all_preds, average='macro')

# ----- Print Metrics -----
print("Evaluation Metrics:")
print(f"ROC AUC     : {roc_auc:.4f}")
print(f"AUPR        : {aupr:.4f}")
print(f"Precision   : {precision:.4f}")
print(f"Recall      : {recall:.4f}")
print(f"F1 Score    : {f1:.4f}")

Evaluation Metrics:
ROC AUC     : 0.6025
AUPR        : 0.0251
Precision   : 0.0085
Recall      : 0.0179
F1 Score    : 0.0115


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
