In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import pandas as pd
import numpy as np
import logging

logger = logging.getLogger(__name__)

In [None]:
from lewidi_lib import load_dataset, enable_logging

enable_logging()

# Load the train split of each dataset and stack them into a single frame.
# The concatenation keeps each frame's own index labels.
datasets = ["CSC", "MP"]
train_frames = [load_dataset(name, split="train") for name in datasets]
ddf = pd.concat(train_frames)
print(ddf.shape[0])
ddf.head(2)

In [None]:
import duckdb
from lewidi_lib import process_rdf


# Read every parquet file under ../parquets into one pandas frame.
# The connection lives in a `with` block so it is closed as soon as `.df()`
# has materialised the result, instead of staying open for the rest of the
# kernel session.
with duckdb.connect() as con:
    rdf = con.sql("SELECT * FROM read_parquet('../parquets/*.parquet')").df()

# NOTE(review): every row is labelled "CSC" although the dataset cell also
# loads "MP" -- presumably the parquet files only contain CSC runs; confirm.
rdf["dataset_name"] = "CSC"
# `process_rdf` returns the processed frame plus `reasoning_by_run`, which is
# merged onto the per-run summary by `run_id` in a later cell.
rdf, reasoning_by_run = process_rdf(rdf)
rdf.head(2)

In [None]:
# Per-run summary: share of valid predictions and mean output length,
# joined with `reasoning_by_run` on `run_id`.
run_stats = rdf.groupby(["run_id", "model_id", "model_size"], as_index=False).agg(
    is_valid_pred=("is_valid_pred", "mean"),
    n_output_tokens=("n_output_tokens", "mean"),
)
# Left join keeps every run even when it has no entry in `reasoning_by_run`.
run_info = pd.merge(run_stats, reasoning_by_run, on="run_id", how="left").sort_values(
    ["model_size", "is_reasoning"]
)
run_info.round(2)

In [None]:
# Average the per-run summary over runs that share model, size and
# reasoning setting.
run_group_keys = ["model_id", "model_size", "is_reasoning"]
grun_info = (
    run_info.groupby(run_group_keys)[["is_valid_pred", "n_output_tokens"]]
    .mean()
    .sort_values(["model_size", "is_reasoning"])
)
grun_info.round(3)

In [None]:
import seaborn as sns

# Plot from a derived frame whose `model_size` is cast to str, instead of
# overwriting the column on `run_info`. Mutating `run_info` here would make
# the earlier sort/groupby cells order model sizes lexicographically if they
# were re-run after this cell.
run_info_plot = run_info.assign(model_size=run_info["model_size"].astype(str))
ax = sns.lineplot(
    run_info_plot,
    x="model_size",
    y="is_valid_pred",
    hue="is_reasoning",
    marker="o",
)
ax.set_ylabel("Proportion of Valid Predictions")
ax.set_xlabel("Model Params [B]")
ax.legend(title="Reasoning")
# Proportions are at most 1; the small headroom keeps markers at 1.0 visible.
ax.set_ylim(0, 1.05)
ax.grid(alpha=0.5)


In [None]:
# Keep only rows flagged as valid predictions. The flag is read once and
# reused for both the log message and the filter, and `rdf` is rebound to a
# new frame rather than modified with `inplace=True`.
is_valid = rdf["is_valid_pred"]
logger.info(
    "Dropping %d predictions that don't sum to 1", int((~is_valid).sum())
)
# `.copy()` makes the filtered frame independent of the unfiltered one.
rdf = rdf.loc[is_valid].copy()

In [None]:
# Attach each request's target to its predictions. The default inner join
# keeps only (dataset_name, request_idx) pairs present on both sides.
join_keys = ["dataset_name", "request_idx"]
targets = ddf[join_keys + ["target"]]
joint_df = targets.merge(rdf, on=join_keys)
joint_df.head(2)

In [11]:
from lewidi_lib import assign_col_ws_loss, assign_col_l0_loss

# Apply the two loss helpers in sequence (l0 first, then ws), each taking the
# frame returned by the previous step.
joint_df = joint_df.pipe(assign_col_l0_loss).pipe(assign_col_ws_loss)


In [None]:
# Per-run performance: mean `ws_loss` and mean number of output tokens.
perf_group_keys = ["run_id", "model_id", "model_size", "is_reasoning"]
perf_by_run = joint_df.groupby(perf_group_keys, as_index=False).agg(
    avg_ws_loss=pd.NamedAgg(column="ws_loss", aggfunc="mean"),
    avg_n_output_tokens=pd.NamedAgg(column="n_output_tokens", aggfunc="mean"),
)
perf_metrics_df = perf_by_run.sort_values(["model_size", "is_reasoning"])
perf_metrics_df.round(2)

# Baseline: Uniform Distribution

In [13]:
from lewidi_lib import assign_n_classes, baseline_pred


# Build the baseline frame: derive `n_classes`, map each value through
# `baseline_pred` to get `pred`, then score it with the same two loss helpers
# used for the model predictions (ws first, then l0).
bdf = (
    assign_n_classes(ddf)
    .assign(pred=lambda frame: frame["n_classes"].apply(baseline_pred))
    .pipe(assign_col_ws_loss)
    .pipe(assign_col_l0_loss)
)

In [None]:
# Mean baseline loss per dataset, one row per `dataset_name`.
baseline_losses = bdf.groupby("dataset_name", as_index=False).agg(
    ws_loss=("ws_loss", "mean"),
    l0_loss=("l0_loss", "mean"),
)
baseline_losses

In [None]:
import seaborn as sns

# Plot from a derived frame whose `model_size` is cast to str, instead of
# overwriting the column on `perf_metrics_df`. Later cells that need the
# column as a number then see it unchanged.
perf_plot_df = perf_metrics_df.assign(
    model_size=perf_metrics_df["model_size"].astype(str)
)
ax = sns.lineplot(
    perf_plot_df,
    x="model_size",
    y="avg_ws_loss",
    hue="is_reasoning",
    marker="o",
)
ax.set_ylim(0, None)
# Mean `ws_loss` of the baseline predictions on CSC, drawn as a horizontal
# reference line. Raises IndexError if `baseline_losses` has no CSC row.
csc_baseline_ws_loss = baseline_losses.query("dataset_name == 'CSC'")["ws_loss"].iloc[0]
ax.axhline(csc_baseline_ws_loss, color="red", linestyle="--", label="Baseline")
# Called after `axhline` so the labelled baseline is included in the legend.
ax.legend(title="Reasoning")
ax.set_ylabel("Wasserstein Distance (Lower is Better)")
ax.set_xlabel("Model Params [B]")
ax.grid(alpha=0.5)

In [None]:
import matplotlib.pyplot as plt

# Scatter of mean output length against mean ws loss per run. `model_size`
# is cast to float for the hue on a temporary frame only.
scatter_df = perf_metrics_df.assign(
    model_size=perf_metrics_df["model_size"].astype("float")
)
ax = sns.scatterplot(
    scatter_df,
    x="avg_n_output_tokens",
    y="avg_ws_loss",
    hue="model_size",
    style="is_reasoning",
)
ax.grid(alpha=0.5)
# Flip the y axis so lower loss values appear towards the top of the plot.
ax.invert_yaxis()
# Anchor the legend just outside the right edge of the axes.
ax.legend(bbox_to_anchor=(1.05, 1), loc="upper left")