# 06. Model ST-GNN
ST-GNN model implementation.

In [None]:
import numpy as np
import pandas as pd
import tensorflow as tf
from tensorflow.keras import layers, models
from sklearn.preprocessing import StandardScaler
import utils

# ST-GNN specific data loading with fixed splits
def load_real_translink_data(
    stops_txt_path='data/google_transit/stops.txt',
    time_bin='15min'
):
    """Load data using fixed train/valid/test splits"""

    # Use shared function for data loading
    df_train, df_valid, df_test, df, split_info = utils.load_split_data_with_combined()
    
    if df is None:
        return None, None, None, None, None, None, None
    
    train_ratio = split_info['train_ratio_actual']

    # Filter routes
    target_routes = [6641, 6636, 37810, 6622, 6705, 6627, 16718, 6624, 37807, 6617]
    df = df[df['route_id'].isin(target_routes)]

    df['actual_arrival_time'] = pd.to_datetime(df['actual_arrival_time'], utc=True)
    df['arrival_delay'] = df['arrival_delay_agg']
    df['scheduled_arrival_time'] = pd.to_datetime(df['scheduled_arrival_time'], utc=True)

    unique_stops = df['stop_id'].unique()
    unique_stops.sort()
    stop_to_idx = {stop_id: i for i, stop_id in enumerate(unique_stops)}
    num_stops = len(unique_stops)

    # Stops meta
    try:
        stops_meta = pd.read_csv(stops_txt_path)
        stops_meta['stop_id'] = stops_meta['stop_id'].astype(str)
        unique_stops_str = [str(s) for s in unique_stops]
        stops_meta = stops_meta[stops_meta['stop_id'].isin(unique_stops_str)]
        stops_df = pd.DataFrame({'stop_id': unique_stops_str})
        stops_df = stops_df.merge(stops_meta[['stop_id', 'stop_lat', 'stop_lon']], on='stop_id', how='left')
    except:
        stops_df = pd.DataFrame({'stop_id': unique_stops})

    # Edges
    df = df.sort_values(['trip_id', 'actual_arrival_time'])
    df['next_trip_id'] = df['trip_id'].shift(-1)
    df['next_stop_id'] = df['stop_id'].shift(-1)

    valid_connections = df[df['trip_id'] == df['next_trip_id']]
    edges = set()
    for _, row in valid_connections.iterrows():
        src = stop_to_idx[row['stop_id']]
        dst = stop_to_idx[row['next_stop_id']]
        if src != dst:
            edges.add((src, dst))
    edge_list = list(edges)

    # Time series features
    df.set_index('actual_arrival_time', inplace=True)
    travel_time_matrix = df.pivot_table(index=pd.Grouper(freq=time_bin), columns='stop_id', values='arrival_delay', aggfunc='mean')
    travel_time_matrix = travel_time_matrix.reindex(columns=unique_stops)

    # Mask Creation
    mask_matrix = (~travel_time_matrix.isna()).astype(np.float32)
    travel_time_matrix_filled = travel_time_matrix.ffill(limit=4).fillna(0)

    travel_time_np = travel_time_matrix_filled.values
    mask_np = mask_matrix.values

    # Use fixed split ratio for scaler fitting
    split_idx = int(len(travel_time_np) * train_ratio)
    train_data = travel_time_np[:split_idx]

    scaler_time = StandardScaler()
    scaler_time.fit(train_data.reshape(-1, 1))
    travel_time_np = scaler_time.transform(travel_time_np.reshape(-1, 1)).reshape(travel_time_np.shape)

    # Features and targets
    node_features = []
    targets = []
    for t in range(len(travel_time_matrix) - 1):
        node_features.append(np.stack([travel_time_np[t]], axis=1))
        targets.append(travel_time_np[t+1])

    valid_sources = set(src for src, dst in edge_list)
    node_mask = np.array([1.0 if i in valid_sources else 0.0 for i in range(num_stops)], dtype=np.float32)

    return stops_df, edge_list, np.array(node_features, dtype=np.float32), np.array(targets, dtype=np.float32), scaler_time, node_mask, mask_np

