In [None]:
# Re-import edited modules (e.g. lewidi_lib) automatically before each cell runs.
%load_ext autoreload
%autoreload 2
import seaborn as sns
# Use seaborn's "talk" context: larger fonts and line widths than the default.
sns.set_context("talk")

In [None]:
from pathlib import Path
from lewidi_lib import (
    assign_cols_perf_metrics_softlabel,
    enable_logging,
    join_dataset,
    load_preds,
    make_query_from_dict,
    process_rdf,
)
import pandas as pd
import logging

logger = logging.getLogger(__name__)

enable_logging()

# ---- Configuration -----------------------------------------------------------
dataset = "CSC"
pred_model_dir = "Qwen_Qwen3-32B"
root = Path("/mnt/md0/tomasruiz/dss_home/lewidi-data/sbatch/di38bec")
# Selects how the judge score is parsed further down:
# assing_col_score_from_json (True) or assign_col_score_from_scalar (False).
use_json_ratings = True

# Shared prefix of the prediction run. Both the predictions and the judge
# responses over them live underneath it.
run_dir = root / pred_model_dir / "set2" / "t31" / dataset / "train" / "allex_20loops"

# ---- Judge ratings -----------------------------------------------------------
# Previously used rating sources (JSONL, read with pd.read_json(..., lines=True)):
#   ../parquets/reasoning-ratings/template-2-reasoning-judge-responses.jsonl
#   {root}/{pred_model_dir}/set2/t31/CSC/allexs_20loops/judge/gemini-2.5-flash/t2/500ex-10loops/responses.jsonl
ratings = pd.read_parquet(
    run_dir
    / "judge"
    / "Qwen"
    / "Qwen3-32B"
    / "set2"
    / "t22"
    / "1000exs_10loops"
    / "responses.parquet"
)
if "split" not in ratings.columns:
    # The ratings above were computed on the train split (see the path).
    ratings = ratings.assign(split="train")
print("len(ratings)=", len(ratings))

# ---- Model predictions -------------------------------------------------------
# Previously used prediction sources:
#   load_preds(parquets_dir="../parquets")
#   {root}/tasks_0_cscfull_t31_Qwen_Qwen3-32B_set2/preds
# Reassign instead of inplace=True so re-running the cell is idempotent.
rdf = load_preds(str(run_dir / "preds")).drop_duplicates()


def preprocess(
    rdf: pd.DataFrame,
    model_id="Qwen/Qwen3-32B",
    n_runs: int = 10,
    dataset_name=None,
) -> pd.DataFrame:
    """Filter raw predictions to one configuration and attach soft-label metrics.

    Args:
        rdf: Raw predictions, as returned by ``load_preds``.
        model_id: Model whose predictions are kept.
        n_runs: Keep only rows whose ``run_idx`` is in ``range(n_runs)``.
            The default of 10 keeps runs 0..9.
        dataset_name: Dataset to keep. If None, falls back to the
            notebook-level ``dataset`` variable at call time.

    Returns:
        The filtered predictions after ``process_rdf`` and ``join_dataset``,
        with the columns added by ``assign_cols_perf_metrics_softlabel``.
    """
    if dataset_name is None:
        dataset_name = dataset  # notebook-level configuration
    metadata = {
        "template_id": 31,
        "model_id": model_id,
        "gen_kwargs": "set2",
        "dataset": dataset_name,
        # NOTE(review): the ratings loaded in this notebook come from a
        # Qwen3-32B judge. Confirm whether "gemini-2.5-pro" is intended here.
        # It presumably only filters if rdf has a judge_model_id column
        # (make_query_from_dict receives rdf.columns) -- TODO confirm.
        "judge_model_id": "gemini-2.5-pro",
    }
    query = make_query_from_dict(metadata, rdf.columns)
    rdf = rdf.query(query)
    rdf = rdf[rdf["run_idx"].isin(list(range(n_runs)))]
    rdf = process_rdf(rdf)
    rdf = join_dataset(rdf)
    rdf = assign_cols_perf_metrics_softlabel(rdf)
    return rdf


