In [None]:
!pip install torch-scatter torch-sparse torch-cluster torch-spline-conv torch-geometric -f https://data.pyg.org/whl/torch-2.0.0+cu118.html
# NOTE(review): the -f wheel index above is pinned to torch 2.0.0 built for CUDA 11.8.
# The compiled extension wheels presumably need to match the installed torch build —
# confirm `torch.__version__` (and CUDA version) before running, and adjust the URL if it differs.

# ===================================
# Step 1: Import libraries
# ===================================
import pandas as pd
import numpy as np
import torch
from torch import nn
from torch_geometric.data import Data
from torch_geometric.loader import DataLoader
from torch_geometric.nn import GATConv, global_mean_pool
import pickle

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# ===================================
# Step 2: Load features and similarity matrices
# ===================================
# Load drug features: a mapping from drug id to that drug's feature vector (used in Step 3).
# NOTE(review): pickle.load can execute arbitrary code — only load files you produced yourself.
with open('graph2seq_features.pkl', 'rb') as f:
    graph2seq_features = pickle.load(f)

# Graph node i is drug_ids[i], following the mapping's key order.
drug_ids = list(graph2seq_features.keys())
id_to_idx = {drug_id: idx for idx, drug_id in enumerate(drug_ids)}

# Load similarity matrices. They are indexed positionally (row/column i == drug_ids[i]),
# so they must have been saved in the same drug order as graph2seq_features.
target_sim = np.load('target_similarity_matrix.npy')  # assume pre-saved
enzyme_sim = np.load('enzyme_similarity_matrix.npy')
smiles_sim = np.load('smiles_similarity_matrix.npy')

# Fail fast on a size mismatch: a matrix built over a different drug set would
# otherwise silently connect the wrong drugs (or fail later with an IndexError).
expected_sim_shape = (len(drug_ids), len(drug_ids))
for sim_name, sim_matrix in (('target', target_sim), ('enzyme', enzyme_sim), ('smiles', smiles_sim)):
    if sim_matrix.shape != expected_sim_shape:
        raise ValueError(
            f"{sim_name} similarity matrix has shape {sim_matrix.shape}, "
            f"expected {expected_sim_shape} to match graph2seq_features"
        )

# ===================================
# Step 3: Build node features
# ===================================
# One row per drug, in drug_ids order. Stacking into a single float32 ndarray first
# avoids torch's very slow element-by-element conversion of a list of per-drug arrays.
x = torch.from_numpy(
    np.asarray([graph2seq_features[drug_id] for drug_id in drug_ids], dtype=np.float32)
)

# ===================================
# Step 4: Build edges with real weights
# ===================================
n = len(drug_ids)

threshold = 0.5  # Only keep high similarities (tuneable)

# Multi-view similarity: element-wise average of the three matrices over the n drugs.
avg_sim = (target_sim[:n, :n] + enzyme_sim[:n, :n] + smiles_sim[:n, :n]) / 3

# Keep each unordered pair i < j whose averaged score is strictly above the threshold.
# np.nonzero walks the upper triangle in row-major order, i.e. the same (i, j) order
# as a nested `for i: for j > i` loop, without an O(n^2) Python-level loop.
pair_i, pair_j = np.nonzero(np.triu(avg_sim > threshold, k=1))
pair_scores = avg_sim[pair_i, pair_j]

# Add both directions of every kept pair, interleaved as (i, j), (j, i), each carrying
# the same weight. Building from arrays also keeps well-formed shapes when no pair
# passes the threshold: edge_index is (2, 0) and edge_attr is (0, 1).
src = np.stack([pair_i, pair_j], axis=1).reshape(-1)
dst = np.stack([pair_j, pair_i], axis=1).reshape(-1)
edge_index = torch.as_tensor(np.stack([src, dst]), dtype=torch.long)
edge_attr = torch.as_tensor(np.repeat(pair_scores, 2)[:, None], dtype=torch.float32)

# ===================================
# Step 5: Load interaction labels
# ===================================
events = pd.read_csv('events_extract.csv')

# Map drug ids to node indices. Series.map yields NaN for ids that have no feature
# vector; left in place, that would turn the index arrays into floats and break
# tensor indexing during training, so such interactions are dropped explicitly.
id1_mapped = events['id1'].map(id_to_idx)
id2_mapped = events['id2'].map(id_to_idx)
known = id1_mapped.notna() & id2_mapped.notna()
if not known.all():
    print(f"Dropping {int((~known).sum())} interactions whose drugs are missing from graph2seq_features")
    events = events[known].reset_index(drop=True)
    id1_mapped = id1_mapped[known]
    id2_mapped = id2_mapped[known]

id1_idx = id1_mapped.to_numpy(dtype=np.int64)
id2_idx = id2_mapped.to_numpy(dtype=np.int64)
# NOTE(review): num_classes is later derived as labels.max() + 1, so labels are
# assumed to be 0-based integer class ids — confirm against events_extract.csv.
labels = torch.tensor(events['label'].values, dtype=torch.long)

# ===================================
# Step 6: GAT Model Definition
# ===================================

