In [29]:
import ast
import json
import os
import sys

import ipywidgets as widgets
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import plotly.graph_objects as go
import torch
from IPython.display import clear_output, display
from plotly.subplots import make_subplots

In [30]:
# ---------------------------------------------------------
# 1. SETUP & IMPORTS
# ---------------------------------------------------------

# Make the project root importable so the `src.*` imports below resolve.
# Assumes the notebook is run from the project root — TODO confirm.
base_dir = "."
already_on_path = base_dir in sys.path
if not already_on_path:
    # Guarded so re-running the cell does not append duplicate entries.
    sys.path.append(base_dir)

# Import our model and functions
# (Make sure your src folder is in the same directory)
from src.models.model import STGCN
from src.utils.utils import z_score, z_score_inverse
from src.data.data_processing import process_adjacency, sequence_data, split_data

# Define available routes (Same as in create_data.ipynb)
# Keep this list in sync with create_data.ipynb. All three routes end with
# the same sequence of stations (DID ... PAD), so that shared tail is written
# once and appended to each route's distinct head. List concatenation builds
# a fresh list per route, so the routes do not share a mutable tail.
_common_tail = ["DID", "CHO", "GOR", "PAN", "TLH", "RDG", "TWY", "MAI",
                "BNM", "SLO", "LNY", "IVR", "WDT", "HAY", "STL", "EAL", "PAD"]
available_routes = [
    ["SWA", "NTH", "PTA", "BGN", "CDF", "NWP", "BPW", "SWI"] + _common_tail,
    ["WSM", "WNM", "WOR", "YAT", "NLS", "BRI", "BTH", "CPM", "SWI"] + _common_tail,
    ["BAN", "KGS", "HYD", "TAC", "OXF", "RAD", "CUM", "APF"] + _common_tail,
]
del _common_tail

def get_links(stop_1, stop_2, available_routes):
    """Return the consecutive-station links from ``stop_1`` to ``stop_2``.

    Only the FIRST route in ``available_routes`` that contains both stations
    is consulted. Each link is the two adjacent station codes concatenated
    (e.g. ``"SWA" + "NTH" -> "SWANTH"``).

    Args:
        stop_1: station code where the journey segment starts.
        stop_2: station code where the journey segment ends.
        available_routes: list of routes, each an ordered list of codes.

    Returns:
        List of link labels in travel order. Empty when no route contains
        both stations, or when ``stop_2`` does not come after ``stop_1`` on
        the first route that contains both.
    """
    for route in available_routes:
        if stop_1 not in route or stop_2 not in route:
            continue
        start = route.index(stop_1)
        end = route.index(stop_2)
        # First matching route wins; later routes are never consulted, even
        # if this one yields no links (e.g. the stops are in reverse order).
        return [route[k] + route[k + 1] for k in range(start, end)]
    return []

In [31]:
# ---------------------------------------------------------
# 2. CONFIGURATION & LOADING
# ---------------------------------------------------------
# Model hyper-parameters. They must match the values used in main.py,
# otherwise loading the saved checkpoint below is expected to fail.
n_nodes = 40
n_timesteps_in = 12
n_features_in = 6    # input width of the current checkpoint (older model used 1)
n_features_out = 1
ks = 5               # passed to both STGCN and process_adjacency
kt = 3               # passed to STGCN
drop_prob = 0.0
approx = "cheb_poly"
blocks = [
    [n_features_in, 32, 64],
    [64, 32, 128],
    [128, n_features_out],
]

# --- Device Setup ---
use_cuda = torch.cuda.is_available()
device = torch.device("cuda" if use_cuda else "cpu")
print("Using device: CUDA" if use_cuda else "Using Device: CPU")

# --- 1. Load Model ---
model_path = "./models/checkpoints/optimized_model.model"
stats_path = "./models/checkpoints/output_stats.json"

# Every later cell uses `model`. Fail fast here rather than printing and
# leaving `model` undefined, which would surface later as a NameError.
if not os.path.exists(model_path):
    raise FileNotFoundError(
        f"Model checkpoint not found at {model_path}. Please run main.py first."
    )

