In [None]:
# Reload all imported modules before executing each cell, so edits to
# lewidi_lib are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
import duckdb
from lewidi_lib import list_preds


# Select the prediction files of a single run configuration.
files = list_preds().query(
    "run_name == 'allex_10loops' and model_id == 'Qwen/Qwen3-32B' and template_id == '31' and split == 'train'"
)
pred_paths = [str(f) for f in files.preds_file]
# Fail early with a clear message instead of an opaque DuckDB error on an empty list.
assert pred_paths, "no prediction files match the run filter"
# Hand the paths to DuckDB as a Python list instead of splicing their repr
# into SQL text, so quotes/backslashes in file names cannot break the query.
rdf = duckdb.read_parquet(pred_paths).df()

In [None]:
from lewidi_lib import enable_logging, process_rdf_and_add_perf_metrics

# Highest run index kept; presumably the 10 sampling loops (0..9) of the
# 'allex_10loops' run — confirm if the run configuration changes.
MAX_RUN_IDX = 9

enable_logging()
rdf = rdf.query("run_idx <= @MAX_RUN_IDX")
rdf = process_rdf_and_add_perf_metrics(rdf, discard_invalid_pred=True)
# Keep one prediction per (dataset, item, run). Reassign rather than using
# inplace=True so the lineage of `rdf` stays explicit.
rdf = rdf.drop_duplicates(subset=["dataset", "dataset_idx", "run_idx"])

In [None]:
# Number of predictions left after filtering and de-duplication.
rdf.shape[0]

In [None]:
from lewidi_lib import avg_pairwise_ws_loss

# One row per item: the average pairwise WS loss between the sampled
# predictions, and the mean of the per-run `ws_loss`.
answer_diversity = (
    rdf
    .groupby(["dataset", "dataset_idx"], as_index=False)
    .agg(
        avg_pairwise_ws_loss=("pred", avg_pairwise_ws_loss),
        avg_ws_loss=("ws_loss", "mean"),
    )
)

In [None]:
from lewidi_lib import compute_average_baseline_and_assing_perf_metrics

# Per-item loss of the averaged baseline (presumably the WS loss of the
# averaged prediction, judging by the helper's name — confirm), renamed so it
# can be joined next to the per-run mean loss.
model_avg_rdf = (
    compute_average_baseline_and_assing_perf_metrics(rdf)
    .loc[:, ["dataset", "dataset_idx", "ws_loss"]]
    .rename(columns={"ws_loss": "model_avg_ws_loss"})
)

In [None]:
from lewidi_lib import assign_cols_perf_metrics_softlabel

# Best-of-N oracle: for every item keep the run with the lowest ws_loss,
# recompute its metrics, and keep only the item key plus the oracle loss.
oracle = (
    rdf.loc[rdf.groupby(["dataset", "dataset_idx"])["ws_loss"].idxmin()]
    .pipe(assign_cols_perf_metrics_softlabel)
    .loc[:, ["dataset", "dataset_idx", "ws_loss"]]
    .rename(columns={"ws_loss": "oracle_ws_loss"})
)

In [None]:
from lewidi_lib import diversity

# Label each item's diversity with `diversity`, applied to each dataset's
# avg_pairwise_ws_loss separately (presumably quantile bins, judging by the
# saved figure names — confirm).
answer_diversity = answer_diversity.assign(
    diversity=answer_diversity.groupby("dataset")["avg_pairwise_ws_loss"].transform(
        diversity
    )
)

In [None]:
# One row per item with its diversity, the model-averaging loss and the
# BoN-oracle loss. `validate` makes a duplicated item key fail loudly instead
# of silently multiplying rows.
joint = (
    answer_diversity
    .merge(model_avg_rdf, on=["dataset", "dataset_idx"], how="left", validate="one_to_one")
    .merge(oracle, on=["dataset", "dataset_idx"], how="left", validate="one_to_one")
    .assign(
        # Positive values: the method's loss is below the mean single-run loss.
        model_avg_improvement=lambda df: df["avg_ws_loss"] - df["model_avg_ws_loss"],
        oracle_improvement=lambda df: df["avg_ws_loss"] - df["oracle_ws_loss"],
    )
)

In [None]:
from pathlib import Path
import seaborn as sns

sns.set_context("talk")

# Dataset shown in the joint plot; the saved file name does not encode it.
JOINT_PLOT_DATASET = "CSC"

grid = sns.JointGrid(
    data=joint.query("dataset == @JOINT_PLOT_DATASET"),
    x="avg_pairwise_ws_loss",
    y="model_avg_improvement",
)
# JointGrid forwards its own (already filtered) x/y vectors to each plotting
# function, so a `data=` argument here would be ignored; omit it so the code
# does not suggest that every dataset is drawn.
grid.plot_joint(sns.scatterplot, alpha=0.2)
grid.plot_joint(sns.regplot, scatter=False, lowess=True)  # LOWESS trend line
grid.plot_marginals(sns.histplot)
grid.ax_joint.grid(alpha=0.5)
grid.set_axis_labels(
    xlabel="Prediction Diversity", ylabel="Model Averaging Improvement"
)
tgt_dir = Path("./imgs/diversity")
tgt_dir.mkdir(parents=True, exist_ok=True)
grid.figure.savefig(
    tgt_dir / "model-avg-improvement-vs-answer-diversity.pdf", bbox_inches="tight"
)

