### Defs

In [None]:
import torch
import joblib
import json
import numpy as np
import torch.nn as nn
from tqdm import tqdm
from pathlib import Path
import torch.nn.functional as F
from torch_geometric.nn import NNConv, global_mean_pool
from torch_geometric.data import Data, Batch
from sklearn.preprocessing import StandardScaler, FunctionTransformer
from utils.time import name_to_day
from data_processing import (
    load_train_test_split,
    load_graph,
    load_delays,
    load_stop_id_to_idx,
    load_trip_idx_list
)
# Project root. Absolute, machine-specific path: adjust it when running elsewhere.
root = Path('/mnt/c/Users/rdsup/Desktop/vitmma19-pw-delay-detection')

# Data directories derived from the project root.
data_path = root / 'data'
delays_path = data_path / 'delays'  # handed to load_delays in the cells below
graphs_path = data_path / 'graphs'  # one sub-directory of graph files per record

### Calc scalers

In [169]:
# Gather every stored graph and the 'delay' column of each training record;
# these samples are what the scalers are fitted on.
graphs_train, delays_train = [], []
records = load_train_test_split(data_path)
for record_name in tqdm(records['train']):
    record_dir = graphs_path / record_name
    graphs_train.extend(
        torch.load(graph_file, weights_only=False)
        for graph_file in record_dir.iterdir()
    )

    record_delays = load_delays(record_name, delays_path)
    delays_train.append(record_delays['delay'].to_numpy())

  0%|          | 0/14 [00:00<?, ?it/s]

100%|██████████| 14/14 [00:09<00:00,  1.48it/s]


In [198]:
def save_all_scalers(graphs, delays, data_path:Path):
    """Fit one StandardScaler per feature group and pickle them together.

    Args:
        graphs: Graph objects whose ``x`` and ``edge_attr`` tensors are
            stacked row-wise to fit the node and edge scalers.
        delays: Sequence of arrays; they are concatenated and reshaped to a
            single column to fit the delay scaler.
        data_path: Directory that receives ``scalers.pkl``.
    """
    def _fitted(samples):
        # StandardScaler.fit returns the scaler itself.
        return StandardScaler().fit(samples)

    # Keys: 'node' (node features), 'edge' (edge features), 'delay' (target).
    scalers = {
        'node': _fitted(np.vstack([graph.x.numpy() for graph in graphs])),
        'edge': _fitted(np.vstack([graph.edge_attr.numpy() for graph in graphs])),
        'delay': _fitted(np.concatenate(delays).reshape(-1, 1)),
    }

    # Persist all three scalers in a single file.
    joblib.dump(scalers, data_path / 'scalers.pkl')
    print(f"Saved all scalers to {data_path / 'scalers.pkl'}")

def load_all_scalers(data_path):
    """Load the pickled scalers and add the closed-form 'day' and 'sec' rescalers.

    Returns:
        The dictionary stored in ``scalers.pkl`` extended with two callables:
        ``'day'`` maps 0..6 onto -1..1 and ``'sec'`` maps 0..24*3600 onto -1..1.
    """
    def _symmetric(upper):
        # Linear map sending 0 to -1 and ``upper`` to +1.
        return lambda x: 2*(x/upper)-1

    scalers = joblib.load(data_path / 'scalers.pkl')
    scalers.update(day=_symmetric(6), sec=_symmetric(24*3600))
    return scalers

# Fit the scalers on the training-split samples collected above and persist them under data_path.
save_all_scalers(graphs_train, delays_train, data_path)

Saved all scalers to /mnt/c/Users/rdsup/Desktop/vitmma19-pw-delay-detection/data/scalers.pkl


In [None]:
# Trip lookup (indexed by trip id in later cells) and the train/test split of records.
trip_idx_list = load_trip_idx_list(graphs_path)
records = load_train_test_split(data_path)

# Single record used for the experiments in the cells below.
record_name = '20251009'
day = name_to_day(record_name)[0] # 0-6 TODO: tensor
df = load_delays(record_name,delays_path)