stops_df, edge_list, X_all, y_all, scaler_time, node_mask, data_mask_all = load_real_translink_data()

In [2]:
# GAT Layer & ST-GNN Model
class GATLayer(layers.Layer):
    def __init__(self, units, heads=1, concat=True, **kwargs):
        super(GATLayer, self).__init__(**kwargs)
        self.units = units
        self.heads = heads
        self.concat = concat
    
    def build(self, input_shape):
        feat_dim = input_shape[0][-1]
        self.W = self.add_weight(shape=(feat_dim, self.units * self.heads), initializer='glorot_uniform', trainable=True)
        self.a = self.add_weight(shape=(1, self.heads, 2 * self.units), initializer='glorot_uniform', trainable=True)
        
    def call(self, inputs):
        h, adj = inputs
        batch_size = tf.shape(h)[0]
        num_nodes = tf.shape(h)[1]
        
        h_prime = tf.matmul(h, self.W)
        h_prime = tf.reshape(h_prime, (batch_size, num_nodes, self.heads, self.units))
        
        a1 = self.a[:, :, :self.units]
        a2 = self.a[:, :, self.units:]
        
        score_i = tf.reduce_sum(h_prime * a1, axis=-1)
        score_j = tf.reduce_sum(h_prime * a2, axis=-1)
        
        score_i = tf.expand_dims(score_i, 2)
        score_j = tf.expand_dims(score_j, 1)
        
        e = tf.nn.leaky_relu(score_i + score_j)
        
        mask = tf.cast(adj, dtype=tf.bool)
        mask = tf.expand_dims(mask, 0)
        mask = tf.expand_dims(mask, -1)
        
        # Optimized to avoid creating a full tensor for zero_vec
        attention = tf.where(mask, e, -9e15)
        attention = tf.nn.softmax(attention, axis=2)
        
        out = tf.einsum('bijh,bjhu->bihu', attention, h_prime)
        
        if self.concat:
            out = tf.reshape(out, (batch_size, num_nodes, self.heads * self.units))
        else:
            out = tf.reduce_mean(out, axis=2)
        return out

class ST_GNN(models.Model):
    def __init__(self, in_channels, hidden_channels, out_channels, num_nodes, num_heads=2):
        super(ST_GNN, self).__init__()
        self.gat1 = GATLayer(hidden_channels, heads=num_heads, concat=True)
        self.gat2 = GATLayer(hidden_channels, heads=1, concat=False)
        self.lstm = layers.LSTM(hidden_channels, return_sequences=False)
        self.fc = layers.Dense(out_channels)
        
    def call(self, inputs):
        x, adj = inputs
        batch_size = tf.shape(x)[0]
        window_size = tf.shape(x)[1]
        num_nodes = tf.shape(x)[2]
        features = tf.shape(x)[3]
        
        x_reshaped = tf.reshape(x, (batch_size * window_size, num_nodes, features))
        h = self.gat1([x_reshaped, adj])
        h = tf.nn.elu(h)
        h = self.gat2([h, adj])
        h = tf.nn.elu(h)
        
        h = tf.reshape(h, (batch_size, window_size, num_nodes, -1))
        h = tf.transpose(h, perm=[0, 2, 1, 3])
        h_reshaped = tf.reshape(h, (batch_size * num_nodes, window_size, -1))
        
        lstm_out = self.lstm(h_reshaped)
        out = self.fc(lstm_out)
        out = tf.reshape(out, (batch_size, num_nodes, -1))
        return out

