In [None]:
%load_ext autoreload
%autoreload 2

In [13]:
import pandas as pd
import numpy as np
import logging

# Notebook-level logger; used further down to report how many predictions get dropped.
logger = logging.getLogger(__name__)

In [None]:
from lewidi_lib import load_dataset, enable_logging

enable_logging()

# Load every (dataset, split) combination and stack the frames into one table.
# Iteration order is dataset-major: all splits of the first dataset come first.
datasets = ["CSC", "MP"]
splits = ["train", "dev"]
ddf = pd.concat(
    [
        load_dataset(dataset_name, split=split_name)
        for dataset_name in datasets
        for split_name in splits
    ]
)
print(len(ddf))
ddf.head(2)

In [None]:
import duckdb
from lewidi_lib import process_rdf


# Read all result parquet files into a single DataFrame via an in-memory DuckDB
# connection. The connection is only needed for this one query, so it is scoped
# to a `with` block and closed as soon as the frame has been materialised.
with duckdb.connect() as con:
    rdf_raw = con.sql("SELECT * FROM read_parquet('../parquets/*.parquet')").df()

# Keep the raw load under its own name so the processed frame has a clear lineage.
rdf = process_rdf(rdf_raw)
rdf.head(2)

In [22]:
# Columns that identify a single run together with its configuration.
run_info_cols = ["run_id", "model_id", "model_size", "gen_kwargs", "dataset", "split"]

# One row per run: mean validity flag and mean output length, ordered by model size.
run_info = (
    rdf.groupby(run_info_cols, as_index=False)
    .agg(
        is_valid_pred=pd.NamedAgg(column="is_valid_pred", aggfunc="mean"),
        n_output_tokens=pd.NamedAgg(column="n_output_tokens", aggfunc="mean"),
    )
    .sort_values("model_size")
)
# To inspect: run_info.round(2)

In [23]:
# Average the per-run summary over runs that share the same configuration,
# i.e. group by every run-info column except the run identifier itself.
grun_info = (
    run_info.groupby([col for col in run_info_cols if col != "run_id"])
    .agg(
        is_valid_pred=("is_valid_pred", "mean"),
        n_output_tokens=("n_output_tokens", "mean"),
    )
    .sort_values("model_size")
)
# To inspect: grun_info.round(3)

In [None]:
import seaborn as sns

# Share of valid predictions per model size, one facet per (split, dataset).
# model_size is cast to str only on the copy handed to seaborn (presumably so the
# sizes are drawn as evenly spaced categories). run_info itself is left untouched,
# so re-running the cells that sort by model_size still sees the original dtype.
g = sns.relplot(
    data=run_info.assign(model_size=run_info["model_size"].astype(str)),
    x="model_size",
    y="is_valid_pred",
    hue="gen_kwargs",
    col="dataset",
    row="split",
    kind="line",
    marker="o",
    height=3,
    aspect=1.2,
)
g.set_axis_labels("Model Params [B]", "Proportion of Valid Predictions")
g.legend.set_title("Reasoning")
for ax in g.axes.flat:
    # Proportions live in [0, 1]; a little headroom keeps markers at 1.0 visible.
    ax.set_ylim(0, 1.05)
    ax.grid(alpha=0.5)


In [None]:
# Report how many rows fail the validity flag, then keep only the valid ones.
logger.info(
    "Dropping %d predictions that don't sum to 1", len(rdf.query("~is_valid_pred"))
)
# Rebind to the filtered result instead of using inplace=True, so the frame
# object produced by the loading cell is not mutated behind the reader's back.
rdf = rdf.query("is_valid_pred")

In [None]:
# Attach the target of each request to its predictions; rows are matched on
# (dataset, split, request_idx) and only matching pairs are kept.
joint_df = ddf[["dataset", "split", "request_idx", "target"]].merge(
    rdf,
    how="inner",
    on=["dataset", "split", "request_idx"],
)
joint_df.head(2)

