In [None]:
import polars as pl
from glob import glob
import os
import seaborn as sb


def get_dfs_from_glob(glob_str: str) -> "pl.DataFrame":
    """
    Read every CSV matching `glob_str` and stack them into one de-duplicated frame.

    Args:
        glob_str: Glob pattern of the CSV files to read.

    Returns:
        The vertical concatenation of all matched files with duplicate rows
        removed. Row order is not guaranteed (`unique()` does not keep it).

    Raises:
        ValueError: If no file matches `glob_str`.
    """
    # Sorted so the files are always read in the same order.
    paths = sorted(glob(glob_str))
    if not paths:
        # pl.concat([]) would raise a ValueError too, but without saying which
        # pattern came up empty.
        raise ValueError(f'no files match glob pattern {glob_str!r}')
    dfs = [
        # Force the id columns to strings instead of letting them be inferred.
        pl.read_csv(path, dtypes={'stopid': pl.Utf8, 'routeid': pl.Utf8})
        for path in paths
    ]
    return pl.concat(dfs, how='vertical').unique()


def load_raw_data(parquet_path: str = 'raw.parquet',
                  csv_glob: str = 'raw_data/raw_trip_*'):
    """
    Load raw data from existing parquet file or from CSVs.

    Args:
        parquet_path: Parquet file to read when it exists.
        csv_glob: Glob pattern of the raw CSV files, read through
            get_dfs_from_glob when `parquet_path` does not exist.

    Returns:
        The loaded frame with 'period', 'routeid' and 'stopid' cast to
        Categorical.
    """
    if os.path.exists(parquet_path):
        df = pl.read_parquet(parquet_path)
    else:
        # get_dfs_from_glob already returns de-duplicated rows, so a second
        # .unique() pass over the whole frame is not needed here.
        df = get_dfs_from_glob(csv_glob)
    # Save memory by re-using strings for categorical data
    return df.cast({
        'period': pl.Categorical,
        'routeid': pl.Categorical,
        'stopid': pl.Categorical
    })


# Default cutoff on |delay| used by plot_stop: 60*30 (30 minutes if delays are
# in seconds — TODO confirm units).
# NOTE: a later cell in this notebook reassigns MAX_DELAY, and plot_stop reads
# the global at call time unless `max_delay` is passed explicitly.
MAX_DELAY = 60*30


def plot_stop(dfs: pl.DataFrame, stopid: str, max_delay: "int | None" = None):
    """
    Scatter-plot delay against lastupdate for one stop, one hue per trip id.

    Args:
        dfs: Frame with at least 'id', 'stopid', 'delay' and 'lastupdate'
            columns.
        stopid: Stop to plot; rows with any other 'stopid' are dropped.
        max_delay: Rows whose absolute delay is not strictly below this value
            are dropped. When None, the module-level MAX_DELAY is used as it
            is at call time.

    Returns:
        The matplotlib Axes returned by seaborn's scatterplot.
    """
    if max_delay is None:
        max_delay = MAX_DELAY
    cats = (
        dfs
        .filter(
            (pl.col('delay').abs() < max_delay)
            & (pl.col('stopid') == stopid)
        )
        # Only these three columns are plotted; the former group-by/explode
        # round-trip produced the same rows plus an unused 'length' column.
        .select('id', 'delay', 'lastupdate')
        .sort('id', 'lastupdate')
        # Presumably so seaborn treats ids as discrete hue levels rather than
        # a numeric scale — TODO confirm.
        .with_columns(pl.col('id').cast(pl.Utf8).cast(pl.Categorical))
    )
    ax = sb.scatterplot(cats, x='lastupdate', y='delay', hue='id')
    ax.set_title(f'Stop {stopid}')
    return ax

In [None]:
# Read the trip data: from raw.parquet when it exists, otherwise from the CSVs.
raw_dfs = load_raw_data()

In [None]:
# Loading from the CSVs takes like 30 seconds!!!
# Save into a more compact form for easier retrieval later.
# Only write when the cache is missing: if raw.parquet already exists,
# load_raw_data() read raw_dfs from it and rewriting it is wasted I/O.
if not os.path.exists('raw.parquet'):
    raw_dfs.write_parquet('raw.parquet')

So we know that every trip id is unique within a single day, so there is no need to worry about overlaps.

Average delay at a stop: every 3 minutes, 10 recordings of a bus.

AM: 7am-9am
PM: 4pm-7pm
OFF: 5am-7am, 9am-4pm, 7pm-10pm

In [None]:
# Cutoff on each group's mean absolute delay: 60 * 20 (20 minutes if delays are
# in seconds — TODO confirm units).
# NOTE: this reassigns the notebook-wide MAX_DELAY (60*30 in the first cell),
# which plot_stop reads at call time, so plots drawn after this cell use this
# tighter value by default.
MAX_DELAY = 60 * 20
delay = pl.col('delay')
aggregated = (
    raw_dfs
    # Work with delay magnitudes from here on.
    .with_columns(delay.abs())
    .with_columns(
        pl.from_epoch('lastupdate', time_unit='s')
        # MST
        .dt.offset_by('-7h')
    )
    .with_columns(pl.col('lastupdate').dt.date().alias('date'))
    # Sort whole rows by time. Sorting only the 'lastupdate' column inside
    # with_columns would detach the timestamps from their rows and leave the
    # per-group delay lists in arbitrary order; row order within each group is
    # kept by group_by, so the lists below come out chronological.
    .sort('lastupdate')
    .group_by('id', 'routeid', 'stopid', 'date', 'period')
    .agg(
        delay.max().alias('maxdelay'),
        delay.mean().alias('meandelay'),
        delay.median().alias('mediandelay'),
        delay.std().alias('stddelay'),
        # Keep the individual delays of the group as a list column.
        delay
    )
    # Just remove trips whose mean delay is not below MAX_DELAY
    .filter(pl.col('meandelay') < MAX_DELAY)
    .sort('mediandelay')
)

How many counts should we deem useful for visualization of data?

In [None]:
# Let list-typed cells (the per-group `delay` lists) show up to 100 elements
# when the frame is rendered.
pl.Config.set_fmt_table_cell_list_len(100)
aggregated
# 1_250_916  (presumably the row count seen on an earlier run — TODO confirm)