model = STGCN(blocks, n_timesteps_in, n_features_out, n_nodes, device, ks, kt, drop_prob).to(device).double()
# NOTE(review): torch.load unpickles the file, so only load checkpoints you
# produced yourself (e.g. via main.py).
model.load_state_dict(torch.load(model_path, map_location=device))
model.eval()
print(f"Model loaded successfully (Expects {n_features_in} features).")

# --- 2. Load Stats ---
# Mean/std passed to z_score_inverse later to turn model outputs into minutes.
if not os.path.exists(stats_path):
    print("Error: Stats file not found.")
    # Best-effort fallback. With mean 0 / std 1 the inverse scaling is
    # presumably the identity, so values shown as "minutes" would really be
    # z-scores.
    output_stats = {"mean": 0, "std": 1} # Fallback
else:
    with open(stats_path, 'r') as f:
        output_stats = json.load(f)
    print(f"Stats loaded: Mean={output_stats['mean']:.4f}, Std={output_stats['std']:.4f}")

# --- 3. Load Graph Kernel (Lk) ---
data_dir = "./data/processed/"
Lk = process_adjacency(data_dir, "raildelays", ks, n_nodes, approx, device)
print("Graph kernel (Lk) loaded.")

# --- 4. Load Mappings ---
# Each file is stored in `mappings` under its name without the ".json"
# extension (e.g. "G_stop2idx"). JSON object keys are always strings, so the
# keys of LG_idx2node are converted back to int here.
mapping_files = ["G_stop2idx.json", "G_idx2stop.json", "LG_node2idx.json", "LG_idx2node.json", "LG_node2label.json"]
mappings = {}
for mapping_file in mapping_files:
    mapping_name = mapping_file.split('.')[0]
    with open(os.path.join(data_dir, mapping_file), 'r') as handle:
        loaded = json.load(handle)
    if mapping_name == "LG_idx2node":
        loaded = {int(key): value for key, value in loaded.items()}
    mappings[mapping_name] = loaded

# Helper to convert keys
def parse_tuple_key(key_str):
    """Convert a JSON object key back into the Python literal it was written from.

    JSON only allows string keys, so composite keys in the LG_* mapping files
    are stored as their string repr (presumably e.g. "('SWA', 'NTH')").
    Keys that are valid Python literals are parsed. Any other key (such as a
    plain station code) is returned unchanged.

    Uses ``ast.literal_eval`` instead of ``eval``, so a crafted key in a data
    file cannot execute code, and a name such as "max" is not turned into a
    builtin. Only parse failures are caught; a bare ``except`` would also
    swallow KeyboardInterrupt/SystemExit.

    Args:
        key_str: raw string key read from a JSON file.

    Returns:
        The parsed literal (tuple, int, ...) or ``key_str`` itself.
    """
    try:
        return ast.literal_eval(key_str)
    except (ValueError, SyntaxError, TypeError, MemoryError, RecursionError):
        return key_str

# The line-graph mappings were saved with stringified keys (JSON keys must be
# strings); parse them back so they can be used as lookup keys again.
for mapping_name in ("LG_node2label", "LG_node2idx"):
    mappings[mapping_name] = {
        parse_tuple_key(raw_key): value
        for raw_key, value in mappings[mapping_name].items()
    }

print("All files loaded and ready.")

Using Device: CPU
Model loaded successfully (Expects 6 features).
Stats loaded: Mean=0.0734, Std=0.8663
Graph kernel (Lk) loaded.
All files loaded and ready.


In [32]:
# ---------------------------------------------------------
# 3. GET TEST SAMPLE (seeded for reproducibility)
# ---------------------------------------------------------
# Seed used to pick the test sample. Set to None to draw a different sample
# on every run.
SAMPLE_SEED = 42

# NOTE(review): the positional 42 below is passed straight to sequence_data;
# its meaning is defined in src.data.data_processing — confirm it is not
# meant to be SAMPLE_SEED.
dataset_seq = sequence_data(data_dir, n_nodes, 42, n_timesteps_in, 1, n_features_in)
_, _, data_test, _ = split_data(dataset_seq, n_timesteps_in)
X_test, y_test = data_test

