In [None]:
# Reload edited modules (e.g. lewidi_lib) automatically before each cell executes.
%load_ext autoreload
%autoreload 2

In [None]:
import pandas as pd
from lewidi_lib import (
    enable_logging,
    load_preds,
    process_rdf_and_add_perf_metrics,
)

enable_logging()

# Predictions of one sbatch sweep (Qwen3-32B, CSC train split).
# TODO: machine-specific absolute path; make configurable (e.g. via an env var) before sharing.
folder = "/home/tomasruiz/datasets/dss_home/lewidi-data/sbatch/di38bec/Qwen_Qwen3-32B/set2/t31/CSC/train/allexs_20loops/preds"
# Keep only runs with run_idx <= MAX_RUN_IDX.
MAX_RUN_IDX = 9

rdf = load_preds(folder)
rdf = rdf.query("run_idx <= @MAX_RUN_IDX")
rdf = process_rdf_and_add_perf_metrics(rdf, discard_invalid_pred=True)
# At most one row per (example, run); reassign instead of mutating in place.
rdf = rdf.drop_duplicates(subset=["dataset_idx", "run_idx"])

In [None]:
rdf.shape[0]  # number of rows remaining after filtering and deduplication

In [26]:
from itertools import combinations

import numpy as np
import scipy
from lewidi_lib import as_np


def avg_pairwise_ws_loss(preds) -> float:
    """Mean Wasserstein distance over all unordered pairs of predictions.

    Each prediction is treated as a vector of weights over the integer support
    ``0..dim-1``. The 1-D Wasserstein distance between every pair of predictions
    is averaged; the result is used in this notebook as a measure of answer diversity.

    Args:
        preds: Predictions for a single example, converted with ``as_np``
            (assumed to yield an array of shape (n_preds, dim) — TODO confirm).

    Returns:
        The mean pairwise distance, or NaN when fewer than two predictions
        are available (there are no pairs to compare).
    """
    # atleast_2d: a single 1-D prediction becomes shape (1, dim) instead of
    # being iterated element-wise by `combinations`.
    np_preds = np.atleast_2d(as_np(preds))
    n_preds, dim = np_preds.shape
    if n_preds < 2:
        # np.mean([]) would also give NaN, but with a "Mean of empty slice" warning.
        return float("nan")
    space = np.arange(dim)
    dists = [
        scipy.stats.wasserstein_distance(space, space, p1, p2)
        for p1, p2 in combinations(np_preds, r=2)
    ]
    return float(np.mean(dists))


# One row per dataset_idx: diversity of the answers across runs, plus the
# mean WS loss of the individual runs.
answer_diversity = (
    rdf.groupby("dataset_idx")
    .agg(
        avg_pairwise_ws_loss=("pred", avg_pairwise_ws_loss),
        avg_ws_loss=("ws_loss", "mean"),
    )
    .reset_index()
)

In [27]:
from lewidi_lib import compute_average_baseline_and_assing_perf_metrics


# Loss of the averaged-prediction baseline, with ws_loss renamed so it can be
# merged next to the per-run statistics (presumably one row per dataset_idx).
model_avg_rdf = (
    compute_average_baseline_and_assing_perf_metrics(rdf)
    .loc[:, ["dataset_idx", "ws_loss"]]
    .rename(columns={"ws_loss": "model_avg_ws_loss"})
)

In [28]:
import pandas as pd

# Equal-count bins of answer diversity; Q1 holds the least diverse examples.
# NOTE: pd.qcut raises "Bin edges must be unique" if many examples share the
# same diversity value (e.g. all runs identical -> 0).
N_DIVERSITY_BINS = 5
answer_diversity = answer_diversity.assign(
    diversity=pd.qcut(
        answer_diversity["avg_pairwise_ws_loss"],
        N_DIVERSITY_BINS,
        labels=[f"Q{i}" for i in range(1, N_DIVERSITY_BINS + 1)],
    )
)
# validate= fails loudly instead of silently duplicating rows if either side
# has more than one row per dataset_idx.
joint = answer_diversity.merge(
    model_avg_rdf, on="dataset_idx", how="left", validate="one_to_one"
)
# Positive improvement: model averaging has a lower WS loss than the mean single-run loss.
joint = joint.assign(improvement=lambda df: df["avg_ws_loss"] - df["model_avg_ws_loss"])


In [None]:
import seaborn as sns

sns.set_context("talk")  # global seaborn setting: applies to all later figures

# Answer diversity vs. gain from model averaging: scatter, LOWESS trend, marginals.
grid = sns.JointGrid(data=joint, x="avg_pairwise_ws_loss", y="improvement")
# JointGrid supplies the x/y vectors itself, so no `data=` is needed here.
grid.plot_joint(sns.scatterplot, alpha=0.2)
grid.plot_joint(sns.regplot, scatter=False, lowess=True)  # lowess needs statsmodels
grid.plot_marginals(sns.histplot)
grid.ax_joint.grid(alpha=0.5)
grid.set_axis_labels(xlabel="Answer Diversity", ylabel="Model Averaging Improvement");

In [None]:
import matplotlib.pyplot as plt

# WS loss per diversity quintile: individual runs (left) vs. model averaging (right).
per_run_with_diversity = rdf.merge(
    answer_diversity[["dataset_idx", "diversity"]], on="dataset_idx", how="left"
)
panels = [
    (per_run_with_diversity, "ws_loss", "Simple"),
    (joint, "model_avg_ws_loss", "Model Averaging"),
]

fig, axs = plt.subplots(figsize=(12, 4), ncols=2, gridspec_kw={"wspace": 0.3})
for ax, (data, y_col, label) in zip(axs, panels):
    sns.boxplot(
        data,
        x="diversity",
        y=y_col,
        showfliers=False,
        ax=ax,
        whis=(5, 95),  # whiskers at the 5th and 95th percentiles
    )
    ax.set_xlabel("Answer Diversity")
    ax.set_ylabel(f"WS Loss ({label})")
    ax.grid(alpha=0.5, axis="y")
    ax.set_ylim(-.1, 3)  # same range on both panels for direct comparison

# What Is the Best-Case (Oracle) Performance by Diversity?

In [41]:
# Oracle (best case): for each example keep the run with the lowest WS loss.
# idxmin returns index *labels*; rdf's index is never reset after query/dedup,
# so reset it first to guarantee one row per label in the .loc lookup.
rdf_unique_idx = rdf.reset_index(drop=True)
oracle = rdf_unique_idx.loc[rdf_unique_idx.groupby("dataset_idx")["ws_loss"].idxmin()]
oracle = oracle.merge(
    answer_diversity[["dataset_idx", "diversity"]],
    on="dataset_idx",
    how="left",
    validate="one_to_one",
)

In [None]:
# WS loss of the oracle-selected run per example, by diversity quintile.
ax = sns.boxplot(
    oracle,
    x="diversity",
    y="ws_loss",
    showfliers=False,
    whis=(5, 95),  # whiskers at the 5th and 95th percentiles
)
ax.set(xlabel="Answer Diversity", ylabel="WS Loss (Oracle)")
ax.grid(alpha=0.5, axis="y")
ax.set_ylim(-.1, 3);  # same range as the Simple / Model Averaging panels