# Defs

In [122]:
import torch
import joblib
import json
import pandas as pd
import numpy as np
import torch.nn as nn
from tqdm import tqdm
from pathlib import Path
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torch_geometric.nn import NNConv, global_mean_pool
from torch_geometric.data import Data, Batch
from torch_geometric.loader import DataLoader
from utils.time import name_to_day
from data_processing import (
    load_train_test_split,
    load_graph,
    load_delays,
    load_stop_id_to_idx,
    load_trip_idx_list,
    available_bin_codes
)

from train import (
    load_scalers, 
    load_train_data
)
import os

# Project root. Defaults to the original author's WSL checkout; set the
# DELAY_DETECTION_ROOT environment variable to run the notebook elsewhere.
root = Path(os.environ.get(
    'DELAY_DETECTION_ROOT',
    '/mnt/c/Users/rdsup/Desktop/vitmma19-pw-delay-detection',
))

data_path = root / 'data'
delays_path = data_path / 'delays'
graphs_path = data_path / 'graphs'

In [17]:
# Per-trip node lookups plus one record's delays, restricted to the time bins
# that have a graph snapshot on disk.
trip_idx_list = load_trip_idx_list(graphs_path)
records = load_train_test_split(data_path)

record_name = '20251011'
day = name_to_day(record_name)[0]  # 0-6 TODO: tensor
df = load_delays(record_name, delays_path)
bin_codes = available_bin_codes(record_name, graphs_path)

# Edges of dt-hour bins spanning one day, converted from hours to seconds.
dt = 1/2
bins = np.arange(0, 24 + dt, step=dt)
bins *= 60 * 60

# Tag every trip with its start-time bin, then drop trips whose bin has no graph.
df['bin'] = pd.cut(df['trip_start'], bins, right=False)
df['bin_code'] = df['bin'].cat.codes
df = df.loc[df['bin_code'].isin(bin_codes)]

In [18]:
# Per-trip tensors aligned with the rows of the filtered `df`.
trip_id_list = df.index.to_numpy()
sec_list, delay_list = (
    torch.tensor(df[column].to_numpy(), dtype=torch.float32)
    for column in ('trip_start', 'delay')
)
day_list = torch.ones_like(sec_list, dtype=torch.float32) * day

In [None]:
# Training tensors, plus the scalers whose .transform is applied in later cells.
x_train, edge_attr_train, delays_train = load_train_data(data_path)
scalers = load_scalers(data_path)

100%|██████████| 14/14 [00:12<00:00,  1.08it/s]


### Mini-batches (PyG concatenates all graphs of a batch into one large disconnected graph, so per-graph node indices must be offset)

In [None]:
# Build one mini-batch of up to `batch_size` trips whose graph snapshot exists.
batch_size = 32

day_batch = torch.zeros((batch_size, 1), dtype=torch.float)
sec_batch = torch.zeros((batch_size, 1), dtype=torch.float)
delay_batch = torch.zeros((batch_size, 1), dtype=torch.float)
graphs_batch = []
trip_idx_batch = []

i = 0
for trip_id, row in df.iterrows():
    # Read columns by name: `df` also carries the 'bin' and 'bin_code' columns
    # added when binning, so positional unpacking of `row` does not line up.
    delay = row['delay']
    sec = row['trip_start']

    try:
        g = load_graph(sec, 1/2, record_name, graphs_path)
    except FileNotFoundError:
        # No graph snapshot for this trip's time bin: skip the trip.
        continue

    graphs_batch.append(g)
    day_batch[i, 0] = day
    sec_batch[i, 0] = sec
    # Node indices here are local to this trip's own graph (not yet offset
    # into the batched node space).
    trip_idx_batch.append(trip_idx_list[trip_id])
    delay_batch[i, 0] = delay
    i += 1

    # Stop loading once the batch is full.
    if i == batch_size:
        break

# Drop unused rows if fewer than `batch_size` trips had a graph, and batch the
# collected graphs whether or not the batch filled up.
day_batch = day_batch[:i]
sec_batch = sec_batch[:i]
delay_batch = delay_batch[:i]
if graphs_batch:
    graphs_batch = Batch.from_data_list(graphs_batch)

# Now you have your batch:
# graphs_batch, day_batch, sec_batch, trip_idx_batch, delay_batch