# Per-trip arrays aligned with the rows of df (trip ids come from the index).
trip_id_list = df.index.to_numpy()
sec_list = torch.tensor(df['trip_start'].to_numpy(),dtype=torch.float32)
# Same day value repeated once per trip.
day_list = torch.ones_like(sec_list,dtype=torch.float32)*day
delay_list = torch.tensor(df['delay'].to_numpy(),dtype=torch.float32)


In [None]:
class MLP(nn.Module):
    """Fully connected network: Linear layers with ReLU between them.

    ``layer_sizes`` lists the feature width at every stage, so consecutive
    entries define one ``nn.Linear``. No activation follows the final layer.
    """

    def __init__(self, layer_sizes: list[int]):
        super().__init__()
        final_idx = len(layer_sizes) - 2
        modules = []
        for idx, (n_in, n_out) in enumerate(zip(layer_sizes[:-1], layer_sizes[1:])):
            modules.append(nn.Linear(n_in, n_out))
            if idx != final_idx:
                # Hidden layers get a non-linearity; the output layer stays linear.
                modules.append(nn.ReLU())
        self.net = nn.Sequential(*modules)

    def forward(self, x):
        """Apply the stacked layers to ``x``."""
        return self.net(x)

### Mini-batches (PyG merges the graphs of a batch into one large disconnected graph)

In [None]:
# Assemble one mini-batch of up to `batch_size` trips from `df`.
batch_size = 32

# Pre-allocated per-trip tensors; rows that stay unused are trimmed after the loop.
day_batch = torch.zeros((batch_size,1), dtype=torch.float)
sec_batch = torch.zeros((batch_size,1), dtype=torch.float)
delay_batch = torch.zeros((batch_size,1), dtype=torch.float)
graphs_batch = []
trip_idx_batch = []

i = 0  # number of trips stored so far
for trip_id, row in df.iterrows():
    # Read by column name so the cell does not depend on the column order of df.
    delay, sec = row['delay'], row['trip_start']

    try:
        g = load_graph(sec, 1/2, record_name, graphs_path)
    except FileNotFoundError:
        # No stored graph for this trip start: skip the trip.
        continue

    graphs_batch.append(g)
    day_batch[i,0] = day
    sec_batch[i,0] = sec
    trip_idx_batch.append(trip_idx_list[trip_id])
    delay_batch[i,0] = delay
    i += 1

    # If batch is full, stop loading
    if i == batch_size:
        break

# Fewer than batch_size trips had a graph: drop the unused all-zero rows so the
# tensors stay aligned with graphs_batch / trip_idx_batch.
if i < batch_size:
    day_batch, sec_batch, delay_batch = day_batch[:i], sec_batch[:i], delay_batch[:i]

# Collate the collected graphs, including a partially filled batch. If no trip
# had a graph, graphs_batch is left as an empty list.
if graphs_batch:
    graphs_batch = Batch.from_data_list(graphs_batch)

# Now you have your batch:
# graphs_batch, day_batch, sec_batch, trip_idx_batch, delay_batch

In [41]:
class GCNWithEdge(nn.Module):
    """Stack of edge-conditioned ``NNConv`` layers.

    Consecutive entries of ``layer_sizes`` give the input/output width of one
    convolution. Each convolution owns a small edge network that maps the
    ``edge_attr_dim`` edge features to ``in * out`` values, and aggregates
    messages with the mean.
    """

    def __init__(self, layer_sizes: list[int], edge_attr_dim: int):
        super().__init__()
        self.convs = nn.ModuleList()
        for width_in, width_out in zip(layer_sizes[:-1], layer_sizes[1:]):
            n_weights = width_in * width_out
            edge_net = nn.Sequential(
                nn.Linear(edge_attr_dim, n_weights),
                nn.ReLU(),
                nn.Linear(n_weights, n_weights),
            )
            self.convs.append(NNConv(width_in, width_out, edge_net, aggr='mean'))

    def forward(self, x, edge_index, edge_attr):
        """Run all convolutions; ReLU follows every layer except the last."""
        hidden_layers, output_layer = self.convs[:-1], self.convs[-1]
        for layer in hidden_layers:
            x = F.relu(layer(x, edge_index, edge_attr))
        return output_layer(x, edge_index, edge_attr)

