In [None]:
from pathlib import Path

import altair as alt
import numpy as np
import pandas as pd
import polars as pl

In [None]:
# colors for the AM / PM peak-period categories in the scatter plots
color_ampm_domain = ["AM", "PM"]
color_ampm_range = ["#8cb7c9", "#d3d655"]

# CMP monitoring years being compared: 2019, 2021, 2023
years = range(2019, 2024, 2)
# the 2019 auto file lives under the 2021 monitoring folder and uses a different
# file name, so the paths are listed explicitly instead of being built from `years`
auto_filepaths = {
    2019: Path(
        r"Q:\CMP\LOS Monitoring 2021\Auto_LOS\CMP2019_Auto_Speeds_Reliability.csv"
    ),
    2021: Path(
        r"Q:\CMP\LOS Monitoring 2023\Auto_LOS_and_Reliability\CMP2021_Auto_LOS_and_Reliability.csv"
    ),
    2023: Path(
        r"Q:\CMP\LOS Monitoring 2023\Auto_LOS_and_Reliability\CMP2023_Auto_LOS_and_Reliability.csv"
    ),
}
transit_filepaths = {
    2019: Path(
        r"Q:\CMP\LOS Monitoring 2019\Transit\Speed\SF_CMP_Transit_Speeds_2019_Final.csv"
    ),
    2021: Path(
        r"Q:\CMP\LOS Monitoring 2021\Transit\Speed\CMP2021_APC_Transit_Speeds_Final.csv"
    ),
    2023: Path(
        r"Q:\CMP\LOS Monitoring 2023\transit\volume_and_speed\2304-2305\Muni-APC-Transit_Speeds-2023.csv"
    ),
}
save_dir = Path(
    r"Q:\CMP\reports\CMPSF 2023\Draft\figures\multimodal_performance\auto-transit-speed-diffs"
)
# parents=True: also create any missing intermediate folders instead of failing
# with FileNotFoundError on a fresh report directory tree
save_dir.mkdir(parents=True, exist_ok=True)
# segment IDs 1-245 are the officially defined CMP segments
cmp_segid_filter = pl.col("cmp_segid") < 246


def output_filepath_stem(save_dir, subtrahend_year, minuend_year):
    """Build the extension-less output path for one pair of compared years.

    The CSV and PNG outputs for a comparison share this stem; callers append
    ".csv" / ".png" themselves.

    Parameters
    ----------
    save_dir : pathlib.Path
        Output directory.
    subtrahend_year : int
        Base year of the comparison.
    minuend_year : int
        Comparison year.

    Returns
    -------
    pathlib.Path
        ``save_dir / "auto-transit-speed-diffs-scatter-<subtrahend>-<minuend>"``.
    """
    file_stem = "auto-transit-speed-diffs-scatter-{}-{}".format(
        subtrahend_year, minuend_year
    )
    return save_dir / file_stem

In [None]:
def read_dfs(filepaths, years):
    """Read one speed CSV per year into a dict of polars DataFrames.

    Only the columns cmp_segid, year, period and avg_speed are read, with
    explicit dtypes; any other columns (e.g. "source") are ignored because they
    are not needed for this comparison.

    Parameters
    ----------
    filepaths : dict
        Mapping of year -> CSV path.
    years : iterable of int
        Years to read; each must be a key of `filepaths`.

    Returns
    -------
    dict
        Mapping of year -> polars.DataFrame.
    """
    # TODO rename "period" to "peak period" for more clarity
    column_dtypes = {
        "cmp_segid": int,
        "year": int,
        "period": str,
        "avg_speed": float,
    }
    frames_by_year = {}
    for year in years:
        frames_by_year[year] = pl.read_csv(
            filepaths[year], columns=list(column_dtypes), dtypes=column_dtypes
        )
    return frames_by_year


def dfs_to_long(dfs, filter):
    """Stack the per-year frames vertically and keep only rows matching `filter`.

    Parameters
    ----------
    dfs : dict
        Mapping of year -> polars.DataFrame (as returned by ``read_dfs``); the
        frames must share the same columns for the vertical concat.
    filter : polars expression
        Row predicate applied after stacking (e.g. ``cmp_segid_filter``).

    Returns
    -------
    polars.DataFrame
        One long frame holding the matching rows of every year.
    """
    stacked = pl.concat(list(dfs.values()))
    return stacked.filter(filter)