In [41]:
class GCNWithEdge(nn.Module):
    """Stack of edge-conditioned ``NNConv`` layers with mean aggregation.

    ``layer_sizes`` lists node feature widths, input first; e.g. ``[7, 8, 8]``
    builds two layers, 7 -> 8 and 8 -> 8. Each layer has a small edge network
    mapping the ``edge_attr_dim`` edge features to ``in * out`` values, which
    is the size NNConv requires for its per-edge weight. ReLU follows every
    layer except the last.
    """

    def __init__(self, layer_sizes: list[int], edge_attr_dim: int):
        super().__init__()
        self.convs = nn.ModuleList()
        for c_in, c_out in zip(layer_sizes[:-1], layer_sizes[1:]):
            weight_dim = c_in * c_out
            edge_network = nn.Sequential(
                nn.Linear(edge_attr_dim, weight_dim),
                nn.ReLU(),
                nn.Linear(weight_dim, weight_dim),
            )
            self.convs.append(NNConv(c_in, c_out, edge_network, aggr='mean'))

    def forward(self, x, edge_index, edge_attr):
        # The final layer is left linear; indexing [-1] raises IndexError if no
        # layers were built.
        hidden_layers, output_layer = self.convs[:-1], self.convs[-1]
        for conv in hidden_layers:
            x = F.relu(conv(x, edge_index, edge_attr))
        return output_layer(x, edge_index, edge_attr)

In [50]:
# Smoke-test the GCN on the mini-batch above and pool one embedding per trip.
# (No `day = day_batch`-style aliases here: they would overwrite the scalar
# `day` that the batching cell and `day_list` depend on.)
gcn = GCNWithEdge(layer_sizes=[7, 8, 8], edge_attr_dim=2)
mlp = MLP(layer_sizes=[8 + 2, 4, 1])

x = gcn(graphs_batch.x, graphs_batch.edge_index, graphs_batch.edge_attr)
print(x.shape)

trip_embeddings = []
for k, nodes in enumerate(trip_idx_batch):
    # trip_idx_batch holds node indices local to each trip's own graph, while
    # Batch.from_data_list stacks the nodes of all graphs into one matrix.
    # Shift by graph k's first-node offset (ptr[k]) to address its rows of x.
    global_nodes = torch.as_tensor(nodes, dtype=torch.long) + graphs_batch.ptr[k]
    trip_embeddings.append(x[global_nodes].mean(dim=0))  # or .sum / .max

# Stack to [num_trips, embedding_dim]
x = torch.stack(trip_embeddings, dim=0)

print(x.shape)

torch.Size([144032, 8])
torch.Size([32, 8])


In [None]:
class EndOfTripDelay(nn.Module):
    """Pool GCN node embeddings per trip and predict a delay with an MLP head.

    Args:
        gcn_layers: Node feature widths for ``GCNWithEdge``, input first.
        edge_attr_dim: Number of edge features.
        mlp_layers: Widths for the ``MLP`` head. ``mlp_layers[0]`` must equal
            ``gcn_layers[-1]`` plus the column counts of ``day`` and ``sec``.

    ``forward`` arguments:
        trip_nodes: One collection of node indices per trip. The indices must
            address rows of the GCN output, i.e. of ``x``; in a PyG batch they
            must already be offset to each graph's position.
        day, sec: 2-D tensors with one row per trip, concatenated along dim 1.

    Returns:
        The MLP output, one row per trip.
    """

    def __init__(self, gcn_layers: list[int], edge_attr_dim: int, mlp_layers: list[int]):
        super().__init__()
        self.GCN = GCNWithEdge(gcn_layers, edge_attr_dim)
        self.MLP = MLP(mlp_layers)  # NOTE: MLP is defined in a later cell

    def forward(self, x, edge_index, edge_attr, trip_nodes, day, sec):
        node_emb = self.GCN(x, edge_index, edge_attr)

        # Pool each trip from the full node-embedding matrix. Keep it under its
        # own name so no trip reads from a previous trip's pooled vector.
        trip_emb = torch.stack(
            [node_emb[nodes].mean(dim=0) for nodes in trip_nodes],  # or .sum / .max
            dim=0,
        )  # [num_trips, embedding_dim]

        # Append day/time features, then predict.
        graph_emb = torch.cat([trip_emb, day, sec], dim=1)
        return self.MLP(graph_emb)


# Model

