In [None]:
%load_ext autoreload
%autoreload 2
import seaborn as sns

sns.set_context("talk")

In [None]:
import duckdb
from lewidi_lib import list_preds

# Which prediction files to analyse. Kept as named values so the selection
# is visible (and changeable) in one place.
RUN_NAME = "1000ex_10loops"
MODEL_ID = "Qwen/Qwen3-32B"
TEMPLATE_ID = "60"
SPLIT = "train"

files = list_preds().query(
    "run_name == @RUN_NAME and model_id == @MODEL_ID"
    " and template_id == @TEMPLATE_ID and split == @SPLIT"
)
# Fail with a clear message rather than an opaque DuckDB error on an empty list.
assert len(files) > 0, "no prediction files match the selection above"

# Hand the file list to DuckDB's Python API directly instead of splicing the
# repr of a Python list into SQL text (which breaks on paths containing quotes).
rdf = duckdb.read_parquet(
    [str(f) for f in files.preds_file], union_by_name=True
).df()

In [None]:
from lewidi_lib import (
    assign_cols_perf_metrics_softlabel,
    enable_logging,
    join_dataset,
    process_rdf,
)

enable_logging()
# NOTE: this cell reassigns `rdf` from its own previous value, so it must run
# exactly once after the loading cell; re-running it on its own would push the
# already processed rows through process_rdf/join_dataset a second time.
# Keep only runs with run_idx <= 9 (presumably the first 10 runs — confirm indexing).
rdf = rdf.query("run_idx <= 9")
rdf = process_rdf(rdf, response_contains_steps=True, discard_invalid_pred=True)
rdf = join_dataset(rdf)
rdf = assign_cols_perf_metrics_softlabel(rdf)
# Keep a single row per (dataset, item, run); reassign instead of inplace=True.
rdf = rdf.drop_duplicates(subset=["dataset", "dataset_idx", "run_idx"])

In [None]:
rdf.shape[0]

In [None]:
from lewidi_lib import avg_pairwise_ws_loss

# Per item (dataset, dataset_idx), reduce the `pred` values of all its runs
# with avg_pairwise_ws_loss — presumably the mean pairwise Wasserstein loss
# between runs (confirm in lewidi_lib).
item_keys = ["dataset", "dataset_idx"]
answer_diversity = rdf.groupby(item_keys, as_index=False).agg(
    **{"avg_pairwise_ws_loss": ("pred", avg_pairwise_ws_loss)}
)

In [None]:
from lewidi_lib import compute_average_baseline_and_assing_perf_metrics

# Loss of the model-averaging baseline computed by lewidi_lib, keyed by
# (dataset, dataset_idx) and renamed so it can sit next to the per-run ws_loss.
model_avg_rdf = (
    compute_average_baseline_and_assing_perf_metrics(rdf)
    .loc[:, ["dataset", "dataset_idx", "ws_loss"]]
    .rename(columns={"ws_loss": "model_avg_ws_loss"})
)

In [None]:
from lewidi_lib import assign_cols_perf_metrics_softlabel

# Best-of-N oracle: for every (dataset, dataset_idx) keep the run with the
# lowest ws_loss. idxmin yields index labels, so `.loc` below relies on `rdf`
# having a unique index.
best_run_idx = rdf.groupby(["dataset", "dataset_idx"])["ws_loss"].idxmin()
# NOTE(review): rdf already carries these metrics from the processing cell;
# recomputing them here looks redundant — confirm before removing.
oracle = assign_cols_perf_metrics_softlabel(rdf.loc[best_run_idx])
oracle = oracle.loc[:, ["dataset", "dataset_idx", "ws_loss"]].rename(
    columns={"ws_loss": "oracle_ws_loss"}
)

In [None]:
from lewidi_lib import diversity