def df_long_to_wide(df_long, years):
    """Pivot a long speed frame to wide form: one avg_speed column per year.

    The wide layout (one row per segment and period) is what the per-year
    difference calculation and the scatter charts need.

    Parameters
    ----------
    df_long : polars.DataFrame
        Columns cmp_segid, period, year, avg_speed. With
        ``aggregate_function=None`` each (cmp_segid, period, year) is expected
        to appear at most once; presumably polars raises otherwise.
    years : iterable of int
        Years present in the "year" column; used to rename the pivoted columns.

    Returns
    -------
    polars.DataFrame
        Columns cmp_segid, period and ``avg_speed-<year>`` for each year,
        sorted by period, then cmp_segid.
    """
    wide = df_long.pivot(
        index=["cmp_segid", "period"],  # "source" is not carried along
        columns="year",
        values="avg_speed",
        aggregate_function=None,
        separator="-",
    )
    # with a single `values` column polars names the new columns after the bare
    # "year" values ("2019", ...) -- the separator prefix is only added when
    # several value columns are pivoted -- so add the prefix explicitly
    prefixed_names = {str(year): f"avg_speed-{year}" for year in years}
    return wide.rename(prefixed_names).sort("period", "cmp_segid")


def df_diff_cols(df_wide, subtrahend_year, minuend_year, mode):
    """Compute the per-segment speed change between two years for one mode.

    difference = avg_speed[minuend_year] - avg_speed[subtrahend_year]

    Parameters
    ----------
    df_wide : polars.DataFrame
        Wide frame from ``df_long_to_wide`` with ``avg_speed-<year>`` columns.
    subtrahend_year : int
        Year whose speed is subtracted (the base year).
    minuend_year : int
        Year whose speed is subtracted from (the comparison year).
    mode : str
        Label used as the prefix of the output column, e.g. "auto" or "transit".

    Returns
    -------
    polars.DataFrame
        Columns cmp_segid, period and
        ``<mode>-avg_speed_diff-<subtrahend_year>-<minuend_year>``. Rows where
        the difference is null (a speed is missing in either year) are dropped.
    """
    minuend_speed = pl.col(f"avg_speed-{minuend_year}")
    subtrahend_speed = pl.col(f"avg_speed-{subtrahend_year}")
    diff_name = f"{mode}-avg_speed_diff-{subtrahend_year}-{minuend_year}"
    diffs = df_wide.select(
        "cmp_segid",
        "period",
        (minuend_speed - subtrahend_speed).alias(diff_name),
    )
    return diffs.drop_nulls()


def make_auto_transit_diffs_df(
    auto_df_wide, transit_df_wide, subtrahend_year, minuend_year
):
    """Pair the auto and transit speed changes of each segment and peak period.

    difference = minuend_year - subtrahend_year, computed per mode by
    ``df_diff_cols``.

    Parameters
    ----------
    auto_df_wide, transit_df_wide : polars.DataFrame
        Wide speed frames (``df_long_to_wide`` output) for each mode.
    subtrahend_year : int
        Base year.
    minuend_year : int
        Comparison year.

    Returns
    -------
    polars.DataFrame
        Columns cmp_segid, period, ``auto-avg_speed_diff-<s>-<m>`` and
        ``transit-avg_speed_diff-<s>-<m>``; only (cmp_segid, period) pairs with
        a difference for BOTH modes are kept (inner join).
    """
    auto_diffs = df_diff_cols(auto_df_wide, subtrahend_year, minuend_year, "auto")
    transit_diffs = df_diff_cols(
        transit_df_wide, subtrahend_year, minuend_year, "transit"
    )
    return auto_diffs.join(transit_diffs, on=["cmp_segid", "period"], how="inner")