In [114]:
class MLP(nn.Module):
    """Fully connected network with ReLU between consecutive Linear layers.

    ``layer_sizes`` lists every layer width, input first and output last.
    No activation follows the final Linear layer.
    """

    def __init__(self, layer_sizes: list[int]):
        super().__init__()
        n_linear = len(layer_sizes) - 1
        modules = []
        widths = zip(layer_sizes[:-1], layer_sizes[1:])
        for idx, (fan_in, fan_out) in enumerate(widths):
            modules.append(nn.Linear(fan_in, fan_out))
            is_last = idx == n_linear - 1
            if not is_last:
                modules.append(nn.ReLU())
        self.net = nn.Sequential(*modules)

    def forward(self, x):
        return self.net(x)


class GCNWithEdge(nn.Module):
    """Stack of edge-conditioned ``NNConv`` layers with mean aggregation.

    ``layer_sizes`` lists node feature widths, input first; e.g. ``[7, 32, 16]``
    builds two layers, 7 -> 32 and 32 -> 16. Each layer has a small edge
    network mapping the ``edge_attr_dim`` edge features to ``in * out`` values,
    which is the size NNConv requires for its per-edge weight. ReLU follows
    every layer except the last.
    """

    def __init__(self, layer_sizes: list[int], edge_attr_dim: int):
        super().__init__()
        self.convs = nn.ModuleList()
        for c_in, c_out in zip(layer_sizes[:-1], layer_sizes[1:]):
            weight_dim = c_in * c_out
            edge_network = nn.Sequential(
                nn.Linear(edge_attr_dim, weight_dim),
                nn.ReLU(),
                nn.Linear(weight_dim, weight_dim),
            )
            self.convs.append(NNConv(c_in, c_out, edge_network, aggr='mean'))

    def forward(self, x, edge_index, edge_attr):
        # The final layer is left linear; indexing [-1] raises IndexError if no
        # layers were built.
        hidden_layers, output_layer = self.convs[:-1], self.convs[-1]
        for conv in hidden_layers:
            x = F.relu(conv(x, edge_index, edge_attr))
        return output_layer(x, edge_index, edge_attr)
    

class EndOfTripDelay(nn.Module):
    """Predict the end-of-trip delay of one trip from a graph snapshot.

    Pipeline:
      1. A ``GCNWithEdge`` node encoder (``in_dim -> *gcn_dims -> embd_dim``).
      2. Mean pooling over the trip's nodes.
      3. Concatenation with the ``day`` and ``sec`` features.
      4. An ``MLP`` head (``embd_dim + 2 -> *mlp_dims -> out_dim``).

    NOTE(review): ``forward`` pools with ``mean(dim=0)`` and concatenates along
    the default dim 0. It therefore handles a single trip, with 1-D ``day`` and
    ``sec`` that together contribute the 2 extra MLP inputs. Batched (2-D)
    inputs are not supported as written.
    """

    def __init__(self,
            in_dim: int,
            gcn_dims: list[int], edge_attr_dim: int,
            embd_dim: int,
            mlp_dims: list[int],
            out_dim: int,
        ):
        super().__init__()

        gcn_sizes = [in_dim] + gcn_dims + [embd_dim]
        self.GCN = GCNWithEdge(gcn_sizes, edge_attr_dim)

        # Two extra inputs: the day and sec features appended after pooling.
        head_sizes = [embd_dim + 2] + mlp_dims + [out_dim]
        self.MLP = MLP(head_sizes)

    def forward(self, x, edge_index, edge_attr, trip_nodes, day, sec):
        node_emb = self.GCN(x, edge_index, edge_attr)
        trip_emb = node_emb[trip_nodes].mean(dim=0)    # average over the trip's nodes
        features = torch.cat([trip_emb, day, sec])     # meta features appended
        return self.MLP(features)


In [None]:
# Single-sample forward pass through EndOfTripDelay (one "training iteration"
# without a backward step).
model = EndOfTripDelay(
    in_dim=7, gcn_dims=[32], edge_attr_dim=2, embd_dim=16, mlp_dims=[2], out_dim=1
)

#########--TRAIN ITER--##########
i = 5
trip_id = trip_id_list[i]
###########################

day = day_list[i].unsqueeze(0)
sec = sec_list[i].unsqueeze(0)
delay = delay_list[i].unsqueeze(0)

