# GNN Link Prediction Training

Train a Graph Convolutional Network (GCN), a type of Graph Neural Network, on the ViMed
medical knowledge graph for **link prediction**: predicting missing relations between medical entities.

## Pipeline
1. Load the NetworkX graph from pickle
2. Engineer node/edge features
3. Convert to PyTorch Geometric Data
4. Define GCN encoder + link decoder
5. Train with binary cross-entropy on positive and negative (non-existent) edges
6. Evaluate with ROC-AUC & Average Precision (AP)
7. Save the best model checkpoint (selected by validation AUC)

## 1. Setup & Imports

In [None]:
import os
import pickle
import numpy as np
import networkx as nx
import torch
import torch.nn.functional as F
from torch_geometric.data import Data
from torch_geometric.nn import GCNConv
from torch_geometric.transforms import RandomLinkSplit
from sklearn.metrics import roc_auc_score, average_precision_score

import random

print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")

# Seed every RNG the pipeline may draw from (Python's `random`, NumPy, torch)
# so the random edge split, sampled negatives, weight init and dropout masks
# are repeatable between runs. Some CUDA kernels may still be
# non-deterministic, so GPU runs can differ slightly.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {DEVICE}")

## 2. Load & Explore the Medical Knowledge Graph

In [None]:
# Input graph and output checkpoint locations (relative to the current
# working directory).
GRAPH_PATH = os.path.join("..", "amg_data", "graph_improved.pkl")
MODEL_SAVE_PATH = os.path.join("..", "amg_data", "gnn_link_predictor.pt")

# NOTE: unpickling can execute arbitrary code -- only load graph files that
# were produced by a trusted pipeline.
with open(GRAPH_PATH, "rb") as graph_file:
    G = pickle.load(graph_file)

for label, value in (
    ("Graph type", type(G).__name__),
    ("Number of nodes", G.number_of_nodes()),
    ("Number of edges", G.number_of_edges()),
):
    print(f"{label}: {value}")

In [None]:
# Analyze node types: count nodes per "type" attribute (missing -> "UNKNOWN").
from collections import Counter

node_types = Counter(attrs.get("type", "UNKNOWN") for _, attrs in G.nodes(data=True))

print("\n--- Node type distribution ---")
# most_common() orders by descending count; ties keep first-seen order.
for ntype, count in node_types.most_common():
    print(f"  {ntype}: {count}")

In [None]:
# Analyze edge/relation types: count edges per "relation" attribute
# (missing -> "UNKNOWN").
from collections import Counter

edge_types = Counter(
    attrs.get("relation", "UNKNOWN") for _, _, attrs in G.edges(data=True)
)

print("\n--- Edge relation distribution ---")
# Only the 20 most frequent relations are shown; ties keep first-seen order.
for rel, count in edge_types.most_common(20):
    print(f"  {rel}: {count}")

## 3. Feature Engineering

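- **Node features**: One-hot encoding of node type
- **Edge attributes**: Relation types are only inspected above for analysis; per-edge confidences are collected and summarized below, but neither is used as a training input or target

For link prediction we train on edge existence (binary classification: does an edge exist between two nodes?).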

In [None]:
# Fix one node ordering: row i of every node-level tensor refers to nodes_list[i].
nodes_list = list(G.nodes())
node_to_idx = dict(zip(nodes_list, range(len(nodes_list))))
num_nodes = len(nodes_list)

# Read each node's "type" attribute once, in nodes_list order
# (missing attribute -> "UNKNOWN"), and build a sorted type vocabulary.
node_type_labels = [G.nodes[n].get("type", "UNKNOWN") for n in nodes_list]
all_node_types = sorted(set(node_type_labels))
type_to_idx = {t: i for i, t in enumerate(all_node_types)}
num_node_types = len(all_node_types)

print(f"Node type vocabulary ({num_node_types} types): {all_node_types}")

# One-hot feature matrix: exactly one 1.0 per row, in the column of that
# node's type.
type_columns = torch.tensor([type_to_idx[t] for t in node_type_labels], dtype=torch.long)
x = torch.zeros(num_nodes, num_node_types, dtype=torch.float)
x[torch.arange(num_nodes), type_columns] = 1.0

print(f"Node feature matrix shape: {x.shape}")

In [None]:
# Build a directed edge index from G, collapsing parallel edges: only the first
# (source, target) occurrence is kept, together with its "confidence"
# attribute (default 0.5). NOTE(review): if G is undirected, each edge appears
# once here in one arbitrary direction -- confirm G is a directed graph.
first_confidence = {}
for src_node, dst_node, attrs in G.edges(data=True):
    pair = (node_to_idx[src_node], node_to_idx[dst_node])
    if pair not in first_confidence:
        first_confidence[pair] = attrs.get("confidence", 0.5)

# dicts preserve insertion order, so edges stay in first-seen order.
seen_edges = set(first_confidence)
edge_src = [s for s, _ in first_confidence]
edge_dst = [d for _, d in first_confidence]
edge_confidences = list(first_confidence.values())

edge_index = torch.tensor([edge_src, edge_dst], dtype=torch.long)
edge_conf = torch.tensor(edge_confidences, dtype=torch.float)

print(f"Unique directed edges: {edge_index.shape[1]}")
print(f"Edge confidence stats: mean={edge_conf.mean():.3f}, "
      f"min={edge_conf.min():.3f}, max={edge_conf.max():.3f}")

## 4. Convert to PyTorch Geometric Data & Split

In [None]:
# Wrap the one-hot node features and the deduplicated directed edges in a
# single PyG graph; the node count is stored explicitly rather than inferred.
data = Data(x=x, edge_index=edge_index, num_nodes=num_nodes)

print(f"PyG Data: {data}")
for label, value in (
    ("Nodes", data.num_nodes),
    ("Edges", data.num_edges),
    ("Node features dim", data.num_node_features),
):
    print(f"  {label}: {value}")

In [None]:
# Split edges into train / val / test
# 85% train, 5% val, 10% test. Each split gets labelled candidate edges
# (edge_label_index / edge_label) with one sampled negative per positive
# (neg_sampling_ratio=1.0); training negatives are sampled once, here.
transform = RandomLinkSplit(
    num_val=0.05,
    num_test=0.10,
    is_undirected=False,
    add_negative_train_samples=True,
    neg_sampling_ratio=1.0,
)

train_data, val_data, test_data = transform(data)

# .item() so the f-string prints the number, not "tensor(..., dtype=...)".
num_train_pos = int(train_data.edge_label.sum().item())
print(f"Train edges: {train_data.edge_label_index.shape[1]} "
      f"(pos + neg, labels sum={num_train_pos})")
print(f"Val   edges: {val_data.edge_label_index.shape[1]}")
print(f"Test  edges: {test_data.edge_label_index.shape[1]}")

## 5. Model Definition

**Architecture:**
- 2-layer GCN encoder that produces node embeddings
- Dot-product decoder that outputs one raw logit per candidate edge (the sigmoid is applied only when converting scores to probabilities)

In [None]:
class GCNEncoder(torch.nn.Module):
    """Two-layer GCN that maps node features to node embeddings.

    Args:
        in_channels: Size of each input node feature vector.
        hidden_channels: Width of the first GCN layer.
        out_channels: Size of each output node embedding.
        dropout: Dropout probability applied between the two layers
            (only active while the module is in training mode).
    """

    def __init__(self, in_channels, hidden_channels, out_channels, dropout=0.3):
        super().__init__()
        self.conv1 = GCNConv(in_channels, hidden_channels)
        self.conv2 = GCNConv(hidden_channels, out_channels)
        self.dropout = dropout

    def forward(self, x, edge_index):
        """Return one embedding row per node.

        ReLU and dropout follow the first layer only; the second layer's
        output is returned without an activation.
        """
        x = self.conv1(x, edge_index)
        x = F.relu(x)
        x = F.dropout(x, p=self.dropout, training=self.training)
        x = self.conv2(x, edge_index)
        return x


class LinkPredictor(torch.nn.Module):
    """GCN encoder + dot-product link decoder.

    ``encode`` produces node embeddings; ``decode`` turns pairs of node
    indices into raw edge logits.
    """

    def __init__(self, in_channels, hidden_channels, out_channels, dropout=0.3):
        super().__init__()
        self.encoder = GCNEncoder(in_channels, hidden_channels, out_channels, dropout)

    def encode(self, x, edge_index):
        """Compute node embeddings by message passing over ``edge_index``."""
        return self.encoder(x, edge_index)

    def decode(self, z, edge_label_index):
        """Score candidate edges with a dot product of endpoint embeddings.

        Args:
            z: Node embeddings, indexed by node id along the first axis.
            edge_label_index: Two rows of node indices -- sources in row 0,
                targets in row 1.

        Returns:
            Raw logits ``z_u . z_v``, one per column of ``edge_label_index``.
            No sigmoid is applied. Use ``F.binary_cross_entropy_with_logits``
            for training, or ``torch.sigmoid`` to get probabilities.

        Note:
            The score is symmetric (``score(u, v) == score(v, u)``), so the
            decoder cannot distinguish edge direction.
        """
        src = z[edge_label_index[0]]
        dst = z[edge_label_index[1]]
        return (src * dst).sum(dim=-1)


# Hyperparameters
HIDDEN_DIM = 128  # width of the first GCN layer
EMBED_DIM = 64    # size of the final node embeddings
DROPOUT = 0.3     # dropout between the two GCN layers
LR = 0.01         # Adam learning rate
EPOCHS = 200      # number of training epochs (one optimizer step each)

# Positional order: in_channels, hidden_channels, out_channels.
model = LinkPredictor(num_node_types, HIDDEN_DIM, EMBED_DIM, dropout=DROPOUT)
model = model.to(DEVICE)

optimizer = torch.optim.Adam(model.parameters(), lr=LR)

print(model)
total_params = sum(param.numel() for param in model.parameters())
print(f"\nTotal parameters: {total_params:,}")

## 6. Training Loop

In [None]:
def train_epoch(model, train_data, optimizer, device):
    """Take a single optimizer step on the whole training split.

    Args:
        model: Object exposing ``encode(x, edge_index)`` and
            ``decode(z, edge_label_index)``; it is switched to training mode.
        train_data: Split providing ``x``, ``edge_index`` (message-passing
            edges), ``edge_label_index`` (scored pairs) and ``edge_label``
            (0/1 targets).
        optimizer: Optimizer over ``model``'s parameters.
        device: Device that ``train_data`` is moved to before use.

    Returns:
        The binary cross-entropy (with logits) loss as a Python float.
    """
    model.train()
    batch = train_data.to(device)

    optimizer.zero_grad()
    embeddings = model.encode(batch.x, batch.edge_index)
    edge_logits = model.decode(embeddings, batch.edge_label_index)
    bce = F.binary_cross_entropy_with_logits(edge_logits, batch.edge_label)
    bce.backward()
    optimizer.step()

    return bce.item()


@torch.no_grad()
def evaluate(model, eval_data, device):
    """Score a labelled edge split without tracking gradients.

    Messages are passed over ``eval_data.edge_index``; the pairs in
    ``eval_data.edge_label_index`` are scored and compared with
    ``eval_data.edge_label``.

    Returns:
        ``(auc, ap)``: ROC-AUC and average precision of the sigmoid scores.
    """
    model.eval()
    batch = eval_data.to(device)

    embeddings = model.encode(batch.x, batch.edge_index)
    edge_probs = torch.sigmoid(model.decode(embeddings, batch.edge_label_index))

    y_true = batch.edge_label.cpu().numpy()
    y_score = edge_probs.cpu().numpy()
    return roc_auc_score(y_true, y_score), average_precision_score(y_true, y_score)

In [None]:
# Training: one optimizer step per epoch; validate on epoch 1 and every 10th
# epoch, checkpointing whenever validation AUC improves.
print("Starting training...")
print(f"{'Epoch':>6} | {'Loss':>8} | {'Val AUC':>8} | {'Val AP':>8}")
print("-" * 42)

best_val_auc = 0.0
history = {"loss": [], "val_auc": [], "val_ap": []}

for epoch in range(1, EPOCHS + 1):
    loss = train_epoch(model, train_data, optimizer, DEVICE)
    history["loss"].append(loss)

    if epoch % 10 == 0 or epoch == 1:
        val_auc, val_ap = evaluate(model, val_data, DEVICE)
        history["val_auc"].append(val_auc)
        history["val_ap"].append(val_ap)
        print(f"{epoch:6d} | {loss:8.4f} | {val_auc:8.4f} | {val_ap:8.4f}")

        # Save best model
        if val_auc > best_val_auc:
            best_val_auc = val_auc
            # Metrics are stored as plain Python floats: depending on the
            # sklearn version they can be NumPy scalars, which
            # torch.load(..., weights_only=True) refuses to unpickle unless
            # explicitly allow-listed.
            torch.save({
                "model_state_dict": model.state_dict(),
                "node_types": all_node_types,
                "type_to_idx": type_to_idx,
                "hidden_dim": HIDDEN_DIM,
                "embed_dim": EMBED_DIM,
                "num_node_types": num_node_types,
                "dropout": DROPOUT,
                "epoch": epoch,
                "val_auc": float(val_auc),
                "val_ap": float(val_ap),
            }, MODEL_SAVE_PATH)

print(f"\nTraining complete. Best validation AUC: {best_val_auc:.4f}")

## 7. Training Curves

In [None]:
# Plot the training loss per epoch and the periodic validation metrics.
# Only the import is guarded, so a genuine plotting error is not mistaken
# for a missing matplotlib installation.
try:
    import matplotlib.pyplot as plt
except ImportError:
    plt = None
    print("matplotlib not installed. Skipping plots.")

if plt is not None:
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 5))

    # Loss curve: history["loss"][i] is the loss of epoch i + 1, so plot it
    # against 1-based epoch numbers to match the axis label.
    loss_epochs = list(range(1, len(history["loss"]) + 1))
    ax1.plot(loss_epochs, history["loss"], color="#FF6B6B", linewidth=1.5)
    ax1.set_xlabel("Epoch")
    ax1.set_ylabel("BCE Loss")
    ax1.set_title("Training Loss")
    ax1.grid(True, alpha=0.3)

    # AUC/AP curve. The training loop validates on epoch 1 and every 10th
    # epoch, so each point is placed at its real epoch. If the number of
    # recorded evaluations does not match that schedule, fall back to
    # plain evaluation-step numbers.
    eval_epochs = [e for e in loss_epochs if e == 1 or e % 10 == 0]
    eval_xlabel = "Epoch"
    if len(eval_epochs) != len(history["val_auc"]):
        eval_epochs = list(range(1, len(history["val_auc"]) + 1))
        eval_xlabel = "Evaluation step"
    ax2.plot(eval_epochs, history["val_auc"], label="AUC", color="#4ECDC4", linewidth=2)
    ax2.plot(eval_epochs, history["val_ap"], label="AP", color="#FFE66D", linewidth=2)
    ax2.set_xlabel(eval_xlabel)
    ax2.set_ylabel("Score")
    ax2.set_title("Validation Metrics")
    ax2.legend()
    ax2.grid(True, alpha=0.3)

    plt.tight_layout()
    plt.show()

## 8. Test Set Evaluation

In [None]:
# Restore the checkpoint saved at the best validation AUC, then score the
# held-out test split with it.
# weights_only=False: the checkpoint holds metadata besides tensors (lists,
# dicts, and possibly NumPy scalar metrics), which the restricted loader may reject.
checkpoint = torch.load(MODEL_SAVE_PATH, map_location=DEVICE, weights_only=False)
model.load_state_dict(checkpoint["model_state_dict"])

test_auc, test_ap = evaluate(model, test_data, DEVICE)

rule = "=" * 40
report_lines = [
    rule,
    "       TEST SET RESULTS",
    rule,
    f"  AUC:              {test_auc:.4f}",
    f"  Average Precision: {test_ap:.4f}",
    f"  Best epoch:        {checkpoint['epoch']}",
    rule,
]
print("\n".join(report_lines))

In [None]:
# Detailed predictions on test edges: threshold the predicted probabilities
# and report a confusion matrix plus precision / recall / F1.
model.eval()
td = test_data.to(DEVICE)

with torch.no_grad():
    z = model.encode(td.x, td.edge_index)
    test_logits = model.decode(z, td.edge_label_index)
    test_probs = torch.sigmoid(test_logits).cpu().numpy()
    test_labels = td.edge_label.cpu().numpy()

# A pair is predicted as an edge when its probability reaches the threshold.
threshold = 0.5
predicted_pos = test_probs >= threshold
predicted_neg = ~predicted_pos
actual_pos = test_labels == 1
actual_neg = test_labels == 0

tp = (predicted_pos & actual_pos).sum()
fp = (predicted_pos & actual_neg).sum()
tn = (predicted_neg & actual_neg).sum()
fn = (predicted_neg & actual_pos).sum()

# Each ratio falls back to 0 when its denominator is zero.
precision = tp / (tp + fp) if tp + fp else 0
recall = tp / (tp + fn) if tp + fn else 0
f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0

print(f"\nConfusion Matrix (threshold={threshold}):")
print(f"  TP={tp}  FP={fp}")
print(f"  FN={fn}  TN={tn}")
print(f"  Precision: {precision:.4f}")
print(f"  Recall:    {recall:.4f}")
print(f"  F1 Score:  {f1:.4f}")

## 9. Predict Missing Links

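Use the trained model to score pairs of nodes that are not already connected by an edge and find
the most likely missing connections. To keep this tractable, only pairs among the 50 highest-degree nodes are scored.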

In [None]:
# Predict top-K most likely missing links
from itertools import combinations

TOP_K = 20

model.eval()
full_data = data.to(DEVICE)

with torch.no_grad():
    # Message passing over every known edge (the unsplit graph) for inference.
    z = model.encode(full_data.x, full_data.edge_index)

# Existing directed edges as a set of (src, dst) index pairs for fast lookup.
existing = set(zip(*full_data.edge_index.cpu().tolist()))

# Sample candidate pairs (full NxN is too large)
# Focus on high-degree nodes for meaningful predictions
degrees = dict(G.degree())
top_nodes = sorted(degrees, key=degrees.get, reverse=True)[:50]
top_indices = [node_to_idx[n] for n in top_nodes]

# The dot-product decoder is symmetric (score(i, j) == score(j, i)), so it
# cannot tell edge direction apart. Score each unordered pair once and skip
# pairs already connected in either direction. Otherwise the ranking lists
# every pair twice and reports reversals of known edges as "missing".
# The Source/Target orientation printed below is therefore arbitrary.
candidates = [
    (i, j)
    for i, j in combinations(top_indices, 2)
    if (i, j) not in existing and (j, i) not in existing
]

if candidates:
    cand_index = torch.tensor(candidates, dtype=torch.long).t().contiguous().to(DEVICE)

    with torch.no_grad():
        scores = torch.sigmoid(model.decode(z, cand_index)).cpu().numpy()

    # Get top-K predictions
    top_k_idx = np.argsort(scores)[::-1][:TOP_K]

    print(f"\nTop {TOP_K} predicted missing links:")
    print(f"{'Source':<30} {'Target':<30} {'Score':>8}")
    print("-" * 70)
    for idx in top_k_idx:
        src_pos, dst_pos = candidates[idx]
        # str() so non-string node ids (e.g. tuples) still accept the width spec.
        src_node = str(nodes_list[src_pos])
        dst_node = str(nodes_list[dst_pos])
        print(f"{src_node:<30} {dst_node:<30} {scores[idx]:8.4f}")
else:
    print("No candidate pairs found.")

In [None]:
# Report where the best checkpoint (written by the training loop) lives.
saved_path = os.path.abspath(MODEL_SAVE_PATH)
print(f"\nModel saved to: {saved_path}")
print("Training notebook complete.")