# The directory name presumably encodes the HF model id with "/" -> "_".
# Restore only the first (org) separator so underscores in the model name survive.
rdf = preprocess(rdf, model_id=pred_model_dir.replace("_", "/", 1))

In [None]:
from lewidi_lib import (
    assign_col_score_from_scalar,
    assing_col_score_from_json,
    create_rating_matrix,
    discard_failed_rows,
    discard_na_response_rows,
)

# Filter out failed / response-less judge rows (see discard_failed_rows and
# discard_na_response_rows), then parse the score in the configured format.
ratings = discard_na_response_rows(discard_failed_rows(ratings))

ratings = (
    assing_col_score_from_json if use_json_ratings else assign_col_score_from_scalar
)(ratings)

In [None]:
# Keys shared by ratings and predictions. Extend when more identifying columns appear.
join_cols = ["dataset", "dataset_idx", "run_idx"]
# Ratings-side columns kept by default when joining ratings onto predictions.
ratings_cols = [
    "response", "step_ratings", "score", "reasoning", "judge_model_id",
    "dataset", "split", "dataset_idx", "run_idx",
]


def join_ratings(rdf: pd.DataFrame, ratings: pd.DataFrame, ratings_cols=ratings_cols):
    """Inner-join judge ratings onto predictions.

    Only ``ratings_cols`` are taken from ``ratings``, and the join keys are the
    notebook-level ``join_cols``. For columns present on both sides, the
    prediction-side column keeps its name and the ratings-side copy gets a
    ``_judge`` suffix.
    """
    judge_side = ratings[ratings_cols]
    return pd.merge(
        judge_side,
        rdf,
        how="inner",
        on=join_cols,
        suffixes=("_judge", ""),
    )


# Keep every ratings column. row_idx is a stable per-row id, and score_rank ranks
# each run's judge score within its example (1 = lowest, ties by order).
joint = (
    join_ratings(rdf, ratings, ratings_cols=ratings.columns)
    .pipe(lambda frame: frame.assign(row_idx=range(len(frame))))
    .assign(
        score_rank=lambda frame: frame.groupby("dataset_idx")["score"]
        .rank(method="first")
        .astype("int")
    )
)
# NOTE: the join is inner, so ratings without a matching prediction are dropped
# and len(joint) may be smaller than len(ratings).

In [None]:
from lewidi_lib import compute_n_steps_equality

# Sanity threshold: most judged responses must pass compute_n_steps_equality
# when steps are split on line breaks.
min_fraction_n_steps_equal = 0.85
fraction_n_steps_equal = compute_n_steps_equality(
    joint,
    step_split_type="linebreaks",
)
print(fraction_n_steps_equal)
assert fraction_n_steps_equal > min_fraction_n_steps_equal, fraction_n_steps_equal

In [None]:
# Mean judge score and Wasserstein loss per (reduction, rating_type) combination
# produced by create_rating_matrix. This was originally meant for JSON ratings only.
all_best_rows = create_rating_matrix(ratings)
rating_matrix_summary = (
    join_ratings(rdf, all_best_rows, ratings_cols=ratings_cols + ["rating_type", "reduction"])
    .groupby(["reduction", "rating_type"])
    .agg(score=("score", "mean"), ws_loss=("ws_loss", "mean"))
    .round(2)
)
rating_matrix_summary

In [None]:
# Pearson correlation between judge score and loss within each example,
# summarised across examples.
corrs = [
    frame[["score", "ws_loss"]].corr().loc["score", "ws_loss"]
    for _, frame in joint.groupby("dataset_idx")
]
pd.Series(corrs).describe()

In [None]:
joint.loc[:, ["score", "ws_loss"]].corr()  # pooled over all judged runs

In [None]:
# For every example, the run with the highest judge score (first one on ties).
best_by_judge_idxs = joint.groupby("dataset_idx")["score"].idxmax()
best_by_judge = joint.loc[
    best_by_judge_idxs,
    ["dataset_idx", "score", "tgt_has_holes", "ws_loss", "pred_entropy", "target_entropy"],
]

