In [1]:
# Enable IPython's autoreload extension; mode 2 re-imports modules before each
# cell is executed, so edits to imported project code are picked up live.
%load_ext autoreload
%autoreload 2

In [2]:
from typing import Any
import pandas as pd
import numpy as np
import logging

logger = logging.getLogger(__name__)


def soft_label_to_nparray(d: dict | Any) -> np.ndarray:
    if not isinstance(d, dict):
        logger.info("Not a dict: %s", repr(d))
        return pd.NA

    array = np.zeros(6)
    for k, v in d.items():
        try:
            array[int(k) - 1] = v
        except ValueError:
            logger.warning("Invalid key: %s", k)
            return pd.NA
    return array


In [None]:
from lewidi_lib import load_dataset, enable_logging

# Turn on the project's logging before any data is loaded.
enable_logging()

# Load the CSC dataset, then attach a positional request index and the dense
# target vector derived from each row's soft label.
ddf = load_dataset(dataset="CSC")
ddf = ddf.assign(
    request_idx=range(len(ddf)),
    target=lambda frame: frame["soft_label"].apply(soft_label_to_nparray),
)
print(len(ddf))
ddf.head(2)

In [None]:
import duckdb


# Read every parquet file under parquets/ into a single pandas frame.
con = duckdb.connect()
rdf = con.sql("SELECT * FROM read_parquet('parquets/*.parquet')").df()

# The model size is the number directly in front of a trailing "B" in the
# model id; ids without that suffix end up as NaN.
rdf["model_size"] = (
    rdf["model_id"].str.extract(r"-(\d+(?:\.\d+)?)B$", expand=False).astype("float")
)

# Report how many rows have no response, drop them, then trim whitespace.
are_na = int(rdf["response"].isna().sum())
logger.info("Number of responses that are NA: %d", are_na)
rdf.query("~response.isna()", inplace=True)
rdf["response"] = rdf["response"].str.strip()
rdf.head(2)

In [None]:
import json_repair

# Parse each response as (repaired) JSON, then convert it to a dense vector.
rdf["pred"] = (
    rdf["response"]
    .apply(json_repair.loads)
    .apply(soft_label_to_nparray)
)
logger.info("Dropping %d NA predictions", len(rdf.query("pred.isna()")))
rdf.query("~pred.isna()", inplace=True)

# A prediction counts as valid when its entries sum to 1 within 0.01.
rdf["pred_sum"] = rdf["pred"].apply(lambda dist: dist.sum())
rdf["is_valid_pred"] = rdf["pred_sum"].sub(1).abs().lt(0.01)
rdf["reasoning_isnull"] = rdf["reasoning"].isna()

# Per-run flag: True only when no row of the run has a null reasoning.
reasoning_by_run = rdf.groupby("run_id", as_index=False).agg(
    is_reasoning=pd.NamedAgg(
        column="reasoning_isnull", aggfunc=lambda flags: ~flags.max()
    )
)
rdf = rdf.merge(reasoning_by_run, how="left", on="run_id")
rdf = rdf.drop(columns=["reasoning_isnull"])
rdf.head(2)

In [None]:
# One row per run: share of valid predictions and mean output length,
# joined with the per-run reasoning flag.
run_info = (
    rdf.groupby(["run_id", "model_id", "model_size"], as_index=False)
    .agg(
        is_valid_pred=pd.NamedAgg(column="is_valid_pred", aggfunc="mean"),
        n_output_tokens=pd.NamedAgg(column="n_output_tokens", aggfunc="mean"),
    )
    .merge(reasoning_by_run, how="left", on="run_id")
    .sort_values(by=["model_size", "is_reasoning"])
)
run_info.round(2)

In [None]:
# Average the per-run statistics for each (model, size, reasoning) combination.
(
    run_info.groupby(["model_id", "model_size", "is_reasoning"])
    .agg(
        is_valid_pred=("is_valid_pred", "mean"),
        n_output_tokens=("n_output_tokens", "mean"),
    )
    .sort_values(by=["model_size", "is_reasoning"])
    .round(3)
)

In [None]:
# Log how many predictions fail the sum-to-one check, then keep only valid ones.
n_invalid = int((~rdf["is_valid_pred"]).sum())
logger.info("Dropping %d predictions that don't sum to 1", n_invalid)
rdf.query("is_valid_pred", inplace=True)

In [None]:
# Attach the target vector of each request to its predictions
# (inner join on request_idx).
joint_df = ddf[["request_idx", "target"]].merge(rdf, how="inner", on="request_idx")
joint_df.head(2)

