In [None]:
import pickle
import shutil
import sys
from zoneinfo import ZoneInfo
sys.path.append("../")

from dotenv import load_dotenv
load_dotenv()
import geopandas as gpd
import importlib
import logging
import matplotlib.pyplot as plt
import numpy as np
import os
import pandas as pd
from pathlib import Path
import lightning.pytorch as pl
import rasterio as rio
from rasterio.plot import show
import seaborn as sns
import shapely
import torch
from torch.utils.data import DataLoader
from lightning.pytorch.callbacks.early_stopping import EarlyStopping
from lightning.pytorch.loggers import TensorBoardLogger
from torch.utils.data import DataLoader, SequentialSampler, SubsetRandomSampler
from sklearn.model_selection import KFold

from openbustools import plotting, spatial, standardfeeds
from openbustools.traveltime import data_loader, model_utils
from openbustools.drivecycle import trajectory
from openbustools.drivecycle.physics import conditions, energy, vehicle

from srai.embedders import GTFS2VecEmbedder
from srai.joiners import IntersectionJoiner
from srai.loaders import GTFSLoader
from srai.loaders.osm_loaders.filters import HEX2VEC_FILTER
from srai.neighbourhoods.h3_neighbourhood import H3Neighbourhood
from srai.regionalizers import H3Regionalizer, geocode_to_region_gdf
from srai.plotting import plot_regions, plot_numeric_data

In [None]:
# Hardware-dependent settings reused by the DataLoader / Trainer calls in this notebook.
use_cuda = torch.cuda.is_available()
num_workers = 4 if use_cuda else 0
pin_memory = True if use_cuda else False
accelerator = "cuda" if use_cuda else "cpu"

# Raise the Lightning loggers to ERROR so only errors are emitted.
for logger_name in (
    "lightning",
    "pytorch_lightning",
    "lightning.pytorch.accelerators.cuda",
):
    logging.getLogger(logger_name).setLevel(logging.ERROR)

# Network / architecture identifiers; used below to build the model and results paths.
model_network = "kcm"
model_type = "GRU"

# Keep only the part of each entry before the first "." (presumably this drops a
# file extension from the names returned by get_date_list -- TODO confirm).
train_days = [day.split(".")[0] for day in standardfeeds.get_date_list('2024_01_05', 5)]
test_days = [day.split(".")[0] for day in standardfeeds.get_date_list('2024_02_06', 7)]

# Per-source table; later cells join it to the metrics on the 'uuid' column.
cleaned_sources = pd.read_csv(Path('..', 'data') / 'cleaned_sources.csv')

In [None]:
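# # Test base model and heuristic on each city
# # NOTE: this cell is disabled (commented out). The res_base / res_avg dicts it builds
# # are pickled by the following cell and read back from disk further down.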
# res_base = {}
# res_avg = {}
# test_sources = cleaned_sources.iloc[:]
# for i, row in test_sources.iterrows():

#     # Load base model
#     model = model_utils.load_model("../logs/", model_network, model_type, 0)
#     model.eval()
#     test_data_folders = [f"../data/other_feeds/{row['uuid']}_realtime/processed"]
#     available_days = [x.name for x in Path(test_data_folders[0], "training").glob('*')]
#     if test_days[0] not in available_days:
#         continue

#     # Test inference for city
#     print(row['provider'])
#     test_dataset = data_loader.NumpyDataset(
#         test_data_folders,
#         test_days,
#         holdout_routes=model.holdout_routes,
#         load_in_memory=True,
#         config = model.config,
#     )
#     test_loader = DataLoader(
#         test_dataset,
#         collate_fn=model.collate_fn,
#         batch_size=model.batch_size,
#         shuffle=False,
#         drop_last=False,
#         num_workers=num_workers,
#         pin_memory=pin_memory
#     )
#     trainer = pl.Trainer(
#         accelerator=accelerator,
#         logger=False,
#         inference_mode=True,
#         enable_progress_bar=False
#     )
#     preds_and_labels = trainer.predict(model=model, dataloaders=test_loader)
#     preds = np.concatenate([x['preds'] for x in preds_and_labels])
#     labels = np.concatenate([x['labels'] for x in preds_and_labels])
#     res_base[row['uuid']] = {'preds':preds, 'labels':labels}

#     # Load and test heuristic
#     model = pickle.load(open(Path("../logs/", model_network, "AVG-0.pkl"), 'rb'))
#     preds_and_labels = model.predict(test_dataset)
#     res_avg[row['uuid']] = {'preds':preds_and_labels['preds'], 'labels':preds_and_labels['labels']}

In [None]:
# save_dir = Path("..", "results", model_network, "multicity_tuning")
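# # NOTE(review): rmtree wipes the whole folder, including {model_type}_TUNED.pkl, which
# # the tuning cell saves into this same directory -- re-running this cell discards it.
# if save_dir.exists():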
#     shutil.rmtree(save_dir)
# save_dir.mkdir(parents=True, exist_ok=True)
# with open(save_dir / f"{model_type}.pkl", "wb") as f:
#     pickle.dump(res_base, f)
# with open(save_dir / f"AVG.pkl", "wb") as f:
#     pickle.dump(res_avg, f)

In [None]:
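# # Tune, then re-test the base model on an increasing number of data samples
# # NOTE: this cell is disabled (commented out); its res_tuned output is pickled by the
# # following cell and read back from disk further down.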
# n_batches = [1, 10, 100, 500, 1000]
# batch_size = 10
# res_tuned = {}
# test_sources = cleaned_sources.iloc[:]

# for i, row in test_sources.iterrows():
#     # Load model and check if day has data for given city
#     data_folders = [f"../data/other_feeds/{row['uuid']}_realtime/processed"]
#     available_days = [x.name for x in Path(data_folders[0], "training").glob('*')]
#     if train_days[0] not in available_days:
#         print(f"Skipping {row['provider']}")
#         continue
#     print(row['provider'])
#     res_tuned[row['uuid']] = {}
#     for j, batch_limit in enumerate(n_batches):
#         print(f"Training with {batch_limit} batches")
#         # Try increasing amounts of training samples per-city
#         model = model_utils.load_model("../logs/", model_network, model_type, 0)
#         model.train()
#         train_dataset = data_loader.NumpyDataset(
#             data_folders,
#             train_days,
#             load_in_memory=True,
#             config=model.config
#         )
#         k_fold = KFold(5, shuffle=True, random_state=42)
#         train_idx, val_idx = list(k_fold.split(np.arange(len(train_dataset))))[0]
#         train_sampler = SubsetRandomSampler(train_idx)
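#         # NOTE(review): SequentialSampler(val_idx) iterates positions 0..len(val_idx)-1,
#         # not the index values stored in val_idx, so the validation loader may overlap
#         # the training subset -- confirm whether the val_idx rows were intended here.
#         val_sampler = SequentialSampler(val_idx)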
#         train_loader = DataLoader(
#             train_dataset,
#             collate_fn=model.collate_fn,
#             batch_size=batch_size,
#             sampler=train_sampler,
#             drop_last=True,
#             num_workers=num_workers,
#             pin_memory=pin_memory,
#         )
#         val_loader = DataLoader(
#             train_dataset,
#             collate_fn=model.collate_fn,
#             batch_size=batch_size,
#             sampler=val_sampler,
#             drop_last=True,
#             num_workers=num_workers,
#             pin_memory=pin_memory,
#         )
#         trainer = pl.Trainer(
#             check_val_every_n_epoch=1,
#             max_epochs=100,
#             accelerator=accelerator,
#             callbacks=[EarlyStopping(monitor=f"valid_loss", min_delta=.0001, patience=3)],
#             limit_train_batches=batch_limit,
#             enable_progress_bar=False
#         )
#         trainer.fit(model=model, train_dataloaders=train_loader, val_dataloaders=val_loader)

#         # Test after fine-tuning
#         model.eval()
#         test_dataset = data_loader.NumpyDataset(
#             data_folders,
#             test_days,
#             load_in_memory=True,
#             config=model.config
#         )
#         test_loader = DataLoader(
#             test_dataset,
#             collate_fn=model.collate_fn,
#             batch_size=batch_size,
#             shuffle=False,
#             drop_last=False,
#             num_workers=num_workers,
#             pin_memory=pin_memory
#         )
#         trainer = pl.Trainer(
#             accelerator=accelerator,
#             logger=False,
#             inference_mode=True,
#             enable_progress_bar=False
#         )
#         preds_and_labels = trainer.predict(model=model, dataloaders=test_loader)
#         preds = np.concatenate([x['preds'] for x in preds_and_labels])
#         labels = np.concatenate([x['labels'] for x in preds_and_labels])
#         res_tuned[row['uuid']][f"{batch_limit}_batches"] = {'preds':preds, 'labels':labels}

In [None]:
# save_dir = Path("..", "results", model_network, "multicity_tuning")
# save_dir.mkdir(parents=True, exist_ok=True)
# with open(save_dir / f"{model_type}_TUNED.pkl", "wb") as f:
#     pickle.dump(res_tuned, f)

In [None]:
# Read back the cached prediction results (base model, heuristic, tuned model).
# NOTE: pickle.load can execute arbitrary code -- only load files produced by this project.
results_dir = Path("..", "results", model_network, "multicity_tuning")


def _read_pickle(path):
    """Return the object unpickled from ``path``; the file is closed on exit."""
    with open(path, "rb") as handle:
        return pickle.load(handle)


res_base = _read_pickle(results_dir / f"{model_type}.pkl")
res_avg = _read_pickle(results_dir / "AVG.pkl")
res_tuned = _read_pickle(results_dir / f"{model_type}_TUNED.pkl")

In [None]:
# Score the untuned base model once per source uuid.
metrics = {
    source_uuid: model_utils.performance_metrics(result['labels'], result['preds'])
    for source_uuid, result in res_base.items()
}

# One row per uuid, tagged as the zero-tuning baseline and joined to the source table.
metrics_base_df = (
    pd.DataFrame(metrics).T
    .rename_axis(index=['uuid'])
    .assign(n_batches='0_batches')
    .reset_index()
    .merge(cleaned_sources, on='uuid')
    .assign(experiment='not_tuned')
)

# Preview the first rows, ordered by provider.
summary_columns = ['uuid', 'provider', 'n_batches', 'experiment', 'mae', 'mape', 'rmse', 'ex_var', 'r_score']
metrics_base_df[summary_columns].sort_values(['provider', 'n_batches']).head(10)

In [None]:
# Score every (uuid, tuning size) combination of the fine-tuned model.
metrics = {
    (source_uuid, batch_key): model_utils.performance_metrics(result['labels'], result['preds'])
    for source_uuid, per_batch in res_tuned.items()
    for batch_key, result in per_batch.items()
}

# Tuple keys become a two-level index, flattened into 'uuid' / 'n_batches' columns
# before joining to the source table.
metrics_tuned_df = (
    pd.DataFrame(metrics).T
    .rename_axis(index=['uuid', 'n_batches'])
    .reset_index()
    .merge(cleaned_sources, on='uuid')
    .assign(experiment='tuned')
)

# Preview the first rows, ordered by provider and tuning size.
summary_columns = ['uuid', 'provider', 'n_batches', 'experiment', 'mae', 'mape', 'rmse', 'ex_var', 'r_score']
metrics_tuned_df[summary_columns].sort_values(['provider', 'n_batches']).head(10)

In [None]:
# Combine the untuned and tuned metrics into one long table for plotting.
# NOTE(review): an 'avg' / 'Heuristic' category is declared below, but metrics_avg_df
# (built in a later cell) is not concatenated here, so that category has no rows --
# confirm whether the heuristic was meant to appear in the plot.

# Explicit ordering of the tuning sizes (a plain string sort would misorder them).
batch_order = ['0_batches', '1_batches', '10_batches', '100_batches', '500_batches', '1000_batches', 'avg']
# Display labels for each category; the counts look like n_batches x 10 samples per
# batch (the batch_size used in the tuning cell) -- TODO confirm.
sample_size_labels = {
    '0_batches': 'No Tuning',
    '1_batches': '10 Samples',
    '10_batches': '100 Samples',
    '100_batches': '1,000 Samples',
    '500_batches': '5,000 Samples',
    '1000_batches': '10,000 Samples',
    'avg': 'Heuristic',
}

# ignore_index avoids duplicate row labels (both inputs were reset to 0..n-1).
all_metrics = pd.concat([metrics_base_df, metrics_tuned_df], ignore_index=True).sort_values(['provider', 'n_batches'])
all_metrics['n_batches'] = pd.Categorical(all_metrics['n_batches'], batch_order)
# rename_categories keeps the category order; Series.replace on a categorical column
# is deprecated for renaming categories (pandas 2.2+).
all_metrics['Tuning Sample Size'] = all_metrics['n_batches'].cat.rename_categories(sample_size_labels)
all_metrics['MAPE'] = all_metrics['mape']

In [None]:
# Box plot of the per-source MAPE values, one box per tuning sample size.
fig, axes = plt.subplots(1, 1, figsize=(8, 5))
sns.boxplot(ax=axes, data=all_metrics, x='MAPE', hue='Tuning Sample Size', palette=plotting.PALETTE)
axes.set_title('Tuned Performance for 33 International Cities')
axes.legend(handles=axes.get_legend_handles_labels()[0], loc='upper left', ncol=1)

# Fix the tick positions before labelling them. set_xticklabels on its own relabels
# whatever ticks the automatic locator picked, which can mismatch the nine labels in
# number or position.
mape_ticks = np.linspace(0.10, 0.30, 9)  # 10% to 30% in 2.5% steps
mape_tick_labels = ['10%', '12.5%', '15%', '17.5%', '20%', '22.5%', '25%', '27.5%', '30%']
axes.set_xticks(mape_ticks)
axes.set_xticklabels(mape_tick_labels)
# Apply the limits after the ticks so the visible range is exactly [.08, .3].
axes.set_xlim([.08, .3])

fig.savefig("../plots/multicity_generalization.png", dpi=300)

In [None]:
# Score the averaging heuristic once per source uuid.
metrics = {
    source_uuid: model_utils.performance_metrics(result['labels'], result['preds'])
    for source_uuid, result in res_avg.items()
}

# One row per uuid, tagged with the 'avg' pseudo tuning size and joined to the source table.
metrics_avg_df = (
    pd.DataFrame(metrics).T
    .rename_axis(index=['uuid'])
    .assign(n_batches='avg')
    .reset_index()
    .merge(cleaned_sources, on='uuid')
    .assign(experiment='heuristic')
)

# Preview the first rows, ordered by provider.
summary_columns = ['uuid', 'provider', 'n_batches', 'experiment', 'mae', 'mape', 'rmse', 'ex_var', 'r_score']
metrics_avg_df[summary_columns].sort_values(['provider', 'n_batches']).head(10)