import seaborn as sns

# Loss as a function of the within-example judge-score rank
# (rank 1 = lowest score, since rank() is ascending).
ax = sns.lineplot(joint, x="score_rank", y="ws_loss")
ax.set(xlabel="Judge score rank within example", ylabel="Wasserstein loss")
# Keep the table as the last expression so it is actually displayed
# (a bare mid-cell expression is computed and thrown away).
joint.groupby("score_rank")[["score", "ws_loss"]].mean()

In [None]:
from lewidi_lib import compute_average_baseline

# Oracle best-of-N: per example, the run with the lowest true loss.
bon_oracle = joint.loc[joint.groupby("dataset_idx")["ws_loss"].idxmin()]

# Per example, the run the judge scored lowest, and everything except those runs.
worst_by_judge_idxs = joint.groupby("dataset_idx")["score"].idxmin()
worst_by_judge = joint.loc[worst_by_judge_idxs]
discard_worst = joint[~joint["row_idx"].isin(worst_by_judge["row_idx"])]

(worst_by_judge["ws_loss"].mean().round(3), discard_worst["ws_loss"].mean().round(3))

In [None]:
# Label every run: "best" / "worst" by judge score within its example, else "in-between".
# "worst" is written last so it wins when both picks land on the same row.
joint = joint.assign(judge_rating="in-between")
joint.loc[best_by_judge_idxs, "judge_rating"] = "best"
joint.loc[worst_by_judge_idxs, "judge_rating"] = "worst"

import seaborn as sns

# One panel per example (the first col_wrap * n_rows examples): judge score vs
# loss for each run, coloured by the judge_rating tag, plus a regression line.
col_wrap = 6
n_rows = 4
dataset_idxs = joint["dataset_idx"].unique()[: col_wrap * n_rows]
# Fixed level order so every panel uses the same colour for each rating.
judge_rating_order = ["best", "in-between", "worst"]

fgrid = sns.FacetGrid(
    joint.query("dataset_idx.isin(@dataset_idxs)"),
    col="dataset_idx",
    col_wrap=col_wrap,
    height=3,
    aspect=1,
)
# The per-run tag lives in "judge_rating". There is no "best_by_judge" column in joint.
fgrid.map_dataframe(
    sns.scatterplot,
    x="score",
    y="ws_loss",
    hue="judge_rating",
    hue_order=judge_rating_order,
    alpha=0.5,
)
fgrid.map_dataframe(
    sns.regplot, x="score", y="ws_loss", scatter=False, color="steelblue"
)

for ax in fgrid.axes.flat:
    ax.grid(alpha=0.5)
fgrid.add_legend(title="Judge rating")

In [None]:
import seaborn as sns

# All judged runs in one view: scatter coloured by judge rating, a single overall
# regression line, and per-rating marginal densities (each normalised on its own).
fgrid = sns.JointGrid(data=joint, x="score", y="ws_loss")
fgrid.plot_joint(sns.scatterplot, data=joint, hue="judge_rating", alpha=0.5)
fgrid.plot_joint(sns.regplot, scatter=False)
fgrid.plot_marginals(
    sns.histplot,
    data=joint,
    hue="judge_rating",
    stat="density",
    common_norm=False,
)
joint_ax = fgrid.ax_joint
joint_ax.legend(bbox_to_anchor=(1.2, 1), loc="upper left", title="Judge Rating")
joint_ax.grid(alpha=0.5)

import nltk


def _count_sentences(text):
    """Number of sentences nltk.sent_tokenize finds in ``text``.

    NOTE: sent_tokenize may need NLTK's "punkt" tokenizer data to be downloaded.
    """
    return len(nltk.sent_tokenize(text))


# Two reasoning-length proxies per run, and per example the run maximising each.
joint = joint.assign(reasoning_len_chars=joint["reasoning"].map(len))
most_cot_chars = joint.loc[joint.groupby("dataset_idx")["reasoning_len_chars"].idxmax()]