class GNNEncoder(nn.Module):
    """Two-layer graph attention encoder that yields one embedding per node.

    ``conv1`` uses ``heads`` attention heads whose outputs are ``hidden_dim * heads``
    wide (the input width ``conv2`` is built for); ``conv2`` uses a single head of
    width ``hidden_dim``. Both layers take a 1-dimensional edge feature, and a ReLU
    follows each layer.
    """

    def __init__(self, input_dim, hidden_dim, heads=4):
        super().__init__()
        self.conv1 = GATConv(input_dim, hidden_dim, heads=heads, edge_dim=1)
        self.conv2 = GATConv(heads * hidden_dim, hidden_dim, heads=1, edge_dim=1)

    def forward(self, x, edge_index, edge_attr):
        hidden = torch.relu(self.conv1(x, edge_index, edge_attr))
        return torch.relu(self.conv2(hidden, edge_index, edge_attr))

class DDI_GNN(nn.Module):
    """Score a drug pair from graph node embeddings.

    ``encoder(x, edge_index, edge_attr)`` must return per-node embeddings of width
    ``hidden_dim``. For each pair (idx1[k], idx2[k]) the two embeddings are
    concatenated (in that order) and fed to a two-layer MLP producing
    ``num_classes`` logits.
    """

    def __init__(self, encoder, hidden_dim, num_classes):
        super().__init__()
        self.encoder = encoder
        layers = [
            nn.Linear(2 * hidden_dim, 256),
            nn.ReLU(),
            nn.Linear(256, num_classes),
        ]
        self.mlp = nn.Sequential(*layers)

    def forward(self, x, edge_index, edge_attr, idx1, idx2):
        node_emb = self.encoder(x, edge_index, edge_attr)
        pair_repr = torch.cat((node_emb[idx1], node_emb[idx2]), dim=-1)
        return self.mlp(pair_repr)

# ===================================
# Step 7: Training Loop
# ===================================

model = DDI_GNN(GNNEncoder(x.shape[1], 128), hidden_dim=128, num_classes=labels.max().item()+1).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
criterion = nn.CrossEntropyLoss()

epochs = 30
batch_size = 128

# The graph, the pair indices and the labels never change, so move them to the
# device once instead of copying them on every step.
x_dev = x.to(device)
edge_index_dev = edge_index.to(device)
edge_attr_dev = edge_attr.to(device)
pair1_dev = torch.as_tensor(id1_idx, dtype=torch.long, device=device)
pair2_dev = torch.as_tensor(id2_idx, dtype=torch.long, device=device)
labels_dev = labels.to(device)
num_pairs = labels_dev.numel()

for epoch in range(epochs):
    model.train()
    # Reshuffle the drug pairs every epoch and walk them in mini-batches of batch_size.
    perm = torch.randperm(num_pairs, device=device)
    epoch_loss = 0.0
    epoch_correct = 0

    for start in range(0, num_pairs, batch_size):
        batch = perm[start:start + batch_size]
        labels_batch = labels_dev[batch]

        optimizer.zero_grad()
        # Message passing needs the whole graph, so the encoder always runs over every
        # node; only the drug pairs that are scored and back-propagated are batched.
        preds = model(x_dev, edge_index_dev, edge_attr_dev, pair1_dev[batch], pair2_dev[batch])
        loss = criterion(preds, labels_batch)
        loss.backward()
        optimizer.step()

        # Weight by batch size so a smaller final batch is averaged correctly.
        epoch_loss += loss.item() * batch.numel()
        epoch_correct += (preds.argmax(dim=1) == labels_batch).sum().item()

    # Training-set statistics accumulated while the weights were being updated;
    # this is not a held-out evaluation.
    acc = epoch_correct / num_pairs
    print(f"Epoch {epoch+1} | Loss: {epoch_loss / num_pairs:.4f} | Accuracy: {acc:.4f}")



Epoch 1 | Loss: 4.0515 | Accuracy: 0.0051
Epoch 2 | Loss: 3.9764 | Accuracy: 0.1059
Epoch 3 | Loss: 3.8749 | Accuracy: 0.4391
Epoch 4 | Loss: 3.7195 | Accuracy: 0.4727
Epoch 5 | Loss: 3.4906 | Accuracy: 0.4765
Epoch 6 | Loss: 3.1784 | Accuracy: 0.4771
Epoch 7 | Loss: 2.8204 | Accuracy: 0.4771
Epoch 8 | Loss: 2.5304 | Accuracy: 0.4771
Epoch 9 | Loss: 2.4212 | Accuracy: 0.4771
Epoch 10 | Loss: 2.4019 | Accuracy: 0.4771
Epoch 11 | Loss: 2.3470 | Accuracy: 0.4771
Epoch 12 | Loss: 2.3156 | Accuracy: 0.4771
Epoch 13 | Loss: 2.3771 | Accuracy: 0.4734
Epoch 14 | Loss: 2.3667 | Accuracy: 0.4511
Epoch 15 | Loss: 2.2795 | Accuracy: 0.4774
Epoch 16 | Loss: 2.1973 | Accuracy: 0.4771
Epoch 17 | Loss: 2.1541 | Accuracy: 0.4771
Epoch 18 | Loss: 2.1234 | Accuracy: 0.4771
Epoch 19 | Loss: 2.0877 | Accuracy: 0.4771
Epoch 20 | Loss: 2.0548 | Accuracy: 0.4771
Epoch 21 | Loss: 2.0328 | Accuracy: 0.4771
Epoch 22 | Loss: 2.0199 | Accuracy: 0.4771
Epoch 23 | Loss: 2.0087 | Accuracy: 0.4771
Epoch 24 | Loss: 1.9