# Map each item's avg_pairwise_ws_loss to a `diversity` value using
# lewidi_lib's `diversity`, applied separately within each dataset
# (presumably a per-dataset binning — confirm). `transform` keeps rows aligned.
div_col = answer_diversity.groupby("dataset")["avg_pairwise_ws_loss"].transform(
    lambda scores: diversity(scores)
)
answer_diversity = answer_diversity.assign(diversity=div_col)

In [None]:
# Attach the per-item quantities (diversity, model-averaging loss, oracle loss)
# to every run. validate="many_to_one" checks that each right-hand frame has a
# single row per item, so a duplicated key fails loudly instead of silently
# multiplying the rows of `joint` and skewing every mean computed from it.
item_keys = ["dataset", "dataset_idx"]
joint = rdf.merge(answer_diversity, on=item_keys, how="left", validate="many_to_one")
joint = joint.merge(model_avg_rdf, on=item_keys, how="left", validate="many_to_one")
joint = joint.merge(oracle, on=item_keys, how="left", validate="many_to_one")
# Improvement = single-run loss minus the method's loss for the same item;
# positive means the method beats that individual run.
joint = joint.assign(
    model_avg_improvement=lambda df: df["ws_loss"] - df["model_avg_ws_loss"],
    oracle_improvement=lambda df: df["ws_loss"] - df["oracle_ws_loss"],
)

In [None]:
from pathlib import Path

# Directory for the figures saved by this notebook; created if it does not exist.
tgt_dir = Path("imgs") / "diversity"
tgt_dir.mkdir(exist_ok=True, parents=True)

# Is the Wasserstein loss correlated with prediction diversity?
The Wasserstein loss is empirically correlated with prediction diversity. This means we can use prediction diversity, which is observable at test time, as a proxy for the model's performance on a problem, which is not observable at test time because computing the loss requires the gold labels.

In [None]:
from lewidi_lib import rename_dataset

datasets = ["CSC", "PAR", "MP", "VEN"]
# Mean per-run loss at each diversity level, one panel per dataset.
fgrid = sns.catplot(
    rename_dataset(joint),
    x="diversity",
    y="ws_loss",
    col="dataset",
    col_order=datasets,
    col_wrap=2,
    kind="point",
    height=3,
    aspect=1.5,
    sharey=False,
)
fgrid.set_axis_labels("Prediction Diversity", "Wasserstein\nDistance")
for ax in fgrid.axes.flat:
    ax.grid(alpha=0.5, axis="y")
# MP's loss is labelled as a Manhattan distance. Look the panel up by name
# rather than by position so the label stays right if `datasets` is reordered.
fgrid.axes_dict["MP"].set_ylabel("Manhattan\nDistance")
fgrid.figure.savefig(tgt_dir / "ws-loss-vs-diversity.pdf", bbox_inches="tight")

In [None]:
# Mean per-run ws_loss: datasets as rows, diversity levels as columns.
(
    joint.groupby(["dataset", "diversity"])["ws_loss"]
    .mean()
    .unstack("diversity")
    .round(2)
)

In [None]:
# Long format: one row per (run, method) holding that method's improvement over
# the single run, so both methods share one bar plot via `hue`.
name_map = {
    "model_avg_improvement": "Model Averaging",
    "oracle_improvement": "BoN Oracle",
}
improvement_df = joint.melt(
    id_vars=["dataset", "dataset_idx", "diversity"],
    value_vars=["model_avg_improvement", "oracle_improvement"],
    value_name="improvement",
    var_name="type",
).assign(type=lambda df: df["type"].map(name_map))
improvement_df = rename_dataset(improvement_df)
fgrid = sns.catplot(
    improvement_df,
    x="diversity",
    y="improvement",
    hue="type",
    col="dataset",
    col_wrap=2,
    col_order=datasets,
    kind="bar",
    errorbar="ci",
    sharey=False,
    height=3.5,
    aspect=1.5,
)
sns.move_legend(
    fgrid,
    loc="lower left",
    bbox_to_anchor=(0.2, 1.0),
    ncol=2,
    title="Method",
)
fgrid.set_axis_labels("Prediction Diversity", "Improvement In\nWasserstein Distance")
for ax in fgrid.axes.flat:
    ax.grid(alpha=0.5, axis="y")
    # Bars start at 0; any part of an error bar below 0 is clipped.
    ax.set_ylim(0, None)