In [3]:
# Prepare Data and Train
if X_all is not None:
    def create_dataset(edge_list, node_features, targets, masks, window_size=12):
        num_timesteps = len(node_features)
        num_nodes = node_features.shape[1]
        X, Y, M = [], [], []
        for i in range(window_size, num_timesteps):
            X.append(node_features[i-window_size:i])
            Y.append(targets[i].reshape(num_nodes, 1))
            M.append(masks[i].reshape(num_nodes, 1))
        
        adj = np.eye(num_nodes, dtype=np.float32)
        for src, dst in edge_list:
            adj[dst, src] = 1.0
        return np.array(X, dtype=np.float32), np.array(Y, dtype=np.float32), np.array(M, dtype=np.float32), adj

    window_size = 6
    X_data, Y_data, M_data, adj_matrix = create_dataset(edge_list, X_all, y_all, data_mask_all, window_size)
    
    split_idx = int(len(X_data) * 0.8)
    X_train, X_test = X_data[:split_idx], X_data[split_idx:]
    Y_train, Y_test = Y_data[:split_idx], Y_data[split_idx:]
    M_train, M_test = M_data[:split_idx], M_data[split_idx:]
    
    model = ST_GNN(in_channels=X_train.shape[3], hidden_channels=32, out_channels=1, num_nodes=X_train.shape[2])
    
    print("Training ST-GNN...")
    optimizer = tf.keras.optimizers.Adam(learning_rate=0.01)
    
    # Simplified training loop
    batch_size = 4 # Reduced batch size to avoid OOM
    for epoch in range(5):
        with tf.GradientTape() as tape:
            preds = model([X_train[:batch_size], adj_matrix]) 
            # Masked Loss
            loss = tf.reduce_sum(tf.square(Y_train[:batch_size] - preds) * M_train[:batch_size]) / (tf.reduce_sum(M_train[:batch_size]) + 1e-5)
        grads = tape.gradient(loss, model.trainable_variables)
        optimizer.apply_gradients(zip(grads, model.trainable_variables))
        print(f"Epoch {epoch}, Loss: {loss.numpy():.4f}")

Training ST-GNN...
Epoch 0, Loss: 0.4787
Epoch 1, Loss: 0.4754
Epoch 2, Loss: 0.4555
Epoch 3, Loss: 0.4551
Epoch 4, Loss: 0.4588


In [None]:
# Evaluation on Test Set
evaluation_results = []

print("Evaluating ST-GNN on Test Set...")
scale_factor = scaler_time.scale_[0]

# Collect all predictions and actuals
all_preds = []
all_true = []
all_masks = []

for i in range(0, len(X_test), batch_size):
    end_idx = min(i + batch_size, len(X_test))
    batch_X = X_test[i:end_idx]
    batch_Y = Y_test[i:end_idx]
    batch_M = M_test[i:end_idx]

    preds = model([batch_X, adj_matrix], training=False)

    all_preds.append(preds.numpy())
    all_true.append(batch_Y)
    all_masks.append(batch_M)

# Concatenate all results
all_preds = np.concatenate(all_preds, axis=0)
all_true = np.concatenate(all_true, axis=0)
all_masks = np.concatenate(all_masks, axis=0)

# Inverse transform to original scale (seconds)
n_samples, n_nodes, _ = all_preds.shape
all_preds_flat = all_preds.reshape(-1, 1)
all_true_flat = all_true.reshape(-1, 1)
all_masks_flat = all_masks.reshape(-1, 1)

all_preds_seconds = scaler_time.inverse_transform(all_preds_flat).flatten()
all_true_seconds = scaler_time.inverse_transform(all_true_flat).flatten()
mask_flat = all_masks_flat.flatten().astype(bool)

# Apply mask
y_pred_masked = all_preds_seconds[mask_flat]
y_true_masked = all_true_seconds[mask_flat]

# Unified evaluation
result_stgnn = utils.evaluate_model(
    y_true_masked, y_pred_masked,
    model_name="ST-GNN",
    config={
        "hidden_channels": 32,
        "num_heads": 2,
        "window_size": window_size,
        "scaler": "StandardScaler"
    }
)
evaluation_results.append(result_stgnn)
print(result_stgnn.summary())

In [None]:
# Model Comparison Table and Save Results
utils.display_and_save_results(evaluation_results, 'data/evaluation_results_stgnn.json')