In [None]:
import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)

import os
# Print the full path of every file found under the Kaggle input directory.
# os.walk yields nothing when the directory is absent, so this loop is a no-op outside Kaggle.
for dirname, _, filenames in os.walk('/kaggle/input'):
    for filename in filenames:
        print(os.path.join(dirname, filename))
        
import matplotlib.pyplot as plt
import seaborn as sns

import warnings
warnings.filterwarnings("ignore", category=Warning)

# 1. GraphSAGE model with Pubmed "Big" Dataset

In [None]:
# 1. library
import torch

# Wheel index URL derived from the installed torch version; passed to pip via -f below.
find_links = f"https://data.pyg.org/whl/torch-{torch.__version__}.html"

# Quietly (-q) install PyTorch Geometric and its companion extension packages.
# NOTE(review): the packages are unpinned, so the resolved versions can change between runs.
!pip install -q \
    torch-scatter \
    torch-sparse \
    torch-cluster \
    torch-spline-conv \
    torch-geometric \
    -f $find_links

# Reproducibility: seed the CPU generator and, when a GPU is present, the CUDA generators.
torch.manual_seed(0)
if torch.cuda.is_available():
    # Request deterministic cuDNN behaviour and switch off its benchmark mode.
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True
    # Seed every CUDA device, then the current one (same seed, so the order is irrelevant).
    torch.cuda.manual_seed_all(0)
    torch.cuda.manual_seed(0)

print("Installation Complete.")

In [None]:
# PubMed citation dataset: papers are nodes and their citations are edges.
from torch_geometric.datasets import Planetoid

# Load PubMed through the Planetoid dataset wrapper.
# NOTE: the cache directory is still called "cora_data"; the path is kept unchanged
# so that an existing download under that directory continues to be reused.
dataset = Planetoid(root="./cora_data", name="Pubmed")
data = dataset[0]  # the dataset's first graph

# Print information about the dataset
print(f'Dataset: {dataset}')
print('-------------------')
print(f'Number of graphs: {len(dataset)}')
print(f'Number of nodes: {data.x.shape[0]}')
print(f'Number of features: {dataset.num_features}')
print(f'Number of classes: {dataset.num_classes}')

# Print information about the graph
print('\nGraph:')
print('------')
# Tensor.sum() counts the True entries natively; the builtin sum() would iterate
# over every node of the mask in Python.
print(f'Training nodes: {data.train_mask.sum().item()}')
print(f'Evaluation nodes: {data.val_mask.sum().item()}')
print(f'Test nodes: {data.test_mask.sum().item()}')
print(f'Edges are directed: {data.is_directed()}')
print(f'Graph has isolated nodes: {data.has_isolated_nodes()}')
print(f'Graph has loops: {data.has_self_loops()}')

# The PubMed dataset is a citation network of biomedical publications related to diabetes classification.
# The dataset consists of 19,717 scientific papers and 44,338 citation links (edges) with TF-IDF weighted word vectors.
# It is the "big" dataset referred to in the section title.

In [None]:
# Mini-batch neighbor sampling

from torch_geometric.loader import NeighborLoader
from torch_geometric.utils import to_networkx

# Create batches with neighbor sampling
train_loader = NeighborLoader(
    data, # The PubMed graph loaded in the previous cell
    num_neighbors=[5, 10], # Sample 5 neighbors for the 1st hop, 10 for the 2nd hop
    batch_size=16,  # Seed (center) nodes per batch; at most 16 + 16*5 + 80*10 nodes are sampled per batch
    input_nodes=data.train_mask, # Only use training nodes as seed nodes
)

# Inspect the created subgraphs (mini-batches)
for i, subgraph in enumerate(train_loader):
    print(f'Subgraph {i}: {subgraph}')

In [None]:
import networkx as nx
from matplotlib.lines import Line2D # Tool for creating custom legends
from matplotlib.colors import ListedColormap # Tool for custom color maps

# 1. Human-readable names for the three PubMed class indices
class_map = {
    0: 'Experimental',    # lightcoral
    1: 'Type 1 Diabetes', # lightgreen
    2: 'Type 2 Diabetes'  # lightskyblue
}

