# Peer Recommendation System - Part 3: GNN Model Training

**Course**: SI 670: Applied Machine Learning

**Name**: Yuganshi Agrawal  
**uniqname**: yuganshi  


**Name**: Sai Sneha Siddapura Venkataramappa  
**uniqname**: saisneha

This notebook trains a Graph Neural Network (GraphSAGE) model and compares it with baseline models from Notebook 2.

**Workflow**:
1. Load and process OULAD data
2. Generate complementarity pairs
3. Build student graph
4. Train GraphSAGE model
5. Load baseline models from Notebook 2 (LR, XGBoost)
6. Compare GNN vs Baseline performance

**Dependencies**: Requires Notebook 2 to be run first (generates baseline models)

**Inputs**: 
- Raw OULAD CSV files from `OULAD/` directory
- Saved baseline models from `models/checkpoints/` (from Notebook 2)

**Outputs**:
- `models/checkpoints/gnn_model.pth` - Trained GNN
- `data/processed/gnn_embeddings.pkl` - Student embeddings
- `results/gnn_metrics.json` - GNN performance
- `results/gnn_vs_baselines.json` - Complete comparison


## Setup and Imports


In [1]:
import os, math, random, itertools, sys, time
import numpy as np, pandas as pd
from tqdm.auto import tqdm
from sklearn.preprocessing import StandardScaler
from sklearn.decomposition import PCA
from sklearn.model_selection import train_test_split
from sklearn.metrics import roc_auc_score, average_precision_score, classification_report
import xgboost as xgb
import torch, torch.nn as nn, torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import warnings
warnings.filterwarnings('ignore')


## Configuration


In [3]:
RNG = 42
np.random.seed(RNG)
random.seed(RNG)
torch.manual_seed(RNG)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(RNG)

DATA_DIR = 'OULAD'
PATHS = {
    'student_info': os.path.join(DATA_DIR, 'studentInfo.csv'),
    'student_vle': os.path.join(DATA_DIR, 'studentVle.csv'),
    'vle': os.path.join(DATA_DIR, 'vle.csv'),
    'assessments': os.path.join(DATA_DIR, 'assessments.csv'),
    'student_assessment': os.path.join(DATA_DIR, 'studentAssessment.csv'),
    'courses': os.path.join(DATA_DIR, 'courses.csv'),
    'student_registration': os.path.join(DATA_DIR, 'studentRegistration.csv'),
}

MAX_PAIRS_PER_MODULE = 30000
EMB_DIM = 48
GNN_HIDDEN = 64
HOLDOUT_FRAC = 0.1
BATCH_EDGE = 4096
EPOCHS = 15
LR = 1e-3
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'

print("GNN-BASED COMPLEMENTARY PEER RECOMMENDATION")
print(f"Device: {DEVICE}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")


GNN-BASED COMPLEMENTARY PEER RECOMMENDATION
Device: cpu


## Data Loading


In [4]:
print("\n Loading data...")
for name, p in PATHS.items():
    if not os.path.exists(p):
        raise FileNotFoundError(f"{p} not found")

student_info = pd.read_csv(PATHS['student_info'])
student_vle = pd.read_csv(PATHS['student_vle'])
vle = pd.read_csv(PATHS['vle'])
assessments = pd.read_csv(PATHS['assessments'])
student_assessment = pd.read_csv(PATHS['student_assessment'])
courses = pd.read_csv(PATHS['courses'])
student_registration = pd.read_csv(PATHS['student_registration'])

if 'week' not in student_vle.columns and 'date' in student_vle.columns:
    student_vle['week'] = (student_vle['date'] // 7).astype(int)

print(f"Loaded {len(student_info)} students, {len(student_vle)} VLE interactions")



 Loading data...
Loaded 32593 students, 10655280 VLE interactions
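
The loading cell derives `week` from `date` with floor division. As a minimal sketch, assuming `date` is a day offset that can be negative for activity before the module start, the example below shows how floor division assigns such days to negative weeks rather than rounding them toward week 0. The helper name `day_to_week` and the sample offsets are illustrative only.

```python
def day_to_week(day: int) -> int:
    """Map a day offset to a week index with floor division, as the loading cell does."""
    return day // 7

# Hypothetical day offsets; negative values stand for activity before the module start.
for day in (-8, -7, -1, 0, 6, 7):
    print(day, "->", day_to_week(day))
```

Because `//` floors rather than truncates, days -7 through -1 all land in week -1, so pre-start activity stays separate from the first teaching week.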


## Feature Engineering


In [5]:
print("\n Engineering features...")

demo = student_info[['id_student','gender','region','highest_education','imd_band',
                     'age_band','num_of_prev_attempts','studied_credits','disability']].copy()
demo_oh = pd.get_dummies(demo, columns=['gender','region','highest_education',
                                         'imd_band','age_band','disability'],
                         dummy_na=True, drop_first=False).fillna(0)

assess_pivot = student_assessment.pivot_table(
    index='id_student', columns='id_assessment',
    values='score', aggfunc='mean'
).fillna(0)

assess_pivot_z = assess_pivot.apply(lambda c: (c - c.mean()) / (c.std() + 1e-8), axis=0).fillna(0)
assess_diversity = assess_pivot_z.std(axis=1).fillna(0).rename('assess_diversity')

vle_week = student_vle.groupby(['id_student','week'])['sum_click'].sum().reset_index()
week_pivot = vle_week.pivot(index='id_student', columns='week', values='sum_click').fillna(0)
week_mean = week_pivot.mean(axis=1).rename('week_mean')
week_std = week_pivot.std(axis=1).fillna(0).rename('week_std')
week_trend = week_pivot.apply(lambda r: np.polyfit(range(len(r)), r.values, 1)[0] if len(r)>1 else 0, axis=1).rename('week_trend')

vle_types = vle[['id_site','activity_type']].drop_duplicates()
svt = student_vle.merge(vle_types, on='id_site', how='left')
type_counts = svt.groupby(['id_student','activity_type'])['sum_click'].sum().reset_index()
type_entropy = type_counts.groupby('id_student').apply(
    lambda g: -np.sum((g['sum_click']/g['sum_click'].sum()) * np.log(g['sum_click']/g['sum_click'].sum() + 1e-12))
).rename('type_entropy')
type_diversity = type_counts.groupby('id_student')['activity_type'].nunique().rename('type_diversity')

reg = student_registration.groupby('id_student').agg(
    reg_date=('date_registration', 'min'),
    unreg_date=('date_unregistration', 'min')
).fillna({'reg_date':0, 'unreg_date':9999})

features = (demo_oh
    .merge(assess_diversity.reset_index(), on='id_student', how='left')
    .merge(week_mean.reset_index(), on='id_student', how='left')
    .merge(week_std.reset_index(), on='id_student', how='left')
    .merge(week_trend.reset_index(), on='id_student', how='left')
    .merge(type_entropy.reset_index(), on='id_student', how='left')
    .merge(type_diversity.reset_index(), on='id_student', how='left')
    .merge(reg.reset_index(), on='id_student', how='left')
).fillna(0)

module_map = student_info[['id_student','code_module','code_presentation']].drop_duplicates()
features = features.merge(module_map, on='id_student', how='left')

print(f"Base features shape: {features.shape}")



 Engineering features...
Base features shape: (40801, 54)
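
The `type_entropy` feature is the Shannon entropy of a student's clicks across activity types. The sketch below shows the same quantity on toy data in plain Python; the function name `click_entropy` and the click counts are hypothetical, and it skips zero counts instead of adding the `1e-12` offset used in the cell.

```python
import math

def click_entropy(clicks_by_type: dict) -> float:
    """Shannon entropy (natural log) of a click distribution over activity types."""
    total = sum(clicks_by_type.values())
    if total == 0:
        return 0.0
    probs = [c / total for c in clicks_by_type.values() if c > 0]
    return sum(-p * math.log(p) for p in probs)

# Hypothetical click counts per activity type.
focused = {"quiz": 100, "forumng": 0, "resource": 0}
spread = {"quiz": 25, "forumng": 25, "resource": 25, "oucontent": 25}
print(click_entropy(focused))  # all clicks on one type
print(click_entropy(spread))   # clicks spread evenly over four types
```

A student who uses a single activity type scores 0, while an even spread over k types scores ln(k), so the feature grows with the variety of a student's engagement.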


## Module-wise Normalization


In [6]:
print("\n Normalizing features per module...")

num_cols = [c for c in features.columns if c not in ['id_student','code_module','code_presentation']]
features_norm_list = []

for (mod, pres), group in tqdm(features.groupby(['code_module','code_presentation']), desc="Module-normalize"):
    if len(group) < 2:
        features_norm_list.append(group.copy())
        continue

    scaler = StandardScaler()
    scaled_vals = scaler.fit_transform(group[num_cols].values)
    g = group.copy()
    g[num_cols] = scaled_vals

    ids = g['id_student'].values
    assess_sub = assess_pivot_z.reindex(ids).fillna(0)

    pca_n = min(EMB_DIM, max(2, len(g)-1), assess_sub.shape[1])
    if pca_n >= 2:
        pca = PCA(n_components=pca_n, random_state=RNG)
        try:
            pca_emb = pca.fit_transform(assess_sub.values)
        except Exception:
            pca_emb = np.zeros((len(g), pca_n))
    else:
        pca_emb = np.zeros((len(g), EMB_DIM))

    if pca_emb.shape[1] < EMB_DIM:
        pad = np.zeros((len(g), EMB_DIM - pca_emb.shape[1]))
        pca_emb = np.concatenate([pca_emb, pad], axis=1)

    emb_cols = [f'assess_emb_{i}' for i in range(EMB_DIM)]
    emb_df = pd.DataFrame(pca_emb[:, :EMB_DIM], index=g.index, columns=emb_cols)
    g = pd.concat([g.reset_index(drop=True), emb_df.reset_index(drop=True)], axis=1)

    features_norm_list.append(g)

features_proc = pd.concat(features_norm_list, axis=0).reset_index(drop=True)

# Remove duplicates for node features
features_proc_unique = features_proc.drop_duplicates(subset=['id_student'], keep='first')

node_feature_cols = num_cols + [f'assess_emb_{i}' for i in range(EMB_DIM)]
node_feature_matrix = features_proc_unique[['id_student'] + node_feature_cols].set_index('id_student').fillna(0)

global_pca = PCA(n_components=EMB_DIM, random_state=RNG)
node_features_global = global_pca.fit_transform(node_feature_matrix.values)
node_feat_df = pd.DataFrame(
    node_features_global,
    index=node_feature_matrix.index,
    columns=[f'emb_{i}' for i in range(EMB_DIM)]
).reset_index()

print(f"Node features: {node_feat_df.shape}, explained variance: {global_pca.explained_variance_ratio_.sum():.3f}")

node_feat_lookup = {
    int(r['id_student']): r[[c for c in node_feat_df.columns if c.startswith('emb_')]].values.astype(np.float32)
    for _, r in node_feat_df.iterrows()
}



 Normalizing features per module...


Node features: (28785, 49), explained variance: 0.960
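
The loop above fits a separate `StandardScaler` for every module-presentation, so each feature is expressed relative to a student's own cohort. A minimal plain-Python sketch of that idea follows; the function `zscore_within_groups`, the field names, and the values are hypothetical.

```python
from collections import defaultdict
from statistics import mean, pstdev

def zscore_within_groups(rows):
    """Standardize `value` within each `module` group using the population std."""
    groups = defaultdict(list)
    for row in rows:
        groups[row["module"]].append(row["value"])
    stats = {m: (mean(v), pstdev(v)) for m, v in groups.items()}
    out = []
    for row in rows:
        mu, sigma = stats[row["module"]]
        # A constant group has sigma == 0; keep the centered value unscaled in that case.
        z = (row["value"] - mu) / sigma if sigma > 0 else row["value"] - mu
        out.append({**row, "z": z})
    return out

# Hypothetical weekly-click means for two modules on very different scales.
rows = [
    {"module": "AAA", "value": 10.0}, {"module": "AAA", "value": 20.0},
    {"module": "BBB", "value": 100.0}, {"module": "BBB", "value": 300.0},
]
for r in zscore_within_groups(rows):
    print(r["module"], r["value"], round(r["z"], 3))
```

After scaling, the low and high student in each module sit at the same distance from their cohort mean, even though the raw click counts differ by an order of magnitude.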


## Generate Training Pairs


In [7]:
print("\n Generating pairs with complementarity labels...")

pair_data = []
MAX_PAIRS = MAX_PAIRS_PER_MODULE

for (mod, pres), group in tqdm(features_proc.groupby(['code_module','code_presentation']), desc="Pairs"):
    # Deduplicate students: the feature merges can leave repeated id_student rows in a group
    group_dedup = group.drop_duplicates(subset=['id_student'], keep='first')
    students = group_dedup['id_student'].unique().tolist()
    
    if len(students) < 2:
        continue

    all_pairs = list(itertools.combinations(students, 2))
    if len(all_pairs) > MAX_PAIRS:
        all_pairs = random.sample(all_pairs, MAX_PAIRS)

    group_indexed = group_dedup.set_index('id_student')
    assess_cols = [c for c in group_dedup.columns if c.startswith('assess_emb_')]
    
    for a, b in all_pairs:
        if a not in group_indexed.index or b not in group_indexed.index:
            continue
        
        try:
            # Skill complementarity
            vec_a = group_indexed.loc[a, assess_cols].values.astype(float).flatten()
            vec_b = group_indexed.loc[b, assess_cols].values.astype(float).flatten()
            
            if np.allclose(vec_a, 0) and np.allclose(vec_b, 0):
                skill_comp = 0.0
            else:
                skill_comp = float(np.mean(np.abs(vec_a - vec_b)))
            
            # Engagement complementarity
            e_a_series = group_indexed.loc[a, 'week_mean']
            e_b_series = group_indexed.loc[b, 'week_mean']
            
            # Handle Series vs scalar
            e_a = float(e_a_series.iloc[0] if hasattr(e_a_series, 'iloc') else e_a_series)
            e_b = float(e_b_series.iloc[0] if hasattr(e_b_series, 'iloc') else e_b_series)
            
            engage_comp = abs(e_a - e_b)
            overall = 0.7 * skill_comp + 0.3 * engage_comp
            
            pair_data.append({
                'id_i': int(a),
                'id_j': int(b),
                'code_module': mod,
                'code_presentation': pres,
                'skill_comp': skill_comp,
                'engage_comp': engage_comp,
                'comp_score': overall
            })
        except Exception as e:
            continue

pairs_df = pd.DataFrame(pair_data)
q = pairs_df['comp_score'].quantile(0.85)
pairs_df['label'] = (pairs_df['comp_score'] >= q).astype(int)

print(f"Created {len(pairs_df)} pairs, {pairs_df['label'].mean():.2%} positive")



 Generating pairs with complementarity labels...


Created 660000 pairs, 15.00% positive
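
The pair cell scores each pair as 0.7 times the mean absolute difference of the assessment embeddings plus 0.3 times the absolute difference in `week_mean`, then labels pairs at or above the 0.85 quantile as positive. The sketch below reproduces only the scoring step on toy inputs; the function name `comp_score` and the vectors are hypothetical.

```python
def comp_score(skill_a, skill_b, engage_a, engage_b, w_skill=0.7, w_engage=0.3):
    """Weighted complementarity: mean absolute skill gap plus absolute engagement gap."""
    skill_gap = sum(abs(x - y) for x, y in zip(skill_a, skill_b)) / len(skill_a)
    engage_gap = abs(engage_a - engage_b)
    return w_skill * skill_gap + w_engage * engage_gap

# Hypothetical skill vectors and engagement values for two students.
opposite = comp_score([1.0, -1.0], [-1.0, 1.0], engage_a=0.5, engage_b=-0.5)
identical = comp_score([1.0, -1.0], [1.0, -1.0], engage_a=0.5, engage_b=0.5)
print(round(opposite, 2), round(identical, 2))
```

Students with mirrored strengths score high and identical students score 0, which is the sense of "complementary" that the labels encode.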


## Train/Test Split


In [8]:
print("\n Creating train/test split...")

all_students = features_proc_unique['id_student'].unique().tolist()
holdout_students = set(random.sample(all_students, int(len(all_students) * HOLDOUT_FRAC)))
# A pair is held out if either student is in the holdout set (vectorized membership test)
pairs_df['holdout'] = pairs_df['id_i'].isin(holdout_students) | pairs_df['id_j'].isin(holdout_students)

train_pairs = pairs_df[~pairs_df['holdout']].reset_index(drop=True)
test_pairs = pairs_df[pairs_df['holdout']].reset_index(drop=True)
print(f"Train: {len(train_pairs)}, Test: {len(test_pairs)}")



 Creating train/test split...
Train: 533625, Test: 126375
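
The split holds out students, not pairs: any pair that touches a held-out student goes to the test set. A small sketch under that reading is shown below, with a hypothetical `split_pairs` helper and toy IDs. It also illustrates why holding out 10% of students moves roughly 1 - 0.9² ≈ 19% of pairs into the test set, in line with the 126375 of 660000 pairs reported above.

```python
import random

def split_pairs(pairs, students, holdout_frac=0.1, seed=42):
    """Hold out a fraction of students; a pair is a test pair if it touches any of them."""
    rng = random.Random(seed)
    holdout = set(rng.sample(sorted(students), int(len(students) * holdout_frac)))
    train = [(i, j) for i, j in pairs if i not in holdout and j not in holdout]
    test = [(i, j) for i, j in pairs if i in holdout or j in holdout]
    return train, test, holdout

# Hypothetical toy data: 20 students and all 190 unordered pairs among them.
students = list(range(20))
pairs = [(i, j) for i in students for j in students if i < j]
train, test, holdout = split_pairs(pairs, students)
print(len(holdout), len(train), len(test))
```

No training pair contains a held-out student, so test performance reflects students the model never saw as pair endpoints during training.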


## Build Student Graph


In [9]:
print("\n Building student graph...")

from collections import defaultdict
neighbors = defaultdict(set)

for _, r in train_pairs.iterrows():
    i, j = int(r['id_i']), int(r['id_j'])
    neighbors[i].add(j)
    neighbors[j].add(i)

for nid in node_feat_lookup.keys():
    neighbors.setdefault(int(nid), set())

node_list = sorted(list(node_feat_lookup.keys()))
node_to_idx = {nid: idx for idx, nid in enumerate(node_list)}
N = len(node_list)

node_feats = np.zeros((N, EMB_DIM), dtype=np.float32)
for nid, emb in node_feat_lookup.items():
    if nid in node_to_idx:
        node_feats[node_to_idx[nid], :] = emb
node_feats = torch.tensor(node_feats, device=DEVICE)

neighbors_idx = {node_to_idx[n]: [node_to_idx[m] for m in neighbors[n]] for n in node_list}

print(f"Graph: {N} nodes, avg degree: {np.mean([len(v) for v in neighbors_idx.values()]):.1f}")



 Building student graph...
Graph: 28785 nodes, avg degree: 37.1
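
The graph is an undirected adjacency map built from the training pairs, with an empty neighbor set for every student who appears in no training pair. A minimal sketch of that construction, using a hypothetical `build_adjacency` helper and made-up IDs:

```python
from collections import defaultdict

def build_adjacency(pairs, all_nodes):
    """Undirected adjacency sets from pairs; nodes without pairs get an empty set."""
    neighbors = defaultdict(set)
    for i, j in pairs:
        neighbors[i].add(j)
        neighbors[j].add(i)
    for n in all_nodes:
        neighbors.setdefault(n, set())
    return dict(neighbors)

# Hypothetical student IDs; 404 appears in no training pair.
adj = build_adjacency([(101, 202), (101, 303)], all_nodes=[101, 202, 303, 404])
avg_degree = sum(len(v) for v in adj.values()) / len(adj)
print(sorted(adj[101]), sorted(adj[404]), avg_degree)
```

Each pair adds one to the degree of both endpoints, so the average degree is twice the number of pairs divided by the number of nodes; for the notebook that is 2 × 533625 / 28785 ≈ 37.1.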


## GraphSAGE Model Definition


In [10]:
class GraphSAGE(nn.Module):
    def __init__(self, in_dim, hidden_dim, out_dim, dropout=0.3):
        super().__init__()
        self.fc_self = nn.Linear(in_dim, hidden_dim)
        self.fc_neigh = nn.Linear(in_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim*2, out_dim)
        self.dropout = nn.Dropout(dropout)
        self.bn = nn.BatchNorm1d(hidden_dim*2)

    def aggregate(self, node_indices, neighbors_idx):
        neigh_feats = []
        for n in node_indices:
            n = int(n)
            neigh = neighbors_idx.get(n, [])
            if len(neigh) == 0:
                neigh_feats.append(torch.zeros(node_feats.size(1), device=DEVICE))
            else:
                neigh_feats.append(node_feats[neigh].mean(dim=0))
        return torch.stack(neigh_feats, dim=0)

    def forward(self, node_indices, neighbors_idx_local):
        x_self = node_feats[node_indices]
        x_neigh = self.aggregate(node_indices.tolist(), neighbors_idx_local)

        h_self = F.relu(self.fc_self(x_self))
        h_neigh = F.relu(self.fc_neigh(x_neigh))
        h = torch.cat([h_self, h_neigh], dim=1)
        h = self.bn(h)
        h = self.dropout(h)
        return self.fc2(h)

class EdgePredictor(nn.Module):
    def __init__(self, node_emb_dim, hidden=64, dropout=0.3):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(node_emb_dim*2, hidden),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden, hidden//2),
            nn.ReLU(),
            nn.Linear(hidden//2, 1)
        )
    
    def forward(self, emb_i, emb_j):
        return self.mlp(torch.cat([emb_i, emb_j], dim=1)).squeeze(1)

gnn = GraphSAGE(in_dim=EMB_DIM, hidden_dim=GNN_HIDDEN, out_dim=GNN_HIDDEN).to(DEVICE)
edge_pred = EdgePredictor(node_emb_dim=GNN_HIDDEN, hidden=GNN_HIDDEN).to(DEVICE)

opt = torch.optim.Adam(list(gnn.parameters()) + list(edge_pred.parameters()), lr=LR, weight_decay=1e-5)
loss_fn = nn.BCEWithLogitsLoss()
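
`GraphSAGE.aggregate` averages the raw input features of each node's direct neighbors and falls back to a zero vector for isolated nodes. The sketch below shows that single-hop mean aggregation on toy data in plain Python; the function name `mean_aggregate` and the feature values are hypothetical, and lists stand in for tensors.

```python
def mean_aggregate(node, neighbors, feats):
    """Average the feature vectors of a node's neighbors; zeros if it has none."""
    dim = len(next(iter(feats.values())))
    neigh = neighbors.get(node, [])
    if not neigh:
        return [0.0] * dim
    return [sum(feats[m][d] for m in neigh) / len(neigh) for d in range(dim)]

# Hypothetical 2-dimensional features for three nodes.
feats = {0: [1.0, 0.0], 1: [3.0, 2.0], 2: [5.0, 4.0]}
neighbors = {0: [1, 2], 1: [0], 2: []}
print(mean_aggregate(0, neighbors, feats))  # mean of nodes 1 and 2
print(mean_aggregate(2, neighbors, feats))  # isolated node
```

Since the aggregation reads input features rather than previously computed embeddings, each node embedding depends only on the node itself and its immediate neighbors.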


## Train GNN Model


In [11]:
print(f"\n Training GNN for {EPOCHS} epochs...")

train_edge_idx = [
    (node_to_idx.get(int(r['id_i']), None), node_to_idx.get(int(r['id_j']), None), r['label'])
    for _, r in train_pairs.iterrows()
]
train_edge_idx = [(a, b, lbl) for (a, b, lbl) in train_edge_idx if a is not None and b is not None]

def edge_batches(edge_list, batch_size=BATCH_EDGE):
    idxs = np.random.permutation(len(edge_list))
    for start in range(0, len(edge_list), batch_size):
        yield [edge_list[i] for i in idxs[start:start+batch_size]]

for epoch in range(1, EPOCHS+1):
    gnn.train(); edge_pred.train()
    losses = []
    
    for batch in edge_batches(train_edge_idx):
        nodes_in_batch = list({a for a, b, l in batch} | {b for a,b,l in batch})
        nodes_tensor = torch.tensor(nodes_in_batch, dtype=torch.long, device=DEVICE)
        node_embs = gnn(nodes_tensor, neighbors_idx)

        map_idx = {nid: i for i, nid in enumerate(nodes_in_batch)}
        emb_i = torch.stack([node_embs[map_idx[a]] for a, b, l in batch])
        emb_j = torch.stack([node_embs[map_idx[b]] for a, b, l in batch])
        labels = torch.tensor([l for a, b, l in batch], dtype=torch.float32, device=DEVICE)

        logits = edge_pred(emb_i, emb_j)
        loss = loss_fn(logits, labels)
        
        opt.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(list(gnn.parameters()) + list(edge_pred.parameters()), 1.0)
        opt.step()
        
        losses.append(loss.item())

    print(f"Epoch {epoch}/{EPOCHS}, Loss: {np.mean(losses):.4f}")



 Training GNN for 15 epochs...
Epoch 1/15, Loss: 0.3218
Epoch 2/15, Loss: 0.1773
Epoch 3/15, Loss: 0.1591
Epoch 4/15, Loss: 0.1510
Epoch 5/15, Loss: 0.1464
Epoch 6/15, Loss: 0.1416
Epoch 7/15, Loss: 0.1388
Epoch 8/15, Loss: 0.1355
Epoch 9/15, Loss: 0.1339
Epoch 10/15, Loss: 0.1312
Epoch 11/15, Loss: 0.1299
Epoch 12/15, Loss: 0.1288
Epoch 13/15, Loss: 0.1279
Epoch 14/15, Loss: 0.1265
Epoch 15/15, Loss: 0.1252
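
The reported values are mean binary cross-entropy on raw logits. To help read them, the sketch below computes the per-example loss in a numerically stable form and the loss of a constant predictor that always outputs the positive rate. It is an illustration of the formula, not PyTorch's implementation; the function names are hypothetical, and the 15% positive rate is taken from the label step and assumed to hold for the training pairs.

```python
import math

def bce_with_logits(logit: float, label: float) -> float:
    """Stable form of -[y*log(sigmoid(x)) + (1-y)*log(1-sigmoid(x))]."""
    return max(logit, 0.0) - logit * label + math.log1p(math.exp(-abs(logit)))

def prior_loss(pos_rate: float) -> float:
    """Mean loss of a constant predictor that always outputs `pos_rate`."""
    return -(pos_rate * math.log(pos_rate) + (1 - pos_rate) * math.log(1 - pos_rate))

print(round(bce_with_logits(0.0, 1.0), 4))  # uninformative logit
print(round(bce_with_logits(4.0, 1.0), 4))  # confident and correct
print(round(bce_with_logits(4.0, 0.0), 4))  # confident and wrong
print(round(prior_loss(0.15), 4))           # constant-prediction reference
```

Under these assumptions a model that ignores its inputs would sit near 0.42, so the first-epoch loss is already below that reference and the later epochs refine it further.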


## Evaluate GNN Model


In [13]:
print("\n Evaluating GNN...")

gnn.eval(); edge_pred.eval()

with torch.no_grad():
    all_node_emb = gnn(torch.arange(N, device=DEVICE), neighbors_idx)

def score_pairs(pairs_df_in):
    scores, labels = [], []
    with torch.no_grad():
        for start in range(0, len(pairs_df_in), BATCH_EDGE):
            sub = pairs_df_in.iloc[start:start+BATCH_EDGE]
            idx_pairs, pair_labels = [], []
            for _, r in sub.iterrows():
                a = node_to_idx.get(int(r['id_i']), None)
                b = node_to_idx.get(int(r['id_j']), None)
                if a is not None and b is not None:
                    idx_pairs.append((a, b))
                    # Keep each label next to its pair so scores and labels stay aligned
                    pair_labels.append(int(r['label']))
                else:
                    # Unknown node: fall back to a score of 0.0
                    scores.append(0.0)
                    labels.append(int(r['label']))
            
            if idx_pairs:
                emb_i = all_node_emb[[a for a, b in idx_pairs]]
                emb_j = all_node_emb[[b for a, b in idx_pairs]]
                logits = edge_pred(emb_i, emb_j)
                probs = torch.sigmoid(logits).cpu().numpy()
                scores.extend(probs.tolist())
                labels.extend(pair_labels)
    
    return np.array(scores), np.array(labels)

gnn_scores, gnn_labels = score_pairs(test_pairs)

print("GNN RESULTS")
print(f"AUROC: {roc_auc_score(gnn_labels, gnn_scores):.4f}")
print(f"AUPRC: {average_precision_score(gnn_labels, gnn_scores):.4f}")
print("\nClassification Report:")
print(classification_report(gnn_labels, (gnn_scores >= 0.5).astype(int)))



 Evaluating GNN...
GNN RESULTS
AUROC: 0.9764
AUPRC: 0.8930

Classification Report:
              precision    recall  f1-score   support

           0       0.96      0.97      0.97    107202
           1       0.82      0.80      0.81     19173

    accuracy                           0.94    126375
   macro avg       0.89      0.88      0.89    126375
weighted avg       0.94      0.94      0.94    126375
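
AUROC can be read as the probability that a randomly chosen positive pair receives a higher score than a randomly chosen negative pair. The sketch below computes it directly from that definition on toy data. It is quadratic in the number of examples and meant only for illustration; the notebook itself uses `roc_auc_score`. The function name `auroc` and the scores are hypothetical.

```python
def auroc(labels, scores):
    """Fraction of (positive, negative) pairs ranked correctly; ties count as half."""
    pos = [s for y, s in zip(labels, scores) if y == 1]
    neg = [s for y, s in zip(labels, scores) if y == 0]
    if not pos or not neg:
        raise ValueError("need at least one positive and one negative example")
    wins = sum((p > n) + 0.5 * (p == n) for p in pos for n in neg)
    return wins / (len(pos) * len(neg))

# Hypothetical labels and scores.
labels = [0, 0, 1, 1]
scores = [0.1, 0.4, 0.35, 0.8]
print(auroc(labels, scores))
```

Unlike the precision and recall in the classification report, this value does not depend on the 0.5 threshold, which is why it is reported alongside AUPRC for the imbalanced labels.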



## Load Baseline Models from Notebook 2

Load the pre-trained baseline models (Logistic Regression and XGBoost) that were trained in Notebook 2.

**Note**: This requires Notebook 2 to have been run first.


In [16]:
import pickle
import json
from pathlib import Path

print("Loading baseline models from Notebook 2...")

# Check if baseline models exist
baseline_dir = Path('models/checkpoints')
lr_path = baseline_dir / 'lr_model.pkl'
xgb_path = baseline_dir / 'xgb_model.pkl'
baseline_metrics_path = Path('results/baseline_metrics.json')

if not lr_path.exists() or not xgb_path.exists():
    print("ERROR: Baseline models not found!")
    print("Please run Notebook 2 first to generate baseline models.")
    print(f"  Expected: {lr_path}")
    print(f"  Expected: {xgb_path}")
    baseline_metrics = None
    lr_model = None
    xgb_model = None
else:
    # Load baseline models (they're saved as dicts with 'model' and 'scaler' keys)
    with open(lr_path, 'rb') as f:
        lr_data = pickle.load(f)
        lr_model = lr_data['model']
        lr_scaler = lr_data['scaler']
    print(f"  Loaded: {lr_path}")
    
    with open(xgb_path, 'rb') as f:
        xgb_data = pickle.load(f)
        xgb_model = xgb_data['model']
        xgb_scaler = xgb_data['scaler']
    print(f"  Loaded: {xgb_path}")
    
    # Load baseline metrics
    if baseline_metrics_path.exists():
        with open(baseline_metrics_path, 'r') as f:
            baseline_metrics = json.load(f)
        print(f"  Loaded: {baseline_metrics_path}")
        
        # Display baseline performance (note: keys have hyphens and capitals)
        print("\nBaseline Model Performance (from Notebook 2):")
        print(f"  Logistic Regression:")
        lr_metrics = baseline_metrics.get('logistic_regression', {})
        print(f"    ROC-AUC: {lr_metrics.get('ROC-AUC', lr_metrics.get('roc_auc', 'N/A'))}")
        print(f"    PR-AUC:  {lr_metrics.get('PR-AUC', lr_metrics.get('pr_auc', 'N/A'))}")
        
        print(f"  XGBoost:")
        xgb_metrics_loaded = baseline_metrics.get('xgboost', {})
        print(f"    ROC-AUC: {xgb_metrics_loaded.get('ROC-AUC', xgb_metrics_loaded.get('roc_auc', 'N/A'))}")
        print(f"    PR-AUC:  {xgb_metrics_loaded.get('PR-AUC', xgb_metrics_loaded.get('pr_auc', 'N/A'))}")
    else:
        print(f"  Baseline metrics not found at {baseline_metrics_path}")
        baseline_metrics = None



Loading baseline models from Notebook 2...
  Loaded: models/checkpoints/lr_model.pkl
  Loaded: models/checkpoints/xgb_model.pkl
  Loaded: results/baseline_metrics.json

Baseline Model Performance (from Notebook 2):
  Logistic Regression:
    ROC-AUC: 0.7881852266666667
    PR-AUC:  0.47972199260927584
  XGBoost:
    ROC-AUC: 0.8315369866666666
    PR-AUC:  0.5845035938083278
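
The loading cell reads each metric with a nested `.get(...)` so that either key spelling (`'ROC-AUC'` or `'roc_auc'`) works. One way to express the same fallback is a small helper; `get_metric` is a hypothetical name, not part of the notebook, and the dictionaries below are made up.

```python
def get_metric(metrics: dict, *keys: str, default=None):
    """Return the value of the first key present in `metrics`, trying keys in order."""
    for key in keys:
        if key in metrics:
            return metrics[key]
    return default

# Hypothetical metrics dictionaries using the two spellings.
saved_a = {"ROC-AUC": 0.83, "PR-AUC": 0.58}
saved_b = {"roc_auc": 0.83}
print(get_metric(saved_a, "ROC-AUC", "roc_auc"))
print(get_metric(saved_b, "ROC-AUC", "roc_auc"))
print(get_metric(saved_b, "PR-AUC", "pr_auc", default="N/A"))
```

Keeping the lookup in one place means the display, comparison, and save steps cannot disagree about which spelling they accept.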


## Compare GNN vs Baseline Models

Compare the GNN model's performance against the baseline models from Notebook 2.


In [17]:
print("\nModel Comparison:")

# Get GNN metrics (make sure they're calculated)
if 'gnn_scores' in globals() and 'gnn_labels' in globals():
    gnn_roc_auc = roc_auc_score(gnn_labels, gnn_scores)
    gnn_pr_auc = average_precision_score(gnn_labels, gnn_scores)
    
    # Display comparison table
    if baseline_metrics:
        print(f"{'Model':<25} {'ROC-AUC':>12} {'PR-AUC':>12} {'Improvement':>15}")
        print("-"*70)
        
        # Baseline LR (handle both key formats)
        lr_m = baseline_metrics.get('logistic_regression', {})
        lr_roc = lr_m.get('ROC-AUC', lr_m.get('roc_auc', 0))
        lr_pr = lr_m.get('PR-AUC', lr_m.get('pr_auc', 0))
        print(f"{'Logistic Regression':<25} {lr_roc:>12.4f} {lr_pr:>12.4f} {'(baseline)':>15}")
        
        # Baseline XGBoost
        xgb_m = baseline_metrics.get('xgboost', {})
        xgb_roc = xgb_m.get('ROC-AUC', xgb_m.get('roc_auc', 0))
        xgb_pr = xgb_m.get('PR-AUC', xgb_m.get('pr_auc', 0))
        print(f"{'XGBoost':<25} {xgb_roc:>12.4f} {xgb_pr:>12.4f} {'(baseline)':>15}")
        
        # GNN
        # Relative ROC-AUC change versus XGBoost; the format spec adds the sign for gains and losses
        gnn_improvement = ((gnn_roc_auc - xgb_roc) / xgb_roc * 100) if xgb_roc > 0 else 0
        print(f"{'GraphSAGE (GNN)':<25} {gnn_roc_auc:>12.4f} {gnn_pr_auc:>12.4f} {f'{gnn_improvement:+.1f}%':>15}")
        
        print("="*70)
        
        # Determine best model
        best_model_info = max([
            ('Logistic Regression', lr_roc),
            ('XGBoost', xgb_roc),
            ('GraphSAGE', gnn_roc_auc)
        ], key=lambda x: x[1])
        
        print(f"\nBest Model: {best_model_info[0]} (ROC-AUC: {best_model_info[1]:.4f})")
    else:
        print("Baseline metrics not available - showing GNN results only:")
        print(f"  GNN ROC-AUC: {gnn_roc_auc:.4f}")
        print(f"  GNN PR-AUC:  {gnn_pr_auc:.4f}")
else:
    print("ERROR: GNN scores not found. Make sure evaluation cell ran successfully.")




Model Comparison:
Model                          ROC-AUC       PR-AUC     Improvement
----------------------------------------------------------------------
Logistic Regression             0.7882       0.4797      (baseline)
XGBoost                         0.8315       0.5845      (baseline)
GraphSAGE (GNN)                 0.9764       0.8930          +17.4%

Best Model: GraphSAGE (ROC-AUC: 0.9764)
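
The "Improvement" column is the relative change in ROC-AUC over XGBoost, not the absolute difference. The arithmetic can be checked with the rounded values from the table; `relative_gain_percent` is a hypothetical helper name.

```python
def relative_gain_percent(new: float, baseline: float) -> float:
    """Relative change of `new` over `baseline`, in percent."""
    if baseline <= 0:
        raise ValueError("baseline must be positive")
    return (new - baseline) / baseline * 100

# ROC-AUC values as printed in the comparison table.
gnn_roc, xgb_roc = 0.9764, 0.8315
print(f"{relative_gain_percent(gnn_roc, xgb_roc):+.1f}% relative")
print(f"{gnn_roc - xgb_roc:+.4f} absolute")
```

The same gap reads as about 17.4% in relative terms and about 0.145 ROC-AUC points in absolute terms, so it helps to state which one is meant when quoting the result.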


## Save GNN Results and Comparison

Save GNN model, embeddings, and comprehensive comparison with baselines.


In [20]:
import pickle
import json
from pathlib import Path

print("Saving GNN model and results...")

# Create directories
Path('data/processed').mkdir(parents=True, exist_ok=True)
Path('models/checkpoints').mkdir(parents=True, exist_ok=True)
Path('results').mkdir(parents=True, exist_ok=True)

# Save GNN embeddings
if 'all_node_emb' in globals():
    with open('data/processed/gnn_embeddings.pkl', 'wb') as f:
        pickle.dump({
            'embeddings': all_node_emb.cpu().numpy(),
            'student_ids': node_list,
            'node_to_idx': node_to_idx,
            'embedding_dim': int(all_node_emb.shape[1])  # width of the saved embeddings (GNN_HIDDEN), not the input EMB_DIM
        }, f)
    print("  Saved: data/processed/gnn_embeddings.pkl")

# Save GNN and edge-predictor weights to the checkpoint path listed under Outputs
torch.save(
    {'gnn': gnn.state_dict(), 'edge_pred': edge_pred.state_dict()},
    'models/checkpoints/gnn_model.pth'
)
print("  Saved: models/checkpoints/gnn_model.pth")

# Save graph structure
# if 'neighbors_idx' in globals():
#     with open('data/processed/graph_structure.pkl', 'wb') as f:
#         pickle.dump({
#             'neighbors': {k: v.cpu().numpy() for k, v in neighbors_idx.items()},
#             'node_list': node_list,
#             'N': N
#         }, f)
#     print("  Saved: data/processed/graph_structure.pkl")

# Calculate and save GNN metrics
if 'gnn_scores' in globals() and 'gnn_labels' in globals():
    gnn_roc_auc = float(roc_auc_score(gnn_labels, gnn_scores))
    gnn_pr_auc = float(average_precision_score(gnn_labels, gnn_scores))
    
    gnn_metrics = {
        'model': 'GraphSAGE',
        'ROC-AUC': gnn_roc_auc,
        'PR-AUC': gnn_pr_auc,
        'Accuracy': float(((gnn_scores >= 0.5) == gnn_labels).mean()),
        'test_pairs': len(test_pairs) if 'test_pairs' in globals() else None,
        'embedding_dim': EMB_DIM,
        'hidden_dim': GNN_HIDDEN,
        'device': DEVICE,
        'type': 'classification'
    }
    
    with open('results/gnn_metrics.json', 'w') as f:
        json.dump(gnn_metrics, f, indent=2)
    print("  Saved: results/gnn_metrics.json")
    
    # Save comprehensive comparison
    if baseline_metrics:
        # Pick the best baseline by ROC-AUC once, accepting either key spelling
        def _roc(key):
            m = baseline_metrics.get(key, {})
            return m.get('ROC-AUC', m.get('roc_auc', 0))

        best_name, best_roc = max(
            [('Logistic Regression', _roc('logistic_regression')), ('XGBoost', _roc('xgboost'))],
            key=lambda x: x[1]
        )
        comparison = {
            'baselines': baseline_metrics,
            'gnn': gnn_metrics,
            'improvement_over_best_baseline': {
                'best_baseline': best_name,
                'best_baseline_roc_auc': best_roc,
                'gnn_roc_auc': gnn_roc_auc,
                'absolute_improvement': gnn_roc_auc - best_roc,
                'relative_improvement_percent': ((gnn_roc_auc - best_roc) / best_roc * 100) if best_roc > 0 else 0
            }
        }
        
        with open('results/gnn_vs_baselines.json', 'w') as f:
            json.dump(comparison, f, indent=2)
        print("  Saved: results/gnn_vs_baselines.json")
else:
    print("  WARNING: GNN metrics not available - evaluation may not have run")

print("\nAll outputs saved successfully!")



Saving GNN model and results...
  Saved: data/processed/gnn_embeddings.pkl
  Saved: results/gnn_metrics.json
  Saved: results/gnn_vs_baselines.json

All outputs saved successfully!


## Summary

### GNN Model Training Complete

This notebook trained a GraphSAGE model and compared it with baseline models from Notebook 2.

**Approach**:
- Used graph structure to learn student embeddings
- One round of mean neighbor aggregation (GraphSAGE-style), followed by a linear projection
- Trained to predict complementarity between student pairs

**Comparison**:
- Loaded pre-trained baselines (Logistic Regression, XGBoost) from Notebook 2
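- Evaluated the GNN on the held-out test pairs (any pair that includes a held-out student)
- Compared ROC-AUC and PR-AUC against the metrics saved by Notebook 2; the baselines are not re-scored here, so the comparison assumes Notebook 2 used the same split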

**Results**:
- All metrics saved to `results/gnn_metrics.json`
- Comprehensive comparison saved to `results/gnn_vs_baselines.json`
- Student embeddings saved for use in Notebook 4
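
As a sketch of how a later notebook might read the saved embeddings, the example below writes a small stand-in file with the same keys as `gnn_embeddings.pkl` and looks up one student's vector. The helper `load_embedding`, the student IDs, and the values are hypothetical, and plain lists stand in for the NumPy array so the snippet runs on its own.

```python
import pickle
import tempfile
from pathlib import Path

def load_embedding(path, student_id):
    """Look up one student's vector in a pickle with `embeddings` and `node_to_idx` keys."""
    with open(path, "rb") as f:
        payload = pickle.load(f)
    row = payload["node_to_idx"][student_id]
    return payload["embeddings"][row]

# Hypothetical stand-in for gnn_embeddings.pkl.
payload = {
    "embeddings": [[0.1, 0.2], [0.3, 0.4]],
    "student_ids": [11391, 28400],
    "node_to_idx": {11391: 0, 28400: 1},
    "embedding_dim": 2,
}
with tempfile.TemporaryDirectory() as tmp:
    path = Path(tmp) / "gnn_embeddings.pkl"
    with open(path, "wb") as f:
        pickle.dump(payload, f)
    vec = load_embedding(path, 28400)
print(vec)
```

Going through `node_to_idx` rather than assuming row order keeps the lookup correct even if the student list is later filtered or reordered.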
