In [None]:
import pickle
import sys
from zoneinfo import ZoneInfo
sys.path.append("../")

from dotenv import load_dotenv
load_dotenv()
import geopandas as gpd
import importlib
import logging
import matplotlib.pyplot as plt
import numpy as np
import os
import pandas as pd
from pathlib import Path
import lightning.pytorch as pl
import rasterio as rio
from rasterio.plot import show
import seaborn as sns
import shapely
import statsmodels.api as sm
import torch
from torch.utils.data import DataLoader

from openbustools import plotting, spatial, standardfeeds
from openbustools.traveltime import data_loader, model_utils
from openbustools.drivecycle import trajectory
from openbustools.drivecycle.physics import conditions, energy, vehicle

from srai.embedders import GTFS2VecEmbedder
from srai.joiners import IntersectionJoiner
from srai.loaders import GTFSLoader
from srai.loaders.osm_loaders.filters import HEX2VEC_FILTER
from srai.neighbourhoods.h3_neighbourhood import H3Neighbourhood
from srai.regionalizers import H3Regionalizer, geocode_to_region_gdf
from srai.plotting import plot_regions, plot_numeric_data

In [None]:
# Hardware-dependent loader/trainer settings: background DataLoader workers and
# pinned host memory are only enabled when a CUDA device is available.
num_workers, pin_memory, accelerator = (
    (4, True, "cuda") if torch.cuda.is_available() else (0, False, "cpu")
)

# Only let ERROR-level messages through from the Lightning loggers.
for lightning_logger in ("lightning", "pytorch_lightning", "lightning.pytorch.accelerators.cuda"):
    logging.getLogger(lightning_logger).setLevel(logging.ERROR)

# Tuning day and evaluation day. Everything after the first "." is dropped
# (presumably a file extension) so entries match the day names found on disk.
train_days = [day.split(".")[0] for day in standardfeeds.get_date_list('2024_01_04', 1)]
test_days = [day.split(".")[0] for day in standardfeeds.get_date_list('2024_01_05', 1)]

# One row per candidate city feed; later cells use its 'uuid' and 'provider' columns.
cleaned_sources = pd.read_csv(Path('..') / 'data' / 'cleaned_sources.csv')

In [None]:
# Zero-shot transfer: run the KCM-trained GRU model, with no tuning, on the test
# day of every city in cleaned_sources that has an entry for that day under
# processed/training.
# Result layout: res[uuid] = {'preds': ndarray, 'labels': ndarray}.
res = {}
# Narrow this slice (e.g. .iloc[:5]) to evaluate only a subset of cities.
test_sources = cleaned_sources.iloc[:]
for _, row in test_sources.iterrows():
    test_data_folders = [f"../data/other_feeds/{row['uuid']}_realtime/processed"]
    # Check for the test day *before* loading the model so skipped cities cost nothing.
    available_days = {x.name for x in Path(test_data_folders[0], "training").glob('*')}
    if test_days[0] not in available_days:
        continue
    # Fresh copy of the base model for each city.
    model = model_utils.load_model("../logs/", "kcm", "GRU", 0)
    model.eval()
    # Run inference for city
    print(row['provider'])
    test_dataset = data_loader.NumpyDataset(
        test_data_folders,
        test_days,
        holdout_routes=model.holdout_routes,
        load_in_memory=True,
        config=model.config,
    )
    test_loader = DataLoader(
        test_dataset,
        collate_fn=model.collate_fn,
        batch_size=model.batch_size,
        shuffle=False,
        drop_last=False,
        num_workers=num_workers,
        pin_memory=pin_memory,
    )
    trainer = pl.Trainer(
        accelerator=accelerator,
        logger=False,
        inference_mode=True,
    )
    preds_and_labels = trainer.predict(model=model, dataloaders=test_loader)
    if not preds_and_labels:
        # No batches were produced for this city; np.concatenate would raise on an empty list.
        continue
    preds = np.concatenate([x['preds'] for x in preds_and_labels])
    labels = np.concatenate([x['labels'] for x in preds_and_labels])
    res[row['uuid']] = {'preds': preds, 'labels': labels}