# 2. One colour per class index, wrapped in a discrete colormap
custom_colors = ['lightcoral', 'lightgreen', 'lightskyblue']
cmap = ListedColormap(custom_colors)

fig = plt.figure(figsize=(10, 8))

# 3. Draw mini-batches on a 2x2 grid; zip() stops after the four subplot slots are used
for idx, (subdata, pos) in enumerate(zip(train_loader, [221, 222, 223, 224])):

    # PyG mini-batch -> undirected NetworkX graph
    G = to_networkx(subdata, to_undirected=True)

    # The new subplot becomes the current axes, so draw_networkx below draws into it
    ax = fig.add_subplot(pos)
    ax.set_title(f'Subgraph {idx} (Batch Size: {subdata.batch_size})', fontsize=10, fontweight='bold')
    ax.axis('off')

    # Spring layout with a fixed seed so the picture is repeatable
    pos_layout = nx.spring_layout(G, seed=1, k=0.5)

    # Node colours come from the class labels (0, 1, 2) mapped through cmap
    nx.draw_networkx(
        G,
        pos=pos_layout,
        with_labels=False,
        node_color=subdata.y,
        node_size=80,
        cmap=cmap,
        vmin=0, vmax=2,    # pin the colour range to the three classes
        edge_color='gray',
        alpha=0.8,
    )

# 4. One shared legend: a coloured marker per class
legend_elements = [
    Line2D([0], [0], marker='o', color='w', label=class_map[class_idx],
           markerfacecolor=custom_colors[class_idx], markersize=8)
    for class_idx in range(3)
]

# Legend goes at the top centre of the figure, outside the subplots
fig.legend(handles=legend_elements, loc='upper center', ncol=3, fontsize=10, frameon=False)

plt.tight_layout(rect=[0, 0, 1, 0.95]) # leave the top 5% of the figure for the legend
plt.show()

# The massive citation network is sampled into manageable subgraphs.
# In these sampled subgraphs, connected nodes often belong to different classes (apparent heterophily).
# An attention-based model such as GAT can down-weight irrelevant neighbors, whereas GCN averages neighbors with fixed weights.

In [None]:
# Implementation of GraphSAGE model architecture and mini-batch training loop

import torch
# Re-seed the global torch RNG with 1 (this replaces the seed 0 set in the setup cell).
torch.manual_seed(1)
import torch.nn.functional as F
from torch_geometric.nn import SAGEConv

# 1. Helper function to calculate classification accuracy
def accuracy(pred_y, y):
    """Return the fraction of positions where ``pred_y`` equals ``y``, as a Python float."""
    n_correct = (pred_y == y).sum()
    fraction = n_correct / len(y)
    return fraction.item()