# NOTE(review): sequence_data reports stats recomputed from training data,
# which differ from the saved output_stats used for inverse scaling later.
# If y_test was normalised with the recomputed stats, the "actual" minutes
# will be slightly off — confirm which stats the targets use.

# Pick a random sample from the test set. A local Generator makes the choice
# reproducible without touching NumPy's global random state.
sample_rng = np.random.default_rng(SAMPLE_SEED)
idx = int(sample_rng.integers(0, len(X_test)))
X_sample = X_test[idx:idx+1].permute(0, 3, 1, 2).to(device)
y_sample = y_test[idx:idx+1].permute(0, 3, 1, 2)

print(f"Loaded test sample index: {idx}")

Calculated Stats from Training Data: Mean=0.0755, Std=0.8373
Loaded test sample index: 1883


In [33]:
# ---------------------------------------------------------
# 4. PREDICTION FUNCTION (The Core Logic)
# ---------------------------------------------------------

def predict_route_delay(route_list):
    """Predict per-link and cumulative delays along a user-specified route.

    Runs the module-level ``model`` on the module-level test sample
    ``X_sample`` (with graph kernel ``Lk``). It then inverse-scales both the
    prediction and ``y_sample`` with ``output_stats`` and keeps the entries
    for the links on the route.

    Args:
        route_list: ordered station codes, e.g. ``["SWA", "CDF", "PAD"]``.

    Returns:
        ``(df, None)`` on success. ``df`` has one row per route link found in
        the graph, with columns "Link", "Pred Link Delay", "Actual Link
        Delay", "Pred Station Delay", "Actual Station Delay" and "Error".
        ``(None, message)`` when no usable link was found.
    """
    # A. Identify links in route: expand each consecutive station pair.
    links = []
    for stop_a, stop_b in zip(route_list, route_list[1:]):
        segment_links = get_links(stop_a, stop_b, available_routes)
        if not segment_links:
            # Otherwise a typo'd or out-of-order pair would silently drop out
            # of the route.
            print(f"Warning: No links found between {stop_a} and {stop_b}.")
        links.extend(segment_links)

    if not links:
        return None, "No valid links found between these stations."

    # Reverse lookup label -> line-graph node key, built once instead of
    # scanning every node for every link. setdefault keeps the first key for
    # a repeated label, matching a first-match linear scan.
    label2key = {}
    for key, label in mappings["LG_node2label"].items():
        label2key.setdefault(label, key)

    # Keep links and node indices aligned. A link missing from the graph is
    # dropped from both; otherwise the DataFrame below would receive columns
    # of different lengths and raise ValueError.
    found_links = []
    link_indices = []
    for link in links:
        if link not in label2key:
            print(f"Warning: Link {link} not in graph.")
            continue
        found_links.append(link)
        link_indices.append(mappings["LG_node2idx"][label2key[link]])

    if not link_indices:
        return None, "None of the links on this route were found in the graph."

    # B. Run prediction on the module-level test sample.
    with torch.no_grad():
        y_pred = model(X_sample, Lk)

    # C. Inverse scale (back to minutes).
    pred_mins = z_score_inverse(y_pred.cpu(), output_stats["mean"], output_stats["std"]).numpy().squeeze()
    actual_mins = z_score_inverse(y_sample.cpu(), output_stats["mean"], output_stats["std"]).numpy().squeeze()

    # D. Extract data for the route's links (indexed by line-graph node).
    route_pred = [pred_mins[i] for i in link_indices]
    route_actual = [actual_mins[i] for i in link_indices]

    # E. Station delays are cumulative. Row k holds the delay summed over
    # links 0..k, i.e. at the station that ends link k. The origin station
    # has no row of its own.
    station_pred = np.cumsum(route_pred)
    station_actual = np.cumsum(route_actual)
    error = station_pred - station_actual

    # F. Create DataFrame
    df = pd.DataFrame({
        "Link": found_links,
        "Pred Link Delay": route_pred,
        "Actual Link Delay": route_actual,
        "Pred Station Delay": station_pred,
        "Actual Station Delay": station_actual,
        "Error": error
    })

    return df, None