In [None]:
# Few-shot tuning: for every city that has both the train day and the test day
# under processed/training, fine-tune the KCM-trained GRU model on the first
# `n_batches` batches of the train day (max_epochs=10), then predict on that
# city's own test day.
# Result layout: res_tuned[uuid][f"{n_batches}_batches"] = {'preds': ndarray, 'labels': ndarray}.
# Each value in n_batches_options is an independent run that starts from a freshly
# loaded base model; extend the list (e.g. [1, 2, 4]) to sweep the tuning size.
n_batches_options = [1]
res_tuned = {}
# Narrow this slice (e.g. .iloc[:5]) to tune on only a subset of cities.
test_sources = cleaned_sources.iloc[:]
for _, row in test_sources.iterrows():
    data_folders = [f"../data/other_feeds/{row['uuid']}_realtime/processed"]
    # Both the tuning day and the evaluation day must exist for this city.
    available_days = {x.name for x in Path(data_folders[0], "training").glob('*')}
    if train_days[0] not in available_days or test_days[0] not in available_days:
        continue
    print(row['provider'])
    res_tuned[row['uuid']] = {}
    for n_batches in n_batches_options:
        # Start every tuning run from the same untouched base model
        # (trainer.fit switches it to train mode itself).
        model = model_utils.load_model("../logs/", "kcm", "GRU", 0)
        train_dataset = data_loader.NumpyDataset(
            data_folders,
            train_days,
            load_in_memory=True,
            config=model.config,
        )
        train_loader = DataLoader(
            train_dataset,
            collate_fn=model.collate_fn,
            batch_size=model.batch_size,
            # Unshuffled + limit_train_batches: every epoch reuses the same first n_batches batches.
            shuffle=False,
            drop_last=False,
            num_workers=num_workers,
            pin_memory=pin_memory,
        )
        trainer = pl.Trainer(
            max_epochs=10,
            limit_train_batches=n_batches,
            accelerator=accelerator,
        )
        trainer.fit(model=model, train_dataloaders=train_loader)
        # Test after fine-tuning, on this city's test day.
        # NOTE(review): unlike the zero-shot cell, holdout_routes is not passed here,
        # so tuned and untuned scores may cover different routes — confirm intended.
        test_dataset = data_loader.NumpyDataset(
            data_folders,
            test_days,
            load_in_memory=True,
            config=model.config,
        )
        # Must be built from test_dataset: a stale `test_loader` from an earlier
        # cell would otherwise score the wrong city's data.
        test_loader = DataLoader(
            test_dataset,
            collate_fn=model.collate_fn,
            batch_size=model.batch_size,
            shuffle=False,
            drop_last=False,
            num_workers=num_workers,
            pin_memory=pin_memory,
        )
        trainer = pl.Trainer(
            accelerator=accelerator,
            logger=False,
            inference_mode=True,
        )
        preds_and_labels = trainer.predict(model=model, dataloaders=test_loader)
        if not preds_and_labels:
            # No batches were produced; np.concatenate would raise on an empty list.
            continue
        preds = np.concatenate([x['preds'] for x in preds_and_labels])
        labels = np.concatenate([x['labels'] for x in preds_and_labels])
        res_tuned[row['uuid']][f"{n_batches}_batches"] = {'preds': preds, 'labels': labels}

In [None]:
# Score each city's zero-shot predictions, join on the source metadata, and tag
# the rows as the untuned baseline (0 tuning batches).
metrics = {
    uuid: model_utils.performance_metrics(city_res['labels'], city_res['preds'])
    for uuid, city_res in res.items()
}
metrics_df = (
    pd.DataFrame(metrics)
    .T
    .rename_axis('uuid')
    .reset_index()
    .merge(cleaned_sources, on='uuid')
    .assign(experiment='not-tuned', n_batches='0_batches')
)
metrics_df[['provider','n_batches','experiment','mae', 'mape', 'rmse', 'ex_var', 'r_score']].sort_values(['provider', 'n_batches']).head(10)

In [None]:
# Score every (city, tuning size) run, keyed by a (uuid, n_batches) pair, then
# join on the source metadata and tag the rows as tuned.
metrics = {
    (uuid, batch_key): model_utils.performance_metrics(run['labels'], run['preds'])
    for uuid, runs in res_tuned.items()
    for batch_key, run in runs.items()
}
metrics_tuned_df = (
    pd.DataFrame(metrics)
    .T
    .rename_axis(['uuid', 'n_batches'])
    .reset_index()
    .merge(cleaned_sources, on='uuid')
    .assign(experiment='tuned')
)
metrics_tuned_df[['provider','n_batches','experiment','mae', 'mape', 'rmse', 'ex_var', 'r_score']].sort_values(['provider', 'n_batches']).head(10)

In [None]:
# Distribution of per-city MAPE for the untuned model vs. each tuning size.
# NOTE(review): the "N000 Samples" labels assume 1000 samples per batch — confirm
# against model.batch_size.
tuning_labels = {
    '0_batches': 'No Tuning',
    '1_batches': '1000 Samples',
    '2_batches': '2000 Samples',
    '4_batches': '4000 Samples',
}
# ignore_index: both inputs carry their own RangeIndex; duplicated index labels
# can trip up seaborn's hue handling.
plot_df = pd.concat([metrics_df, metrics_tuned_df], axis=0, ignore_index=True)
# .copy() so the column assignments below act on an independent frame rather than
# a slice of the concatenated one (avoids SettingWithCopyWarning).
plot_df = plot_df[plot_df['n_batches'].isin(list(tuning_labels))].copy()
plot_df['Tuning Sample Size'] = plot_df['n_batches'].replace(tuning_labels)
plot_df['MAPE'] = plot_df['mape']

fig, ax = plt.subplots()
sns.histplot(plot_df, x='MAPE', hue='Tuning Sample Size', bins=50, kde=True, ax=ax)
ax.set(title='Per-city MAPE of the KCM GRU model, with and without tuning', ylabel='Count')
plt.show()