# MP's loss is labelled as a Manhattan distance; select the panel by name, not position.
fgrid.axes_dict["MP"].set_ylabel("Improvement In\nManhattan Distance")
fgrid.figure.savefig(
    tgt_dir / "improvement-vs-prediction-diversity-quantiles.pdf",
    bbox_inches="tight",
)

In [None]:
# How much of the oracle's improvement does model averaging capture?
# Ratio of mean improvements per (dataset, diversity); pandas yields NaN/inf
# where the oracle's mean improvement is 0.
mean_improvements = joint.groupby(["dataset", "diversity"], as_index=False)[
    ["model_avg_improvement", "oracle_improvement"]
].mean()
ma_value_df = mean_improvements.assign(
    fraction=mean_improvements["model_avg_improvement"]
    / mean_improvements["oracle_improvement"]
)
ma_value_df.pivot(index="dataset", columns="diversity", values="fraction").round(2)

In [None]:
import pandas as pd

# Stack per-run losses ("Simple") with the model-averaging and oracle losses
# into one long frame, so the three methods can be compared in a single plot.
method_losses = pd.concat(
    [
        rdf[["dataset", "dataset_idx", "ws_loss"]].assign(Method="Simple"),
        model_avg_rdf.rename(columns={"model_avg_ws_loss": "ws_loss"}).assign(
            Method="Model Averaging"
        ),
        oracle.rename(columns={"oracle_ws_loss": "ws_loss"}).assign(
            Method="BoN Oracle"
        ),
    ],
    ignore_index=True,
)
method_losses = method_losses.merge(
    answer_diversity, on=["dataset", "dataset_idx"], how="left"
)
method_losses = rename_dataset(method_losses)
fgrid = sns.catplot(
    method_losses,
    x="diversity",
    y="ws_loss",
    hue="Method",
    hue_order=["Simple", "Model Averaging", "BoN Oracle"],
    kind="point",
    errorbar="ci",
    sharey=False,
    col="dataset",
    # Same panel order as the other figures in this notebook.
    col_order=datasets,
    col_wrap=2,
    dodge=True,
    capsize=0.2,
    height=4,
    aspect=1.2,
)
fgrid.set_axis_labels("Prediction Diversity", "Wasserstein Distance")
for ax in fgrid.axes.flat:
    ax.grid(alpha=0.5, axis="y")
# As in the other figures, MP's loss is labelled as a Manhattan distance.
fgrid.axes_dict["MP"].set_ylabel("Manhattan Distance")

In [None]:
import pandas as pd

# Each facet calls histplot on its own subset, and histplot derives hue levels
# (hence colours) from that subset, so without a shared order the same
# diversity level could get different colours in different panels.
diversity_values = answer_diversity["diversity"]
diversity_levels = (
    list(diversity_values.cat.categories)
    if isinstance(diversity_values.dtype, pd.CategoricalDtype)
    else sorted(diversity_values.dropna().unique())
)

# Per-dataset distribution of the raw diversity score, stacked by diversity level.
fgrid = sns.FacetGrid(
    rename_dataset(answer_diversity),
    col="dataset",
    height=3,
    sharex=False,
    sharey=False,
    col_wrap=2,
    col_order=datasets,
    aspect=1.5,
)
fgrid.map_dataframe(
    sns.histplot,
    x="avg_pairwise_ws_loss",
    hue="diversity",
    hue_order=diversity_levels,
    multiple="stack",
)
fgrid.set_axis_labels("Prediction Diversity", "Count")
for ax in fgrid.axes.flat:
    ax.grid(alpha=0.5, axis="y")
fgrid.figure.savefig(tgt_dir / "diversity-distribution.pdf", bbox_inches="tight")