# 2. GraphSAGE Model Definition
class GraphSAGE(torch.nn.Module):
    """Two-layer GraphSAGE node classifier with mini-batch training helpers.

    Args:
        dim_in: Size of each input node feature vector.
        dim_h: Size of the hidden representation.
        dim_out: Number of output scores per node (one per class).
    """

    def __init__(self, dim_in, dim_h, dim_out):
        super().__init__()

        # [Layer 1] dim_in (number of features) -> dim_h (hidden dimension)
        self.sage1 = SAGEConv(dim_in, dim_h)

        # [Layer 2] dim_h (hidden dimension) -> dim_out (number of classes)
        self.sage2 = SAGEConv(dim_h, dim_out)

    def forward(self, x, edge_index):
        """Return per-node class logits for features ``x`` and connectivity ``edge_index``."""
        # [Step 1] First aggregation & transformation
        h = self.sage1(x, edge_index)

        # [Step 2] Non-linearity
        h = torch.relu(h)

        # [Step 3] Dropout with p=0.5; only active while the module is in training mode
        h = F.dropout(h, p=0.5, training=self.training)

        # [Step 4] Second aggregation produces the output logits
        h = self.sage2(h, edge_index)

        return h

    # 3. Training Loop (Mini-batch Training)
    def fit(self, loader, epochs):
        """Train on the mini-batches yielded by ``loader`` and print metrics every 20 epochs.

        Each batch must expose ``x``, ``edge_index``, ``y``, ``train_mask`` and
        ``val_mask``. The loop runs ``epochs + 1`` passes (indices ``0..epochs``)
        so that the epoch numbered ``epochs`` is also reported.

        Args:
            loader: Iterable of mini-batches (sized, so ``len(loader)`` works).
            epochs: Index of the last epoch to run.
        """
        # Loss function (CrossEntropy) and optimizer (Adam)
        criterion = torch.nn.CrossEntropyLoss()
        optimizer = torch.optim.Adam(self.parameters(), lr=0.01)

        self.train() # Set model to training mode

        for epoch in range(epochs+1):
            total_loss = 0.0
            acc = 0.0
            val_loss = 0.0
            val_acc = 0.0
            val_batches = 0  # batches that actually contained validation nodes

            # Iterate over the sampled subgraphs
            for batch in loader:
                optimizer.zero_grad() # Clear gradients

                # Forward pass on the subgraph (mini-batch)
                out = self(batch.x, batch.edge_index)

                # Loss only on the training nodes present in the batch
                loss = criterion(out[batch.train_mask], batch.y[batch.train_mask])

                total_loss += loss.item()

                # Accuracy on the training nodes of the batch
                acc += accuracy(out[batch.train_mask].argmax(dim=1), batch.y[batch.train_mask])

                # Backward pass & weight update
                loss.backward()
                optimizer.step()

                # Validation metrics on the same batch.
                # A batch may contain no validation nodes; the loss and the
                # accuracy would then be NaN and poison the epoch average, so
                # such batches are skipped.
                if batch.val_mask.any():
                    # Metrics are accumulated as plain floats, without gradient
                    # tracking, so no autograd graph is kept alive across the epoch.
                    # NOTE: these logits come from the training-mode forward pass
                    # above (dropout active), as in the original implementation.
                    with torch.no_grad():
                        val_out = out.detach()[batch.val_mask]
                        val_y = batch.y[batch.val_mask]
                        val_loss += criterion(val_out, val_y).item()
                        val_acc += accuracy(val_out.argmax(dim=1), val_y)
                    val_batches += 1

            # Print metrics every 20 epochs
            if epoch % 20 == 0:
                # Average the accumulated metrics over the batches that contributed
                avg_loss = total_loss / len(loader)
                avg_acc = acc / len(loader)
                avg_val_loss = val_loss / max(val_batches, 1)
                avg_val_acc = val_acc / max(val_batches, 1)

                print(f'Epoch {epoch:>3} | Train Loss: {avg_loss:.3f} | Train Acc: {avg_acc*100:>6.2f}% | '
                      f'Val Loss: {avg_val_loss:.2f} | Val Acc: {avg_val_acc*100:.2f}%')

    # 4. Evaluation (Inference)
    @torch.no_grad() # Disable gradient calculation
    def test(self, data):
        """Return the accuracy on the nodes selected by ``data.test_mask``.

        Switches the module to evaluation mode and runs one forward pass over
        the whole of ``data``.
        """
        self.eval() # Set model to evaluation mode

        # Forward pass over everything in `data`
        out = self(data.x, data.edge_index)

        # Accuracy restricted to the test nodes
        acc = accuracy(out.argmax(dim=1)[data.test_mask], data.y[data.test_mask])
        return acc

In [None]:
# Run the Model
# NOTE: this cell relies on the GraphSAGE class and the train_loader defined above;
# a later cell in this notebook rebinds both names, so cells must be run in order.

# [Model Instantiation]
# Initialize GraphSAGE model with dataset specifications
# Input: 500 features (PubMed), Hidden: 64 units, Output: 3 classes
graphsage = GraphSAGE(dataset.num_features, 64, dataset.num_classes)
print("\n[Model Structure]")
print(graphsage)

# Training
# Train the model using the NeighborLoader (mini-batches).
# fit() iterates over range(epochs + 1), so this runs epochs 0..200.
print("\n[Starting Training]")
graphsage.fit(train_loader, 200)

