# Traffic Forecasting Visualizations

Creates visualizations comparing ground truth (geographic graph) vs learned graph:
1. Node degree comparison
2. Graph adjacency for selected nodes (animated)
3. 6-hour rush hour traffic predictions (animated)

In [None]:
from dotenv import load_dotenv

# Load environment variables from ".env_jupyter" and abort immediately when
# load_dotenv returns a falsy result, so that later cells never run without
# the expected configuration.
# NOTE(review): presumably the project `utils` package reads these variables
# at import time, which would explain why this check precedes the remaining
# imports — confirm.
if not load_dotenv(".env_jupyter"):
    raise RuntimeError("specified .env file not found")

from datetime import datetime
from pathlib import Path

import numpy as np

import utils
from utils.inference import get_learned_adjacency_matrix, get_model_predictions_cached
from utils.visual import (
    animate_traffic_heatmap,
    create_geographic_nodes_animation,
    create_node_degree_comparison,
    select_geographically_dispersed_nodes,
)

# Enable autoreload to automatically reload modules when they change
%load_ext autoreload
%autoreload 2

%matplotlib inline

## Configuration

In [None]:
# --- Notebook configuration -------------------------------------------------

# Datasets to visualize; the generation cell runs once per entry.
DATASETS = ["METR-LA", "PEMS-BAY"]

# Prefix identifying the trained model to load (final model: short, 20 epochs).
MODEL_PREFIX = "STGFORMER_FINAL"

# Short tag prepended to every saved artifact's file name.
EXPERIMENT_NAME = "final_k16"

# Destination directory for the generated images; created if it is missing.
OUTPUT_DIR = Path("../docs/img")
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)

# Animation speed in frames per second
# (e.g., 0.5 = 2 seconds per frame, 1.0 = 1 second per frame).
ANIMATION_FPS = 1

# Color Vision Deficiency mode: True selects CVD-safe colormaps
# (viridis instead of plasma/RdYlGn).
CVD_FRIENDLY = True

# Basemap tile provider.
# Options: "CartoDB.Voyager", "CartoDB.Positron", "OpenStreetMap.Mapnik"
BASEMAP_SOURCE = "CartoDB.Positron"

# Echo the effective settings, one per line, so they are recorded in the
# notebook output next to the figures.
print(
    f"Model: {MODEL_PREFIX}",
    f"Experiment: {EXPERIMENT_NAME}",
    f"Datasets: {', '.join(DATASETS)}",
    f"Output directory: {OUTPUT_DIR}",
    f"Animation FPS: {ANIMATION_FPS}",
    f"CVD-friendly: {CVD_FRIENDLY}",
    f"Basemap: {BASEMAP_SOURCE}",
    sep="\n",
)

## Generate Visualizations for Each Dataset

In [None]:
# Per-dataset pipeline: load graph metadata and the test split, pick a morning
# rush-hour window, fetch model predictions for that window only, then render
# and display three artifacts (node-degree PNG, adjacency GIF, traffic GIF).

# Imported once here instead of on every loop iteration.
from datetime import timedelta

from IPython.display import Image, display

# Number of timesteps contained in one sample, and the spacing between
# consecutive timesteps used when labelling animation frames.
# NOTE(review): both values are carried over from the original hard-coded
# constants (12 and 5); confirm they match the datasets' actual
# horizon and sampling interval.
STEPS_PER_SAMPLE = 12
STEP_MINUTES = 5
WINDOW_TIMESTEPS = 72  # 6 hours at 5-minute resolution

# Seconds each frame is shown; loop-invariant, so computed once.
frame_duration = 1.0 / ANIMATION_FPS

