# GNN Evaluation

This notebook evaluates the trained GNN model on **unseen data** (test or validation split) and reports detailed metrics.

**What this does:**
- Loads the best model checkpoint from training
- Runs it on the test (or val) set: for each hour, predicts next-hour ridership for all stations
- Compares predictions to ground truth in both normalized (z-score) and real tap-in space
- Reports overall error, error by hour, and per-station breakdown (so you can see which stations are hardest/easiest to predict)

**Why this matters:**
- Shows how well the model generalizes to new time periods rather than memorizing the training data
- Lets you spot systematic errors (e.g. consistent underprediction at Times Sq at 8am)
- Per-station and per-hour breakdowns help debug model weaknesses

In [None]:
import os
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
from torch_geometric.nn import GCNConv
from tqdm import tqdm

## Config

In [None]:
# --- Evaluation settings ---
SPLIT = "test"      # Which split to evaluate: "test" (default) or "val"
HIDDEN_DIM = 64     # Must match the model you trained

# --- Paths ---
# ROOT points to the project root (one level up from training/)
ROOT = os.path.dirname(os.path.abspath(""))
PROC_DIR = os.path.join(ROOT, "data", "processed")
MODEL_PATH = os.path.join(ROOT, "models", "model.pt")
STATS_PATH = os.path.join(PROC_DIR, "stats.csv")
CMPLX_PATH = os.path.join(PROC_DIR, "ComplexNodes.csv")
EDGES_PATH = os.path.join(PROC_DIR, "ComplexEdges.csv")

# The test split is Oct–Dec 2023–2025 (never seen during training/validation)
# The val split is Jul–Sep 2023–2025 (used for early stopping, not for training)

## Model Definition

The model architecture must **exactly match** what was used in training:
- 2 × GCNConv layers (each with ReLU)
- Linear regression head

If you change the architecture or hidden size, you must retrain and re-save the model.

In [None]:
class GNN(nn.Module):
    def __init__(self, in_dim, hidden_dim):
        super().__init__()
        # Layer 1: graph convolution from input features → hidden_dim
        self.conv1 = GCNConv(in_dim, hidden_dim)
        # Layer 2: another graph convolution, hidden_dim → hidden_dim
        self.conv2 = GCNConv(hidden_dim, hidden_dim)
        # Output head: maps each station's hidden representation to 1 predicted value
        self.mlp = nn.Linear(hidden_dim, 1)

    def forward(self, x, edge_index):
        # x shape: (num_nodes, 3) — features for every station
        # edge_index shape: (2, num_edges) — pairs of connected station indices
        h = torch.relu(self.conv1(x, edge_index))
        h = torch.relu(self.conv2(h, edge_index))
        # squeeze only the trailing feature dim so the output stays (num_nodes,)
        return self.mlp(h).squeeze(-1)
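
For intuition about what each `GCNConv` layer computes: PyTorch Geometric's `GCNConv` follows the Kipf and Welling propagation rule, which adds self-loops and symmetrically normalizes the adjacency matrix before applying a learned weight matrix. The sketch below is a dense NumPy version of that rule on a toy two-station graph. It omits the bias term and the ReLU, and the function name `gcn_layer_dense` and the toy inputs are illustrative assumptions, not part of this notebook.

```python
import numpy as np

def gcn_layer_dense(x, edge_index, weight):
    """One GCN propagation step: D^-1/2 (A + I) D^-1/2 X W (no bias, no activation)."""
    num_nodes = x.shape[0]
    adj = np.zeros((num_nodes, num_nodes), dtype=np.float64)
    src, dst = edge_index
    adj[dst, src] = 1.0                 # a message flows from src to dst
    adj += np.eye(num_nodes)            # self-loops so a node keeps its own signal
    deg = adj.sum(axis=1)
    d_inv_sqrt = 1.0 / np.sqrt(deg)     # deg >= 1 because of the self-loops
    norm_adj = d_inv_sqrt[:, None] * adj * d_inv_sqrt[None, :]
    return norm_adj @ x @ weight

# Toy graph: two stations connected in both directions, one feature each
toy_edge_index = np.array([[0, 1], [1, 0]])
toy_x = np.array([[1.0], [3.0]])
toy_weight = np.array([[1.0]])
toy_out = gcn_layer_dense(toy_x, toy_edge_index, toy_weight)
print(toy_out)  # each station ends up with the average of itself and its neighbour
```

With two stacked layers, as in the model above, each station's prediction can draw on stations up to two hops away.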

## Load Graph Structure & Stats

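- **Node mapping**: maps MTA station complex IDs to sequential node indices (0, 1, 2, ...), which `edge_index` and the feature tensors use to address stations by row
- **Edges**: pairs of connected stations, each added in both directions so the graph is undirected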
- **Stats**: per-station mean and std, used to denormalize predictions back to real tap-in counts

In [None]:
# Load node mapping: complex_id → node_id (0, 1, 2, ...)
cmplx_df = pd.read_csv(CMPLX_PATH)
ComplexNodes = dict(zip(cmplx_df["complex_id"], cmplx_df["node_id"]))
node_to_cmplx = dict(zip(cmplx_df["node_id"], cmplx_df["complex_id"]))
num_nodes = int(cmplx_df["node_id"].max() + 1)

# Load edges (track connections between stations)
edges_df = pd.read_csv(EDGES_PATH)
edge_list = []
for _, row in edges_df.iterrows():
    s, e = row["from_complex_id"], row["to_complex_id"]
    if s in ComplexNodes and e in ComplexNodes:
        sn, en = ComplexNodes[s], ComplexNodes[e]
        # Add both directions (undirected graph)
        edge_list.append([sn, en])
        edge_list.append([en, sn])
edge_tensor = torch.tensor(edge_list, dtype=torch.long).T

# Load per-station normalization stats (mean/std from training set)
stats = pd.read_csv(STATS_PATH)
stn_mean = dict(zip(stats["station_complex_id"], stats["mean"]))
stn_std = dict(zip(stats["station_complex_id"], stats["std"]))

print(f"Nodes: {num_nodes}, Edges: {edge_tensor.shape[1]}")
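
The edge-loading loop above can also be written without `iterrows`. The following is a minimal sketch, assuming the same column names (`from_complex_id`, `to_complex_id`); the helper name `build_edge_index` and the toy data are hypothetical. It produces the same set of directed edges as the loop (every known pair in both directions, no de-duplication), only in a different column order, which does not affect a GCN.

```python
import numpy as np
import pandas as pd

def build_edge_index(edges_df, node_map):
    """Return a (2, num_edges) int64 array with each known edge in both directions."""
    src = edges_df["from_complex_id"].map(node_map)
    dst = edges_df["to_complex_id"].map(node_map)
    known = src.notna() & dst.notna()          # drop edges touching unmapped stations
    src = src[known].to_numpy(dtype=np.int64)
    dst = dst[known].to_numpy(dtype=np.int64)
    forward = np.stack([src, dst])
    backward = np.stack([dst, src])
    return np.concatenate([forward, backward], axis=1)

# Toy data: station 99 has no node index, so its edge is skipped
toy_edges = pd.DataFrame({"from_complex_id": [10, 20, 99],
                          "to_complex_id": [20, 30, 10]})
toy_node_map = {10: 0, 20: 1, 30: 2}
toy_edge_index = build_edge_index(toy_edges, toy_node_map)
print(toy_edge_index)
```

The result can be handed to PyTorch with `torch.from_numpy(toy_edge_index)`, which yields a `long` tensor of shape `(2, num_edges)`.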

## Load Model

Loads the best model checkpoint from training. The architecture and hidden size must match exactly.

In [None]:
# Instantiate the model and load trained weights
model = GNN(in_dim=3, hidden_dim=HIDDEN_DIM)
model.load_state_dict(torch.load(MODEL_PATH, map_location="cpu"))
model.eval()  # switch dropout/batchnorm layers to inference mode (none are used here, but good practice)
print(f"Loaded model from {MODEL_PATH}")

## Load Split Data & Build Snapshots

Loads the test or validation split (parquet file). Each row is one station's ridership at one timestamp, already normalized and with time encodings.

We then build graph snapshots: for every pair of consecutive hours, the model sees hour $t$ and must predict hour $t+1$ for all stations. Pairs separated by more than one hour (missing data) are skipped.

In [None]:
# Load the split (test or val) as a dataframe
split_path = os.path.join(PROC_DIR, f"{SPLIT}.parquet")
df = pd.read_parquet(split_path)
print(f"{SPLIT.upper()} set: {len(df):,} rows")
print(f"Date range: {df['transit_timestamp'].min()} → {df['transit_timestamp'].max()}")

# Build graph snapshots: for each consecutive hour, create (features, targets, raw_targets, hour)
groups = {t: g for t, g in df.groupby("transit_timestamp")}
timestamps = sorted(groups.keys())

features, targets, raw_targets, hours = [], [], [], []

for t0, t1 in tqdm(zip(timestamps[:-1], timestamps[1:]), total=len(timestamps)-1, desc="Building snapshots"):
    # Skip if there's a gap > 1 hour (missing data)
    if (t1 - t0).total_seconds() > 3600:
        continue

    g0 = groups[t0]
    g1 = groups[t1]

    X = torch.zeros(num_nodes, 3)
    y = torch.zeros(num_nodes)
    y_raw = torch.zeros(num_nodes)

    idx0 = torch.tensor(g0["node_id"].values)
    idx1 = torch.tensor(g1["node_id"].values)

    X[idx0, 0] = torch.tensor(g0["ridership_norm"].values.astype(np.float32))
    X[idx0, 1] = torch.tensor(g0["sin_hour"].values.astype(np.float32))
    X[idx0, 2] = torch.tensor(g0["cos_hour"].values.astype(np.float32))

    y[idx1] = torch.tensor(g1["ridership_norm"].values.astype(np.float32))
    y_raw[idx1] = torch.tensor(g1["ridership"].values.astype(np.float32))

    features.append(X)
    targets.append(y)
    raw_targets.append(y_raw)
    hours.append(pd.Timestamp(t1).hour)

print(f"Snapshots: {len(features)}")
del df, groups
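
The gap check in the snapshot loop can be isolated into a small helper, which makes it easy to see which hours are dropped. This is a sketch on made-up timestamps; the function name `consecutive_hour_pairs` and its `max_gap` parameter are hypothetical.

```python
import pandas as pd

def consecutive_hour_pairs(timestamps, max_gap=pd.Timedelta(hours=1)):
    """Return (t0, t1) for neighbouring timestamps that are at most max_gap apart."""
    ordered = sorted(timestamps)
    return [(t0, t1) for t0, t1 in zip(ordered[:-1], ordered[1:])
            if (t1 - t0) <= max_gap]

# 02:00 is missing, so the 01:00 -> 03:00 pair is skipped
toy_ts = pd.to_datetime(["2023-10-01 00:00", "2023-10-01 01:00",
                         "2023-10-01 03:00", "2023-10-01 04:00"])
toy_pairs = consecutive_hour_pairs(toy_ts)
for t0, t1 in toy_pairs:
    print(t0, "->", t1)
```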

In [None]:
# features: list of (num_nodes, 3) tensors (input for each hour)
# targets: list of (num_nodes,) tensors (normalized ground truth for next hour)
# raw_targets: list of (num_nodes,) tensors (real tap-in counts for next hour)
# hours: list of int (hour of day for each snapshot)
# These are used for evaluation below.
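
The `sin_hour` and `cos_hour` features are computed in preprocessing, which is not shown in this notebook. A common way to build such cyclical encodings is sketched below; the exact formula used for this dataset is an assumption, and `encode_hour` is a hypothetical helper.

```python
import numpy as np

def encode_hour(hour):
    """Map hour of day (0-23) onto the unit circle so 23:00 and 00:00 are neighbours."""
    angle = 2.0 * np.pi * np.asarray(hour, dtype=np.float64) / 24.0
    return np.sin(angle), np.cos(angle)

toy_sin, toy_cos = encode_hour([0, 6, 12, 23])
for h, s, c in zip([0, 6, 12, 23], toy_sin, toy_cos):
    print(f"{h:02d}:00  sin={s:+.3f}  cos={c:+.3f}")
```

Using both components avoids the artificial jump from 23 to 0 that a raw hour value would introduce.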

## Run Inference

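For each snapshot (hour), the model predicts next-hour ridership for all stations. We collect both normalized and real (denormalized) predictions for metric calculation. Predictions are denormalized with each station's mean and std and clipped at zero. Stations with no row at hour $t+1$ keep their zero-initialized targets, so they are scored against a true value of 0 in both spaces.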

In [None]:
# For each snapshot, run the model and collect predictions and ground truth
# We store both normalized and denormalized (real tap-in) values for metrics
all_pred_norm, all_true_norm = [], []
all_pred_raw, all_true_raw = [], []
all_hours, all_stations = [], []

with torch.no_grad():
    for X, y_norm, y_raw, hour in tqdm(
        zip(features, targets, raw_targets, hours),
        total=len(features),
        desc="Evaluating",
    ):
        y_hat_norm = model(X, edge_tensor)  # predict next-hour ridership (normalized)

        for node_id in range(num_nodes):
            if node_id not in node_to_cmplx:
                continue

            cmplx_id = node_to_cmplx[node_id]
            true_norm = y_norm[node_id].item()
            pred_norm = y_hat_norm[node_id].item()
            true_raw_val = y_raw[node_id].item()

            # Denormalize prediction: pred_raw = pred_norm * std + mean
            mean = stn_mean.get(cmplx_id, 0)
            std = stn_std.get(cmplx_id, 1)
            pred_raw = pred_norm * std + mean

            all_pred_norm.append(pred_norm)
            all_true_norm.append(true_norm)
            all_pred_raw.append(max(0, pred_raw))  # don't allow negative tap-ins
            all_true_raw.append(true_raw_val)
            all_hours.append(hour)
            all_stations.append(cmplx_id)

pred_norm = np.array(all_pred_norm)
true_norm = np.array(all_true_norm)
pred_raw = np.array(all_pred_raw)
true_raw = np.array(all_true_raw)
hours_arr = np.array(all_hours)
stations_arr = np.array(all_stations)

print(f"Total predictions: {len(pred_raw):,}")
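
The per-node Python loop above applies `pred_raw = pred_norm * std + mean` one station at a time. The same arithmetic can be done for all stations at once with NumPy broadcasting. The sketch below assumes the per-station means and stds have been arranged into vectors ordered by node index; `denormalize` and the toy numbers are hypothetical.

```python
import numpy as np

def denormalize(pred_norm, mean, std):
    """Invert per-station z-scoring and clip at zero (no negative tap-ins)."""
    raw = np.asarray(pred_norm) * np.asarray(std) + np.asarray(mean)
    return np.clip(raw, 0.0, None)

# Toy stats for three stations, ordered by node index
toy_mean = np.array([100.0, 20.0, 5.0])
toy_std = np.array([50.0, 10.0, 2.0])
# Normalized predictions with shape (snapshots, stations)
toy_pred_norm = np.array([[0.5, -1.0, -4.0],
                          [0.0,  2.0,  1.0]])
toy_pred_raw = denormalize(toy_pred_norm, toy_mean, toy_std)
print(toy_pred_raw)
```

The third station's first prediction denormalizes to a negative count and is clipped to 0, matching the `max(0, pred_raw)` step in the loop.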

## Overall Metrics

Calculates and prints:
- **MSE/MAE in normalized space** (z-score units)
- **MSE/MAE/RMSE in real tap-in space** (actual number of people)
- **Median absolute error** (robust to outliers)
- **R² score** (share of variance in real tap-in counts explained by the model, pooled over all stations and hours)

This gives a sense of both relative and absolute model accuracy.

In [None]:
# --- Normalized space (z-score units) ---
mse_norm = np.mean((pred_norm - true_norm) ** 2)
mae_norm = np.mean(np.abs(pred_norm - true_norm))

# --- Real tap-in space (actual people) ---
mse_raw = np.mean((pred_raw - true_raw) ** 2)
mae_raw = np.mean(np.abs(pred_raw - true_raw))
rmse_raw = np.sqrt(mse_raw)

# --- R² score ---
ss_res = np.sum((true_raw - pred_raw) ** 2)
ss_tot = np.sum((true_raw - np.mean(true_raw)) ** 2)
r2 = 1 - (ss_res / ss_tot) if ss_tot > 0 else 0

# --- Median absolute error ---
median_ae = np.median(np.abs(pred_raw - true_raw))

print("=" * 50)
print("OVERALL METRICS")
print("=" * 50)
print(f"\n  Normalized space:")
print(f"    MSE  = {mse_norm:.4f}")
print(f"    MAE  = {mae_norm:.4f}")
print(f"\n  Real tap-in space:")
print(f"    MSE   = {mse_raw:.2f}")
print(f"    RMSE  = {rmse_raw:.2f}")
print(f"    MAE   = {mae_raw:.2f} tap-ins")
print(f"    MedAE = {median_ae:.2f} tap-ins")
print(f"    R²    = {r2:.4f}")
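
To sanity-check the formulas, the same metrics can be wrapped in a function and run on a tiny example that is easy to verify by hand. The helper `regression_metrics` and the toy arrays are illustrative, not part of the evaluation pipeline.

```python
import numpy as np

def regression_metrics(y_true, y_pred):
    """MSE, RMSE, MAE, median absolute error and R^2 for two 1-D arrays."""
    y_true = np.asarray(y_true, dtype=np.float64)
    y_pred = np.asarray(y_pred, dtype=np.float64)
    err = y_pred - y_true
    mse = np.mean(err ** 2)
    ss_res = np.sum(err ** 2)
    ss_tot = np.sum((y_true - y_true.mean()) ** 2)
    return {
        "mse": mse,
        "rmse": np.sqrt(mse),
        "mae": np.mean(np.abs(err)),
        "medae": np.median(np.abs(err)),
        "r2": 1.0 - ss_res / ss_tot if ss_tot > 0 else 0.0,
    }

toy_true = np.array([10.0, 20.0, 30.0, 40.0])
toy_pred = np.array([12.0, 18.0, 33.0, 40.0])
toy_metrics = regression_metrics(toy_true, toy_pred)
for name, value in toy_metrics.items():
    print(f"{name:>6} = {value:.4f}")
```

Here the errors are 2, -2, 3 and 0, so MSE is 17/4 = 4.25, MAE is 1.75, the median absolute error is 2, and R² is 1 - 17/500 = 0.966.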

## Error by Hour of Day

Shows how model error varies by time of day (e.g. does it struggle more at rush hour?).

For each hour, prints MAE (in tap-ins), the number of predictions, and a text bar with one block per 5 tap-ins of MAE for quick visual comparison.

In [None]:
print(f"{'Hour':>6}  {'MAE':>8}  {'Count':>8}  Bar")
print("-" * 45)
for h in range(24):
    mask = hours_arr == h
    if mask.sum() == 0:
        continue
    h_mae = np.mean(np.abs(pred_raw[mask] - true_raw[mask]))
    h_count = mask.sum()
    bar = "█" * int(h_mae / 5)
    print(f"  {h:02d}:00  {h_mae:>8.2f}  {h_count:>8}  {bar}")
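
The same per-hour breakdown can be expressed as a pandas `groupby`, which returns a table that is easy to sort or plot. This is a sketch on made-up numbers; `mae_by_hour` is a hypothetical helper.

```python
import numpy as np
import pandas as pd

def mae_by_hour(hours, y_true, y_pred):
    """Return a DataFrame indexed by hour with MAE and prediction count."""
    abs_err = np.abs(np.asarray(y_pred, dtype=np.float64)
                     - np.asarray(y_true, dtype=np.float64))
    frame = pd.DataFrame({"hour": hours, "abs_err": abs_err})
    return frame.groupby("hour")["abs_err"].agg(mae="mean", count="count")

toy_hours = [8, 8, 17, 17, 17]
toy_true = [100, 200, 50, 60, 70]
toy_pred = [90, 220, 50, 65, 61]
toy_by_hour = mae_by_hour(toy_hours, toy_true, toy_pred)
print(toy_by_hour)
```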

## Per-Station Breakdown

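For each station with at least 5 predictions, computes:
- MAE (mean absolute error in tap-ins)
- Average true ridership
- MAPE, computed here as MAE divided by average true ridership (a weighted percentage error, not the mean of per-prediction percentage errors)
- Number of predictions

Prints the 10 worst and 10 best stations by MAE.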

In [None]:
station_errors = {}
for cmplx_id in np.unique(stations_arr):
    mask = stations_arr == cmplx_id
    if mask.sum() < 5:
        continue
    s_mae = np.mean(np.abs(pred_raw[mask] - true_raw[mask]))
    s_avg_ridership = np.mean(true_raw[mask])
    station_errors[cmplx_id] = {
        "mae": s_mae,
        "avg_ridership": s_avg_ridership,
        # MAE relative to mean ridership (weighted percentage error), not per-row MAPE
        "mape": (s_mae / (s_avg_ridership + 1e-6)) * 100,
        "n": int(mask.sum()),
    }

sorted_stations = sorted(station_errors.items(), key=lambda x: x[1]["mae"], reverse=True)

print("Top 10 WORST stations (highest MAE):")
print(f"  {'Station':>10}  {'MAE':>8}  {'Avg Ridership':>14}  {'MAPE%':>7}  {'n':>6}")
print(f"  {'-'*10}  {'-'*8}  {'-'*14}  {'-'*7}  {'-'*6}")
for cmplx_id, err in sorted_stations[:10]:
    print(f"  {cmplx_id:>10}  {err['mae']:>8.2f}  {err['avg_ridership']:>14.2f}  {err['mape']:>6.1f}%  {err['n']:>6}")

print("Top 10 BEST stations (lowest MAE):")
print(f"  {'Station':>10}  {'MAE':>8}  {'Avg Ridership':>14}  {'MAPE%':>7}  {'n':>6}")
print(f"  {'-'*10}  {'-'*8}  {'-'*14}  {'-'*7}  {'-'*6}")
for cmplx_id, err in sorted_stations[-10:][::-1]:  # reverse so the lowest MAE is listed first
    print(f"  {cmplx_id:>10}  {err['mae']:>8.2f}  {err['avg_ridership']:>14.2f}  {err['mape']:>6.1f}%  {err['n']:>6}")

## MAPE Distribution

Shows the distribution of the per-station percentage error (the `mape` value, i.e. MAE divided by average ridership) across all stations. Useful for understanding typical vs. worst-case error.

In [None]:
mapes = [v["mape"] for v in station_errors.values()]
print(f"MAPE across stations:")
print(f"  Median = {np.median(mapes):.1f}%")
print(f"  Mean   = {np.mean(mapes):.1f}%")
print(f"  25th   = {np.percentile(mapes, 25):.1f}%")
print(f"  75th   = {np.percentile(mapes, 75):.1f}%")
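
Because the `mape` value in `station_errors` is MAE divided by mean ridership, it can differ a lot from a per-prediction MAPE when a station has hours with very low ridership. The sketch below contrasts the two on made-up numbers; the function names `wape` and `mape_per_row` are illustrative.

```python
import numpy as np

def wape(y_true, y_pred, eps=1e-6):
    """MAE divided by mean true value, in percent (what this notebook reports)."""
    y_true = np.asarray(y_true, dtype=np.float64)
    y_pred = np.asarray(y_pred, dtype=np.float64)
    return 100.0 * np.mean(np.abs(y_pred - y_true)) / (np.mean(y_true) + eps)

def mape_per_row(y_true, y_pred, eps=1e-6):
    """Mean of per-prediction percentage errors, in percent."""
    y_true = np.asarray(y_true, dtype=np.float64)
    y_pred = np.asarray(y_pred, dtype=np.float64)
    return 100.0 * np.mean(np.abs(y_pred - y_true) / (np.abs(y_true) + eps))

# One quiet overnight hour (2 tap-ins) and two busy hours
toy_true = [2, 100, 300]
toy_pred = [6, 90, 310]
toy_wape = wape(toy_true, toy_pred)
toy_mape = mape_per_row(toy_true, toy_pred)
print(f"WAPE-style = {toy_wape:.1f}%")
print(f"per-row MAPE = {toy_mape:.1f}%")
```

The quiet hour contributes a 200% error to the per-row MAPE but only 4 tap-ins to the MAE, so the ratio used in this notebook is much less sensitive to near-empty hours.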