In [10]:
import scipy


def l0_loss(tgt: np.ndarray, pred: np.ndarray) -> float:
    """Return the sum of absolute element-wise differences between tgt and pred.

    Note: despite the name, this is the L1 (Manhattan) distance.
    """
    diff = tgt - pred
    return np.abs(diff).sum()


def ws_loss(tgt: np.ndarray, pred: np.ndarray) -> float:
    """Wasserstein distance between two discrete distributions.

    Both arguments are weight vectors over ordinal classes located at positions
    0, 1, ..., len - 1; see https://stackoverflow.com/a/76061410/5730291.

    Args:
        tgt: Weights of the target distribution.
        pred: Weights of the predicted distribution.

    Returns:
        The first Wasserstein (earth mover's) distance between the two.
    """
    # Import the function explicitly: a bare ``import scipy`` does not
    # guarantee that the ``scipy.stats`` submodule is loaded on older SciPy.
    from scipy.stats import wasserstein_distance

    # Derive the support from the inputs instead of hard-coding six classes.
    return wasserstein_distance(range(len(tgt)), range(len(pred)), tgt, pred)


# Score every (target, prediction) pair row by row with both loss functions.
joint_df["l0_loss"] = joint_df.apply(
    lambda r: l0_loss(tgt=r["target"], pred=r["pred"]), axis="columns"
)
joint_df["ws_loss"] = joint_df.apply(
    lambda r: ws_loss(tgt=r["target"], pred=r["pred"]), axis="columns"
)

In [None]:
# Per-run performance: mean Wasserstein loss and mean output length.
# Further aggregates that are currently switched off:
#   std_ws_loss=("ws_loss", "std")
#   avg_l0_loss=("l0_loss", "mean")
#   std_l0_loss=("l0_loss", "std")
run_keys = ["run_id", "model_id", "model_size", "is_reasoning"]
perf_metrics_df = (
    joint_df.groupby(run_keys, as_index=False)
    .agg(
        avg_ws_loss=pd.NamedAgg(column="ws_loss", aggfunc="mean"),
        avg_n_output_tokens=pd.NamedAgg(column="n_output_tokens", aggfunc="mean"),
    )
    .sort_values(by=["model_size", "is_reasoning"])
)
perf_metrics_df.round(2)

# Baseline: Uniform Distribution

In [None]:
import seaborn as sns

# Uniform prediction over the six classes, scored against every target.
baseline_pred = np.full(6, 1 / 6)
assert np.isclose(1, baseline_pred.sum())

baseline_ws_losses = ddf["target"].apply(ws_loss, pred=baseline_pred)
print("baseline Wasserstein loss:", baseline_ws_losses.mean().round(3))

baseline_l0_losses = ddf["target"].apply(l0_loss, pred=baseline_pred)
print("baseline L0 loss:", baseline_l0_losses.mean().round(3))

# Optional histogram of the baseline losses (disabled):
# ax = sns.histplot(baseline_ws_losses, bins=5, kde=True, stat="density")
# ax.set_xlabel("Wasserstein Distance")
# ax.set_title("Baseline")

In [None]:
import seaborn as sns

# Cast model size to string so the x axis is categorical, not numeric.
perf_metrics_df["model_size"] = perf_metrics_df["model_size"].astype(str)

# One line per reasoning setting, with a marker on every data point.
ax = sns.lineplot(
    data=perf_metrics_df,
    x="model_size",
    y="avg_ws_loss",
    hue="is_reasoning",
    marker="o",
)
ax.set_ylim(bottom=0)

# Dashed red reference line at the uniform-baseline loss.
ax.axhline(
    y=baseline_ws_losses.mean(),
    color="red",
    linestyle="--",
    label="Baseline",
)
ax.legend()
ax.grid(alpha=0.5)

In [None]:
import matplotlib.pyplot as plt

# Loss against average output length; model size is cast to float so the hue
# is numeric.
ax = sns.scatterplot(
    data=perf_metrics_df.assign(
        model_size=perf_metrics_df["model_size"].astype("float")
    ),
    x="avg_n_output_tokens",
    y="avg_ws_loss",
    hue="model_size",
    style="is_reasoning",
)
ax.grid(alpha=0.5)
# Flip the y axis so that lower loss appears higher up.
ax.invert_yaxis()
# Place the legend outside the plotting area, to the right.
ax.legend(bbox_to_anchor=(1.05, 1), loc="upper left")