def calculate_chart_extents(diffs_df, x_col, y_col, min_extent=1.0):
    """Return a symmetric ``(-extent, extent)`` axis domain covering two columns.

    The extent is the largest absolute value in either `x_col` or `y_col`,
    rounded up to a whole number so the chart limits land on integers.

    Parameters
    ----------
    diffs_df : polars.DataFrame
        Frame holding the two numeric difference columns.
    x_col, y_col : str
        Names of the columns the domain must cover.
    min_extent : float, default 1.0
        Floor for the extent. Because the extent is a ceiling of an absolute
        value, this only changes the result when every value is 0 or there are
        no non-null values (e.g. an empty join result). Without the floor, the
        first case gives a degenerate (0, 0) domain and the second fails on
        ``np.ceil(None)``.

    Returns
    -------
    tuple
        ``(-extent, extent)``.
    """
    # largest magnitude difference (max absolute value over both columns)
    max_abs_diff = diffs_df.select(
        pl.max_horizontal(pl.col(x_col).abs().max(), pl.col(y_col).abs().max())
    ).item()
    if max_abs_diff is None or np.isnan(max_abs_diff):
        # no usable values: fall back to min_extent so the chart still gets a
        # non-empty domain instead of crashing
        max_abs_diff = 0.0
    chart_extent = max(np.ceil(max_abs_diff), min_extent)  # use integers
    return (-chart_extent, chart_extent)


def plot_scatter(diffs_df, subtrahend_year, minuend_year, save_dir):
    """Scatter-plot transit vs. auto speed changes per segment and save a PNG.

    Each point is one (cmp_segid, period) row of `diffs_df`, colored by peak
    period. Differences are minuend_year - subtrahend_year, so positive values
    mean speeds increased. Grey lines mark x = 0 and y = 0.

    Parameters
    ----------
    diffs_df : polars.DataFrame
        Output of ``make_auto_transit_diffs_df`` for the same pair of years:
        columns cmp_segid, period and the transit / auto difference columns
        named by ``_x_col`` / ``_y_col``.
    subtrahend_year : int
        Base year.
    minuend_year : int
        Comparison year.
    save_dir : pathlib.Path
        Directory for the PNG, named via ``output_filepath_stem``.

    Returns
    -------
    altair layered chart
        Scatter plus reference lines with pan/zoom enabled. The PNG written to
        disk is the non-interactive version (saved before ``.interactive()``).
    """
    # same naming as the columns created by make_auto_transit_diffs_df
    x_col = _x_col(subtrahend_year, minuend_year)
    y_col = _y_col(subtrahend_year, minuend_year)
    years_label = f"({subtrahend_year} to {minuend_year})"

    chart_scatter = (
        alt.Chart(diffs_df)
        .mark_circle(size=20)
        .encode(
            alt.X(f"{x_col}:Q").title(f"difference in transit speed {years_label}"),
            alt.Y(f"{y_col}:Q").title(
                f"difference in automobile speed {years_label}"
            ),
            color=alt.Color("period:N").scale(
                domain=color_ampm_domain, range=color_ampm_range
            ),
            tooltip=["cmp_segid:O", "period:N", f"{x_col}:Q", f"{y_col}:Q"],
        )
    )

    # symmetric extent covering every point; since layers share scales by
    # default, the reference lines spanning it also give both axes that range
    scale_domain = calculate_chart_extents(diffs_df, x_col, y_col)

    def _reference_line(x_values, y_values):
        # faint grey line through the given (x, y) points
        return (
            alt.Chart(pd.DataFrame({x_col: x_values, y_col: y_values}))
            .mark_line(color="grey", opacity=0.5)
            .encode(
                alt.X(x_col),
                alt.Y(y_col),
            )
        )

    chart_vertical = _reference_line((0, 0), scale_domain)  # x = 0
    chart_horizontal = _reference_line(scale_domain, (0, 0))  # y = 0
    chart = chart_scatter + chart_vertical + chart_horizontal
    chart.save(
        f"{output_filepath_stem(save_dir, subtrahend_year, minuend_year)}.png",
        scale_factor=2,
    )
    return chart.interactive()


def _x_col(subtrahend_year, minuend_year):
    return f"transit-avg_speed_diff-{subtrahend_year}-{minuend_year}"


def _y_col(subtrahend_year, minuend_year):
    return f"auto-avg_speed_diff-{subtrahend_year}-{minuend_year}"


