In [None]:
# Automatically reload edited modules (e.g. lewidi_lib, imported below) before each cell runs.
%load_ext autoreload
%autoreload 2
import seaborn as sns

# "talk" context scales fonts and line widths up for presentation-sized figures.
sns.set_context("talk")

In [None]:
import duckdb
from lewidi_lib import list_preds


# Prediction files for a single configuration: one run name, model, prompt template and split.
files = list_preds().query(
    "run_name == '1000ex_10loops' and model_id == 'Qwen/Qwen3-32B' and template_id == '60' and split == 'train'"
)
assert len(files) > 0, "no prediction files match the query"
# Hand duckdb the paths as a list instead of splicing their Python repr into SQL text:
# a path containing a quote would otherwise break (or alter) the query.
# union_by_name=True aligns columns by name across files whose schemas differ.
rdf = duckdb.read_parquet(
    [str(f) for f in files.preds_file], union_by_name=True
).df()

In [None]:
from lewidi_lib import (
    assign_cols_perf_metrics_softlabel,
    enable_logging,
    join_dataset,
    process_rdf,
)

enable_logging()
# NOTE: this cell overwrites `rdf`; re-run the loading cell before re-running this one.
# Keep runs with run_idx <= 9 (presumably the first ten, zero-based — confirm).
rdf = rdf.query("run_idx <= 9")
rdf = process_rdf(rdf, response_contains_steps=True, discard_invalid_pred=True)
rdf = join_dataset(rdf)
rdf = assign_cols_perf_metrics_softlabel(rdf)
# One row per (dataset, item, run). Reassign instead of mutating in place, and reset the
# index so row labels are unique — the BoN-oracle cell selects rows via `.loc[idxmin()]`.
rdf = rdf.drop_duplicates(subset=["dataset", "dataset_idx", "run_idx"]).reset_index(
    drop=True
)

In [None]:
rdf.shape[0]

In [None]:
from lewidi_lib import avg_pairwise_ws_loss

# One row per (dataset, item): the `pred` values of all its runs reduced with
# avg_pairwise_ws_loss (presumably the mean pairwise Wasserstein loss between runs — confirm).
item_groups = rdf.groupby(["dataset", "dataset_idx"], as_index=False)
answer_diversity = item_groups.agg(avg_pairwise_ws_loss=("pred", avg_pairwise_ws_loss))

In [None]:
from lewidi_lib import compute_average_baseline_and_assing_perf_metrics

# Per-item loss of the model-averaging baseline, renamed so it can sit next to per-run losses.
model_avg_rdf = (
    compute_average_baseline_and_assing_perf_metrics(rdf)
    .loc[:, ["dataset", "dataset_idx", "ws_loss"]]
    .rename(columns={"ws_loss": "model_avg_ws_loss"})
)

In [None]:
from lewidi_lib import assign_cols_perf_metrics_softlabel

# BoN oracle: for every (dataset, item) keep the single run with the lowest Wasserstein loss.
best_row_labels = rdf.groupby(["dataset", "dataset_idx"])["ws_loss"].idxmin()
oracle = (
    assign_cols_perf_metrics_softlabel(rdf.loc[best_row_labels])
    .loc[:, ["dataset", "dataset_idx", "ws_loss"]]
    .rename(columns={"ws_loss": "oracle_ws_loss"})
)

In [None]:
from lewidi_lib import diversity

# Map each item's avg_pairwise_ws_loss to a diversity level, computed separately per dataset.
answer_diversity = answer_diversity.assign(
    diversity=answer_diversity.groupby("dataset")["avg_pairwise_ws_loss"].transform(
        diversity
    )
)

In [None]:
# Attach per-item quantities (diversity, model-averaging loss, BoN-oracle loss) to every run.
# validate="many_to_one" asserts each right-hand frame has at most one row per item, so a
# duplicated key fails loudly instead of silently multiplying the rows of `joint`.
item_keys = ["dataset", "dataset_idx"]
joint = (
    rdf.merge(answer_diversity, on=item_keys, how="left", validate="many_to_one")
    .merge(model_avg_rdf, on=item_keys, how="left", validate="many_to_one")
    .merge(oracle, on=item_keys, how="left", validate="many_to_one")
    .assign(
        # Positive values: the method's loss is lower than this individual run's loss.
        model_avg_improvement=lambda df: df["ws_loss"] - df["model_avg_ws_loss"],
        oracle_improvement=lambda df: df["ws_loss"] - df["oracle_ws_loss"],
    )
)

In [None]:
from pathlib import Path

# Every figure saved by this notebook goes under this directory (relative to the working dir).
tgt_dir = Path("imgs") / "diversity"
tgt_dir.mkdir(exist_ok=True, parents=True)