# Let a missing graph or unknown trip raise. Swallowing the error would make
# the lines below silently reuse a stale `g` / `trip_nodes` left over from
# earlier cells.
g = load_graph(sec, 1/2, record_name, graphs_path)
trip_nodes = trip_idx_list[trip_id]
# Edge attributes are static inputs for now (not learned).

x = scalers['x'].transform(g.x)
edge_attr = scalers['edge_attr'].transform(g.edge_attr)
edge_index = g.edge_index

day = scalers['day'].transform(day)
sec = scalers['sec'].transform(sec)
delay = scalers['delay'].transform(delay)

model(x, edge_index, edge_attr, trip_nodes, day, sec)

tensor([-0.3155], grad_fn=<ViewBackward0>)

# Dataset

In [21]:
record_name = '20251011'  # record used by the dataset experiments below
records = load_train_test_split(data_path)

### Single graph in cache

In [None]:
class TripDataset(torch.utils.data.Dataset):
    """Per-trip samples for one record, keeping a single graph snapshot in memory.

    ``filter_delays`` sorts samples by time-bin code, so consecutive indices
    mostly share a graph. ``__getitem__`` reloads the graph only when the
    requested sample belongs to a different bin than the cached one. Iterate
    in order (``shuffle=False``) to benefit from the cache.

    Args:
        record_name: Name of the record, e.g. ``'20251011'``.
        data_path: Data root. Graphs are read from ``data_path / 'graphs'`` and
            delays from ``data_path / 'delays'``.
        scalers: Dict of objects exposing ``transform``, keyed by ``'x'``,
            ``'edge_attr'``, ``'sec'``, ``'day'`` and ``'delay'``.
        dt: Bin width in hours. It is used both for binning trips and for
            loading graphs, so the two always agree.
    """

    def __init__(self, record_name, data_path: Path, scalers: dict, dt: float = 1/2):
        self.scalers = scalers
        self.record_name = record_name
        self.dt = dt
        self.graphs_path = data_path / 'graphs'
        # Resolve delays from data_path as well (not the notebook-global
        # `delays_path`), so the dataset honors its argument.
        self.delays_path = data_path / 'delays'

        # Nodes of the trips, by trip_id
        self.nodes_of_trips = load_trip_idx_list(self.graphs_path)

        # Which day?
        day = name_to_day(record_name)[0]

        # Load and filter (by available bins) delays
        df = load_delays(record_name, self.delays_path)
        df = self.filter_delays(df, dt=self.dt)

        # Training data
        self.bin_codes = df['bin_code'].to_numpy()
        self.trip_ids = df.index.to_numpy()

        # Cached graph state, filled lazily by __getitem__ (a None bin_code
        # forces a load on first access).
        self.bin_code = None
        self.g = None
        self.x = None
        self.edge_attr = None
        self.edge_index = None

        self.secs = scalers['sec'].transform(self._column_tensor(df, 'trip_start'))

        day_tensor = torch.tensor([[day]], dtype=torch.float32)
        self.days = scalers['day'].transform(day_tensor) * torch.ones_like(self.secs)

        self.delays = scalers['delay'].transform(self._column_tensor(df, 'delay'))

    def __getitem__(self, idx):
        trip_id = self.trip_ids[idx]
        bin_code = self.bin_codes[idx]

        if bin_code != self.bin_code:
            # The sample is in a different time bin: load and scale its graph.
            self.bin_code = bin_code
            self.g = load_graph(
                self.bin_code,
                dt=self.dt,
                record_name=self.record_name,
                graphs_path=self.graphs_path,
            )
            self.x = self.scalers['x'].transform(self.g.x)
            self.edge_attr = self.scalers['edge_attr'].transform(self.g.edge_attr)
            self.edge_index = self.g.edge_index

        return {
            "x": self.x,
            "edge_index": self.edge_index,
            "edge_attr": self.edge_attr,
            "trip_nodes": self.nodes_of_trips[trip_id],
            "day": self.days[idx],
            "sec": self.secs[idx],
            "delay": self.delays[idx],
        }

    def __len__(self):
        return len(self.trip_ids)

    @staticmethod
    def _column_tensor(df: pd.DataFrame, column: str) -> torch.Tensor:
        """Return ``df[column]`` as a float32 tensor of shape ``(len(df), 1)``."""
        return torch.tensor(df[column].to_numpy(), dtype=torch.float32).unsqueeze(-1)

    def filter_delays(self, df: pd.DataFrame, dt=1/2):
        """Bin trips by start time and keep those whose bin has a graph on disk.

        Adds a ``'bin'`` column (interval) and a ``'bin_code'`` column to ``df``
        in place. ``'bin_code'`` is -1 when ``trip_start`` falls outside the
        bins. The kept rows are returned sorted by ``'bin_code'``.
        """
        bin_codes = available_bin_codes(self.record_name, self.graphs_path)
        bins = np.arange(0, 24 + dt, dt) * (60 * 60)  # dt-hour edges, in seconds

        df['bin'] = pd.cut(df['trip_start'], bins=bins, right=False)
        df['bin_code'] = df['bin'].cat.codes
        # Sorting groups same-bin trips so __getitem__ can reuse the cached graph.
        return df[df['bin_code'].isin(bin_codes)].sort_values('bin_code')

