# GNN Training

This notebook trains a **Graph Neural Network (GNN)** with the `DirSAGEEmbRes` architecture to predict next-hour subway ridership across MTA station complexes.
`DirSAGEEmbRes` is used because it combines direction-aware message passing, a learnable identity for each station, and residual connections that keep deeper propagation stable.

Ridership at a station depends on:
- Its own history + time-of-day/day-of-week (temporal)
- Nearby stations / connected lines (spatial spillover)
- Directional effects (inflow and outflow of commuters)

**Learnable Node Embeddings:**
The architecture gives each station a trainable identity vector (a learnable node embedding). This vector can capture properties that are not in the input features, such as baseline demand (quiet vs busy stop), whether the station is a transfer hub, and neighbourhood differences.

**Directional SAGEConv Streams (Two Layers Each):**
Incoming and outgoing edges get separate convolution stacks, so the model learns different transformation weights for each direction. This matters because flows are asymmetric: a station in an industrial area may see far more arriving than departing riders in the morning, while a residential station sees the same imbalance at night. Two separate streams allow more expressive modelling of these asymmetric flow patterns. Each stream stacks two SAGEConv layers, so a node can incorporate information from neighbours up to 2 hops away.
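To make directional, two-layer message passing concrete, here is a toy NumPy sketch (not PyTorch Geometric) of a GraphSAGE-style layer with mean aggregation plus a separate root weight, mirroring the default `SAGEConv` update $x_i' = W_1 x_i + W_2 \cdot \mathrm{mean}_{j \in N(i)} x_j$. The helper names `mean_aggregate` and `sage_layer` are hypothetical, the weights are identity matrices purely for readability, and ReLU is omitted because every value stays non-negative. It runs on a directed path 0 → 1 → 2 and on its reverse.

```python
import numpy as np

def mean_aggregate(x: np.ndarray, edge_index: np.ndarray) -> np.ndarray:
    """Mean of source-node features for each target node (row 0 = source, row 1 = target).

    Nodes with no incoming edges receive a zero vector.
    """
    src, dst = edge_index
    summed = np.zeros_like(x, dtype=float)
    np.add.at(summed, dst, x[src])                       # sum messages per target node
    counts = np.bincount(dst, minlength=x.shape[0]).astype(float)
    return summed / np.maximum(counts, 1.0)[:, None]     # avoid division by zero

def sage_layer(x, edge_index, w_root, w_neigh):
    """One GraphSAGE-style update: x_i' = W_root x_i + W_neigh * mean_{j -> i} x_j."""
    return x @ w_root.T + mean_aggregate(x, edge_index) @ w_neigh.T

# One-hot features: column k tracks how much of node k's signal reached each row
x = np.eye(3)
edge_in = np.array([[0, 1],
                    [1, 2]])      # edges 0 -> 1 and 1 -> 2
edge_out = edge_in[::-1]          # reversed: 1 -> 0 and 2 -> 1
w = np.eye(3)                     # hypothetical identity weights

h_in1 = sage_layer(x, edge_in, w, w)       # 1 hop along edge_in
h_in2 = sage_layer(h_in1, edge_in, w, w)   # 2 hops along edge_in
h_out1 = sage_layer(x, edge_out, w, w)     # 1 hop along the reversed edges
```

After one layer along `edge_in`, node 2 only carries signal from node 1; after the second layer its row also has a non-zero entry for node 0. Reversing the edges changes which node aggregates from which, which is why the two streams can learn different patterns when the edge list is not symmetric.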

**Residual Connections:**
GNN stacks can suffer from over-smoothing (node representations becoming too similar after repeated message passing) and from harder optimisation (vanishing gradients, unstable training). In each stream, the residual connection adds layer 1's output to layer 2's output, so the first-layer representation, which already reflects the node's own ridership and embedding, is carried forward. The second layer then learns a correction instead of rewriting everything, which makes stacking layers safer without washing out station identity. Over-smoothing is an especially high risk here because many stations are connected through short paths. If the representations lose their distinctness, the model can no longer distinguish individual station dynamics.

**ReLU Activation:**
ReLU sets negative values to 0 and leaves positive values unchanged. Without a non-linearity between them, stacked layers collapse into what is effectively a single linear layer, so extra depth adds no expressive power. Applying ReLU after each layer introduces non-linearity before the next layer, so stacking layers genuinely increases what the model can represent.
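A tiny NumPy check of the collapse argument, using hypothetical weights `W1` and `W2` for two stacked layers (biases omitted): without an activation, the two layers give exactly the same output as the single matrix `W2 @ W1`, while inserting ReLU between them breaks that equivalence.

```python
import numpy as np

def relu(z: np.ndarray) -> np.ndarray:
    """Element-wise ReLU: negatives -> 0, positives unchanged."""
    return np.maximum(z, 0.0)

# Hypothetical weights for two stacked layers (biases omitted for brevity)
W1 = np.array([[1.0, -2.0],
               [0.5,  1.0]])
W2 = np.array([[1.0, 1.0]])
x = np.array([1.0, 1.0])

stacked_linear = W2 @ (W1 @ x)     # two layers, no activation
collapsed = (W2 @ W1) @ x          # the single equivalent linear layer
stacked_relu = W2 @ relu(W1 @ x)   # ReLU between the layers zeros the negative unit
```

Here `W1 @ x` is `[-1.0, 1.5]`; ReLU zeros the first entry, so the ReLU network outputs 1.5 while both linear variants output 0.5.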

**Key points:**
- The NYC subway is modeled as a graph: stations are nodes, track connections are edges.
- At each hour t, the model builds a station's representation from its **identity (embedding), current local state (ridership + time), and neighbours' states (message passing), while not forgetting itself (residual)**. It uses this representation to predict ridership at t+1.
- Information flows between connected stations, allowing the model to learn spatial and temporal patterns.

**Data pipeline:**
1. `preprocess.py` splits yearly CSVs into train/val/test parquet files and computes per-station normalization stats.
2. This notebook loads those parquet files and trains the GNN model.

**Split strategy (temporal, no data leakage):**
- 2020–2022: train
- 2023:      val
- 2024:      test

Stats are computed from training data to avoid leakage.
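`preprocess.py` is not shown here, so the following is only a sketch of leakage-free per-station z-score normalization under assumed column names (`node_id` and a raw `ridership` column) and hypothetical helper names. Statistics are fitted on training rows only and then reused unchanged for validation/test rows.

```python
import pandas as pd

def fit_station_stats(train_df: pd.DataFrame) -> pd.DataFrame:
    """Per-station mean/std of raw ridership, computed from TRAINING rows only."""
    stats = train_df.groupby("node_id")["ridership"].agg(["mean", "std"])
    # Guard against NaN std (single-row station) or zero std (constant ridership)
    stats["std"] = stats["std"].fillna(1.0).replace(0.0, 1.0)
    return stats

def apply_station_stats(df: pd.DataFrame, stats: pd.DataFrame) -> pd.DataFrame:
    """z-score every row with its station's TRAINING mean/std (no refitting on df)."""
    out = df.join(stats, on="node_id")
    out["ridership_norm"] = (out["ridership"] - out["mean"]) / out["std"]
    return out.drop(columns=["mean", "std"])

# Toy data: two stations; stats come from the train split only
train = pd.DataFrame({"node_id": [0, 0, 1, 1], "ridership": [10.0, 20.0, 100.0, 300.0]})
val = pd.DataFrame({"node_id": [0, 1], "ridership": [15.0, 500.0]})

stats = fit_station_stats(train)
val_norm = apply_station_stats(val, stats)
```

A validation value far outside the training range (500 for station 1) simply gets a large z-score; the validation data never influences the mean or std.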

In [1]:
import os
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
from torch_geometric.nn import SAGEConv
from tqdm import tqdm



## Config

In [None]:
# Hyperparams
EPOCHS = 15       # Max number of passes through the full training set
LR = 1e-4         # Learning rate for Adam optimizer
HIDDEN_DIM = 64   # Size of the model's internal node representation after a message-passing layer. More -> model can learn more complex patterns, but slower training and higher risk of overfitting
PATIENCE = 4      # Stop training if val loss doesn't improve for this many consecutive epochs

# Paths
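# Project root = parent of the current working directory
# (assumes the notebook is run from a subfolder of the project)
ROOT = os.path.dirname(os.path.abspath(""))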
PROC_DIR = os.path.join(ROOT, "data", "processed")
MODEL_DIR = os.path.join(ROOT, "new_models")
EDGES_PATH = os.path.join(PROC_DIR, "ComplexEdges.csv")
CMPLX_PATH = os.path.join(PROC_DIR, "ComplexNodes.csv")

## Model Definition

The model used here is **DirSAGEEmbRes**, a custom GNN architecture: a directional variant of GraphSAGE.
- Uses learnable node embeddings for each station.
- Two stacked SAGEConv layers for each direction (incoming and outgoing edges).
- Residual connections to reduce over-smoothing.
- Final linear layer outputs the predicted next-hour normalized ridership.

In [None]:
class DirSAGEEmbRes(nn.Module):
    """Directional GraphSAGE with learnable node embeddings and residual connections."""

    def __init__(self, num_nodes: int, in_dim: int, hidden_dim: int, emb_dim: int = 16):
        super().__init__()
        # Learnable table of shape [num_nodes, emb_dim]: row i is node i's trainable
        # station-identity vector node_emb[i], learned by the model
        self.node_emb = nn.Embedding(num_nodes, emb_dim)
        # Effective GNN input width = in_dim raw node features + emb_dim embedding columns
        d0 = in_dim + emb_dim

        # Incoming stream: two GraphSAGE message-passing layers, run on edge_in.
        # Each layer maps (node features, edges) -> new node features by aggregating neighbours.
        self.in1 = SAGEConv(d0, hidden_dim)
        self.in2 = SAGEConv(hidden_dim, hidden_dim)

        # Outgoing stream: same structure with separate weights, run on edge_out
        self.out1 = SAGEConv(d0, hidden_dim)
        self.out2 = SAGEConv(hidden_dim, hidden_dim)

        # The two stream outputs are concatenated (2 * hidden_dim per node) and
        # a linear layer maps them to one prediction per node
        self.lin = nn.Linear(2 * hidden_dim, 1)

    def forward(self, x, edge_in, edge_out):
        """
        x:                 node feature matrix, shape [num_nodes, in_dim]
        edge_in, edge_out: edge index tensors for incoming / outgoing edges, shape [2, num_edges]
        returns:           predicted next-hour normalized ridership, shape [num_nodes]
        """
        # Node indices [0, ..., num_nodes - 1], used to look up each node's embedding.
        # Created on x's device (CPU/GPU). Assumes x is dense and node-aligned (row i = node i).
        node_ids = torch.arange(x.size(0), device=x.device)

        # Each node now has real features (ridership/time) plus a learnable station identity vector
        x = torch.cat([x, self.node_emb(node_ids)], dim=1)

        # Incoming stream.
        # Layer 1: gather neighbour features along edge_in, aggregate and transform them
        # (output can be positive or negative); ReLU keeps positives and zeros negatives.
        h_in1 = torch.relu(self.in1(x, edge_in))
        # Layer 2: repeat on the learned layer-1 representation (reaches 2-hop neighbours)
        h_in2 = torch.relu(self.in2(h_in1, edge_in))
        # Residual: layer 2 refines layer 1 instead of replacing it, limiting over-smoothing
        h_in = h_in2 + h_in1  # [num_nodes, hidden_dim]

        # Outgoing stream: same computation along edge_out
        h_out1 = torch.relu(self.out1(x, edge_out))
        h_out2 = torch.relu(self.out2(h_out1, edge_out))
        h_out = h_out2 + h_out1  # [num_nodes, hidden_dim]

        # Concatenate in/out representations and predict
        h = torch.cat([h_in, h_out], dim=-1)  # [num_nodes, 2 * hidden_dim]
        # [num_nodes, 2 * hidden_dim] -> [num_nodes, 1] -> [num_nodes]
        return self.lin(h).squeeze(-1)

## Load Graph Structure
- **Node mapping** (`ComplexNodes.csv`): Maps station complex IDs to node indices for PyTorch.
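- **Edges** (`ComplexEdges.csv`): Pairs of connected stations. Both directions are included. Self-loops are added so that each node's own features are part of its neighbourhood aggregation, and a node with no edges in one direction still receives a non-empty message.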

In [4]:
# Load the station complex ID -> node index mapping
cmplx_df = pd.read_csv(CMPLX_PATH)
ComplexNodes = dict(zip(cmplx_df["complex_id"], cmplx_df["node_id"]))

num_nodes = len(ComplexNodes)

# Load edges
edges_df = pd.read_csv(EDGES_PATH)

edge_in_list = []   # from -> to
edge_out_list = []  # to -> from

for _, row in edges_df.iterrows():
    s, e = row["from_complex_id"], row["to_complex_id"]
    if s in ComplexNodes and e in ComplexNodes:
        sn, en = ComplexNodes[s], ComplexNodes[e]
        edge_in_list.append([sn, en])
        edge_out_list.append([en, sn])

# Add self-loops (each node is connected to itself)
for i in range(num_nodes):
    edge_in_list.append([i, i])
    edge_out_list.append([i, i])

# edge_in_list, edge_out_list are lists of [from_node, to_node] pairs for each edge, including self-loops.
# edge_in_list = [[u1,v1], [u2,v2], ..., [i,i], ...]

# Convert lists of edges into a tensor of integers with shape [num_edges, 2], then transpose to shape [2, num_edges]
# PyTorch Geometric expects edge_index as row 0 = from_node, row 1 = to_node
# v aggregates messages from its incoming edges, u1, u2, ... are the source nodes sending messages to v
edge_in = torch.tensor(edge_in_list, dtype=torch.long).T
edge_out = torch.tensor(edge_out_list, dtype=torch.long).T

print(f"Nodes: {num_nodes}, Edge_in: {edge_in.shape[1]}, Edge_out: {edge_out.shape[1]}")

Nodes: 424, Edge_in: 976, Edge_out: 976
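Since the edge file is described as containing both directions, it may be worth checking whether `edge_in` and `edge_out` actually differ. If the edge list is symmetric, reversing every pair yields the same set of edges, so both streams aggregate over identical neighbourhoods and differ only in their learned weights. A minimal plain-Python sketch with hypothetical helpers `edge_set` and `is_symmetric`:

```python
def edge_set(edge_list):
    """Set of (source, target) tuples from a list of [source, target] pairs."""
    return {tuple(e) for e in edge_list}

def is_symmetric(edge_list):
    """True if every edge u -> v also appears as v -> u."""
    edges = edge_set(edge_list)
    return all((v, u) in edges for (u, v) in edges)

def reversed_pairs(edge_list):
    """Flip every [u, v] into [v, u], as done when building edge_out_list."""
    return [[v, u] for u, v in edge_list]

# Toy examples: an edge list with both directions vs a one-way edge list
both_directions = [[0, 1], [1, 0], [1, 2], [2, 1]]
one_way = [[0, 1], [1, 2]]
```

In the notebook, `is_symmetric(edge_in_list)` could be run after the cell above; a `True` result would mean `edge_in` and `edge_out` hold the same edges.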


## Build Graph Snapshots

Build (feature, target) pairs from consecutive timestamps.
Each snapshot is one supervised training example: features at hour $t$, target at hour $t+1$.

For every pair of consecutive hours in the data:
- Input (X): a matrix of shape `(num_nodes, 5)`, with one row per station containing:
  - `ridership_norm`: z-score normalized ridership at time $t$
  - `sin_hour`, `cos_hour`: sine/cosine encoding of the hour
  - `sin_dow`, `cos_dow`: sine/cosine encoding of the day of week
- Target (y): normalized ridership at time $t+1$ for each station
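The sine/cosine encodings are computed in `preprocess.py` (not shown). A minimal sketch, assuming a period of 24 for the hour and 7 for the day of week (with Monday = 0, as in pandas' `dayofweek`); `cyclic_encode` is a hypothetical helper. The point of the encoding is that values at the wrap-around (23:00 and 00:00, Sunday and Monday) end up close together, which a raw integer hour or weekday would not capture.

```python
import numpy as np

def cyclic_encode(value, period):
    """Map a periodic integer (hour, weekday) onto the unit circle as (sin, cos)."""
    angle = 2 * np.pi * np.asarray(value, dtype=float) / period
    return np.sin(angle), np.cos(angle)

# Hours 0..23 with period 24
sin_h, cos_h = cyclic_encode(np.arange(24), 24)
d_23_0 = np.hypot(sin_h[23] - sin_h[0], cos_h[23] - cos_h[0])  # 23:00 vs 00:00
d_12_0 = np.hypot(sin_h[12] - sin_h[0], cos_h[12] - cos_h[0])  # noon vs midnight

# Days 0..6 (assumed Monday = 0, Sunday = 6) with period 7
sin_d, cos_d = cyclic_encode(np.arange(7), 7)
d_sun_mon = np.hypot(sin_d[6] - sin_d[0], cos_d[6] - cos_d[0])
d_thu_mon = np.hypot(sin_d[3] - sin_d[0], cos_d[3] - cos_d[0])
```

23:00 and 00:00 end up about 0.26 apart on the circle, while noon and midnight are diametrically opposite (distance 2).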

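The graph edges are defined in terms of node indices, so the feature tensor must use the same indexing scheme: row `i` of X (and entry `i` of y) belongs to node `i`. Some timestamps may be missing certain stations, so dense tensors are used and a missing station's row is left as all zeros. Because ridership is z-score normalized, a zero target corresponds to that station's training mean, so missing stations still contribute to the loss as if they had average ridership.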

This setup enables the model to learn both temporal and spatial dependencies.
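The dense, node-aligned scattering done in `build_snapshots` can be illustrated with a small NumPy sketch (hypothetical helpers `to_dense` and `masked_mse`). It additionally records a boolean `present` mask of observed stations. The notebook itself does not use such a mask, its loss is taken over all rows; the mask only shows one way missing stations could be excluded from the loss if desired.

```python
import numpy as np

def to_dense(node_ids, values, num_nodes):
    """Scatter per-station values into a node-aligned dense array.

    Stations absent at this timestamp stay 0; `present` marks the observed rows.
    """
    dense = np.zeros(num_nodes, dtype=np.float32)
    present = np.zeros(num_nodes, dtype=bool)
    dense[node_ids] = values
    present[node_ids] = True
    return dense, present

def masked_mse(pred, target, mask):
    """Mean squared error over observed rows only."""
    return float(np.mean((pred[mask] - target[mask]) ** 2))

# Toy timestamp: 5 nodes, only stations 0, 3 and 4 reported ridership
y, present = to_dense(np.array([0, 3, 4]), np.array([0.5, -1.2, 2.0]), num_nodes=5)
pred = np.zeros(5)

full_mse = float(np.mean((pred - y) ** 2))   # what an unmasked loss averages over (5 rows)
obs_mse = masked_mse(pred, y, present)       # only the 3 observed rows
```

The unmasked loss divides the same squared error by 5 rows instead of 3, so the two missing stations dilute the loss and pull predictions toward the station mean.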

In [None]:
def build_snapshots(df, num_nodes):
    """
    Build (features, targets) pairs from consecutive timestamps.

    Each snapshot is one supervised training example:
    - features at hour t: [ridership_norm, sin_hour, cos_hour, sin_dow, cos_dow]
    - targets at hour t+1: [ridership_norm]
    """
    # Split df into groups by timestamp
    groups = {t: g for t, g in df.groupby("transit_timestamp")}
    timestamps = sorted(groups.keys())

    # List of X and y tensors (one per t -> t+1 pair)
    features = []
    targets = []

    # Iterate through consecutive timestamp pairs (t, t+1), zip(all except last, all except first)
    for t0, t1 in zip(timestamps[:-1], timestamps[1:]):

        # Only consider pairs that are one hour apart
        if (t1 - t0).total_seconds() > 3600:
            continue

        # Rows for time t0 and t1
        g0 = groups[t0]
        g1 = groups[t1]

        # Allocate dense, node-aligned tensors
        # X: features at time t0, y: target at time t1
        X = torch.zeros(num_nodes, 5)
        y = torch.zeros(num_nodes)

        # Node indices at each timestamp, used to scatter values into the tensors
        # idx0 -> which row in X corresponds to the stations in time t0
        # idx1 -> which row in y corresponds to the stations in time t1
        idx0 = torch.tensor(g0["node_id"].values)
        idx1 = torch.tensor(g1["node_id"].values)

        # Fill X rows for nodes at t0, y for nodes at t1
        X[idx0, 0] = torch.tensor(g0["ridership_norm"].values.astype(np.float32))
        X[idx0, 1] = torch.tensor(g0["sin_hour"].values.astype(np.float32))
        X[idx0, 2] = torch.tensor(g0["cos_hour"].values.astype(np.float32))
        X[idx0, 3] = torch.tensor(g0["sin_dow"].values.astype(np.float32))
        X[idx0, 4] = torch.tensor(g0["cos_dow"].values.astype(np.float32))
        y[idx1] = torch.tensor(g1["ridership_norm"].values.astype(np.float32))

        # Store (X_t, y_t+1) pair
        features.append(X)
        targets.append(y)

    return features, targets

## Load Training Data

The training parquet contains all 2020–2022 data. Each row has one station's ridership at one timestamp, already normalized and with time encodings computed by `preprocess.py`.

The data is converted into a list of graph snapshots, each representing a pair of consecutive hours for all stations.

In [6]:
# Load preprocessed training data
train_df = pd.read_parquet(os.path.join(PROC_DIR, "train.parquet"))
print(f"Train rows: {len(train_df):,}")

# Convert the df into a list of graph snapshots (one per consecutive hour pair)
train_features, train_targets = build_snapshots(train_df, num_nodes)
print(f"Train snapshots: {len(train_features)}")

# Free memory, only the snapshot tensors are needed
del train_df

Train rows: 10,522,218
Train snapshots: 26273


## Load Validation Data

The validation set covers 2023 and is used to detect overfitting. It is converted into graph snapshots in the same way and evaluated after each epoch.

In [7]:
val_df = pd.read_parquet(os.path.join(PROC_DIR, "val.parquet"))
print(f"Val rows: {len(val_df):,}")

val_features, val_targets = build_snapshots(val_df, num_nodes)
print(f"Val snapshots: {len(val_features)}")

# Free memory, only the snapshot tensors are needed
del val_df

Val rows: 3,639,413
Val snapshots: 8757


## Initialize Model

- **DirSAGEEmbRes**: Graph neural network with node embeddings, two SAGEConv layers for incoming and outgoing edges, and residual connections to reduce over-smoothing.
- **Adam optimizer**: An adaptive optimizer that scales each parameter's update step using running averages of its gradients and squared gradients. Because the step size adapts per parameter, it usually needs less tuning than plain SGD and works well out of the box.
- **MSE loss**: Squares each error, so large errors are penalised disproportionately and training focuses on reducing them.
- **Model checkpointing**: Best model (lowest validation loss) is saved to disk.
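A minimal NumPy sketch of one step of the standard Adam update, using PyTorch's default `betas=(0.9, 0.999)` and `eps=1e-8` and the notebook's `LR = 1e-4`. `adam_step` is a hypothetical helper for illustration, not PyTorch's implementation. Two parameters whose gradients differ by four orders of magnitude still move by roughly the same amount, which is what adjusting the step size per parameter means in practice.

```python
import numpy as np

def adam_step(param, grad, m, v, t, lr=1e-4, beta1=0.9, beta2=0.999, eps=1e-8):
    """One Adam update; t is the 1-based step count used for bias correction."""
    m = beta1 * m + (1 - beta1) * grad          # running average of gradients
    v = beta2 * v + (1 - beta2) * grad ** 2     # running average of squared gradients
    m_hat = m / (1 - beta1 ** t)                # bias-corrected first moment
    v_hat = v / (1 - beta2 ** t)                # bias-corrected second moment
    param = param - lr * m_hat / (np.sqrt(v_hat) + eps)
    return param, m, v

param = np.array([1.0, 1.0])
grad = np.array([100.0, 0.01])                  # very different gradient scales
m = np.zeros_like(param)
v = np.zeros_like(param)

new_param, m, v = adam_step(param, grad, m, v, t=1)
delta = param - new_param
```

On the first step `m_hat / sqrt(v_hat)` is roughly `sign(grad)`, so both parameters move by about `lr`, regardless of their raw gradient size.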

In [8]:
# Create the model: 5 input features per node, HIDDEN_DIM hidden units, 16-dim station embeddings
model = DirSAGEEmbRes(num_nodes=num_nodes, in_dim=5, hidden_dim=HIDDEN_DIM, emb_dim=16)

optimizer = torch.optim.Adam(model.parameters(), lr=LR)
loss_fn = nn.MSELoss()

os.makedirs(MODEL_DIR, exist_ok=True)
best_model_path = os.path.join(MODEL_DIR, "model.pt")

print(f"Model parameters: {sum(p.numel() for p in model.parameters()):,}")
print(f"Will save best model to: {best_model_path}")

Model parameters: 28,929
Will save best model to: c:\Users\setho\PersonalProjects\hush\new_models\model.pt
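The printed parameter count can be reproduced by hand. A small sketch, assuming PyTorch Geometric's default `SAGEConv` configuration, in which each layer has a neighbour projection `lin_l` (with bias) and a root projection `lin_r` (without bias); the helper names are hypothetical.

```python
def linear_params(in_dim: int, out_dim: int, bias: bool = True) -> int:
    """Weights plus optional bias of a dense layer."""
    return in_dim * out_dim + (out_dim if bias else 0)

def sage_conv_params(in_dim: int, out_dim: int) -> int:
    # Assumed SAGEConv defaults: lin_l (neighbour term, with bias) + lin_r (root term, no bias)
    return linear_params(in_dim, out_dim, bias=True) + linear_params(in_dim, out_dim, bias=False)

num_nodes, in_dim, emb_dim, hidden_dim = 424, 5, 16, 64
d0 = in_dim + emb_dim

embedding = num_nodes * emb_dim                                   # nn.Embedding table
per_stream = sage_conv_params(d0, hidden_dim) + sage_conv_params(hidden_dim, hidden_dim)
head = linear_params(2 * hidden_dim, 1)                           # final nn.Linear
total = embedding + 2 * per_stream + head
```

This gives 6,784 (embeddings) + 2 × 11,008 (two streams) + 129 (output layer) = 28,929, matching the printed total. Roughly a quarter of the parameters are station embeddings.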


## Training Loop

**Early stopping:**
1. After each epoch, validation loss is checked.
2. If val loss improves, the model checkpoint is saved, so the file on disk always holds the best model so far.
3. If val loss doesn't improve for `PATIENCE` consecutive epochs, training stops early to prevent overfitting.

This guarantees that the saved model is the one with the lowest validation loss, even if later epochs start to overfit.
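The same early-stopping rule, reduced to a pure function over a list of validation losses (a hypothetical helper for illustration, not part of the notebook), makes it easy to see when training would stop and which epoch's checkpoint would be kept:

```python
def early_stopping_trace(val_losses, patience):
    """Return (best_epoch, last_epoch_run, stopped_early) for a sequence of val losses."""
    best = float("inf")
    best_epoch = None
    wait = 0  # epochs since last improvement
    for epoch, loss in enumerate(val_losses, start=1):
        if loss < best:
            best, best_epoch, wait = loss, epoch, 0   # improvement: checkpoint, reset counter
        else:
            wait += 1
            if wait >= patience:
                return best_epoch, epoch, True        # patience exhausted
    return best_epoch, len(val_losses), False

# Improves until epoch 2, then stalls for 4 epochs
trace = early_stopping_trace([0.18, 0.15, 0.16, 0.155, 0.17, 0.16], patience=4)
# Improves every epoch: never stops early
trace_monotone = early_stopping_trace([0.3, 0.2, 0.1], patience=4)
```

With `patience=4`, the first example stops at epoch 6 and keeps the checkpoint from epoch 2. Note that 0.155 does not count as an improvement because it is not below the best value of 0.15.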

In [9]:
# Track best validation loss
best_val_loss = float("inf")
patience_counter = 0  # Epochs since last improvement

print(f"{'Epoch':>6}  {'Train Loss':>11}  {'Val Loss':>11}  {'Status':>10}")
print("-" * 50)

# Epoch = one full pass through all training snapshots.
for epoch in range(1, EPOCHS + 1):
    model.train()         # Put model in training mode (enables dropout/batchnorm, not used here)
    total_train_loss = 0  # Sum of per-snapshot training losses; averaged after the loop
    
    # TRAINING
    # Loop through every snapshot: input hour t, predict hour t+1, update weights
    for X, y in tqdm(zip(train_features, train_targets), total=len(train_features), desc=f"Epoch {epoch}", leave=False):
        optimizer.zero_grad()               # Clear old gradients
        y_hat = model(X, edge_in, edge_out) # Forward pass: Returns one prediction per node for time t+1
        loss = loss_fn(y_hat, y)            # Compute loss by comparing predictions to true targets for time t+1
        loss.backward()                     # Backpropagation: Compute gradients of loss wrt every parameter
        optimizer.step()                    # Update parameters using gradients and learning rate
        total_train_loss += loss.item()     # Accumulate loss for this snapshot
    train_loss = total_train_loss / len(train_features)

    # VALIDATION
    model.eval() # Put model in eval mode (disables dropout/batchnorm, not used here)
    total_val_loss = 0
    with torch.no_grad():  # Disable gradient tracking, do not accidentally backprop during validation, uses less memory
        for X, y in zip(val_features, val_targets):
            y_hat = model(X, edge_in, edge_out) # Compute prediction 
            loss = loss_fn(y_hat, y)            # Compute MSE
            total_val_loss += loss.item()       # Accumulate; averaged after the loop
    val_loss = total_val_loss / len(val_features)

    # EARLY STOPPING CHECK
    if val_loss < best_val_loss:
        # New best validation performance, save model weights and reset patience counter
        best_val_loss = val_loss
        patience_counter = 0
        torch.save(model.state_dict(), best_model_path)
        status = "★ saved"
    else:
        # No improvement, increment patience counter
        patience_counter += 1
        status = f"wait {patience_counter}/{PATIENCE}"

    print(f"  {epoch:>4}   {train_loss:>11.6f}  {val_loss:>11.6f}  {status:>10}")

    # val loss has not improved for PATIENCE consecutive epochs, stop training to avoid overfitting and wasting time
    if patience_counter >= PATIENCE:
        print(f"\nEarly stopping at epoch {epoch} (no improvement for {PATIENCE} epochs)")
        break

print(f"\nBest val loss: {best_val_loss:.6f}")
print(f"Model saved:   {best_model_path}")

 Epoch   Train Loss     Val Loss      Status
--------------------------------------------------
     1      0.176997     0.179397     ★ saved
     2      0.090805     0.154851     ★ saved
     3      0.079282     0.140128     ★ saved
     4      0.073461     0.129812     ★ saved
     5      0.069737     0.123689     ★ saved
     6      0.067052     0.118758     ★ saved
     7      0.065047     0.114990     ★ saved
     8      0.063570     0.111564     ★ saved
     9      0.062494     0.109464     ★ saved
    10      0.061086     0.107900     ★ saved
    11      0.059882     0.106413     ★ saved
    12      0.058819     0.105219     ★ saved
    13      0.057855     0.103947     ★ saved
    14      0.057011     0.103017     ★ saved
    15      0.056299     0.102236     ★ saved

Best val loss: 0.102236
Model saved:   c:\Users\setho\PersonalProjects\hush\new_models\model.pt