In [None]:
# Disabled: joint scatter + LOWESS fit of model-averaging improvement against the raw
# avg_pairwise_ws_loss (CSC only), with marginal histograms. Kept for reference.
# NOTE(review): plot_joint/plot_marginals below pass `data=joint` (all datasets) while the
# grid itself is built on the CSC subset — confirm which data is intended before re-enabling.
# grid = sns.JointGrid(
#     data=joint.query("dataset == 'CSC'"),
#     x="avg_pairwise_ws_loss",
#     y="model_avg_improvement",
# )
# grid.plot_joint(sns.scatterplot, data=joint, alpha=0.2)
# grid.plot_joint(sns.regplot, scatter=False, lowess=True)
# grid.plot_marginals(sns.histplot, data=joint)
# grid.ax_joint.grid(alpha=0.5)
# grid.set_axis_labels(
#     xlabel="Prediction Diversity", ylabel="Model Averaging Improvement"
# )
# grid.figure.savefig(
#     tgt_dir / "model-avg-improvement-vs-answer-diversity.pdf", bbox_inches="tight"
# )

# Is the Wasserstein loss correlated with prediction diversity?
Empirically, the Wasserstein loss is correlated with prediction diversity. Prediction diversity can be observed at test time, while the loss cannot. We can therefore use diversity as a proxy for how well the model performs on a given problem.

In [None]:
# Point estimate (mean) ± 1 SD of the per-run Wasserstein loss at each diversity level,
# one panel per dataset.
fgrid = sns.catplot(
    joint,
    kind="point",
    x="diversity",
    y="ws_loss",
    errorbar="sd",
    col="dataset",
    col_wrap=2,
    sharey=False,
    height=3,
    aspect=1.5,
)
fgrid.set_axis_labels("Prediction Diversity", "Wasserstein\nDistance")
for panel in fgrid.axes.ravel():
    panel.grid(axis="y", alpha=0.5)
fgrid.figure.savefig(tgt_dir / "ws-loss-vs-diversity.pdf", bbox_inches="tight")

In [None]:
# Mean per-run Wasserstein loss: datasets as rows, diversity levels as columns.
joint.groupby(["dataset", "diversity"])["ws_loss"].mean().unstack("diversity").round(2)

In [None]:
# Display names for the two improvement columns; dict order is also the melt order.
name_map = {
    "model_avg_improvement": "Model Averaging",
    "oracle_improvement": "BoN Oracle",
}
# Long format (one row per run and method) so both methods share one bar chart via `hue`.
improvement_df = joint.melt(
    id_vars=["dataset", "dataset_idx", "diversity"],
    value_vars=list(name_map),
    var_name="type",
    value_name="improvement",
).assign(type=lambda df: df["type"].map(name_map))

fgrid = sns.catplot(
    improvement_df,
    kind="bar",
    x="diversity",
    y="improvement",
    hue="type",
    col="dataset",
    errorbar="ci",
    sharey=False,
)
fgrid.set_axis_labels("Prediction Diversity", "Improvement In\nWasserstein Distance")
fgrid.legend.set(title="Method")
for panel in fgrid.axes.ravel():
    panel.grid(axis="y", alpha=0.5)
    panel.set_ylim(0, None)
fgrid.figure.savefig(
    tgt_dir / "improvement-vs-prediction-diversity-quantiles.pdf",
    bbox_inches="tight",
)

In [None]:
# how much value is model_averaging capturing?
ma_value_df = joint.groupby(["dataset", "diversity"], as_index=False)[
    ["model_avg_improvement", "oracle_improvement"]
].mean()
ma_value_df = ma_value_df.assign(
    fraction=lambda df: df["model_avg_improvement"] / df["oracle_improvement"]
)
ma_value_df.pivot(index="dataset", values="fraction", columns="diversity").round(2)

In [None]:
import pandas as pd

# Per-method Wasserstein losses, stacked for side-by-side comparison:
# "Simple" keeps every individual run; the other two have one row per item.
per_method_frames = {
    "Simple": rdf[["dataset", "dataset_idx", "ws_loss"]],
    "Model Averaging": model_avg_rdf.rename(columns={"model_avg_ws_loss": "ws_loss"}),
    "BoN Oracle": oracle.rename(columns={"oracle_ws_loss": "ws_loss"}),
}
loss_by_method = pd.concat(
    [frame.assign(method=method) for method, frame in per_method_frames.items()],
    ignore_index=True,
).merge(answer_diversity, on=["dataset", "dataset_idx"], how="left")

fgrid = sns.catplot(
    loss_by_method,
    kind="point",
    x="diversity",
    y="ws_loss",
    hue="method",
    hue_order=list(per_method_frames),
    errorbar="ci",
    capsize=0.2,
    dodge=True,
    col="dataset",
    col_wrap=2,
    sharey=False,
    height=4,
    aspect=1.2,
)
fgrid.set_axis_labels("Prediction Diversity", "Wasserstein Distance")
for panel in fgrid.axes.ravel():
    panel.grid(axis="y", alpha=0.5)

