# GNN Training

This notebook trains a **Graph Neural Network (GNN)** to predict next-hour subway ridership across all ~426 MTA station complexes.

**How it works:**
- The NYC subway network is modeled as a **graph** — stations are nodes, physical track connections are edges
- At each hour, the model sees every station's current ridership + time encoding, and predicts what ridership will be one hour from now
- The GNN architecture lets information flow between connected stations, so the model can learn patterns like "when Times Square gets busy, nearby stations get busy too"

**Data pipeline:**
1. `preprocess.py` splits yearly CSVs into train/val/test parquet files and computes per-station normalization stats
2. This notebook loads those parquet files and trains the model

**Split strategy (temporal, no data leakage):**
- 2020–2022: 100% train
- 2023–2025: Jan–Jun → train, Jul–Sep → validation, Oct–Dec → test
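The split rule above can be sketched as a small function. The real logic lives in `preprocess.py`, which is not shown here, so `assign_split` is a hypothetical helper written only to make the rule concrete.

```python
from datetime import datetime


def assign_split(ts: datetime) -> str:
    """Return 'train', 'val', or 'test' for a timestamp (hypothetical helper)."""
    if ts.year <= 2022:
        return "train"   # 2020-2022: everything goes to train
    if ts.month <= 6:
        return "train"   # 2023-2025, Jan-Jun
    if ts.month <= 9:
        return "val"     # 2023-2025, Jul-Sep
    return "test"        # 2023-2025, Oct-Dec


print(assign_split(datetime(2021, 11, 5, 8)))   # a 2021 timestamp
print(assign_split(datetime(2024, 8, 1, 17)))   # an August 2024 timestamp
```

Because the rule depends only on the timestamp, every station at a given hour lands in the same split.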

Stats (mean/std) are computed **only from training data** to avoid leakage.
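As a rough illustration of train-only normalization, the sketch below computes per-station stats from training values and reuses them for unseen data. The station IDs, ridership numbers, the use of population standard deviation, and the zero-std guard are all assumptions; `preprocess.py` may differ.

```python
from statistics import mean, pstdev

# Hypothetical training ridership per station complex ID
train = {
    "611": [100.0, 200.0, 300.0],
    "164": [10.0, 10.0, 10.0],   # constant ridership, so std is 0
}

# Stats come from training rows only
stats = {}
for station, values in train.items():
    mu = mean(values)
    sigma = pstdev(values) or 1.0   # assumed guard: avoid dividing by zero
    stats[station] = (mu, sigma)


def normalize(station, ridership):
    """Z-score a raw ridership value using the training stats."""
    mu, sigma = stats[station]
    return (ridership - mu) / sigma


def denormalize(station, z):
    """Map a normalized value (e.g. a model prediction) back to riders."""
    mu, sigma = stats[station]
    return z * sigma + mu


z = normalize("611", 450.0)            # a validation value, scaled with train stats
print(z, denormalize("611", z))
```

The same `stats` dictionary is applied to validation and test rows, so nothing about those periods influences the scaling.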

In [None]:
import os
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
from torch_geometric.nn import GCNConv
from tqdm import tqdm

## Config

In [None]:
# --- Hyperparameters ---
EPOCHS = 15       # Max number of passes through the full training set
LR = 1e-4         # Learning rate for Adam optimizer (how big each weight update step is)
HIDDEN_DIM = 64   # Size of the hidden layer in the GNN (more = more expressive but slower)
PATIENCE = 4      # Stop training if val loss doesn't improve for this many consecutive epochs

# --- Paths ---
# ROOT points to the project root, one level up from training/.
# This assumes the notebook's working directory is training/.
ROOT = os.path.dirname(os.path.abspath(""))
PROC_DIR = os.path.join(ROOT, "data", "processed")
MODEL_DIR = os.path.join(ROOT, "models")
EDGES_PATH = os.path.join(PROC_DIR, "ComplexEdges.csv")   # which stations are connected
CMPLX_PATH = os.path.join(PROC_DIR, "ComplexNodes.csv")   # station complex ID → node index mapping

## Model Definition

The GNN has 3 layers:
1. **GCNConv layer 1** — takes the 3 input features per station and outputs 64-dimensional embeddings. Each station's output incorporates info from its neighbors via graph convolution.
2. **GCNConv layer 2** — another round of message passing. Now each station has info from stations **2 hops away** (neighbor's neighbors).
3. **Linear head** — maps each station's 64-dim embedding down to a single number: the predicted next-hour ridership (in normalized space).

Both GCN layers use **ReLU** activation (clips negatives to 0) to introduce non-linearity.
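To see why two layers reach 2-hop neighbors, here is a minimal NumPy sketch of the standard GCN propagation rule, ReLU(D^-1/2 (A + I) D^-1/2 H W), on a toy 3-station path graph. It omits the bias term and uses a fixed 1×1 weight, so it illustrates the mechanism rather than reproducing `GCNConv`.

```python
import numpy as np


def gcn_layer(A, H, W):
    """One GCN step with self-loops and symmetric normalization, then ReLU."""
    A_hat = A + np.eye(A.shape[0])                   # add self-loops
    d_inv_sqrt = 1.0 / np.sqrt(A_hat.sum(axis=1))    # D^-1/2 as a vector
    A_norm = A_hat * d_inv_sqrt[:, None] * d_inv_sqrt[None, :]
    return np.maximum(A_norm @ H @ W, 0.0)


# Toy path graph: station 0 - station 1 - station 2
A = np.array([[0, 1, 0],
              [1, 0, 1],
              [0, 1, 0]], dtype=float)

H0 = np.array([[1.0], [0.0], [0.0]])   # signal only at station 0
W = np.array([[1.0]])                  # fixed weight, for illustration only

H1 = gcn_layer(A, H0, W)   # after 1 layer: station 2 still sees nothing
H2 = gcn_layer(A, H1, W)   # after 2 layers: station 0's signal reaches station 2

print(H1.ravel())
print(H2.ravel())
```

Station 2 is two hops from station 0, so its value stays at zero after the first layer and becomes positive only after the second.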

In [None]:
class GNN(nn.Module):
    def __init__(self, in_dim, hidden_dim):
        super().__init__()
        # Layer 1: graph convolution from input features → hidden_dim
        # in_dim=3 because each station gets: [ridership_norm, sin(hour), cos(hour)]
        self.conv1 = GCNConv(in_dim, hidden_dim)
        # Layer 2: another graph convolution, hidden_dim → hidden_dim
        # This second layer lets the model see 2-hop neighbors
        self.conv2 = GCNConv(hidden_dim, hidden_dim)
        # Output head: maps each station's hidden representation to 1 predicted value
        self.mlp = nn.Linear(hidden_dim, 1)

    def forward(self, x, edge_index):
        # x shape: (num_nodes, 3) — features for every station
        # edge_index shape: (2, num_edges) — pairs of connected station indices
        h = torch.relu(self.conv1(x, edge_index))   # (num_nodes, hidden_dim)
        h = torch.relu(self.conv2(h, edge_index))   # (num_nodes, hidden_dim)
        return self.mlp(h).squeeze(-1)               # (num_nodes,): one prediction per station

## Load Graph Structure

The subway network graph needs two things:
- **Node mapping** (`ComplexNodes.csv`): maps MTA station complex IDs (e.g. 611) to sequential indices (0, 1, 2, ...) that PyTorch Geometric expects
- **Edges** (`ComplexEdges.csv`): pairs of connected stations. We add both directions (A→B and B→A) since trains run both ways
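The two inputs above combine into an edge list roughly as follows. This is a plain-Python sketch with made-up data: apart from 611, the complex IDs are hypothetical, and `build_edge_index` is an illustrative name rather than a function from this project.

```python
# Hypothetical mapping from complex ID to node index
complex_to_node = {611: 0, 164: 1, 318: 2}

# Hypothetical edges; 999 is not in the mapping and should be skipped
raw_edges = [(611, 164), (164, 318), (318, 999)]


def build_edge_index(raw_edges, complex_to_node):
    """Return [sources, targets] with each valid edge added in both directions."""
    sources, targets = [], []
    for a, b in raw_edges:
        if a not in complex_to_node or b not in complex_to_node:
            continue                      # drop edges touching unknown stations
        u, v = complex_to_node[a], complex_to_node[b]
        sources += [u, v]                 # u -> v and v -> u
        targets += [v, u]
    return [sources, targets]


edge_index = build_edge_index(raw_edges, complex_to_node)
print(edge_index)
```

The result has the same `(2, num_edges)` layout that the notebook later turns into a tensor, with row 0 holding sources and row 1 holding destinations.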

In [None]:
# Load the station complex ID → node index mapping
# e.g. complex_id 611 (Times Sq) → node_id 0
cmplx_df = pd.read_csv(CMPLX_PATH)
ComplexNodes = dict(zip(cmplx_df["complex_id"], cmplx_df["node_id"]))
num_nodes = int(cmplx_df["node_id"].max() + 1)

# Load edges (physical track connections between stations)
edges_df = pd.read_csv(EDGES_PATH)
edge_list = []
for _, row in edges_df.iterrows():
    s, e = row["from_complex_id"], row["to_complex_id"]
    if s in ComplexNodes and e in ComplexNodes:
        sn, en = ComplexNodes[s], ComplexNodes[e]
        # Add both directions — undirected graph (trains go both ways)
        edge_list.append([sn, en])
        edge_list.append([en, sn])

# PyG expects edge_index as a (2, num_edges) tensor:
# row 0 = source nodes, row 1 = destination nodes
edge_tensor = torch.tensor(edge_list, dtype=torch.long).T
print(f"Nodes: {num_nodes}, Edges: {edge_tensor.shape[1]}")

## Build Graph Snapshots

This is where the training examples are built. The targets come from the ridership series itself, so no manual labels are needed. For every pair of consecutive hours in the data:

- **Input (X)**: a matrix of shape `(num_nodes, 3)` where each station gets 3 features:
  - `ridership_norm`: z-score normalized ridership at time $t$ (how far above/below that station's average)
  - `sin_hour`: sine encoding of the hour (so 23:00 and 00:00 are close together)
  - `cos_hour`: cosine encoding of the hour (sin + cos together uniquely identify each hour)
- **Target (y)**: normalized ridership at time $t+1$ for each station
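A quick sketch of why the sine/cosine pair is useful. The formula below (angle = 2π · hour / 24) is the usual cyclical encoding and is an assumption here; the actual columns are produced by `preprocess.py`.

```python
import math


def encode_hour(hour):
    """Map an hour (0-23) to a point on the unit circle."""
    angle = 2 * math.pi * hour / 24
    return math.sin(angle), math.cos(angle)


def distance(h1, h2):
    """Euclidean distance between two encoded hours."""
    return math.dist(encode_hour(h1), encode_hour(h2))


print(distance(23, 0))   # adjacent across midnight
print(distance(0, 1))    # adjacent within the day
print(distance(0, 12))   # opposite sides of the clock
```

With the raw hour as a feature, 23 and 0 would be 23 units apart; on the circle they are exactly as close as 0 and 1.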

Pairs of timestamps more than 1 hour apart (e.g. missing data) are skipped, since the target would no longer be the next hour.
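The pairing rule can be shown in isolation with a few made-up timestamps. `consecutive_pairs` is a hypothetical helper that mirrors the gap check, not a function used elsewhere in this notebook.

```python
from datetime import datetime, timedelta


def consecutive_pairs(timestamps, max_gap=timedelta(hours=1)):
    """Pair each timestamp with the next one, dropping pairs with a larger gap."""
    ordered = sorted(timestamps)
    return [
        (t0, t1)
        for t0, t1 in zip(ordered[:-1], ordered[1:])
        if t1 - t0 <= max_gap
    ]


# Hours 02:00 and 03:00 are missing from this toy series
hours = [
    datetime(2023, 1, 1, 0),
    datetime(2023, 1, 1, 1),
    datetime(2023, 1, 1, 4),
    datetime(2023, 1, 1, 5),
]

pairs = consecutive_pairs(hours)
for t0, t1 in pairs:
    print(t0.hour, "->", t1.hour)
```

The 01:00 to 04:00 pair is dropped, so only true one-hour steps remain as training examples.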

In [None]:
def build_snapshots(df, num_nodes):
    """Build (features, targets) pairs from consecutive timestamps.
    
    Each snapshot is one training example:
    - features: what the model sees (ridership + time at hour t)
    - targets: what the model must predict (ridership at hour t+1)
    """
    # Group all station data by timestamp for fast lookup
    groups = {t: g for t, g in df.groupby("transit_timestamp")}
    timestamps = sorted(groups.keys())

    features = []
    targets = []

    for t0, t1 in zip(timestamps[:-1], timestamps[1:]):
        # Skip if there's a gap > 1 hour (missing data, wouldn't be a valid prediction)
        if (t1 - t0).total_seconds() > 3600:
            continue

        g0 = groups[t0]  # all stations at time t
        g1 = groups[t1]  # all stations at time t+1

        # Initialize tensors for all nodes. Stations without data stay at 0, which is
        # the station's mean in normalized space (their sin/cos features also stay at 0).
        X = torch.zeros(num_nodes, 3)  # input features
        y = torch.zeros(num_nodes)     # target ridership

        # Get node indices for stations present at each timestamp
        idx0 = torch.tensor(g0["node_id"].values, dtype=torch.long)
        idx1 = torch.tensor(g1["node_id"].values, dtype=torch.long)

        # Fill in the 3 features for time t:
        X[idx0, 0] = torch.tensor(g0["ridership_norm"].values.astype(np.float32))  # normalized ridership
        X[idx0, 1] = torch.tensor(g0["sin_hour"].values.astype(np.float32))        # sin(hour)
        X[idx0, 2] = torch.tensor(g0["cos_hour"].values.astype(np.float32))        # cos(hour)

        # Target: normalized ridership at time t+1
        y[idx1] = torch.tensor(g1["ridership_norm"].values.astype(np.float32))

        features.append(X)
        targets.append(y)

    return features, targets

## Load Training Data

The training parquet contains all 2020–2022 data plus Jan–Jun of 2023–2025. Each row has one station's ridership at one timestamp, already normalized and with time encodings computed by `preprocess.py`.

In [None]:
# Load preprocessed training data (already has ridership_norm, sin_hour, cos_hour, node_id)
train_df = pd.read_parquet(os.path.join(PROC_DIR, "train.parquet"))
print(f"Train rows: {len(train_df):,}")

# Convert the flat dataframe into a list of graph snapshots (one per consecutive hour pair)
train_features, train_targets = build_snapshots(train_df, num_nodes)
print(f"Train snapshots: {len(train_features)}")

# Free memory — we only need the snapshot tensors from here on
del train_df

## Load Validation Data

The validation set is Jul–Sep of 2023–2025. The model never trains on this — it's used to detect when the model starts overfitting (memorizing training data instead of learning general patterns).

In [None]:
val_df = pd.read_parquet(os.path.join(PROC_DIR, "val.parquet"))
print(f"Val rows: {len(val_df):,}")

val_features, val_targets = build_snapshots(val_df, num_nodes)
print(f"Val snapshots: {len(val_features)}")
del val_df

## Initialize Model

- **Adam optimizer**: adaptive learning rate optimizer — adjusts step size per parameter, works well out of the box
- **MSE loss**: Mean Squared Error — penalizes large prediction errors more than small ones (squared), which pushes the model to avoid big mistakes
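A tiny numeric example of the "penalizes large errors more" point, comparing MSE with mean absolute error (MAE) on made-up predictions. MAE is included only for contrast; the notebook itself uses MSE.

```python
def mse(preds, targets):
    """Mean squared error."""
    return sum((p - t) ** 2 for p, t in zip(preds, targets)) / len(targets)


def mae(preds, targets):
    """Mean absolute error, shown only for comparison."""
    return sum(abs(p - t) for p, t in zip(preds, targets)) / len(targets)


targets = [0.0, 0.0, 0.0, 0.0]
spread = [1.0, 1.0, 1.0, 1.0]         # four small errors
concentrated = [4.0, 0.0, 0.0, 0.0]   # one large error, same total absolute error

print(mae(spread, targets), mae(concentrated, targets))
print(mse(spread, targets), mse(concentrated, targets))
```

Both prediction sets have the same MAE, but the one with a single large miss has four times the MSE.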

In [None]:
# Create the model with 3 input features and 64-dim hidden layers
model = GNN(in_dim=3, hidden_dim=HIDDEN_DIM)

# Adam: adaptive moment estimation — adjusts LR per parameter based on gradient history
optimizer = torch.optim.Adam(model.parameters(), lr=LR)

# MSE: average of (prediction - truth)² over all stations in one snapshot;
# the training loop then averages these per-snapshot losses over the epoch
loss_fn = nn.MSELoss()

os.makedirs(MODEL_DIR, exist_ok=True)
best_model_path = os.path.join(MODEL_DIR, "model.pt")

print(f"Model parameters: {sum(p.numel() for p in model.parameters()):,}")
print(f"Will save best model to: {best_model_path}")

## Training Loop

**How early stopping works:**
1. After each epoch, we check validation loss (performance on unseen Jul–Sep data)
2. If val loss improved → save the model checkpoint (this is the best model so far)
3. If val loss didn't improve for `PATIENCE` consecutive epochs → stop training
4. This prevents **overfitting** — the model memorizing training data instead of learning real patterns

Without early stopping, the model can keep "improving" on training data while getting *worse* at predicting unseen data.
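The stopping rule can be traced on a made-up list of validation losses, independent of the model. `early_stopping_epoch` is a hypothetical helper that follows the same rule: reset the counter on improvement, stop once it reaches the patience value.

```python
def early_stopping_epoch(val_losses, patience):
    """Return (best_epoch, stop_epoch), both 1-indexed."""
    best_loss = float("inf")
    best_epoch = None
    wait = 0
    for epoch, loss in enumerate(val_losses, start=1):
        if loss < best_loss:
            best_loss, best_epoch, wait = loss, epoch, 0   # new best: checkpoint here
        else:
            wait += 1
            if wait >= patience:
                return best_epoch, epoch                   # patience exhausted
    return best_epoch, len(val_losses)                     # ran all epochs


# Hypothetical validation losses, one per epoch
losses = [0.90, 0.70, 0.65, 0.66, 0.68, 0.67, 0.69, 0.50]
best, stopped = early_stopping_epoch(losses, patience=4)
print(best, stopped)
```

Note the trade-off: in this toy series the loss would have dropped again at epoch 8, but training stops at epoch 7, so a larger patience costs more epochs in exchange for tolerating longer plateaus.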

In [None]:
# Track the best validation loss seen so far (starts at infinity)
best_val_loss = float("inf")
patience_counter = 0  # how many epochs since last improvement

print(f"{'Epoch':>6}  {'Train Loss':>11}  {'Val Loss':>11}  {'Status':>10}")
print("-" * 50)

for epoch in range(1, EPOCHS + 1):
    # === TRAINING PHASE ===
    # model.train() enables dropout/batchnorm training behavior (not used here, but good practice)
    model.train()
    total_train_loss = 0

    # Loop through every snapshot: feed in hour t, predict hour t+1, update weights
    for X, y in tqdm(zip(train_features, train_targets), total=len(train_features), desc=f"Epoch {epoch}", leave=False):
        optimizer.zero_grad()           # clear gradients from previous step
        y_hat = model(X, edge_tensor)   # forward pass: predict next-hour ridership for all stations
        loss = loss_fn(y_hat, y)        # compute MSE between predictions and actual values
        loss.backward()                 # backpropagation: compute gradient of loss w.r.t. each weight
        optimizer.step()                # update weights in the direction that reduces loss
        total_train_loss += loss.item()
    train_loss = total_train_loss / len(train_features)

    # === VALIDATION PHASE ===
    # model.eval() + no_grad: no weight updates, just measure how well the model generalizes
    model.eval()
    total_val_loss = 0
    with torch.no_grad():  # disable gradient tracking (saves memory, faster)
        for X, y in zip(val_features, val_targets):
            y_hat = model(X, edge_tensor)
            loss = loss_fn(y_hat, y)
            total_val_loss += loss.item()
    val_loss = total_val_loss / len(val_features)

    # === EARLY STOPPING CHECK ===
    if val_loss < best_val_loss:
        # New best! Save this checkpoint — it's the best model so far
        best_val_loss = val_loss
        patience_counter = 0
        torch.save(model.state_dict(), best_model_path)
        status = "★ saved"
    else:
        # No improvement — increment patience counter
        patience_counter += 1
        status = f"wait {patience_counter}/{PATIENCE}"

    print(f"  {epoch:>4}   {train_loss:>11.6f}  {val_loss:>11.6f}  {status:>10}")

    if patience_counter >= PATIENCE:
        # Val loss hasn't improved for PATIENCE epochs — the model is probably overfitting
        print(f"\nEarly stopping at epoch {epoch} (no improvement for {PATIENCE} epochs)")
        break

print(f"\nBest val loss: {best_val_loss:.6f}")
print(f"Model saved:   {best_model_path}")