In [50]:
# Unpack the collated batch into the tensors the model consumes.
x = graphs_batch.x
edge_index = graphs_batch.edge_index
edge_attr = graphs_batch.edge_attr
trip_nodes = trip_idx_batch
day = day_batch
sec = sec_batch


gcn = GCNWithEdge(layer_sizes=[7,8,8],edge_attr_dim=2)
mlp = MLP(layer_sizes=[8+2,4,1])

# Node embeddings for every node of the flattened batch.
x = gcn(x, edge_index, edge_attr)
print(x.shape)

# Mean-pool the node embeddings of each trip; `nodes` holds row indices into x.
# (.sum(dim=0) or a max over dim 0 would be alternative poolings.)
trip_embeddings = [x[nodes].mean(dim=0) for nodes in trip_nodes]

# Stack to [num_trips, embedding_dim]
x = torch.stack(trip_embeddings, dim=0)

print(x.shape)

torch.Size([144032, 8])
torch.Size([32, 8])


In [None]:
class EndOfTripDelay(nn.Module):
    """Predict one value per trip from a (batched) graph plus day/time features.

    Node features are embedded by a ``GCNWithEdge`` stack, the embeddings of the
    nodes belonging to each trip are mean-pooled, the ``day`` and ``sec``
    columns are appended, and an ``MLP`` produces the prediction.

    Args:
        gcn_layers: Layer widths passed to ``GCNWithEdge``.
        edge_attr_dim: Number of edge features passed to ``GCNWithEdge``.
        mlp_layers: Layer widths passed to ``MLP``. Its first entry has to equal
            the last GCN width plus 2, because ``day`` and ``sec`` are appended.
    """

    def __init__(self, gcn_layers: list[int], edge_attr_dim: int, mlp_layers: list[int]):
        super().__init__()
        self.GCN = GCNWithEdge(gcn_layers, edge_attr_dim)
        self.MLP = MLP(mlp_layers)

    def forward(self, x, edge_index, edge_attr, trip_nodes, day, sec):
        """Return the MLP output for every trip in ``trip_nodes``.

        Args:
            x, edge_index, edge_attr: Graph tensors forwarded to the GCN.
            trip_nodes: One index collection per trip; each selects rows of the
                node-embedding matrix.
                NOTE(review): for a PyG ``Batch`` the indices must already
                include the per-graph node offsets - confirm at the caller.
            day, sec: 2-D tensors with one row per trip, concatenated column-wise
                to the pooled embeddings.
        """
        # GCN
        node_emb = self.GCN(x, edge_index, edge_attr)

        # Pool per trip. Always index the full embedding matrix: reusing `x` as
        # the loop variable would make every trip after the first index the
        # previous trip's mean vector instead of the node embeddings.
        trip_embeddings = [node_emb[nodes].mean(dim=0) for nodes in trip_nodes]

        # Stack to [num_trips, embedding_dim]
        trip_emb = torch.stack(trip_embeddings, dim=0)

        # Add day/hour info
        graph_emb = torch.cat([trip_emb, day, sec], dim=1)

        # MLP for prediction
        return self.MLP(graph_emb)


### Without minibatches

In [180]:
class GCNWithEdge(nn.Module):
    """Stack of edge-conditioned ``NNConv`` layers.

    Consecutive entries of ``layer_sizes`` give the input/output width of one
    convolution. Each convolution owns a small edge network that maps the
    ``edge_attr_dim`` edge features to ``in * out`` values, and aggregates
    messages with the mean.
    """

    def __init__(self, layer_sizes: list[int], edge_attr_dim: int):
        super().__init__()
        self.convs = nn.ModuleList()
        for width_in, width_out in zip(layer_sizes[:-1], layer_sizes[1:]):
            n_weights = width_in * width_out
            edge_net = nn.Sequential(
                nn.Linear(edge_attr_dim, n_weights),
                nn.ReLU(),
                nn.Linear(n_weights, n_weights),
            )
            self.convs.append(NNConv(width_in, width_out, edge_net, aggr='mean'))

    def forward(self, x, edge_index, edge_attr):
        """Run all convolutions; ReLU follows every layer except the last."""
        hidden_layers, output_layer = self.convs[:-1], self.convs[-1]
        for layer in hidden_layers:
            x = F.relu(layer(x, edge_index, edge_attr))
        return output_layer(x, edge_index, edge_attr)
    