In [34]:
# ---------------------------------------------------------
# 5. VISUALIZATION FUNCTION (Plotly)
# ---------------------------------------------------------

def show_results(df, route_str, error_threshold=2):
    """Plot and tabulate the output of ``predict_route_delay``.

    Shows three things:
    - a line chart of cumulative predicted vs. actual delay per link;
    - a bar chart of the cumulative error;
    - a styled table in which errors above ``error_threshold`` in magnitude
      are red and all others green.

    Args:
        df: DataFrame with the columns produced by ``predict_route_delay``.
        route_str: route label used in the first chart's title.
        error_threshold: absolute error (minutes) above which an error is
            highlighted red. Defaults to 2, the previously hard-coded value.
    """
    # 1. Station Delay Plot
    fig = go.Figure()
    fig.add_trace(go.Scatter(x=df["Link"], y=df["Pred Station Delay"], mode='lines+markers', name='Predicted Delay', line=dict(color='blue', width=3)))
    fig.add_trace(go.Scatter(x=df["Link"], y=df["Actual Station Delay"], mode='lines+markers', name='Actual Delay', line=dict(color='red', width=3, dash='dash')))
    fig.update_layout(title=f"Station-Wise Delay: {route_str}", xaxis_title="Link (Station to Station)", yaxis_title="Total Delay (min)", template="plotly_white")

    # 2. Error Plot
    fig_err = go.Figure()
    fig_err.add_trace(go.Bar(x=df["Link"], y=df["Error"], name="Prediction Error", marker_color="purple"))
    fig_err.update_layout(title="Prediction Error per Station", xaxis_title="Link", yaxis_title="Error (min)", template="plotly_white")

    # 3. Display
    fig.show()
    fig_err.show()

    # 4. Data Table with Styling
    minute_columns = ["Pred Link Delay", "Actual Link Delay", "Pred Station Delay", "Actual Station Delay", "Error"]

    def color_val(val):
        # Red when the error magnitude exceeds the threshold, green otherwise.
        color = 'red' if abs(val) > error_threshold else 'green'
        return f'color: {color}'

    styler = df.style.format("{:.2f} min", subset=minute_columns)
    # Styler.map exists from pandas 2.1 (it replaced the deprecated
    # Styler.applymap); fall back to applymap on older pandas.
    style_elementwise = getattr(styler, "map", None) or styler.applymap
    display(style_elementwise(color_val, subset=["Error"]))

In [35]:
# ---------------------------------------------------------
# 6. USER INTERFACE
# ---------------------------------------------------------
# Interactive front-end: results depend on what is typed into the box, so
# call predict_route_delay / show_results directly for a reproducible run.

text_input = widgets.Text(
    value='SWA,NTH,PTA,BGN,CDF,NWP,BPW,SWI,DID,RDG,PAD',
    placeholder='Type station codes (e.g., SWA,PAD)',
    description='<b>Route:</b>',
    layout=widgets.Layout(width='80%')
)

btn = widgets.Button(
    description='Predict',
    button_style='primary',
    icon='train'
)

out = widgets.Output()

def on_click(b):
    """Button callback: parse the route text, predict, and render results in ``out``."""
    with out:
        clear_output()
        # Skip empty tokens so trailing or doubled commas ("SWA,PAD,") do not
        # become a bogus "" station code.
        route = [code.strip().upper() for code in text_input.value.split(',') if code.strip()]
        if len(route) < 2:
            print("Error: Enter at least 2 stations.")
            return

        print(f"Processing route: {' -> '.join(route)}...")
        df_res, err = predict_route_delay(route)

        if err:
            print(err)
        else:
            # Title uses the normalised codes, not the raw text-box contents.
            show_results(df_res, ",".join(route))

btn.on_click(on_click)

display(widgets.VBox([text_input, btn, out]))

VBox(children=(Text(value='SWA,NTH,PTA,BGN,CDF,NWP,BPW,SWI,DID,RDG,PAD', description='<b>Route:</b>', layout=L…