In [2]:
import pickle
import sys
from zoneinfo import ZoneInfo
sys.path.append("../")

from dotenv import load_dotenv
load_dotenv()
import geopandas as gpd
import importlib
import logging
import matplotlib.pyplot as plt
import numpy as np
import os
import pandas as pd
from pathlib import Path
import lightning.pytorch as pl
import rasterio as rio
from rasterio.plot import show
import seaborn as sns
import shapely
import statsmodels.api as sm
from torch.utils.data import DataLoader

from openbustools import plotting, spatial, standardfeeds
from openbustools.traveltime import data_loader, model_utils
from openbustools.drivecycle import trajectory
from openbustools.drivecycle.physics import conditions, energy, vehicle

In [2]:
# Load the table of cleaned feed sources (one row per provider) and preview it.
sources_csv = Path('..') / 'data' / 'cleaned_sources.csv'
cleaned_sources = pd.read_csv(sources_csv)
cleaned_sources.head()

Unnamed: 0,country_code,subdivision_name,municipality,provider,static_url,realtime_url,min_lon,max_lon,min_lat,max_lat,uuid,tz_str,epsg_code
0,CA,Ontario,London,London Transit Commission,http://www.londontransit.ca/gtfsfeed/google_tr...,http://gtfs.ltconline.ca/Vehicle/VehiclePositi...,-81.36311,-81.137591,42.905244,43.051188,0fd41f26-6bc5-45f0-88ff-594bde1e8c24,America/Toronto,32617
1,CA,Ontario,Barrie,Barrie Transit,http://www.myridebarrie.ca/gtfs/Google_transit...,http://www.myridebarrie.ca/gtfs/GTFS_VehiclePo...,-79.740632,-79.610896,44.321804,44.420207,3196b95d-1eb2-4e05-bb36-7a2099d2348b,America/Toronto,32617
2,US,California,Santa Monica,Big Blue Bus,http://gtfs.bigbluebus.com/current.zip,http://gtfs.bigbluebus.com/vehiclepositions.bin,-118.549205,-118.237266,33.929498,34.075669,dbeb97e3-012d-4c23-93aa-a80696884340,America/Los_Angeles,32611
3,US,California,Mountain View,Mountain View Transportation Management Associ...,http://data.trilliumtransit.com/gtfs/mountainv...,https://ridemvgo.org/gtfs-rt/vehiclepositions,-122.111591,-122.047584,37.387656,37.431429,6885d33d-2f4e-4ac7-84a7-2e4747c723f0,America/Los_Angeles,32610
4,US,California,Merced,Merced County Transit (The Bus),http://data.trilliumtransit.com/gtfs/mercedthe...,https://thebuslive.com/gtfs-rt/vehiclepositions,-121.02024,-120.248,36.964503,37.521786,ea803659-c28f-43ac-8c01-619d767e0b1d,America/Los_Angeles,32610


In [3]:
# Gather every pickled realtime sample for every provider listed in
# `cleaned_sources`, tag each sample with its provider uuid, and join the
# provider metadata (municipality, provider name) onto the combined frame.
provider_frames = []
for uuid in cleaned_sources.uuid:
    provider_path = Path('..', 'data', 'other_feeds', f"{uuid}_realtime")
    # glob() yields files in filesystem order; sort so that the row order of
    # the combined frame is reproducible across machines and re-runs.
    for pkl_path in sorted(provider_path.glob('*.pkl')):
        # NOTE: unpickling can execute arbitrary code -- only load files that
        # were produced by this project's own scraper.
        data = pd.read_pickle(pkl_path)
        data['uuid'] = uuid
        provider_frames.append(data)

# Fail with an actionable message instead of pd.concat's generic
# "No objects to concatenate" when the data directory is empty or missing.
if not provider_frames:
    raise FileNotFoundError(
        f"No realtime .pkl files found under {Path('..', 'data', 'other_feeds')}"
    )

# Inner join: samples whose uuid is not in `cleaned_sources` are dropped.
all_data = pd.merge(
    cleaned_sources[['uuid', 'municipality', 'provider']],
    pd.concat(provider_frames),
    on='uuid',
)
all_data.head()

Unnamed: 0,uuid,municipality,provider,trip_id,file,locationtime,lat,lon,vehicle_id
0,0fd41f26-6bc5-45f0-88ff-594bde1e8c24,London,London Transit Commission,2007915,2024_01_01,1704116300,43.018391,-81.243179,3155
1,0fd41f26-6bc5-45f0-88ff-594bde1e8c24,London,London Transit Commission,2007915,2024_01_01,1704116331,43.019211,-81.240379,3155
2,0fd41f26-6bc5-45f0-88ff-594bde1e8c24,London,London Transit Commission,2007915,2024_01_01,1704116361,43.020302,-81.236549,3155
3,0fd41f26-6bc5-45f0-88ff-594bde1e8c24,London,London Transit Commission,2007915,2024_01_01,1704116393,43.020969,-81.234413,3155
4,0fd41f26-6bc5-45f0-88ff-594bde1e8c24,London,London Transit Commission,2007915,2024_01_01,1704116423,43.022369,-81.229729,3155


In [4]:
# Non-null observation counts per provider, largest feeds first.
(
    all_data
    .groupby('provider')
    .count()
    .sort_values('uuid', ascending=False)
)