# Testing
# Evaluate the model on the full graph using the test mask
print("\n[Final Evaluation]")
acc = graphsage.test(data)
print(f'GraphSAGE Test Accuracy: {acc*100:.2f}%')

In [None]:
# Inductive Node Classification using GraphSAGE
# Train a GraphSAGE model on a multi-graph dataset using neighbor sampling and evaluate on unseen graphs

import torch
from sklearn.metrics import f1_score
from torch_geometric.datasets import PPI
from torch_geometric.data import Batch
from torch_geometric.loader import DataLoader, NeighborLoader
from torch_geometric.nn import GraphSAGE

# 1. Device Configuration
# Use GPU if available for faster computation.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# 2. Data Loading & Preprocessing
# Load the PPI dataset.
# Unlike the single-graph PubMed dataset used above, this dataset consists of multiple graphs.
# - Train: 20 graphs
# - Val: 2 graphs
# - Test: 2 graphs
train_dataset = PPI(root=".", split='train')
val_dataset = PPI(root=".", split='val')
test_dataset = PPI(root=".", split='test')

# Merge training graphs into a single large graph (Batch).
# This allows 'NeighborLoader' to sample subgraphs across all training graphs simultaneously.
train_data = Batch.from_data_list(train_dataset)

# Initialize NeighborLoader for the training set.
# NOTE: this rebinds `train_loader`, which previously held the PubMed loader.
# - batch_size: Number of seed nodes per batch.
# - num_neighbors: [20, 10] means sampling 20 neighbors at 1-hop and 10 at 2-hop.
# - shuffle: Randomly shuffle data for stochastic gradient descent.
# - num_workers / persistent_workers: two loader worker processes, kept alive between epochs.
train_loader = NeighborLoader(train_data, batch_size=2048, shuffle=True, num_neighbors=[20, 10], num_workers=2, persistent_workers=True)

# Loaders for validation and testing.
# Since validation/test graphs are relatively small, we load them graph-by-graph (batch_size=2 graphs).
val_loader = DataLoader(val_dataset, batch_size=2)
test_loader = DataLoader(test_dataset, batch_size=2)

# 3. Model Definition
# Initialize GraphSAGE.
# NOTE: `GraphSAGE` here is torch_geometric.nn.GraphSAGE imported at the top of this cell;
# it shadows the hand-written GraphSAGE class defined earlier in the notebook.
# GraphSAGE is suitable for inductive learning (generalizing to unseen nodes/graphs).
model = GraphSAGE(
    in_channels=train_dataset.num_features,  # Input feature dimension
    hidden_channels=512,                     # Hidden layer dimension
    num_layers=2,                            # Number of GNN layers (hops)
    out_channels=train_dataset.num_classes,  # Output dimension (Number of classes)
).to(device)

# 4. Loss & Optimizer
# - Loss: BCEWithLogitsLoss (Binary Cross Entropy) is used for multi-label classification.
# - Optimizer: Adam optimizer.
criterion = torch.nn.BCEWithLogitsLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.005)

# 5. Training Function
def fit(loader):
    """Run one training epoch over ``loader`` and return the mean loss per mini-batch.

    Uses the module-level ``model``, ``optimizer``, ``criterion`` and ``device``.

    The previous implementation scaled every batch loss by ``data.num_graphs``
    and divided the total by ``len(loader.data)``. For mini-batches sampled
    from one merged graph those are not sample counts, so the value was not an
    average. The mean over the mini-batches actually processed is returned
    instead.

    Args:
        loader: Iterable of mini-batches exposing ``x``, ``edge_index`` and ``y``.

    Returns:
        Mean of the per-batch losses as a float (``0.0`` for an empty loader).
    """
    model.train() # Set model to training mode

    total_loss = 0.0
    num_batches = 0
    for data in loader:
        data = data.to(device) # Move batch to device
        optimizer.zero_grad()  # Reset gradients

        # Forward pass: compute predictions for every node in the sampled subgraph
        out = model(data.x, data.edge_index)

        # NOTE(review): the loss covers all nodes of the sampled subgraph, not only
        # the seed nodes; kept as in the original — confirm this is intended.
        loss = criterion(out, data.y)

        # Backward pass & optimization
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        num_batches += 1

    # Guard against an empty loader instead of dividing by zero
    return total_loss / max(num_batches, 1)

