# 06. Model ST-GNN
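Spatio-temporal graph neural network (ST-GNN) for stop-level arrival delays. Graph attention (GAT) layers mix information between stops connected by consecutive trip segments at each time step, and an LSTM then models each stop's sequence over time.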

In [None]:
import numpy as np
import pandas as pd
import tensorflow as tf
from tensorflow.keras import layers, models
from sklearn.preprocessing import StandardScaler
import utils
import os

# ST-GNN specific data loading
def _stop_coordinates(stops_txt_path, stop_ids):
    """Return ``stop_id`` (as str, in ``stop_ids`` order) left-joined with GTFS lat/lon.

    If ``stops_txt_path`` is missing, unreadable, or lacks the expected columns,
    a frame holding only the ``stop_id`` column (still as str) is returned.
    """
    stops_df = pd.DataFrame({'stop_id': stop_ids})
    try:
        stops_meta = pd.read_csv(stops_txt_path)
        stops_meta['stop_id'] = stops_meta['stop_id'].astype(str)
        stops_meta = stops_meta.loc[
            stops_meta['stop_id'].isin(stop_ids), ['stop_id', 'stop_lat', 'stop_lon']
        ]
    except (OSError, KeyError, pd.errors.EmptyDataError, pd.errors.ParserError) as err:
        print(f"Stop metadata unavailable ({err!r}); continuing without coordinates.")
        return stops_df
    return stops_df.merge(stops_meta, on='stop_id', how='left')


def _trip_edges(df, unique_stops):
    """Return directed ``(src_idx, dst_idx)`` pairs linking consecutive stops of a trip.

    Rows are ordered by ``trip_id`` then ``actual_arrival_time``; indices are
    positions in ``unique_stops``. Self-loops are dropped and duplicates collapsed.
    """
    ordered = df.sort_values(['trip_id', 'actual_arrival_time'])
    trip_ids = ordered['trip_id'].to_numpy()
    node_idx = pd.Index(unique_stops).get_indexer(ordered['stop_id'])
    # Row k links to row k + 1 only when both rows belong to the same trip.
    same_trip = trip_ids[:-1] == trip_ids[1:]
    src = node_idx[:-1][same_trip]
    dst = node_idx[1:][same_trip]
    keep = src != dst
    return list(set(zip(src[keep].tolist(), dst[keep].tolist())))


def load_real_translink_data(
    csv_path='data/processed_data/processed_trip_data.csv',
    stops_txt_path='data/google_transit/stops.txt',
    time_bin='15min',
    target_routes=(6641, 6636, 37810, 6622, 6705, 6627, 16718, 6624, 37807, 6617),
):
    """Load processed TransLink trip data and build ST-GNN inputs.

    Keeps only ``target_routes`` (when a ``route_id`` column exists). Derives a
    directed stop graph from consecutive stops within each trip. Bins the mean
    ``arrival_delay`` per stop into ``time_bin`` intervals and standardizes it.

    Args:
        csv_path: Processed trip CSV (produced by 02_process_data.ipynb).
        stops_txt_path: GTFS ``stops.txt``; used only for stop lat/lon.
        time_bin: pandas frequency string used to bin arrival times.
        target_routes: Route ids to keep.

    Returns:
        A tuple ``(stops_df, edge_list, node_features, targets, scaler_time, node_mask)``:

        - ``stops_df``: one row per stop, with ``stop_id`` as str.
        - ``edge_list``: ``(src_idx, dst_idx)`` pairs.
        - ``node_features``: float32 array of shape ``(T - 1, num_stops, 1)``
          holding the standardized delay of bin ``t``.
        - ``targets``: float32 array of shape ``(T - 1, num_stops)`` holding
          the standardized delay of bin ``t + 1``.
        - ``scaler_time``: the fitted ``StandardScaler``.
        - ``node_mask``: 1.0 for stops with at least one outgoing edge.

        ``T`` is the number of time bins. Six ``None`` values are returned
        when ``csv_path`` does not exist.
    """
    print(f"Loading data from {csv_path}...")
    if not os.path.exists(csv_path):
        print("Processed file not found. Please run 02_process_data.ipynb.")
        return None, None, None, None, None, None
    df = pd.read_csv(csv_path)

    if 'route_id' in df.columns:
        # .copy() so the column assignments below modify a real frame, not a view.
        df = df[df['route_id'].isin(list(target_routes))].copy()

    if 'actual_arrival_time' in df.columns:
        df['actual_arrival_time'] = pd.to_datetime(df['actual_arrival_time'], utc=True)

    # The processed file is assumed to be already cleaned, so utils.prepare_trip_data is not applied.
    if 'arrival_delay_agg' in df.columns:
        df['arrival_delay'] = df['arrival_delay_agg']

    if 'scheduled_arrival_time' in df.columns:
        df['scheduled_arrival_time'] = pd.to_datetime(df['scheduled_arrival_time'], utc=True)
        if 'actual_arrival_time' not in df.columns:
            # Reconstruct the actual time as schedule + delay (delay in seconds).
            df['actual_arrival_time'] = df['scheduled_arrival_time'] + pd.to_timedelta(
                df['arrival_delay'], unit='s'
            )

    unique_stops = np.sort(df['stop_id'].unique())
    num_stops = len(unique_stops)

    stops_df = _stop_coordinates(stops_txt_path, [str(s) for s in unique_stops])
    edge_list = _trip_edges(df, unique_stops)

    delay_by_bin = df.set_index('actual_arrival_time').pivot_table(
        index=pd.Grouper(freq=time_bin), columns='stop_id', values='arrival_delay', aggfunc='mean'
    )
    # Carry a stop's last observed delay forward for up to 4 bins, then fill gaps with 0.
    delay_by_bin = delay_by_bin.reindex(columns=unique_stops).ffill(limit=4).fillna(0)

    delay_np = delay_by_bin.to_numpy()
    # NOTE(review): the scaler is fit on every bin, including the bins that later
    # form the test split, so test statistics leak into training.
    scaler_time = StandardScaler()
    delay_np = scaler_time.fit_transform(delay_np.reshape(-1, 1)).reshape(delay_np.shape)

    # Feature at bin t is that bin's delay; its target is the delay in bin t + 1.
    node_features = delay_np[:-1, :, np.newaxis].astype(np.float32)
    targets = delay_np[1:].astype(np.float32)

    node_mask = np.zeros(num_stops, dtype=np.float32)
    if edge_list:
        node_mask[[src for src, _ in edge_list]] = 1.0

    return stops_df, edge_list, node_features, targets, scaler_time, node_mask

# Load once with the default paths; the training cell checks X_all and uses edge_list / y_all.
(
    stops_df,
    edge_list,
    X_all,
    y_all,
    scaler_time,
    node_mask,
) = load_real_translink_data()

In [None]:
# GAT Layer & ST-GNN Model
class GATLayer(layers.Layer):
    """Multi-head graph attention layer over a dense adjacency mask.

    Called as ``layer([h, adj])``:

    - ``h``: ``(batch, num_nodes, feat_dim)``.
    - ``adj``: a ``(num_nodes, num_nodes)`` matrix. Node ``i`` attends to node
      ``j`` only where ``adj[i, j]`` is non-zero.

    Output is ``(batch, num_nodes, heads * units)`` when ``concat`` is True;
    otherwise heads are averaged to ``(batch, num_nodes, units)``.
    """

    def __init__(self, units, heads=1, concat=True, **kwargs):
        super().__init__(**kwargs)
        self.units = units
        self.heads = heads
        self.concat = concat

    def build(self, input_shape):
        feat_dim = input_shape[0][-1]
        # One projection matrix whose columns hold all heads side by side.
        self.W = self.add_weight(shape=(feat_dim, self.units * self.heads), initializer='glorot_uniform', trainable=True)
        # Per-head attention vector: first half scores node i, second half scores neighbour j.
        self.a = self.add_weight(shape=(1, self.heads, 2 * self.units), initializer='glorot_uniform', trainable=True)

    def call(self, inputs):
        h, adj = inputs
        batch_size = tf.shape(h)[0]
        num_nodes = tf.shape(h)[1]

        h_prime = tf.matmul(h, self.W)
        h_prime = tf.reshape(h_prime, (batch_size, num_nodes, self.heads, self.units))

        a1 = self.a[:, :, :self.units]
        a2 = self.a[:, :, self.units:]

        score_i = tf.reduce_sum(h_prime * a1, axis=-1)  # (batch, nodes, heads)
        score_j = tf.reduce_sum(h_prime * a2, axis=-1)

        # Broadcast to pairwise scores e[b, i, j, head].
        score_i = tf.expand_dims(score_i, 2)
        score_j = tf.expand_dims(score_j, 1)

        e = tf.nn.leaky_relu(score_i + score_j)

        # adj (nodes, nodes) -> (1, nodes, nodes, 1) so it broadcasts over batch and heads.
        mask = tf.cast(adj, dtype=tf.bool)
        mask = tf.expand_dims(mask, 0)
        mask = tf.expand_dims(mask, -1)

        # Non-edges get a huge negative logit so softmax assigns them ~0 weight.
        zero_vec = -9e15 * tf.ones_like(e)
        attention = tf.where(mask, e, zero_vec)
        attention = tf.nn.softmax(attention, axis=2)  # normalize over neighbours j

        out = tf.einsum('bijh,bjhu->bihu', attention, h_prime)

        if self.concat:
            out = tf.reshape(out, (batch_size, num_nodes, self.heads * self.units))
        else:
            out = tf.reduce_mean(out, axis=2)
        return out

    def get_config(self):
        """Return the constructor arguments so the layer can be rebuilt from its config."""
        config = super().get_config()
        config.update({'units': self.units, 'heads': self.heads, 'concat': self.concat})
        return config

class ST_GNN(models.Model):
    """Spatio-temporal GNN: two GAT layers per time step, then an LSTM per node.

    Input is ``[x, adj]``:

    - ``x``: ``(batch, window, num_nodes, features)``.
    - ``adj``: a ``(num_nodes, num_nodes)`` attention mask.

    Output is ``(batch, num_nodes, out_channels)``, taken from the LSTM's final
    state. ``in_channels`` and ``num_nodes`` are accepted for API compatibility;
    the layers infer their input sizes when first built.
    """

    def __init__(self, in_channels, hidden_channels, out_channels, num_nodes, num_heads=2, **kwargs):
        super().__init__(**kwargs)
        self.hidden_channels = hidden_channels
        self.out_channels = out_channels
        self.gat1 = GATLayer(hidden_channels, heads=num_heads, concat=True)
        # Single head, not concatenated: emits exactly hidden_channels features per node.
        self.gat2 = GATLayer(hidden_channels, heads=1, concat=False)
        self.lstm = layers.LSTM(hidden_channels, return_sequences=False)
        self.fc = layers.Dense(out_channels)

    def call(self, inputs):
        x, adj = inputs
        batch_size = tf.shape(x)[0]
        window_size = tf.shape(x)[1]
        num_nodes = tf.shape(x)[2]
        features = tf.shape(x)[3]

        # Fold time into the batch so the GAT layers process one graph snapshot per row.
        frames = tf.reshape(x, (batch_size * window_size, num_nodes, features))
        h = tf.nn.elu(self.gat1([frames, adj]))
        h = tf.nn.elu(self.gat2([h, adj]))

        # Unfold time, then turn each node into its own sequence for the LSTM.
        # Using the known channel count instead of -1 keeps the LSTM's input size
        # statically defined when the model is traced (e.g. tf.function / fit).
        h = tf.reshape(h, (batch_size, window_size, num_nodes, self.hidden_channels))
        h = tf.transpose(h, perm=[0, 2, 1, 3])
        sequences = tf.reshape(h, (batch_size * num_nodes, window_size, self.hidden_channels))

        out = self.fc(self.lstm(sequences))
        return tf.reshape(out, (batch_size, num_nodes, self.out_channels))


In [None]:
# Prepare Data and Train
def create_dataset(edge_list, node_features, targets, window_size=12):
    """Cut per-bin features into sliding windows and build the adjacency mask.

    Sample ``s`` uses ``node_features[s : s + window_size]`` as input. Its label
    is ``targets[s + window_size - 1]``, the target paired with the window's
    last bin. The loader pairs each bin with the next bin's delay, so this is a
    one-step-ahead label.

    Returns:
        A tuple ``(X, Y, adj)`` of float32 arrays, where
        ``S = max(T - window_size + 1, 0)``:

        - ``X``: ``(S, window_size, num_nodes, F)``.
        - ``Y``: ``(S, num_nodes, 1)``.
        - ``adj``: ``(num_nodes, num_nodes)``, with ``adj[dst, src] = 1`` for
          every edge plus self-loops.

    Raises:
        ValueError: if ``window_size`` is smaller than 1.
    """
    if window_size < 1:
        raise ValueError(f"window_size must be >= 1, got {window_size}")
    num_timesteps, num_nodes = node_features.shape[0], node_features.shape[1]
    num_samples = max(num_timesteps - window_size + 1, 0)
    if num_samples:
        X = np.stack([node_features[s:s + window_size] for s in range(num_samples)])
    else:
        X = np.empty((0, window_size) + node_features.shape[1:])
    Y = targets[window_size - 1:window_size - 1 + num_samples].reshape(num_samples, num_nodes, 1)

    adj = np.eye(num_nodes, dtype=np.float32)
    for src, dst in edge_list:
        adj[dst, src] = 1.0  # row = receiving node, column = neighbour it attends to
    return X.astype(np.float32), Y.astype(np.float32), adj


if X_all is not None:
    window_size = 6
    X_data, Y_data, adj_matrix = create_dataset(edge_list, X_all, y_all, window_size)

    # First 80% of windows (in time order, no shuffling) train; the rest test.
    split_idx = int(len(X_data) * 0.8)
    X_train, X_test = X_data[:split_idx], X_data[split_idx:]
    Y_train, Y_test = Y_data[:split_idx], Y_data[split_idx:]

    if len(X_train) == 0:
        print("Not enough time bins to build training windows; skipping training.")
    else:
        model = ST_GNN(in_channels=X_train.shape[3], hidden_channels=32, out_channels=1, num_nodes=X_train.shape[2])

        print("Training ST-GNN...")
        optimizer = tf.keras.optimizers.Adam(learning_rate=0.01)
        batch_size = 32
        rng = np.random.default_rng(0)  # fixed seed so batch order is reproducible

        for epoch in range(5):
            order = rng.permutation(len(X_train))
            epoch_loss, num_batches = 0.0, 0
            for start in range(0, len(order), batch_size):
                idx = order[start:start + batch_size]
                with tf.GradientTape() as tape:
                    preds = model([X_train[idx], adj_matrix])
                    # Tensor on the left so the op stays in TensorFlow and remains differentiable.
                    loss = tf.reduce_mean(tf.square(preds - Y_train[idx]))
                grads = tape.gradient(loss, model.trainable_variables)
                optimizer.apply_gradients(zip(grads, model.trainable_variables))
                epoch_loss += float(loss)
                num_batches += 1
            print(f"Epoch {epoch}, Loss: {epoch_loss / num_batches:.4f}")

        # Held-out MSE in the standardized delay units the model was trained on.
        test_sq_err, test_count = 0.0, 0
        for start in range(0, len(X_test), batch_size):
            y_batch = Y_test[start:start + batch_size]
            preds = model([X_test[start:start + batch_size], adj_matrix])
            test_sq_err += float(tf.reduce_sum(tf.square(preds - y_batch)))
            test_count += y_batch.size
        if test_count:
            print(f"Test MSE: {test_sq_err / test_count:.4f}")
        else:
            print("No test windows available for evaluation.")