Unnamed: 0_level_0,uuid,municipality,trip_id,file,locationtime,lat,lon,vehicle_id
provider,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1
King County Metro,1035847,1035847,1035847,1035847,1035847,1035847,1035847,1035847
Roma Servizi per la Mobilità,995006,995006,995006,995006,995006,995006,995006,995006
Massachusetts Bay Transportation Authority (MBTA),868739,868739,868739,868739,868739,868739,868739,868739
Edmonton Transit System,765328,765328,765328,765328,765328,765328,765328,765328
Capital Metro,736330,736330,736330,736330,736330,736330,736330,736330
Port Authority of Allegheny County,688369,688369,688369,688369,688369,688369,688369,688369
Metropolitan Atlanta Rapid Transit Authority (MARTA),412615,412615,412615,412615,412615,412615,412615,412615
Adelaide Metro,364396,364396,364396,364396,364396,364396,364396,364396
Central Florida Regional Transit Authority (LYNX),353448,353448,353448,353448,353448,353448,353448,353448
York Region Transit,300899,300899,300899,300899,300899,300899,300899,300899


In [4]:
# Goal for this section: for each city, test the torch model on that city,
# then fine-tune and re-test. This cell only loads the model; the per-city
# evaluation and fine-tuning are not implemented in it.
# NOTE(review): the positional arguments are presumably (log directory,
# training network name, model type, fold/version index) -- confirm against
# model_utils.load_model.
model = model_utils.load_model("../logs/", "kcm", "GRU", 0)

# Switch the module to evaluation mode before running inference. Because the
# lines below are all comments, this call is the last expression of the cell,
# so the returned module is what the cell displays.
model.eval()

# --- Disabled reference snippet: none of the lines below are executed. ---
# It refers to `self`, so it appears to have been copied from a method defined
# elsewhere (TODO confirm origin) and is kept as a sketch of how to build model
# inputs from a trajectory GeoDataFrame and run prediction on CPU.
# gdf = self.gdf.copy()
# gdf['t_min_of_day'] = self.traj_attr['t_min_of_day']
# gdf['t_day_of_week'] = self.traj_attr['t_day_of_week']
# # Fill modeling features with -1 if not added to trajectory gdf
# for col in data_loader.NUM_FEAT_COLS:
#     if col not in gdf.columns:
#         gdf[col] = -1
# feats_n = gdf[data_loader.NUM_FEAT_COLS].to_numpy().astype('int32')
# return {0: {'feats_n': feats_n}}

# data_loader.normalize_samples(samples, model.config)
# dataset = data_loader.H5Dataset(samples)
# if model.is_nn:
#     loader = DataLoader(
#         dataset,
#         collate_fn=model.collate_fn,
#         batch_size=model.batch_size,
#         shuffle=False,
#         drop_last=False,
#         num_workers=0,
#         pin_memory=False
#     )
#     trainer = pl.Trainer(
#         accelerator='cpu',
#         logger=False,
#         inference_mode=True,
#         enable_progress_bar=False,
#         enable_checkpointing=False,
#         enable_model_summary=False
#     )
#     preds_and_labels = trainer.predict(model=model, dataloaders=loader)
# else:
#     preds_and_labels = model.predict(dataset, 'h')
# preds = [x['preds_raw'].flatten() for x in preds_and_labels][0]
# preds[0] = 0
# self.pred_time_s = preds
# self.gdf['cumul_time_s'] = np.cumsum(preds)

GRU(
  (loss_fn): MSELoss()
  (min_em): MinuteEmbedding(
    (em): Embedding(1440, 48)
  )
  (day_em): DayEmbedding(
    (em): Embedding(7, 4)
  )
  (rnn): GRU(4, 64, num_layers=2, dropout=0.05)
  (feature_extract): Linear(in_features=116, out_features=1, bias=True)
  (feature_extract_activation): ReLU()
)

In [None]:
# Run the loaded model over a test set and collect predictions and labels.
#
# NOTE(review): `args`, `test_dates`, `num_workers`, `pin_memory` and
# `accelerator` are not defined anywhere in the visible part of this notebook
# (the cell looks lifted from a command-line script). Unless they are defined
# elsewhere, this cell raises NameError on a fresh kernel -- define them in a
# configuration cell before running.

# Load the test samples, reusing the model's own holdout routes and
# normalization config.
test_data, holdout_routes, test_config = data_loader.load_h5(
    args.train_data_folders,
    test_dates,
    holdout_routes=model.holdout_routes,
    config=model.config,
)

# Wrap the samples in a dataset configured the same way as the model.
test_dataset = data_loader.H5Dataset(test_data)
test_dataset.include_grid = model.include_grid

# Keep the sample order and keep the final partial batch.
test_loader = DataLoader(
    test_dataset,
    batch_size=model.batch_size,
    collate_fn=model.collate_fn,
    shuffle=False,
    drop_last=False,
    pin_memory=pin_memory,
    num_workers=num_workers,
)

# Prediction-only trainer: no experiment logging.
trainer = pl.Trainer(
    inference_mode=True,
    logger=False,
    accelerator=accelerator,
)
preds_and_labels = trainer.predict(model=model, dataloaders=test_loader)

# Stitch the per-batch outputs back into flat arrays.
preds = np.concatenate([batch['preds'] for batch in preds_and_labels])
labels = np.concatenate([batch['labels'] for batch in preds_and_labels])