joint = joint.assign(reasoning_len_steps=joint["reasoning"].map(_count_sentences))
most_cot_steps = joint.loc[joint.groupby("dataset_idx")["reasoning_len_steps"].idxmax()]

In [None]:
from lewidi_lib import (
    agg_perf_metrics,
    compute_average_baseline_and_assing_perf_metrics,
    process_rdf_and_add_perf_metrics,
)

# Averaging baseline over the judged runs, with perf metrics attached (lewidi_lib helper).
model_avg_baseline = joint.pipe(compute_average_baseline_and_assing_perf_metrics)

In [None]:
# Rank runs within each example by predicted entropy (1 = lowest, ties by order),
# then average loss and entropy per rank. All ranks share one label, "entropy".
joint = joint.assign(
    entropy_rank=lambda frame: frame.groupby("dataset_idx")["pred_entropy"]
    .rank(method="first")
    .astype(int)
)
by_entropy = (
    joint.groupby("entropy_rank", as_index=False)[["ws_loss", "pred_entropy"]]
    .mean()
    .assign(type="entropy")
)

# Do we have enough examples to judge?

In [None]:
import numpy as np


# Is the judged subset representative? Compare its mean Wasserstein loss with the
# mean over all preprocessed predictions. A large gap means judged-subset results
# do not transfer to the full dataset.
pd.Series(
    {
        "judged subset": joint["ws_loss"].mean(),
        "full dataset": rdf["ws_loss"].mean(),
    },
    name="mean ws_loss",
).round(3)

In [None]:
from lewidi_lib import bootstrap_avg

# Bootstrap CI of the mean Wasserstein loss for each selection method.
# NOTE(review): "Simple" bootstraps over every judged run (several per example),
# whereas BoN has one row per example. The "Estimating Mean Variability" section
# averages per example first; consider doing the same here.
df_to_name = [
    ("Simple", joint),
    ("BoN", best_by_judge),
    ("Averaging", model_avg_baseline),
]
baselines_cis = pd.DataFrame(
    [{**bootstrap_avg(df["ws_loss"]).to_dict(), "name": name} for name, df in df_to_name]
)

# to_csv does not create missing directories, so make sure the target exists.
tables_dir = Path("tables")
tables_dir.mkdir(parents=True, exist_ok=True)
baselines_cis.round(3).to_csv(tables_dir / f"ws_loss_bootstrap_{dataset}.csv", index=False)
baselines_cis

In [None]:
# One bar per selection method. Bar height is seaborn's default estimator (mean).
baselines = pd.concat([df.assign(type=name) for name, df in df_to_name]).assign(
    dataset=dataset
)
fgrid = sns.catplot(
    baselines,
    x="type",
    y="ws_loss",
    kind="bar",
    hue="type",
    col="dataset",
)
fgrid.set_axis_labels("", "Wasserstein Loss")
for ax in fgrid.axes.flat:
    ax.grid(alpha=0.5)

# savefig does not create missing directories, so make sure the target exists.
fig_path = Path("imgs/bon-eval/ws_loss_by_method.pdf")
fig_path.parent.mkdir(parents=True, exist_ok=True)
fgrid.figure.savefig(fig_path, bbox_inches="tight")

In [None]:
cols = ["pred_entropy", "ws_loss"]
# Strategy / baseline label -> the predictions it yields. most_cot_chars and
# most_cot_steps can be added here as further strategies.
method_frames = {
    "best_by_judge": best_by_judge,
    "discard_worst_by_judge": discard_worst,
    "simple (full dataset)": rdf,
    "simple (judged subset)": joint,
    "model-avg (judged subset)": model_avg_baseline,
    "model-avg (full dataset)": compute_average_baseline_and_assing_perf_metrics(rdf),
    "BoN Oracle (judged subset)": bon_oracle,
}
# One row per strategy with its mean entropy and loss, followed by the per-entropy-rank rows.
loss_vs_entropy = (
    pd.DataFrame({label: frame[cols].mean() for label, frame in method_frames.items()})
    .T.reset_index(names="type")
)
loss_vs_entropy = pd.concat([loss_vs_entropy, by_entropy.drop(columns="entropy_rank")])