In [43]:
from lewidi_lib import assign_col_ws_loss, assign_col_l0_loss

# Add the loss columns to the joined frame: first the l0 loss, then the ws loss.
joint_df = joint_df.pipe(assign_col_l0_loss).pipe(assign_col_ws_loss)


In [None]:
# Display the run-identifying columns (defined with run_info) for reference.
run_info_cols

In [None]:
# Per-run performance summary: mean ws_loss and mean output length for every
# run/configuration/dataset/split combination, ordered by model size and gen_kwargs.
# Further aggregates (std of ws_loss, mean/std of l0_loss) can be added to .agg().
perf_metrics_df = (
    joint_df.groupby(
        ["run_id", "model_id", "model_size", "gen_kwargs", "dataset", "split"],
        as_index=False,
    )
    .agg(
        avg_ws_loss=pd.NamedAgg(column="ws_loss", aggfunc="mean"),
        avg_n_output_tokens=pd.NamedAgg(column="n_output_tokens", aggfunc="mean"),
    )
    .sort_values(by=["model_size", "gen_kwargs"])
)
# To inspect: perf_metrics_df.round(2)

# Baseline: Uniform Distribution

In [46]:
from lewidi_lib import assign_n_classes, baseline_pred


# Baseline frame: give every row a baseline prediction derived from its
# n_classes value, then score it with the same ws and l0 loss helpers.
bdf = (
    assign_n_classes(ddf)
    # The callable receives the whole frame; baseline_pred is applied per n_classes value.
    .assign(pred=lambda frame: frame["n_classes"].apply(baseline_pred))
    .pipe(assign_col_ws_loss)
    .pipe(assign_col_l0_loss)
)

In [None]:
# Mean baseline losses for each (dataset, split) pair.
baseline_losses = bdf.groupby(["dataset", "split"], as_index=False).agg(
    ws_loss=("ws_loss", "mean"),
    l0_loss=("l0_loss", "mean"),
)
baseline_losses

In [None]:
import seaborn as sns

# Mean Wasserstein loss per model size, one facet per (split, dataset).
# model_size is cast to str only on the copy handed to seaborn (presumably so the
# sizes are drawn as evenly spaced categories). perf_metrics_df keeps its original
# dtype, so later cells do not have to undo a string cast.
g = sns.relplot(
    data=perf_metrics_df.assign(
        model_size=perf_metrics_df["model_size"].astype(str)
    ),
    x="model_size",
    y="avg_ws_loss",
    hue="gen_kwargs",
    col="dataset",
    row="split",
    kind="line",
    marker="o",
    height=3,
    aspect=1.5,
)
# Anchor the y axis at zero; the upper limit stays automatic.
g.set(ylim=(0, None))
# TODO: overlay the uniform baseline from baseline_losses as a dashed horizontal
# line per facet. Note baseline_losses is keyed by "dataset" and "split" (an
# earlier draft filtered on "dataset_name", which is not one of its columns).
g.legend.set_title("Reasoning")
g.set_axis_labels("Model Params [B]", "Wasserstein Distance")
for ax in g.axes.flat:
    ax.grid(alpha=0.5)

In [None]:
# Average output length against mean Wasserstein loss, one facet per
# (split, dataset). model_size is cast to float on a copy for the hue encoding;
# marker style distinguishes gen_kwargs.
g = sns.relplot(
    data=perf_metrics_df.astype({"model_size": "float"}),
    kind="scatter",
    x="avg_n_output_tokens",
    y="avg_ws_loss",
    hue="model_size",
    style="gen_kwargs",
    row="split",
    col="dataset",
    height=3,
    aspect=1.5,
)
g.set_axis_labels("Avg Output Tokens", "Wasserstein Distance")
g.legend.set_title("Model Size")
for facet_ax in g.axes.flat:
    facet_ax.grid(alpha=0.5)