# 6. Evaluation Function
@torch.no_grad() # Disable gradient calculation for inference
def test(loader):
    """Return the micro-averaged F1 score of ``model`` over every batch in ``loader``.

    The previous implementation scored only ``next(iter(loader))`` and silently
    ignored any further batches. All batches are now evaluated and scored
    together, which gives the same result whenever the loader yields a single
    batch.

    Uses the module-level ``model`` and ``device``.

    Args:
        loader: Iterable of batches exposing ``x``, ``edge_index`` and ``y``.

    Returns:
        The micro F1 score, or ``0`` when the loader is empty or nothing is
        predicted positive.
    """
    model.eval() # Set model to evaluation mode

    ys, preds = [], []
    for data in loader:
        # Predict on the whole batch
        out = model(data.x.to(device), data.edge_index.to(device))

        # Logit > 0 is the decision threshold for a positive label
        preds.append((out > 0).float().cpu())
        ys.append(data.y.cpu())

    if not ys:
        return 0

    y, pred = torch.cat(ys, dim=0).numpy(), torch.cat(preds, dim=0).numpy()

    # Micro F1-score; reported as 0 when there are no positive predictions
    return f1_score(y, pred, average='micro') if pred.sum() > 0 else 0

# 7. Execution Loop
print("Starting Inductive GraphSAGE Training...")

# 301 iterations (epochs 0..300) so that epoch 300 is also reported below.
for epoch in range(301):
    loss = fit(train_loader)
    val_f1 = test(val_loader) # Evaluate on validation set (computed every epoch)
    
    # Report every 50th epoch
    if epoch % 50 == 0:
        print(f'Epoch {epoch:>3} | Train Loss: {loss:.3f} | Val F1-score: {val_f1:.4f}')

# Final evaluation on the test set (Unseen data)
print(f'Test F1-score: {test(test_loader):.4f}')

# 2. Inductive Learning on PPI Networks using GraphSAGE

### 1. Background: The "Social Network" of Proteins
- Proteins do not function in isolation. They interact with one another to perform complex biological processes.
- These interactions can be modeled as a Graph.
- Nodes - proteins
- Edges - physical interactions
- Node Features - biological signatures (e.g., positional gene sets, motif gene sets)
- Just like classifying a person's role based on their friends in a social network, we can predict a protein's function based on its interacting neighbors.
### 2. The Challenge: Generalizing to Unseen Data
- In many real-world biological applications, we encounter new tissues or organisms with proteins that the model has never seen before.
- Traditional Transductive Learning: Memorizes the graph structure (e.g., standard GCN). It fails when new nodes are introduced.
- Inductive Learning: Learns "how to aggregate neighbor information" rather than memorizing specific nodes. This allows the model to generalize to completely new graphs.
### 3. Objective
The goal of this notebook is to build a GraphSAGE (Graph Sample and Aggregate) model to predict protein functions in unseen protein-protein interaction (PPI) networks.
- Task - Multi-label Node Classification (predicting 121 biological functions)
- Dataset - PPI Dataset (20 graphs for training, 2 for validation, 2 for testing)
- Key Feature - The test graphs are completely distinct from the training graphs, testing the model's true inductive capability

In [None]:
# NOTE(review): no cell visible in this notebook imports rdkit — confirm this install is still needed.
!pip install rdkit

In [None]:
import torch
import torch.nn.functional as F
from torch_geometric.datasets import PPI
from torch_geometric.loader import DataLoader
from torch_geometric.nn import SAGEConv
from sklearn.metrics import f1_score

# Device Configuration
# Pick the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

In [None]:
# PPI Dataset Loading & Preprocessing

