## GOAL: predict whether an edge (interaction) between two proteins should be included in a pathway (label 1) or not (label 0)

In [1]:
import torch
import torch.nn.functional as F
from torch import nn
from torch_geometric.nn import EdgeConv, NNConv
from torch_geometric.loader import DataLoader
from torch.utils.data import random_split

In [2]:
import pandas as pd
import numpy as np
import os
from sklearn.metrics import precision_score, recall_score, f1_score, roc_auc_score

In [3]:
# Number of CUDA devices visible to PyTorch (0 on a CPU-only machine).
num_gpus: int = torch.cuda.device_count()
print(f"Number of GPUs available: {num_gpus}")

Number of GPUs available: 2


In [4]:
# Use the first CUDA GPU when available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda:0')
else:
    device = torch.device('cpu')

### Data Loading:

**Load the protein node list and the saved PyTorch Geometric graph objects:**

In [5]:
# Collect every protein ID that appears in either interaction column and give each
# a dense integer index; IDs are sorted so the mapping is deterministic.
union_ppi = pd.read_csv('processed-data/union_ppi.txt', sep='\t', header=None)
unique_nodes = set(union_ppi[0].tolist()).union(union_ppi[1].tolist())
label_id_map = {node: node_id for node_id, node in enumerate(sorted(unique_nodes))}
num_nodes = len(label_id_map)
print(f"Total unique nodes: {num_nodes}")

Total unique nodes: 17407


In [6]:
# weights_only=False unpickles arbitrary Python objects (required for the saved graph
# objects); only ever load a dataset file you created or otherwise trust.
data_list = torch.load(f='dataset.pt', weights_only=False)

### Class Weight Calculation:

- Calculate the weight for the positive class to handle class imbalance.
- This is needed because negative samples (non-selected edges) far outnumber positive samples (selected edges).
- Weighting the **loss function** helps the model learn from the minority class effectively.
- Note: the weight below is computed over the full `data_list`, i.e. before the train/test split, so test-set labels also influence it.

In [7]:
def calculate_class_weights(data_list, device=None):
    """Compute the ``pos_weight`` for ``BCEWithLogitsLoss`` from 0/1 edge labels.

    The weight is ``num_negative / num_positive`` over every label in
    ``data_list``, so that the rare positive (selected) edges contribute to the
    loss about as much as the far more numerous negative edges.

    Args:
        data_list: Iterable of graph objects whose ``y`` holds 0/1 edge labels
            (shape ``[E]`` or ``[E, 1]``).
        device: Device for the returned tensor. Defaults to the notebook-level
            ``device`` global, or the CPU if that global is not defined.

    Returns:
        A 1-element float32 tensor suitable for ``BCEWithLogitsLoss(pos_weight=...)``.

    Raises:
        ValueError: if ``data_list`` is empty or contains no positive labels.
    """
    if device is None:
        # the parameter shadows the notebook global, so look it up explicitly
        device = globals().get('device', torch.device('cpu'))
    # flatten each label tensor so [E] and [E, 1] layouts are counted identically
    labels = [data.y.reshape(-1) for data in data_list]
    if not labels:
        raise ValueError("data_list is empty; cannot compute class weights")
    y = torch.cat(labels)
    num_positive = float(y.sum())
    num_negative = y.numel() - num_positive
    if num_positive == 0:
        # a ratio here would be inf (or nan) and silently poison the loss
        raise ValueError("no positive labels found; pos_weight would be infinite")
    # ratio of negatives to positives; float32 to match the default dtype of the
    # model's logits instead of silently producing a float64 weight
    return torch.tensor([num_negative / num_positive], dtype=torch.float32, device=device)

# Weight applied to positive-edge terms of the loss (negatives / positives).
pos_weight = calculate_class_weights(data_list=data_list)
print(f"Positive class weight: {pos_weight.item()}")

Positive class weight: 14725.188571936373


### Train-test split:

In [8]:
train_ratio = 0.8
# int() truncates, so any rounding remainder lands in the test split
test_size = len(data_list) - int(train_ratio * len(data_list))
train_size = len(data_list) - test_size

In [9]:
# A fixed-seed generator makes the graph-level train/test split (and therefore the
# reported metrics) reproducible across runs and kernel restarts.
train_dataset, test_dataset = random_split(
    data_list, [train_size, test_size], generator=torch.Generator().manual_seed(42)
)

### Edge Sampling

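- A technique that oversamples the positive edges to address class imbalance.
- Positive edges are duplicated in the training graphs to increase their representation during training, so the model sees more positive examples.
- Oversampling helps the model better learn the characteristics of the minority class (selected edges), improving its ability to classify them correctly.
- Caveat: the duplicated edges are also appended to `edge_index`, so they take part in message passing, not only in the loss. Also, `pos_weight` above was computed before oversampling, so the two imbalance corrections compound.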

In [10]:
def edge_sampling(data_list, sampling_ratio=0.5, rng=None):
    """
    Oversample edges from the minority class (positive edges).

    For each graph, ``int(sampling_ratio * num_positive)`` positive edges are drawn
    uniformly with replacement and appended to ``edge_index``, ``edge_attr`` (when
    present) and ``y``. The input graphs are NOT modified. Each augmented graph is a
    shallow copy whose edge attributes are rebound to new, concatenated tensors.

    NOTE(review): the duplicates are appended to ``edge_index`` itself, so they also
    take part in message passing (e.g. they are counted again by a 'mean'
    aggregation), not only in the loss.

    Args:
        data_list: Iterable of PyG data objects with ``edge_index`` (2, E), ``y``
            (one label row per edge) and optionally ``edge_attr`` (one row per edge).
        sampling_ratio: Number of extra positive edges to add per graph, as a
            fraction of that graph's positive edges.
        rng: Optional ``numpy.random.Generator`` for reproducible sampling. When
            None, the global ``np.random`` state is used.
    Returns:
        List with one data object per input graph. Graphs that receive no extra
        edges are returned unchanged (the same object).
    """
    import copy

    sampler = np.random if rng is None else rng
    augmented_data_list = []
    for data in data_list:
        y = data.y.cpu().numpy()
        # row indices of positive edges (works for labels shaped [E] or [E, 1])
        positive_indices = np.where(y == 1)[0]
        num_samples = int(sampling_ratio * len(positive_indices))
        if num_samples == 0:
            augmented_data_list.append(data)
            continue
        # draw positive edges with replacement; build a long index tensor on the graph's device
        sampled = sampler.choice(positive_indices, num_samples, replace=True)
        index = torch.as_tensor(sampled, dtype=torch.long, device=data.edge_index.device)
        # shallow copy: attributes are rebound below (never modified in place), so the
        # caller's graphs, e.g. objects shared with data_list via random_split, stay untouched
        augmented = copy.copy(data)
        augmented.edge_index = torch.cat([data.edge_index, data.edge_index[:, index]], dim=1)
        edge_attr = getattr(data, 'edge_attr', None)
        if edge_attr is not None:
            augmented.edge_attr = torch.cat(
                [edge_attr, edge_attr[index.to(edge_attr.device)]], dim=0
            )
        augmented.y = torch.cat([data.y, data.y[index.to(data.y.device)]], dim=0)
        augmented_data_list.append(augmented)
    return augmented_data_list

In [11]:
# Oversample positive edges in the training graphs only; the test split stays untouched.
train_dataset = edge_sampling(data_list=train_dataset, sampling_ratio=0.5)

### Make DataLoader:

- for batching and shuffling data
- feeds data into model during training and eval

In [12]:
# Shuffle only the training batches; the test order stays fixed for evaluation.
test_loader = DataLoader(dataset=test_dataset, batch_size=32, shuffle=False, num_workers=4)
train_loader = DataLoader(dataset=train_dataset, batch_size=32, shuffle=True, num_workers=4)

### Model Definition: using `EdgeConv` and `NNConv` layers

- Define a GNN for edge classification.
- The architecture is designed to capture both node and edge features, which is needed to accurately classify edges in the PPI network.
- MLPs transform node and edge features into higher-dimensional representations to capture patterns in the data.

**Input Graph:**
- Node features (x): `[num_nodes, node_feat_dim]`
- Edge index (edge_index): `[2, num_edges]`
- Edge features (edge_attr): `[num_edges, edge_feat_dim]`

**Layers:**
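- `EdgeConv` captures local patterns in the graph by considering the relationships between a node and its neighbors.
- `NNConv` uses an MLP on each edge's features to produce the weight matrix applied to the neighbor's node features, so edge features shape the messages.
- A final MLP classifies each edge from the concatenation of its source-node embedding, target-node embedding, and edge features.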

In [13]:
class EdgeClassificationGNN(nn.Module):
    """Edge classifier: an EdgeConv -> NNConv node encoder followed by a per-edge MLP head.

    ``forward`` returns one row of ``out_dim`` raw logits per column of
    ``data.edge_index``; no sigmoid is applied.
    """
    def __init__(self, node_feat_dim: int, hidden_dim: int, edge_feat_dim: int, out_dim: int = 1) -> None:
        """
        Args:
            node_feat_dim (int): Dimensionality of node features (1 in this case).
            hidden_dim (int): Width of the hidden node embeddings and MLP hidden layers.
            edge_feat_dim (int): Dimensionality of edge features (1 for prize).
            out_dim (int): Output dimension per edge (1 logit for binary classification).
        """
        # NOTE: the order in which submodules are created below fixes the random
        # weight-initialisation draws and the order of model.parameters(); keep it
        # stable if reproducibility or previously saved checkpoints matter.
        super(EdgeClassificationGNN, self).__init__()

        # --- First Layer: EdgeConv ---
        # PURPOSE: updates each node's features by aggregating messages from its direct neighbors.
        # Per the PyG EdgeConv docs, the MLP receives [x_i, x_j - x_i] for every edge,
        # hence the 2 * node_feat_dim input width.
        # This Sequential is also stored as conv1's `nn`, so its weights appear under both
        # `mlp_edgeconv.*` and `conv1.nn.*` in state_dict (parameters() yields them once).
        self.mlp_edgeconv = nn.Sequential(
            nn.Linear(2 * node_feat_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim)
        )
        # messages from all neighbors are combined with an element-wise max
        self.conv1 = EdgeConv(nn=self.mlp_edgeconv, aggr='max')

        # --- Second Layer: NNConv ---
        # PURPOSE: updates node features using edge-specific information.
        # edge_nn maps each edge's features to hidden_dim * hidden_dim values, which NNConv
        # reshapes into a (hidden_dim x hidden_dim) matrix applied to the neighbor's features.
        self.edge_nn = nn.Sequential(
            nn.Linear(edge_feat_dim, hidden_dim * hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim * hidden_dim, hidden_dim * hidden_dim)
        )
        # transformed neighbor features are averaged over all neighbors
        self.conv2 = NNConv(in_channels=hidden_dim,
                            out_channels=hidden_dim,
                            nn=self.edge_nn,
                            aggr='mean')

        # --- Final Edge Classifier ---
        # input per edge: [source embedding, target embedding, raw edge features]
        # output: out_dim raw logits (pair with BCEWithLogitsLoss, or apply sigmoid downstream)
        self.edge_classifier = nn.Sequential(
            nn.Linear(2 * hidden_dim + edge_feat_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, out_dim)
        )


    def forward(self, data) -> torch.Tensor:
        """Return per-edge logits of shape [num_edges, out_dim].

        Args:
            data: graph or batch exposing ``x`` ([num_nodes, node_feat_dim]),
                ``edge_index`` ([2, num_edges]) and ``edge_attr``
                ([num_edges, edge_feat_dim]), as described in the notebook's
                input-graph notes.
        """
        x, edge_index, edge_attr = data.x, data.edge_index, data.edge_attr
        # EdgeConv: node features -> hidden_dim-wide node embeddings
        x = self.conv1(x, edge_index)
        x = F.relu(x)
        # NNConv: refine embeddings with messages weighted by each edge's features
        x = self.conv2(x, edge_index, edge_attr)
        x = F.relu(x)
        # one representation per edge: [x_src, x_dst, edge_attr] -> 2 * hidden_dim + edge_feat_dim columns
        src, dst = edge_index
        edge_representation = torch.cat([x[src], x[dst], edge_attr], dim=1)
        logits = self.edge_classifier(edge_representation)
        # raw logits, not probabilities
        return logits

### Training Setup & Loop

- Initialize the model, loss function, and optimizer.
- `BCEWithLogitsLoss` is chosen for binary classification: it computes binary cross-entropy directly on logits (raw scores), which is more numerically stable than applying a sigmoid followed by `BCELoss`.
- The `Adam` optimizer is used to update the model parameters during training.

In [14]:
# hidden-layer width, node-feature width, edge-feature width
hidden_dim, node_feat_dim, edge_feat_dim = 8, 1, 1

In [15]:
# Build the GNN and move its parameters to the selected device (Module.to returns the module).
model = EdgeClassificationGNN(node_feat_dim, hidden_dim, edge_feat_dim).to(device)

In [16]:
optimizer = torch.optim.Adam(params=model.parameters(), lr=0.01)
# class-weighted loss: pos_weight scales up the loss on positive (minority-class) edges
criterion = nn.BCEWithLogitsLoss(pos_weight=pos_weight)

- Train the model using the loop function below.
- Each step updates the model parameters to minimize the loss function.

In [17]:
def train(model, train_loader, epochs=20, *, optimizer=None, criterion=None, device=None):
    """Train ``model`` on ``train_loader`` for ``epochs`` epochs.

    Prints the mean per-batch loss after every epoch.

    Args:
        model: Module mapping a batch to logits shaped like ``batch.y``.
        train_loader: Sized iterable of batches exposing ``.y`` and ``.to(device)``.
        epochs: Number of full passes over ``train_loader``.
        optimizer: Optimizer over ``model``'s parameters. Defaults to the
            notebook-level ``optimizer`` global. Pass one explicitly when training
            any other model, since the global optimizer only updates the
            parameters it was built with.
        criterion: Loss taking ``(logits, targets)``. Defaults to the
            notebook-level ``criterion`` global.
        device: Device batches are moved to. Defaults to the notebook-level
            ``device`` global, or the CPU if that global is not defined.

    Raises:
        ValueError: if ``train_loader`` has no batches.
        KeyError: if ``optimizer``/``criterion`` are omitted and no notebook-level
            global of that name exists.
    """
    # parameters shadow the notebook globals, so fall back to them explicitly
    notebook_globals = globals()
    if optimizer is None:
        optimizer = notebook_globals['optimizer']
    if criterion is None:
        criterion = notebook_globals['criterion']
    if device is None:
        device = notebook_globals.get('device', torch.device('cpu'))

    num_batches = len(train_loader)
    if num_batches == 0:
        raise ValueError("train_loader has no batches")

    # enable training-mode behaviour of layers (e.g. dropout/batch-norm) once for all epochs
    model.train()
    for epoch in range(epochs):
        epoch_loss = 0.0
        for batch in train_loader:
            batch = batch.to(device)
            optimizer.zero_grad()                 # clear gradients from the previous step
            out = model(batch)                    # forward pass -> logits
            loss = criterion(out, batch.y)        # compare logits with ground-truth labels
            loss.backward()                       # backpropagate
            optimizer.step()                      # update parameters
            epoch_loss += loss.item()
        # mean of the per-batch losses for this epoch
        avg_loss = epoch_loss / num_batches
        print(f"Epoch {epoch + 1}/{epochs}, Loss: {avg_loss:.4f}")

In [18]:
# Run 10 training epochs, printing the average loss after each one.
train(model=model, train_loader=train_loader, epochs=10)



Epoch 1/10, Loss: 1.3782
Epoch 2/10, Loss: 0.2979
Epoch 3/10, Loss: 0.0443
Epoch 4/10, Loss: 0.0154
Epoch 5/10, Loss: 0.0047
Epoch 6/10, Loss: 0.0024
Epoch 7/10, Loss: 0.0016
Epoch 8/10, Loss: 0.0011
Epoch 9/10, Loss: 0.0010
Epoch 10/10, Loss: 0.0010


### Evaluation

- Evaluate the model's performance using precision, recall, F1-score, and AUC-ROC.
- These metrics give a comprehensive assessment of classification performance, which is especially important for imbalanced datasets. Hard predictions use a logit threshold of 0 (probability 0.5); AUC-ROC is threshold-independent.

In [19]:
def evaluate_metrics(model, test_loader, threshold=0.0, device=None):
    """Evaluate edge-classification performance on ``test_loader``.

    Runs ``model`` in eval mode without gradients. It thresholds the logits to
    get hard predictions, then prints and returns precision, recall, F1 and
    AUC-ROC for the positive (label 1) class.

    Args:
        model: Module mapping a batch to per-edge logits, one per label in ``batch.y``.
        test_loader: Iterable of batches exposing ``.y`` and ``.to(device)``.
        threshold: Logit cut-off; an edge is predicted positive when its logit is
            strictly greater than this value. The default ``0.0`` equals
            probability 0.5. NOTE: with ``BCEWithLogitsLoss(pos_weight=w)`` the
            loss-minimising logits sit roughly ``log(w)`` above the unweighted
            ones, so a higher threshold may give better precision.
        device: Device batches are moved to. Defaults to the notebook-level
            ``device`` global, or the CPU if that global is not defined.

    Returns:
        dict with keys ``'precision'``, ``'recall'``, ``'f1'`` and ``'roc_auc'``.
        ``roc_auc`` is ``nan`` when the labels contain a single class, because
        AUC is undefined in that case.

    Raises:
        ValueError: if ``test_loader`` yields no batches.
    """
    if device is None:
        # the parameter shadows the notebook global, so look it up explicitly
        device = globals().get('device', torch.device('cpu'))
    model.eval()
    label_chunks, logit_chunks = [], []

    with torch.no_grad():
        for batch in test_loader:
            batch = batch.to(device)
            # flatten so [E, 1] and [E] label/logit layouts are handled alike
            logit_chunks.append(model(batch).reshape(-1).cpu())
            label_chunks.append(batch.y.reshape(-1).cpu())

    if not label_chunks:
        raise ValueError("test_loader yielded no batches")

    # concatenate once instead of extending Python lists row by row
    y_true = torch.cat(label_chunks).numpy()
    logits = torch.cat(logit_chunks)
    y_pred = (logits > threshold).float().numpy()
    y_score = logits.sigmoid().numpy()  # probabilities for AUC-ROC

    precision = precision_score(y_true, y_pred)
    recall = recall_score(y_true, y_pred)
    f1 = f1_score(y_true, y_pred)
    # roc_auc_score raises when only one class is present; report nan instead
    if len(np.unique(y_true)) > 1:
        roc_auc = roc_auc_score(y_true, y_score)
    else:
        roc_auc = float('nan')
    print(f"Precision: {precision:.4f}, Recall: {recall:.4f}, F1: {f1:.4f}, AUC-ROC: {roc_auc:.4f}")
    return {'precision': precision, 'recall': recall, 'f1': f1, 'roc_auc': roc_auc}

In [20]:
# Report precision / recall / F1 / AUC-ROC on the held-out graphs.
evaluate_metrics(model=model, test_loader=test_loader)

Precision: 0.5339, Recall: 1.0000, F1: 0.6962, AUC-ROC: 1.0000
