# **Patient-to-Drug Link Prediction: A Walkthrough**
This notebook demonstrates how to perform a **link prediction task** using a **biological knowledge graph**. The goal is to predict whether a specific drug is suitable for a given patient, based on their features and the relationships in the graph. Download the data [here](https://drive.google.com/drive/folders/12fFDPScwcvYjykKm_olKKF_-JXN6d__P?usp=drive_link)

---

## **Objective**
- Learn how to set up a **biological knowledge graph** where nodes represent entities (patients, drugs, etc.) and edges represent relationships (e.g., drug efficacy).
- Build a **Graph Neural Network (GNN)** for link prediction.
- Train the model to predict the presence of edges between specific pairs of nodes (patient-drug pairs).

---

## **1. Introduction to Link Prediction**
### **What is Link Prediction?**
Link prediction is the task of predicting whether a link (edge) exists between two nodes in a graph. It is widely used in:
- **Biology**: Predicting interactions (e.g., protein-protein interactions).
- **Recommender Systems**: Suggesting products or friends.
- **Drug Discovery**: Identifying new drug-target interactions.

### **Biological Knowledge Graph**
In this case:
- **Nodes**: Represent patients and drugs.
- **Edges**: Represent known relationships, such as whether a drug is prescribed for a patient.
- **Features**:
  - **Patients**: Demographic, clinical, or genetic information.
  - **Drugs**: Molecular properties or existing indications.

### **Goal**
Predict whether a drug (node) is suitable for a patient (node) by learning from the structure of the biological knowledge graph.

---


In [None]:
# 1. Install required libraries
!pip install torch torch-geometric pandas psycopg2

# Import required libraries
import torch
import torch.nn.functional as F
from torch_geometric.nn import GCNConv, SAGEConv
from torch_geometric.data import Data
from torch_geometric.transforms import RandomLinkSplit
import torch.nn as nn
import pandas as pd
from torch_geometric.data import Data
import numpy as np
from google.colab import drive
drive.mount('/content/drive', force_remount=True)
from google.colab import files

Mounted at /content/drive


---

## **2. Data Setup**
In this section, we will:
1. Construct a **biological knowledge graph** with nodes for patients and drugs.
2. Split the edges into training, validation, and test sets for link prediction.

---

### **Graph Construction**
We create a synthetic graph where:
- **Nodes**:
  - Patients: Represented by patient-specific features (e.g., age, symptoms).
  - Drugs: Represented by drug-specific features (e.g., molecular descriptors).
- **Edges**: Represent known relationships between patients and drugs.

---


In [None]:
# Load patient and prescription data
patients = pd.read_csv("/content/drive/Shareddrives/Bootcamp/Bootcamp 8 - GNN/Module 5/csvs/patients.csv")
prescriptions = pd.read_csv("/content/drive/Shareddrives/Bootcamp/Bootcamp 8 - GNN/Module 5/csvs/prescriptions.csv")
print(patients.head())
print(prescriptions.head())

# Step 1: Preprocess Patient Features
# Calculate age and select relevant columns
patients['age'] = 2024 - pd.to_datetime(patients['anchor_year']).dt.year
patient_features = patients[['subject_id', 'gender', 'age']].copy()

# Encode categorical features (e.g., gender)
patient_features['gender'] = patient_features['gender'].map({'M': 0, 'F': 1})
patient_features_tensor = torch.tensor(patient_features[['age', 'gender']].values, dtype=torch.float)

# Step 2: Preprocess Drug Features
# Extract and one-hot encode drug form
drug_features = prescriptions[['drug', 'form_rx']].drop_duplicates()
drug_features_encoded = pd.get_dummies(drug_features, columns=['form_rx'])
drug_features_tensor = torch.tensor(drug_features_encoded.iloc[:, 1:].values, dtype=torch.float)

# Step 3: Align Dimensions
# Determine the maximum feature length between patients and drugs
max_features = max(patient_features_tensor.size(1), drug_features_tensor.size(1))

# Pad the smaller tensor to match dimensions
if patient_features_tensor.size(1) < max_features:
    padding = max_features - patient_features_tensor.size(1)
    patient_features_tensor = torch.nn.functional.pad(patient_features_tensor, (0, padding))
elif drug_features_tensor.size(1) < max_features:
    padding = max_features - drug_features_tensor.size(1)
    drug_features_tensor = torch.nn.functional.pad(drug_features_tensor, (0, padding))

# Step 4: Create Edges
edges = prescriptions[['subject_id', 'drug']].drop_duplicates()

# Map subject_id to zero-based indices
edges['subject_id'] = edges['subject_id'].map(lambda x: x - 1)

# Map drugs to indices in the drug_features DataFrame
drug_index_map = {drug: idx for idx, drug in enumerate(drug_features['drug'])}
edges['drug'] = edges['drug'].map(drug_index_map)

# Convert edges to PyTorch Geometric format
edge_index = torch.tensor(edges.dropna().values.T, dtype=torch.long)

x = torch.cat([patient_features_tensor, drug_features_tensor], dim=0)

# Total number of nodes
num_patients = patient_features_tensor.size(0)
num_drugs = drug_features_tensor.size(0)
num_nodes = num_patients + num_drugs

# Step 5: Create Edge Index
# Prepare edges by mapping `subject_id` and `drug` to their respective node indices
edges = prescriptions[['subject_id', 'drug']].drop_duplicates()

# Map `subject_id` to zero-based indices
patient_ids = patients['subject_id'].unique()
patient_index_map = {patient_id: idx for idx, patient_id in enumerate(patient_ids)}
edges['subject_id'] = edges['subject_id'].map(patient_index_map)


# Map `drug` to indices, starting after the patients
unique_drugs = edges['drug'].unique()
drug_index_map = {drug: idx + num_patients for idx, drug in enumerate(unique_drugs)}
edges['drug'] = edges['drug'].map(drug_index_map)


# Drop NaN values in the mapping
edges = edges.dropna()

# Convert edges to PyTorch Geometric format
edge_index = torch.tensor(edges.values.T, dtype=torch.long)

# Step 6: Validate Edge Indices
# Ensure all indices in `edge_index` are within the range of the node feature matrix
if edge_index.max() >= num_nodes:
    raise ValueError(
        f"Edge indices out of bounds: max index {edge_index.max()} exceeds total nodes {num_nodes}."
    )

# Step 7: Create the Graph
data = Data(x=x, edge_index=edge_index)

# Display Graph Information
print("Graph created successfully!")
print(f"Number of nodes: {data.num_nodes}")
print(f"Number of edges: {data.num_edges}")
print(f"Feature dimension: {data.x.size(1)}")


   subject_id gender  anchor_age  anchor_year anchor_year_group         dod
0    10014729      F          21         2125       2011 - 2013         NaN
1    10003400      F          72         2134       2011 - 2013  2137-09-02
2    10002428      F          80         2155       2011 - 2013         NaN
3    10032725      F          38         2143       2011 - 2013  2143-03-30
4    10027445      F          48         2142       2011 - 2013  2146-02-09
   subject_id   hadm_id  pharmacy_id poe_id  poe_seq order_provider_id  \
0    10027602  28166872     27168639    NaN      NaN               NaN   
1    10027602  28166872     40720238    NaN      NaN               NaN   
2    10027602  28166872     62845687    NaN      NaN               NaN   
3    10027602  28166872     24340150    NaN      NaN               NaN   
4    10027602  28166872     14435820    NaN      NaN               NaN   

             starttime stoptime drug_type              drug  ...  gsn ndc  \
0  2201-10-30 12:00:00

### **Edge Splitting for Link Prediction**
For link prediction, we split the edges into:
- **Training edges**: Used to train the model.
- **Validation edges**: Used to tune hyperparameters.
- **Test edges**: Used to evaluate the model's performance.

---


In [None]:
# Split edges into train, validation, and test sets
def split_edges(edge_index, val_ratio=0.1, test_ratio=0.1):
    # Get all edges
    edges = edge_index.t().numpy()  # Shape: (num_edges, 2)
    num_edges = edges.shape[0]

    # Shuffle edges
    indices = np.random.permutation(num_edges)
    edges = edges[indices]

    # Calculate split sizes
    num_val = int(num_edges * val_ratio)
    num_test = int(num_edges * test_ratio)
    num_train = num_edges - num_val - num_test

    # Split edges
    train_edges = edges[:num_train]
    val_edges = edges[num_train:num_train + num_val]
    test_edges = edges[num_train + num_val:]

    # Convert back to PyTorch tensors
    train_edges = torch.tensor(train_edges, dtype=torch.long).t()
    val_edges = torch.tensor(val_edges, dtype=torch.long).t()
    test_edges = torch.tensor(test_edges, dtype=torch.long).t()

    return train_edges, val_edges, test_edges

val_ratio = 0.1
test_ratio = 0.1
train_edges, val_edges, test_edges = split_edges(data.edge_index, val_ratio, test_ratio)


print("Train/Test split complete.")
print(f"Train edges: {train_edges.size(1)}")
print(f"Validation edges: {val_edges.size(1)}")
print(f"Test edges: {test_edges.size(1)}")


Train/Test split complete.
Train edges: 4457
Validation edges: 557
Test edges: 557


---

## **3. Graph Neural Network for Link Prediction**
We will use a GNN-based encoder to generate node embeddings and a decoder to predict the existence of edges.

### **Model Architecture**
1. **Encoder**: A GNN (e.g., GCN or GraphSAGE) to learn node embeddings.
2. **Decoder**: A dot product layer to predict edge existence.

---


In [None]:
class LinkPredictor(nn.Module):
    def __init__(self):
        super(LinkPredictor, self).__init__()

    def forward(self, z, edge_index):
        # Extract embeddings for the source and target nodes of each edge
        row, col = edge_index
        # Compute the dot product between the embeddings
        scores = (z[row] * z[col]).sum(dim=-1)
        return scores
# 3. Define GNN Encoder
class GNNEncoder(torch.nn.Module):
    def __init__(self, input_dim, hidden_dim):
        super(GNNEncoder, self).__init__()
        self.conv1 = SAGEConv(input_dim, hidden_dim)
        self.conv2 = SAGEConv(hidden_dim, hidden_dim)

    def forward(self, x, edge_index):
        x = self.conv1(x, edge_index)
        x = F.relu(x)
        x = self.conv2(x, edge_index)
        return x

# Define the encoder and link predictor
input_dim = data.x.size(1)
hidden_dim = 64

encoder = GNNEncoder(input_dim, hidden_dim)
link_predictor = LinkPredictor()

print("Model initialized.")


Model initialized.


---

## **4. Training and Evaluation**
### **Training**
We optimize the model using a binary cross-entropy loss for the link prediction task.

### **Evaluation**
Evaluate the model on the validation and test sets using metrics such as accuracy or ROC-AUC.

---


In [None]:
# 4. Training and evaluation
from torch_geometric.utils import negative_sampling
# Function to sample negative edges
def sample_negative_edges(edge_index, num_nodes, num_samples):
    neg_edge_index = negative_sampling(
        edge_index=edge_index,
        num_nodes=num_nodes,
        num_neg_samples=num_samples
    )
    return neg_edge_index
# Optimizer
optimizer = torch.optim.Adam(
    list(encoder.parameters()) + list(link_predictor.parameters()), lr=0.01
)

# Training Function
def train():
    encoder.train()
    link_predictor.train()
    optimizer.zero_grad()

    # Encode node embeddings
    z = encoder(data.x, train_edges)

    # Sample negative edges
    num_train_edges = train_edges.size(1)
    neg_edges = sample_negative_edges(train_edges, data.x.size(0), num_train_edges)

    # Predict edge probabilities
    pos_pred = link_predictor(z, train_edges)
    neg_pred = link_predictor(z, neg_edges)

    # Create labels for positive and negative edges
    pos_label = torch.ones(pos_pred.size(0), dtype=torch.float)
    neg_label = torch.zeros(neg_pred.size(0), dtype=torch.float)

    # Compute binary cross-entropy loss
    loss = F.binary_cross_entropy_with_logits(
        torch.cat([pos_pred, neg_pred]), torch.cat([pos_label, neg_label])
    )
    loss.backward()
    optimizer.step()
    return loss.item()

# Evaluation Function
from sklearn.metrics import roc_auc_score

def evaluate(edge_index, neg_edge_index):
    encoder.eval()
    link_predictor.eval()

    # Encode node embeddings
    z = encoder(data.x, edge_index)

    # Predict edge probabilities
    pos_pred = link_predictor(z, edge_index)
    neg_pred = link_predictor(z, neg_edge_index)

    # Create labels
    pos_label = torch.ones(pos_pred.size(0), dtype=torch.float)
    neg_label = torch.zeros(neg_pred.size(0), dtype=torch.float)

    # Compute ROC-AUC
    pred = torch.cat([pos_pred, neg_pred]).sigmoid().cpu().detach().numpy()
    label = torch.cat([pos_label, neg_label]).cpu().detach().numpy()
    auc = roc_auc_score(label, pred)
    return auc

# Training Loop
for epoch in range(50):
    loss = train()

    # Validation AUC
    num_val_edges = val_edges.size(1)
    val_neg_edges = sample_negative_edges(val_edges, data.x.size(0), num_val_edges)
    val_auc = evaluate(val_edges, val_neg_edges)

    print(f"Epoch {epoch + 1}, Loss: {loss:.4f}, Validation AUC: {val_auc:.4f}")

# Test AUC
num_test_edges = test_edges.size(1)
test_neg_edges = sample_negative_edges(test_edges, data.x.size(0), num_test_edges)
test_auc = evaluate(test_edges, test_neg_edges)
print(f"Test ROC-AUC: {test_auc:.4f}")


Epoch 1, Loss: 2062.0457, Validation AUC: 0.6517
Epoch 2, Loss: 717.3783, Validation AUC: 0.6499
Epoch 3, Loss: 423.7495, Validation AUC: 0.6391
Epoch 4, Loss: 696.3655, Validation AUC: 0.6364
Epoch 5, Loss: 818.4996, Validation AUC: 0.6697
Epoch 6, Loss: 586.0899, Validation AUC: 0.2123
Epoch 7, Loss: 283.5254, Validation AUC: 0.9120
Epoch 8, Loss: 177.8408, Validation AUC: 0.9084
Epoch 9, Loss: 195.6187, Validation AUC: 0.8860
Epoch 10, Loss: 245.7602, Validation AUC: 0.8824
Epoch 11, Loss: 243.9477, Validation AUC: 0.0380
Epoch 12, Loss: 232.1940, Validation AUC: 0.8896
Epoch 13, Loss: 198.2712, Validation AUC: 0.8923
Epoch 14, Loss: 188.8858, Validation AUC: 0.8815
Epoch 15, Loss: 160.1725, Validation AUC: 0.8950
Epoch 16, Loss: 125.3144, Validation AUC: 0.0467
Epoch 17, Loss: 97.2902, Validation AUC: 0.7327
Epoch 18, Loss: 75.6229, Validation AUC: 0.8311
Epoch 19, Loss: 62.5485, Validation AUC: 0.0410
Epoch 20, Loss: 57.1136, Validation AUC: 0.8959
Epoch 21, Loss: 60.3896, Validat

---

## **5. Results and Conclusion**
### **Test Performance**
Evaluate the model on the test set.

### **Insights**
The model demonstrates how GNNs can effectively predict links (e.g., patient-drug associations) in a biological knowledge graph.

---


In [None]:
# Test the model
num_test_edges = test_edges.size(1)  # Number of positive test edges
test_neg_edges = sample_negative_edges(test_edges, data.x.size(0), num_test_edges)

# Evaluate on test set
test_auc = evaluate(test_edges, test_neg_edges)

print(f"Test ROC-AUC: {test_auc:.4f}")


Test ROC-AUC: 0.8245


# **Check-in Questions:**

Why is the SAGEConv layer used in the model, and how does it aggregate node information?

What role do node embeddings play in the context of link prediction?

How does the model leverage graph structure (nodes and edges) during training?