In [None]:
# Column name -> display name; its key order also fixes the melt order.
name_map = {
    "model_avg_improvement": "Model Averaging",
    "oracle_improvement": "BoN Oracle",
}

# Long format: one row per (item, method) with a human-readable method name.
improvement_df = joint.melt(
    id_vars=["dataset", "dataset_idx", "diversity"],
    value_vars=list(name_map),
    var_name="type",
    value_name="improvement",
).assign(type=lambda df: df["type"].map(name_map))

# Mean improvement per diversity bin, one panel per dataset.
fgrid = sns.catplot(
    data=improvement_df,
    kind="bar",
    col="dataset",
    x="diversity",
    y="improvement",
    hue="type",
    sharey=False,
)
fgrid.set_axis_labels("Prediction Diversity", "Improvement In\nWasserstein Distance")
fgrid.legend.set_title("Method")
for ax in fgrid.axes.flat:
    ax.grid(alpha=0.5, axis="y")
fgrid.figure.savefig(
    tgt_dir / "improvement-vs-prediction-diversity-quantiles.pdf", bbox_inches="tight"
)

In [None]:
# Per-dataset histogram of pairwise prediction diversity, stacked by diversity label.
fgrid = (
    sns.FacetGrid(joint, col="dataset", height=5, sharex=False, sharey=False)
    .map_dataframe(
        sns.histplot, x="avg_pairwise_ws_loss", hue="diversity", multiple="stack"
    )
    .set_axis_labels("Prediction Diversity", "Count")
)
for ax in fgrid.axes.flat:
    ax.grid(alpha=0.5, axis="y")
fgrid.figure.savefig(tgt_dir / "diversity-distribution.pdf", bbox_inches="tight")

# Plots Below Are Not Yet Split by Dataset (All Datasets Are Pooled)

In [None]:
import matplotlib.pyplot as plt

# WS loss per diversity label: single runs (left) vs. model averaging (right).
fig, axs = plt.subplots(figsize=(12, 4), ncols=2, gridspec_kw={"wspace": 0.3})
ax1, ax2 = axs

# Join on the full item key: items are keyed by (dataset, dataset_idx) above,
# so joining on dataset_idx alone would pair rows with other datasets' labels.
rdf_with_div = rdf.merge(
    answer_diversity[["dataset", "dataset_idx", "diversity"]],
    on=["dataset", "dataset_idx"],
    how="left",
    validate="many_to_one",
)
sns.boxplot(
    rdf_with_div,
    x="diversity",
    y="ws_loss",
    showfliers=False,
    ax=ax1,
    whis=(5, 95),
)
ax1.set_title("Simple")

sns.boxplot(
    joint,
    x="diversity",
    y="model_avg_ws_loss",
    showfliers=False,
    ax=ax2,
    whis=(5, 95),
)
ax2.set_title("Model Averaging")

for ax in axs:
    ax.grid(alpha=0.5, axis="y")
    ax.set_ylim(-0.1, 3)
    ax.set_xlabel("Prediction Diversity")
    ax.set_ylabel("WS Loss")

# What Is the Best-Case (BoN Oracle) Performance by Diversity?

In [None]:
# `oracle` only keeps the item key and `oracle_ws_loss`, so plot from `joint`,
# which carries both the oracle loss and the diversity label.
ax = sns.boxplot(
    joint,
    x="diversity",
    y="oracle_ws_loss",
    showfliers=False,
    whis=(5, 95),
)
ax.grid(alpha=0.5, axis="y")
ax.set_ylim(None, 3)
ax.set_title("BoN Oracle WS Loss")
ax.set_xlabel("Prediction Diversity");

# What Is the Range of Improvement from Repeated Sampling?

In [None]:
# Per single run: how much its loss exceeds the best (oracle) run of the same
# item. Join on the full item key; `joint` already holds the oracle loss
# (`oracle_ws_loss`) and the diversity label.
wsloss_improv = (
    rdf[["dataset", "dataset_idx", "ws_loss"]]
    .merge(
        joint[["dataset", "dataset_idx", "oracle_ws_loss", "diversity"]],
        on=["dataset", "dataset_idx"],
        how="left",
        validate="many_to_one",
    )
    .assign(improvement=lambda df: df["ws_loss"] - df["oracle_ws_loss"])
)

fig, axs = plt.subplots(figsize=(12, 4), ncols=2, gridspec_kw={"wspace": 0.3})
ax1, ax2 = axs

# Left: per-item gain of model averaging over the mean single-run loss.
sns.boxplot(
    joint,
    x="diversity",
    y="model_avg_improvement",
    showfliers=False,
    whis=(5, 95),
    ax=ax1,
)
ax1.set_title("Model Averaging")

# Right: per-run gap to the item's best-of-N oracle run.
sns.boxplot(
    wsloss_improv,
    x="diversity",
    y="improvement",
    showfliers=False,
    whis=(5, 95),
    ax=ax2,
)
ax2.set_title("BoN Oracle")

for ax in axs:
    ax.grid(alpha=0.5, axis="y")
    ax.set_ylim(-0.1, 2.3)
    ax.set_xlabel("Prediction Diversity")
    ax.set_ylabel("Improvement")