# Root directory for the PPI dataset files (the current working directory).
path = '.'
ppi_train_dataset = PPI(path, split='train')
ppi_val_dataset = PPI(path, split='val')
ppi_test_dataset = PPI(path, split='test')

print(f"Train Graphs: {len(ppi_train_dataset)}")  # 20 graphs
print(f"Val Graphs:   {len(ppi_val_dataset)}")    # 2 graphs
print(f"Test Graphs:  {len(ppi_test_dataset)}")   # 2 graphs
print(f"Number of features: {ppi_train_dataset.num_features}") # 50 features per protein
print(f"Number of classes:  {ppi_train_dataset.num_classes}")  # 121 labels (multi-label classification)

# 24 human tissues (graphs) - 20 for train, 2 for val, 2 for test
# Each graph has on average about 2,300 proteins (nodes) and roughly 34,000-60,000 edges
# (figures as quoted by the original author; not verified in this notebook).

In [None]:
# Setting up DataLoaders to feed graphs into the model
# PPI graphs are dense and large enough that we treat each graph as a batch.
# batch_size=1 means we load 1 whole tissue graph at a time for training.

ppi_train_loader = DataLoader(ppi_train_dataset, batch_size=1, shuffle=True)
ppi_val_loader = DataLoader(ppi_val_dataset, batch_size=2, shuffle=False)
ppi_test_loader = DataLoader(ppi_test_dataset, batch_size=2, shuffle=False)

# Extracting and inspecting the first batch from the training loader
# NOTE: the training loader shuffles, so which graph comes first may differ between runs;
# the numbers in the trailing comments below are example values from one run.

batch = next(iter(ppi_train_loader)) # Get the first batch

print("the First Batch Structure")
print(f"1. Data Type: {type(batch)}") 
print(f"2. Overall Structure: {batch}")
print(f"3. Number of Nodes: {batch.num_nodes}") # e.g. 591 - proteins
print(f"4. Number of Edges: {batch.num_edges}") # e.g. 7708 - connectivity
print(f"5. Node Feature Matrix (x) Shape: {batch.x.shape}") # e.g. (591, 50) - proteins, features
print(f"6. Label Matrix (y) Shape: {batch.y.shape}") # e.g. (591, 121) - proteins, functions
print(f"7. Edge Index Shape: {batch.edge_index.shape}") # e.g. (2, 7708) - source & target nodes, connectivity
print(f"8. Batch Vector : {batch.batch}")

In [None]:
import networkx as nx
import matplotlib.pyplot as plt
from torch_geometric.utils import to_networkx, k_hop_subgraph

# Take the first of the training graphs
ppi_data = ppi_train_dataset[0]

print(f"Selected Graph Info: {ppi_data}")
print(f"- Num Nodes: {ppi_data.num_nodes}")
print(f"- Num Edges: {ppi_data.edge_index.shape[1]}")
print(f"- Node Feature Dim: {ppi_data.num_features}")

# Visualization
target_node = 0  # Center node index to visualize
num_hops = 2     # 2-hop neighborhood

# Extract the k-hop subgraph around the target node.
# Only `subset` (the original node ids inside the neighborhood) is used below.
subset, edge_index, mapping, edge_mask = k_hop_subgraph(
    node_idx=target_node,
    num_hops=num_hops,
    edge_index=ppi_data.edge_index,
    relabel_nodes=True
)

# Convert the PyG Data object to a NetworkX graph (also reused by the next cell)
g = to_networkx(ppi_data)
# Keep only the nodes found above; node ids are preserved, so labels are original ids
sub_g = g.subgraph(subset.tolist())

plt.figure(figsize=(10, 8))
pos = nx.spring_layout(sub_g, seed=42)

# Node colours: centre -> lightcoral, 1-hop neighbors -> lightgreen,
# every other node of the subgraph (2-hop) -> lightskyblue.
# The neighbor set is built once instead of re-scanning g.neighbors() for every node.
one_hop = set(g.neighbors(target_node))
node_colors = [
    'lightcoral' if node == target_node
    else 'lightgreen' if node in one_hop
    else 'lightskyblue'
    for node in sub_g.nodes()
]