In [None]:
# Stacked histogram of avg_pairwise_ws_loss per dataset, coloured by the assigned diversity level.
fgrid = sns.FacetGrid(
    answer_diversity, col="dataset", sharex=False, sharey=False, height=5
)
fgrid.map_dataframe(
    sns.histplot, x="avg_pairwise_ws_loss", hue="diversity", multiple="stack"
)
fgrid.set_axis_labels("Prediction Diversity", "Count")
for panel in fgrid.axes.ravel():
    panel.grid(axis="y", alpha=0.5)
fgrid.figure.savefig(tgt_dir / "diversity-distribution.pdf", bbox_inches="tight")

# The Plots Below Are Not Yet Generalized to Multiple Datasets

In [None]:
import matplotlib.pyplot as plt

# Per-run loss by diversity level, pooled over datasets: individual runs vs. model averaging.
# Items are keyed by (dataset, dataset_idx) everywhere else in this notebook; joining on
# dataset_idx alone would pair a run with the diversity of every dataset sharing that index.
simple_with_div = rdf.merge(
    answer_diversity[["dataset", "dataset_idx", "diversity"]],
    on=["dataset", "dataset_idx"],
    how="left",
    validate="many_to_one",
)

fig, axs = plt.subplots(figsize=(12, 4), ncols=2, gridspec_kw={"wspace": 0.3})
ax1, ax2 = axs

# whis=(5, 95): whiskers at the 5th/95th percentiles; outliers hidden.
sns.boxplot(
    simple_with_div,
    x="diversity",
    y="ws_loss",
    showfliers=False,
    ax=ax1,
    whis=(5, 95),
)
ax1.set_title("Simple")

sns.boxplot(
    joint,
    x="diversity",
    y="model_avg_ws_loss",
    showfliers=False,
    ax=ax2,
    whis=(5, 95),
)
ax2.set_title("Model Averaging")

for ax in axs:
    ax.grid(alpha=0.5, axis="y")
    ax.set_ylim(-0.1, 3)
    ax.set_xlabel("Prediction Diversity")
    ax.set_ylabel("Wasserstein Distance")

# What Is the Worst-Case Performance of the BoN Oracle by Diversity?

In [None]:
# `oracle` only holds (dataset, dataset_idx, oracle_ws_loss): attach each item's diversity
# level before plotting, and plot the renamed loss column.
oracle_with_div = oracle.merge(
    answer_diversity[["dataset", "dataset_idx", "diversity"]],
    on=["dataset", "dataset_idx"],
    how="left",
    validate="many_to_one",
)
# whis=(5, 95): whiskers at the 5th/95th percentiles; outliers hidden.
ax = sns.boxplot(
    oracle_with_div,
    x="diversity",
    y="oracle_ws_loss",
    showfliers=False,
    whis=(5, 95),
)
ax.grid(alpha=0.5, axis="y")
ax.set_ylim(None, 3)
ax.set_title("BoN Oracle WS Loss")
ax.set_xlabel("Prediction Diversity")
ax.set_ylabel("Wasserstein Distance")

# What Is the Range of Improvement from Repeated Sampling?

In [None]:
# Per-run improvement over the individual run, by diversity level (pooled over datasets).
# BoN-oracle improvement = run loss minus the lowest loss among that item's runs.
# `oracle` stores that loss as `oracle_ws_loss`; items are keyed by (dataset, dataset_idx).
wsloss_improv = (
    rdf[["dataset", "dataset_idx", "ws_loss"]]
    .merge(
        oracle.rename(columns={"oracle_ws_loss": "ws_loss_best"}),
        on=["dataset", "dataset_idx"],
        how="left",
        validate="many_to_one",
    )
    .assign(improvement=lambda df: df["ws_loss"] - df["ws_loss_best"])
    .merge(
        answer_diversity[["dataset", "dataset_idx", "diversity"]],
        on=["dataset", "dataset_idx"],
        how="left",
        validate="many_to_one",
    )
)

fig, axs = plt.subplots(figsize=(12, 4), ncols=2, gridspec_kw={"wspace": 0.3})
ax1, ax2 = axs

# `joint` names the model-averaging gain `model_avg_improvement` (it has no `improvement` column).
sns.boxplot(
    joint,
    x="diversity",
    y="model_avg_improvement",
    showfliers=False,
    whis=(5, 95),
    ax=ax1,
)
ax1.set_title("Model Averaging")

sns.boxplot(
    wsloss_improv,
    x="diversity",
    y="improvement",
    showfliers=False,
    whis=(5, 95),
    ax=ax2,
)
ax2.set_title("BoN Oracle")

for ax in axs:
    ax.grid(alpha=0.5, axis="y")
    ax.set_ylim(-0.1, 2.3)
    ax.set_xlabel("Prediction Diversity")
    ax.set_ylabel("Improvement")