for DATASET_NAME in DATASETS:
    print(f"\n{'=' * 60}")
    print(f"Processing {DATASET_NAME}")
    print(f"{'=' * 60}\n")

    # Load graph metadata
    adj_mx_gt, _, locations = utils.io.get_graph_metadata(DATASET_NAME)
    num_sensors = len(locations)
    print(f"Number of sensors: {num_sensors}")

    # Find the 7am index FIRST so we only compute predictions for that window
    print("\nFinding 7am rush hour window...")
    dataset_hf = utils.io.get_dataset_hf(DATASET_NAME)
    test_df = dataset_hf["test"].to_pandas()

    # Every node is assumed to share the same sample timestamps, so the
    # timeline is read from the first node only.
    node_ids = sorted(test_df["node_id"].unique())
    timestamps_df = test_df[test_df["node_id"] == node_ids[0]]
    test_timestamps = timestamps_df["t0_timestamp"].values

    test_ts_parsed = [datetime.fromisoformat(ts) for ts in test_timestamps]
    morning_7am_indices = [
        idx for idx, ts in enumerate(test_ts_parsed) if ts.hour == 7 and ts.minute == 0
    ]

    if morning_7am_indices:
        start_idx = morning_7am_indices[0]
    else:
        # No sample starts exactly at 07:00: fall back to the first sample in
        # the 06:00-08:59 range, or to the very first sample as a last resort.
        morning_indices = [
            idx for idx, ts in enumerate(test_ts_parsed) if 6 <= ts.hour <= 8
        ]
        start_idx = morning_indices[0] if morning_indices else 0

    # Extract 6 hours of continuous data (72 timesteps = 6 samples)
    # NOTE(review): this treats consecutive sample indices as back-to-back,
    # non-overlapping windows. If the test split is built with a sliding
    # window of stride 1, these samples overlap and cover far less than
    # 6 hours — confirm against the dataset construction.
    num_samples_for_6h = WINDOW_TIMESTEPS // STEPS_PER_SAMPLE
    end_idx = start_idx + num_samples_for_6h

    print(f"  Will use samples {start_idx} to {end_idx} (6 samples for 6-hour window)")

    # Load model predictions ONLY for the needed window
    predictions, ground_truth = get_model_predictions_cached(
        dataset_name=DATASET_NAME,
        hf_repo_prefix=MODEL_PREFIX,
        sample_indices=(start_idx, end_idx),  # Only compute what we need!
    )
    print(f"Predictions shape: {predictions.shape}")

    # Load learned adjacency matrix
    adj_mx_model = get_learned_adjacency_matrix(
        dataset_name=DATASET_NAME,
        hf_repo_prefix=MODEL_PREFIX,
    )
    print(f"Learned adjacency shape: {adj_mx_model.shape}")

    # Compute node degrees for node selection (row sums of absolute weights)
    degree_gt = np.abs(adj_mx_gt).sum(axis=1)

    # Select geographically dispersed nodes
    selected_nodes = select_geographically_dispersed_nodes(locations, degree_gt)
    print(f"Selected {len(selected_nodes)} geographically dispersed nodes")

    # 1. Node Degree Comparison
    print("\n1. Creating node degree comparison...")
    output_path = (
        OUTPUT_DIR
        / f"{EXPERIMENT_NAME}_{DATASET_NAME.lower()}_node_degree_comparison.png"
    )
    create_node_degree_comparison(
        adj_mx_ground_truth=adj_mx_gt,
        adj_mx_model=adj_mx_model,
        locations=locations,
        output_path=output_path,
        dataset_name=DATASET_NAME,
        cvd_friendly=CVD_FRIENDLY,
        basemap_source=BASEMAP_SOURCE,
    )

    # Display the image
    display(Image(filename=str(output_path)))

    # 2. Geographic Nodes Animation
    print("\n2. Creating geographic nodes animation...")
    output_path = (
        OUTPUT_DIR
        / f"{EXPERIMENT_NAME}_{DATASET_NAME.lower()}_geographic_nodes_comparison.gif"
    )
    create_geographic_nodes_animation(
        adj_mx_ground_truth=adj_mx_gt,
        adj_mx_model=adj_mx_model,
        locations=locations,
        output_path=output_path,
        selected_nodes=selected_nodes,
        frame_seconds=frame_duration,
        cvd_friendly=CVD_FRIENDLY,
        basemap_source=BASEMAP_SOURCE,
    )

    # Display the GIF
    display(Image(filename=str(output_path)))

    # 3. Rush Hour Traffic Animation
    print("\n3. Creating rush hour traffic animation...")

    # Flatten the leading axes into one continuous time axis, keeping the
    # sensor and feature axes (predictions and ground_truth already contain
    # only the samples of the selected window).
    gt_6h = ground_truth.reshape(-1, num_sensors, ground_truth.shape[-1])
    pred_6h = predictions.reshape(-1, num_sensors, predictions.shape[-1])

    # Generate time labels with real datetime arithmetic so that a start time
    # with non-zero minutes carries into the hour correctly and the date rolls
    # over if the window crosses midnight.
    start_time = test_ts_parsed[start_idx]
    time_labels_6h = [
        (start_time + timedelta(minutes=STEP_MINUTES * i)).strftime("%Y-%m-%d %H:%M")
        for i in range(gt_6h.shape[0])
    ]

    print(f"  Time window: {time_labels_6h[0]} -> {time_labels_6h[-1]}")
    print(f"  Shape: {gt_6h.shape}")

    output_path = (
        OUTPUT_DIR
        / f"{EXPERIMENT_NAME}_{DATASET_NAME.lower()}_6h_rush_hour_comparison.gif"
    )
    animate_traffic_heatmap(
        values=gt_6h[:, :, 0],  # Take first feature (speed)
        locations=locations,
        output_path=output_path,
        fps=ANIMATION_FPS,
        duration_seconds=gt_6h.shape[0] / ANIMATION_FPS,
        timestamps=time_labels_6h,
        figsize=(18, 8),
        cmap="RdYlGn_r",
        values_comparison=pred_6h[:, :, 0],
        title_left="Ground Truth",
        title_right="Model Prediction",
        cvd_friendly=CVD_FRIENDLY,
        basemap_source=BASEMAP_SOURCE,
    )

    # Display the GIF
    display(Image(filename=str(output_path)))

    print(f"\n✓ Completed visualizations for {DATASET_NAME}")

## Summary

Generated 3 visualizations for each dataset:
1. **Node Degree Comparison** (PNG): Shows which sensors have the most connections
2. **Geographic Nodes Animation** (GIF): Shows graph adjacency for selected sensors
3. **6-Hour Rush Hour** (GIF): Compares ground truth vs model predictions over time

All outputs are saved to `../docs/img/` (the `OUTPUT_DIR` set in the Configuration cell).