# Draw Graph
nx.draw_networkx_nodes(sub_g, pos, node_color=node_colors, node_size=100, alpha=0.9)
nx.draw_networkx_edges(sub_g, pos, edge_color='lightgray', alpha=0.5)
nx.draw_networkx_labels(sub_g, pos, font_size=8, font_color='black')

plt.title(f"Visualizing Neighborhood of Node {target_node} ({num_hops}-hop)", fontsize=15)
plt.axis('off')
plt.show()

In [None]:
# Node Degree Distribution
# The connectivity of proteins - hub identification.
# Uses `g` (NetworkX graph of the first training graph) from the previous cell and `np` from the first cell.
# NOTE(review): `g` was built by to_networkx() without to_undirected=True; if that yields a directed
# graph and the edge list stores both directions, g.degree() counts each interaction twice — confirm.

d = np.array([val for (node, val) in g.degree()])

# Visualization
plt.figure(figsize=(8, 3))
plt.hist(d, bins=50, color='skyblue', edgecolor='black')

plt.title("Node Degree Distribution (Connectivity)", fontsize=12)
plt.xlabel("Number of Connections (Degree)")
plt.ylabel("Count of Proteins")
plt.grid(axis='y', alpha=0.5)
plt.show()

In [None]:
# Distribution of the features associated with each node.
# Only the first 50 nodes of `ppi_data` (the first training graph) are visualized.

plt.figure(figsize=(8, 4))

features = ppi_data.x[:50, :].numpy()  # Slice: first 50 rows (nodes), all feature columns

# Heatmap: one row per node, one column per feature
sns.heatmap(features, cmap="viridis", cbar=True)

plt.title("Feature Matrix Heatmap (First 50 Nodes)", fontsize=12)
plt.xlabel("Features (Gene Signatures, Motif sets, etc.)")
plt.ylabel("Nodes (Proteins)")
plt.show()

In [None]:
# GraphSAGE Architecture

import torch.nn as nn

class GraphSAGE_Advanced(torch.nn.Module):
    """Residual GraphSAGE stack: linear projection, N SAGEConv blocks, linear head.

    Args:
        in_channels: Size of each input node feature vector.
        hidden_channels: Width shared by the projection and every SAGEConv block.
        out_channels: Number of output scores per node.
        num_layers: Number of SAGEConv + LayerNorm blocks (default 3).
    """

    def __init__(self, in_channels, hidden_channels, out_channels, num_layers=3):
        super().__init__()

        # Containers for the per-layer modules (ModuleList registers their parameters)
        self.convs = nn.ModuleList()
        self.norms = nn.ModuleList()

        # Projects the raw features to the hidden width so residual sums line up
        self.lin = nn.Linear(in_channels, hidden_channels)

        # Hidden -> hidden blocks: mean-aggregating SAGEConv followed by LayerNorm
        for _layer in range(num_layers):
            sage = SAGEConv(hidden_channels, hidden_channels, aggr='mean')
            layer_norm = nn.LayerNorm(hidden_channels)
            self.convs.append(sage)
            self.norms.append(layer_norm)

        # Classification head
        self.output_layer = nn.Linear(hidden_channels, out_channels)

    def forward(self, x, edge_index):
        """Return per-node logits for features ``x`` and connectivity ``edge_index``."""
        # [Step 1] Input projection to the hidden width
        h = self.lin(x)

        # [Step 2] Residual message-passing blocks
        for conv, norm in zip(self.convs, self.norms):
            # conv -> norm -> ELU -> dropout (p=0.2, training mode only)
            update = F.elu(norm(conv(h, edge_index)))
            update = F.dropout(update, p=0.2, training=self.training)

            # Skip connection: add the block input back onto its output
            h = update + h

        # [Step 3] Final classification
        return self.output_layer(h)

# Model Instantiation
# Dimensions are read from ppi_train_dataset, the dataset this section loads and trains on,
# rather than from `train_dataset`, which belongs to an earlier, separate experiment cell.
model = GraphSAGE_Advanced(
    in_channels=ppi_train_dataset.num_features,
    hidden_channels=512,
    out_channels=ppi_train_dataset.num_classes,
    num_layers=4
).to(device)