class EndOfTripDelay(nn.Module):
    """Single-graph variant: one graph and one trip in, one prediction out.

    The GCN widths are ``in_dim``, then ``gcn_dims``, then ``embd_dim``; the MLP
    widths are ``embd_dim + 2`` (embedding plus day and sec), then ``mlp_dims``,
    then ``out_dim``.
    """

    def __init__(self,
            in_dim: int,
            gcn_dims:list[int], edge_attr_dim: int,
            embd_dim: int,
            mlp_dims: list[int],
            out_dim: int,
        ):
        super().__init__()

        gcn_widths = [in_dim] + gcn_dims + [embd_dim]
        self.GCN = GCNWithEdge(gcn_widths, edge_attr_dim)

        # Two extra inputs: the day and sec scalars appended in forward().
        mlp_widths = [embd_dim + 2] + mlp_dims + [out_dim]
        self.MLP = MLP(mlp_widths)

    def forward(self, x, edge_index, edge_attr, trip_nodes, day, sec):
        """Embed the nodes, mean-pool the trip's nodes, append day/sec, run the MLP."""
        node_emb = self.GCN(x, edge_index, edge_attr)

        # Average over the rows selected by trip_nodes.
        trip_emb = node_emb[trip_nodes].mean(dim=0)

        # Concatenate the pooled embedding with the day and sec tensors.
        features = torch.cat([trip_emb, day, sec])

        return self.MLP(features)


In [None]:
# Run the single-graph model on one trip of the current record.
scalers = load_all_scalers(data_path)

scaler_node = scalers['node']
scaler_edge = scalers['edge']
scaler_delay = scalers['delay']
scaler_day = scalers['day']
scaler_sec = scalers['sec']

#########--ITER--##########
i = 5
trip_id = trip_id_list[i]
###########################

day = day_list[i].unsqueeze(0)
sec = sec_list[i].unsqueeze(0)
try:
    # Pass the start time as the plain scalar stored in df, like the
    # mini-batch loading loop does, rather than as a 1-element tensor.
    g = load_graph(df['trip_start'].iloc[i], 1/2, record_name, graphs_path)
    trip_nodes = trip_idx_list[trip_id]
except (FileNotFoundError, KeyError, IndexError) as err:
    # Fail loudly: silently passing here would let the cell continue with a
    # stale `g` / `trip_nodes` left over from an earlier cell.
    raise RuntimeError(
        f"No graph or node indices available for trip {trip_id!r}"
    ) from err

model = EndOfTripDelay(
    in_dim=7,gcn_dims=[32],edge_attr_dim=2,embd_dim=16,mlp_dims=[2],out_dim=1
)

# scikit-learn returns NumPy arrays; convert back to float32 tensors so the
# torch modules can consume them. The rebuilt tensors carry no gradient
# history, so the edge attributes are static without an explicit detach().
x = torch.as_tensor(scaler_node.transform(g.x.numpy()), dtype=torch.float32)
edge_attr = torch.as_tensor(
    scaler_edge.transform(g.edge_attr.numpy()), dtype=torch.float32
)
day = scaler_day(day)
sec = scaler_sec(sec)

# Scaled target for this trip, shaped (1,) like the model output (out_dim=1).
delay = torch.as_tensor(
    scaler_delay.transform(delay_list[i].reshape(1, 1).numpy()),
    dtype=torch.float32,
).squeeze(0)

model(x,g.edge_index,edge_attr, trip_nodes, day, sec)