In [108]:
# Dataset from the cell above. shuffle=False keeps the bin-sorted order, so
# consecutive samples reuse the cached graph instead of reloading it.
scalers = load_scalers(data_path)
dataset = TripDataset(record_name='20251011', data_path=data_path, scalers=scalers)
loader = DataLoader(dataset, batch_size=1, shuffle=False)

In [109]:
# Drain the loader once; tqdm reports the sample throughput.
for data in tqdm(loader):
    continue

  0%|          | 0/8884 [00:00<?, ?it/s]

100%|██████████| 8884/8884 [00:01<00:00, 5138.74it/s]


# Full graph cache

In [None]:
class TripDataset(torch.utils.data.Dataset):
    """Per-trip PyG ``Data`` samples for one record, with all graphs preloaded.

    Every ``graph_bin_<code>.pt`` file of the record is loaded and scaled once
    in ``__init__``. ``__getitem__`` then only looks up the cached graph for the
    sample's time bin, so shuffled access does no disk I/O.

    Args:
        record_name: Name of the record, e.g. ``'20251011'``.
        data_path: Data root. Graphs are read from
            ``data_path / 'graphs' / record_name`` and delays from
            ``data_path / 'delays'``.
        scalers: Dict of objects exposing ``transform``, keyed by ``'x'``,
            ``'edge_attr'``, ``'sec'``, ``'day'`` and ``'delay'``.
        dt: Bin width in hours used to assign trips to time bins. It presumably
            must match the binning the graph files were generated with — confirm.
    """

    def __init__(self, record_name, data_path: Path, scalers: dict, dt: float = 1/2):
        self.scalers = scalers
        self.record_name = record_name
        self.dt = dt
        self.graphs_path = data_path / 'graphs'
        # Resolve delays from data_path as well (not the notebook-global
        # `delays_path`), so the dataset honors its argument.
        self.delays_path = data_path / 'delays'

        # Must run before filter_delays, which keeps only bins present in g_cache.
        self.load_graph_cache()

        # Nodes of the trips, by trip_id
        self.nodes_of_trips = load_trip_idx_list(self.graphs_path)

        # Which day?
        day = name_to_day(record_name)[0]

        # Load and filter (by available bins) delays
        df = load_delays(record_name, self.delays_path)
        df = self.filter_delays(df, dt=self.dt)

        # Training data
        self.bin_codes = df['bin_code'].to_numpy()
        self.trip_ids = df.index.to_numpy()

        self.secs = scalers['sec'].transform(self._column_tensor(df, 'trip_start'))

        day_tensor = torch.tensor([[day]], dtype=torch.float32)
        self.days = scalers['day'].transform(day_tensor) * torch.ones_like(self.secs)

        self.delays = scalers['delay'].transform(self._column_tensor(df, 'delay'))

    def __getitem__(self, idx):
        g = self.g_cache[self.bin_codes[idx]]

        data = Data(
            x=g.x,
            edge_index=g.edge_index,
            edge_attr=g.edge_attr,
            y=self.delays[idx],
            day=self.days[idx],
            sec=self.secs[idx],
            trip_nodes=self.nodes_of_trips[self.trip_ids[idx]]
        )
        return data

    def __len__(self):
        return len(self.trip_ids)

    @staticmethod
    def _column_tensor(df: pd.DataFrame, column: str) -> torch.Tensor:
        """Return ``df[column]`` as a float32 tensor of shape ``(len(df), 1)``."""
        return torch.tensor(df[column].to_numpy(), dtype=torch.float32).unsqueeze(-1)

    def load_graph_cache(self):
        """Load and scale every graph snapshot of the record into ``self.g_cache``.

        Keys are the integer bin codes parsed from ``graph_bin_<code>.pt`` file
        names. Raises FileNotFoundError if the record directory does not exist.
        """
        self.g_cache = {}
        for f in (self.graphs_path / self.record_name).iterdir():
            if f.is_file() and f.name.startswith("graph_bin_") and f.suffix == ".pt":
                bin_code = int(f.stem.removeprefix("graph_bin_"))
                # NOTE: weights_only=False unpickles arbitrary Python objects
                # (presumably needed to restore the saved graph objects). Only
                # point this at trusted, locally generated files.
                g = torch.load(f, weights_only=False)
                g.x = self.scalers['x'].transform(g.x)
                g.edge_attr = self.scalers['edge_attr'].transform(g.edge_attr)
                self.g_cache[bin_code] = g

    def filter_delays(self, df: pd.DataFrame, dt=1/2):
        """Bin trips by start time and keep those whose bin has a cached graph.

        Adds a ``'bin'`` column (interval) and a ``'bin_code'`` column to ``df``
        in place. ``'bin_code'`` is -1 when ``trip_start`` falls outside the
        bins. Requires ``self.g_cache`` to be populated.
        """
        bins = np.arange(0, 24 + dt, dt) * (60 * 60)  # dt-hour edges, in seconds

        df['bin'] = pd.cut(df['trip_start'], bins=bins, right=False)
        df['bin_code'] = df['bin'].cat.codes
        return df[df['bin_code'].isin(list(self.g_cache))]