In [None]:
# Leave the oracle out: it selects runs using the true loss (idxmin of ws_loss),
# so it is not a comparable method. Its label is "BoN Oracle (judged subset)",
# so match on the prefix; an exact comparison with "BoN Oracle" matches nothing.
# reset_index because the concat that built loss_vs_entropy leaves duplicate labels.
is_oracle = loss_vs_entropy["type"].str.startswith("BoN Oracle")
data_ = loss_vs_entropy[~is_oracle].reset_index(drop=True)
grid = sns.JointGrid(data=data_, x="pred_entropy", y="ws_loss")
grid.plot_joint(sns.scatterplot, hue=data_["type"], style=data_["type"])
grid.plot_marginals(sns.histplot, multiple="stack")
grid.ax_joint.legend(bbox_to_anchor=(1.2, 1), loc="upper left")
grid.ax_joint.grid(alpha=0.5)

In [None]:
# Ratings coverage: #examples, #runs, #unique (example, run) pairs.
(
    ratings["dataset_idx"].nunique(),
    ratings["run_idx"].nunique(),
    ratings[["dataset_idx", "run_idx"]].drop_duplicates().shape[0],
)

In [None]:
loss_vs_entropy.round(3).head(10)  # first 10 strategies/ranks, rounded for display

In [None]:
tuple(rdf[col].nunique() for col in ("dataset_idx", "run_idx"))  # (#examples, #runs)

# Estimating Mean Variability

In [None]:
from lewidi_lib import bootstrap_avg

# Average over runs first so each example contributes exactly one value to the bootstrap.
losses_by_example = (
    joint.groupby("dataset_idx", as_index=False)
    .agg(ws_loss=("ws_loss", "mean"))
)
bootstrap_avg(losses_by_example.loc[:, "ws_loss"])

In [None]:
# CI for the judge-selected runs. best_by_judge already has one row per example.
# (Previously this result was computed mid-cell and never shown.)
low, mean, high = bootstrap_avg(best_by_judge["ws_loss"])
print(f"BoN by judge -- Mean: {mean:.3f}, 95% CI: {low:.3f} - {high:.3f}")

# Same statistic over the full prediction set, averaging runs 0..9 of each example.
full_losses_by_example = (
    rdf.query("run_idx < 10").groupby("dataset_idx", as_index=False)["ws_loss"].mean()
)
low, mean, high = bootstrap_avg(full_losses_by_example["ws_loss"])
print(f"Mean: {mean:.3f}, 95% CI: {low:.3f} - {high:.3f}")

from lewidi_lib import get_stable_random_subset

# How does the CI of the mean per-example loss shrink with the number of examples?
# For each size, take a stable random subset of examples, average their runs,
# and bootstrap the mean.
all_res = []
n_samples = [100, 300, 500, 1000, 2000, 3000, 4000, 5000]
for n in n_samples:
    ds_idxs = get_stable_random_subset(rdf["dataset_idx"], n)
    subset = rdf.query("dataset_idx in @ds_idxs")
    subset_losses_by_example = subset.groupby("dataset_idx", as_index=False)["ws_loss"].mean()
    res = bootstrap_avg(subset_losses_by_example["ws_loss"])
    all_res.append(res)
    low, mean, high = res
    # 3 decimals (as in the full-dataset CI above). At 2 decimals the CI widths of
    # the larger subsets collapse to the same value, hiding the trend measured here.
    print(
        f"#Examples: {n}, Mean: {mean:.3f}, 95% CI: {low:.3f} - {high:.3f}, "
        f"CI width: {high - low:.3f}"
    )

# Keep each result next to the subset size that produced it.
ci_by_n_examples = pd.DataFrame(
    [(n, *r) for n, r in zip(n_samples, all_res)],
    columns=["n_examples", "low", "mean", "high"],
).assign(ci_width=lambda frame: frame["high"] - frame["low"])
ci_by_n_examples.round(3)