def calculate_quadrant_totals(diffs_df, x_col, y_col):
    """Count rows falling in each quadrant of the transit (x) vs auto (y) plot.

    Quadrants follow the usual numbering (Q1 = x >= 0 & y >= 0, then
    counter-clockwise). A difference of exactly 0 is counted on the
    non-negative side.

    Parameters
    ----------
    diffs_df : polars.DataFrame
        Frame with a "period" column plus the `x_col` / `y_col` difference
        columns.
    x_col, y_col : str
        Names of the x (transit) and y (auto) difference columns.

    Returns
    -------
    polars.DataFrame
        One row per period ("AM", "PM", and "both peaks" = all rows) with
        columns period, Q1, Q2, Q3, Q4.
    """
    x_nonneg, x_neg = pl.col(x_col) >= 0, pl.col(x_col) < 0
    y_nonneg, y_neg = pl.col(y_col) >= 0, pl.col(y_col) < 0
    quadrant_masks = {
        "Q1": x_nonneg & y_nonneg,
        "Q2": x_neg & y_nonneg,
        "Q3": x_neg & y_neg,
        "Q4": x_nonneg & y_neg,
    }
    row_selectors = {
        "AM": pl.col("period") == "AM",
        "PM": pl.col("period") == "PM",
        "both peaks": True,  # keep every row
    }
    per_period_counts = [
        diffs_df.filter(selector).select(
            pl.lit(label).alias("period"),
            *(mask.sum().alias(q_name) for q_name, mask in quadrant_masks.items()),
        )
        for label, selector in row_selectors.items()
    ]
    return pl.concat(per_period_counts)


def compare_years_auto_transit_diffs(
    auto_df_wide, transit_df_wide, subtrahend_year, minuend_year, save_dir
):
    """Compare auto vs transit speed changes between two years and save outputs.

    Builds the paired difference table, writes it to ``<stem>.csv`` in
    `save_dir` (stem from ``output_filepath_stem``), then plots and saves the
    scatter chart via ``plot_scatter``.

    Parameters
    ----------
    auto_df_wide, transit_df_wide : polars.DataFrame
        Wide speed frames for each mode.
    subtrahend_year : int
        Base year.
    minuend_year : int
        Comparison year.
    save_dir : pathlib.Path
        Output directory for the CSV and PNG.

    Returns
    -------
    altair chart
        The interactive chart returned by ``plot_scatter``.
    """
    paired_diffs = make_auto_transit_diffs_df(
        auto_df_wide, transit_df_wide, subtrahend_year, minuend_year
    )
    csv_stem = output_filepath_stem(save_dir, subtrahend_year, minuend_year)
    paired_diffs.write_csv(f"{csv_stem}.csv")
    return plot_scatter(paired_diffs, subtrahend_year, minuend_year, save_dir)

In [None]:
# read each year's CSV for both modes, stack the years (official CMP segments
# only), then pivot so every year gets its own avg_speed-<year> column
auto_dfs = read_dfs(auto_filepaths, years)
transit_dfs = read_dfs(transit_filepaths, years)

auto_df_long = dfs_to_long(auto_dfs, cmp_segid_filter)
transit_df_long = dfs_to_long(transit_dfs, cmp_segid_filter)

auto_df_wide = df_long_to_wide(auto_df_long, years)
transit_df_wide = df_long_to_wide(transit_df_long, years)

In [None]:
# TODO join the CMP segment names onto the diffs table and show the segment name (instead of cmp_segid) in the scatter-plot tooltip

In [None]:
# paired auto / transit speed changes from 2021 to 2023
diffs_df = make_auto_transit_diffs_df(
    auto_df_wide, transit_df_wide, subtrahend_year=2021, minuend_year=2023
)

In [None]:
# inspect the 2021 -> 2023 auto / transit speed-difference table
diffs_df

In [None]:
# names of the 2021 -> 2023 transit (x) and auto (y) difference columns in
# diffs_df; built with the shared helpers so they follow the same naming
# pattern as df_diff_cols instead of being retyped by hand
x_col = _x_col(2021, 2023)
y_col = _y_col(2021, 2023)

In [None]:
# 2019 -> 2023: writes the CSV + PNG and displays the interactive chart
compare_years_auto_transit_diffs(
    auto_df_wide,
    transit_df_wide,
    subtrahend_year=2019,
    minuend_year=2023,
    save_dir=save_dir,
)

In [None]:
# 2021 -> 2023: writes the CSV + PNG and displays the interactive chart
compare_years_auto_transit_diffs(
    auto_df_wide,
    transit_df_wide,
    subtrahend_year=2021,
    minuend_year=2023,
    save_dir=save_dir,
)