print("[Upgraded Model Structure]")
print(model)

In [None]:
# Loss & Optimizer
# BCEWithLogitsLoss is standard for Multi-label Classification
# Both names rebind the objects created in the earlier PPI cell; the optimizer now
# tracks the parameters of the `model` instantiated in the previous cell.
criterion = torch.nn.BCEWithLogitsLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.005)

In [None]:
# Training Function
def train():
    """Run one epoch over ``ppi_train_loader`` and return the mean loss per graph.

    Uses the module-level ``model``, ``optimizer``, ``criterion`` and ``device``.
    Each batch loss is weighted by the number of graphs in that batch and the
    total is divided by the size of the loader's dataset.
    """
    model.train()
    weighted_losses = []

    for graph_batch in ppi_train_loader:
        graph_batch = graph_batch.to(device)

        optimizer.zero_grad()
        logits = model(graph_batch.x, graph_batch.edge_index)
        batch_loss = criterion(logits, graph_batch.y)
        batch_loss.backward()
        optimizer.step()

        weighted_losses.append(batch_loss.item() * graph_batch.num_graphs)

    return sum(weighted_losses) / len(ppi_train_loader.dataset)

In [None]:
# Evaluation Function (Micro F1 Score)
@torch.no_grad()
def test(loader):
    """Return the micro-averaged F1 score of ``model`` over all batches in ``loader``.

    Uses the module-level ``model`` and ``device``. Predictions and labels of
    every batch are collected on the CPU and scored together.
    """
    model.eval()

    label_chunks = []
    pred_chunks = []

    for graph_batch in loader:
        graph_batch = graph_batch.to(device)
        logits = model(graph_batch.x, graph_batch.edge_index)

        # Sigmoid turns logits into probabilities; above 0.5 counts as a positive label
        probabilities = torch.sigmoid(logits)
        hard_preds = (probabilities > 0.5).float()

        label_chunks.append(graph_batch.y.cpu())
        pred_chunks.append(hard_preds.cpu())

    # Stack the per-batch results along the node dimension
    y_true = torch.cat(label_chunks, dim=0).numpy()
    y_pred = torch.cat(pred_chunks, dim=0).numpy()

    # Micro F1 pools true positives, false negatives and false positives over
    # all labels before computing the score.
    return f1_score(y_true, y_pred, average='micro')

In [None]:
# Starting GraphSAGE Training

# Per-epoch history, plotted in the next cell
train_losses = []
val_f1_scores = []

epochs = 200 

for epoch in range(1, epochs + 1):
    loss = train()
    val_f1 = test(ppi_val_loader)
    
    train_losses.append(loss)
    val_f1_scores.append(val_f1)
    
    # Report every 20th epoch
    if epoch % 20 == 0:
        print(f'Epoch: {epoch:03d}, Loss: {loss:.4f}, Val Micro-F1: {val_f1:.4f}')

# Final Test
test_f1 = test(ppi_test_loader)
print(f'\nFinal Test Micro-F1 Score: {test_f1:.4f}')

In [None]:
# Visualization: training loss (left) and validation micro-F1 (right), one point per epoch

# (series, legend label, line colour, panel title, y-axis label) for each panel
curve_panels = [
    (train_losses, 'Train Loss', 'red', 'Training Loss (BCE)', 'Loss'),
    (val_f1_scores, 'Validation F1', 'green', 'Validation Micro-F1 Score', 'F1 Score'),
]

plt.figure(figsize=(8, 3))

for slot, (series, label, color, title, ylabel) in enumerate(curve_panels, start=1):
    # Select panel `slot` of the 1x2 grid, then draw and decorate it
    plt.subplot(1, 2, slot)
    plt.plot(series, label=label, color=color)
    plt.title(title)
    plt.xlabel('Epoch')
    plt.ylabel(ylabel)
    plt.legend()
    plt.grid(True)

plt.tight_layout()
plt.show()