In [124]:
# Fully cached dataset: all graphs are in memory, so shuffling costs no
# reloads. `DataLoader` here is torch_geometric.loader.DataLoader, which
# shadows the torch import.
scalers = load_scalers(data_path)
dataset = TripDataset(record_name='20251011', data_path=data_path, scalers=scalers)
loader = DataLoader(dataset, batch_size=1, shuffle=True)

In [125]:
# Drain the loader once; tqdm reports the sample throughput.
for data in tqdm(loader):
    continue

100%|██████████| 8884/8884 [00:04<00:00, 2084.73it/s]


# Training

In [126]:
# Model, optimizer and loss for the training loop below.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

model = EndOfTripDelay(
    in_dim=7,          # node feature width fed to the GCN
    gcn_dims=[32],     # hidden GCN widths
    edge_attr_dim=2,   # edge feature width
    embd_dim=16,       # pooled trip-embedding width
    mlp_dims=[2],      # hidden MLP widths
    out_dim=1,
)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
criterion = torch.nn.MSELoss()

model.to(device)

EndOfTripDelay(
  (GCN): GCNWithEdge(
    (convs): ModuleList(
      (0): NNConv(7, 32, aggr=mean, nn=Sequential(
        (0): Linear(in_features=2, out_features=224, bias=True)
        (1): ReLU()
        (2): Linear(in_features=224, out_features=224, bias=True)
      ))
      (1): NNConv(32, 16, aggr=mean, nn=Sequential(
        (0): Linear(in_features=2, out_features=512, bias=True)
        (1): ReLU()
        (2): Linear(in_features=512, out_features=512, bias=True)
      ))
    )
  )
  (MLP): MLP(
    (net): Sequential(
      (0): Linear(in_features=18, out_features=2, bias=True)
      (1): ReLU()
      (2): Linear(in_features=2, out_features=1, bias=True)
    )
  )
)

In [None]:
num_epochs = 10

for epoch in range(num_epochs):
    model.train()
    epoch_loss = 0.0

    for batch in tqdm(loader):
        optimizer.zero_grad()
        batch = batch.to(device)

        # Forward pass, using the positional arguments EndOfTripDelay.forward takes.
        pred = model(
            batch.x, batch.edge_index, batch.edge_attr,
            batch.trip_nodes, batch.day, batch.sec,
        )

        # MSE against the scaled delay target, then one optimizer step.
        loss = criterion(pred, batch.y)
        loss.backward()
        optimizer.step()

        epoch_loss += loss.item()

    # Reported loss is the mean of the per-batch losses.
    print(f"Epoch {epoch+1}/{num_epochs} | Loss: {epoch_loss/len(loader):.4f}")