In [7]:
%load_ext autoreload
%autoreload 2

import os


os.environ["DSS_HOME"] = "/mnt/md0/tomasruiz/dss_home"

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [8]:
import pandas as pd
import logging

# Notebook-level logger; messages become visible once enable_logging() runs.
logger = logging.getLogger(name=__name__)

In [9]:
from lewidi_lib import load_dataset, enable_logging

enable_logging()

# Datasets under analysis. VariErrNLI is excluded for now (it needs lots of refactoring).
datasets = ["CSC", "MP", "Paraphrase"]
# Only the train split is analysed; append "dev" here to include it.
splits = ["train"]

# Stack every (dataset, split) combination into one frame, datasets in the order above.
ddf = pd.concat(
    [load_dataset(name, split=split_name) for name in datasets for split_name in splits]
)

In [18]:
from lewidi_lib import preds_file

models = [
    "Qwen/Qwen3-0.6B",
    "Qwen/Qwen3-1.7B",
    "Qwen/Qwen3-4B",
    "Qwen/Qwen3-8B",
    "Qwen/Qwen3-14B",
    "Qwen/Qwen3-32B",
]

# Expected responses file for every (dataset, model) pair of the template-31
# "allex_10loops" train run, paired with whether that file exists on disk yet.
preds_files = [
    (path, path.exists())
    for path in (
        preds_file(
            dataset=name,
            split="train",
            template=31,
            model_id=model_id,
            run_name="allex_10loops",
        )
        for name in datasets
        for model_id in models
    )
]

preds_files

[(PosixPath('/mnt/md0/tomasruiz/dss_home/lewidi-data/sbatch/di38bec/Qwen_Qwen3-0.6B/set2/t31/CSC/train/allex_10loops/preds/responses.parquet'),
  False),
 (PosixPath('/mnt/md0/tomasruiz/dss_home/lewidi-data/sbatch/di38bec/Qwen_Qwen3-1.7B/set2/t31/CSC/train/allex_10loops/preds/responses.parquet'),
  False),
 (PosixPath('/mnt/md0/tomasruiz/dss_home/lewidi-data/sbatch/di38bec/Qwen_Qwen3-4B/set2/t31/CSC/train/allex_10loops/preds/responses.parquet'),
  False),
 (PosixPath('/mnt/md0/tomasruiz/dss_home/lewidi-data/sbatch/di38bec/Qwen_Qwen3-8B/set2/t31/CSC/train/allex_10loops/preds/responses.parquet'),
  False),
 (PosixPath('/mnt/md0/tomasruiz/dss_home/lewidi-data/sbatch/di38bec/Qwen_Qwen3-14B/set2/t31/CSC/train/allex_10loops/preds/responses.parquet'),
  False),
 (PosixPath('/mnt/md0/tomasruiz/dss_home/lewidi-data/sbatch/di38bec/Qwen_Qwen3-32B/set2/t31/CSC/train/allex_10loops/preds/responses.parquet'),
  False),
 (PosixPath('/mnt/md0/tomasruiz/dss_home/lewidi-data/sbatch/di38bec/Qwen_Qwen3-0.6

In [None]:
import duckdb
from lewidi_lib import assign_col_template_alias, load_preds, process_rdf


def narrow_down_analysis(
    rdf: pd.DataFrame,
    excluded_template_ids: tuple[int, ...] = (0, 1, 4),
    excluded_gen_kwargs: tuple[str, ...] = ("set1", "thinking"),
) -> pd.DataFrame:
    """Restrict a predictions frame to the configurations analysed in this notebook.

    Drops every row whose ``template_id`` is in ``excluded_template_ids`` or whose
    ``gen_kwargs`` is in ``excluded_gen_kwargs``. The defaults reproduce the filter
    this notebook has always applied.

    Args:
        rdf: Predictions frame with at least ``template_id`` and ``gen_kwargs`` columns.
        excluded_template_ids: Prompt template ids to drop.
        excluded_gen_kwargs: Generation-config names to drop.

    Returns:
        The remaining rows (original index preserved).

    Raises:
        AssertionError: If no rows survive the filter.
    """
    keep = ~rdf["template_id"].isin(list(excluded_template_ids)) & ~rdf[
        "gen_kwargs"
    ].isin(list(excluded_gen_kwargs))
    narrowed = rdf[keep]
    # An empty frame here would make every downstream plot silently blank.
    assert len(narrowed) > 0, (
        "narrow_down_analysis removed every row "
        f"(excluded templates={excluded_template_ids}, gen_kwargs={excluded_gen_kwargs})"
    )
    return narrowed


# Load all prediction parquets, keep only the configurations under study, then
# derive the analysis columns (process_rdf) used by the plots below.
rdf = (
    load_preds(parquets_dir="../parquets")
    .pipe(narrow_down_analysis)
    .pipe(process_rdf)
)

import seaborn as sns

# Share of predictions flagged is_valid_pred, per model size; one panel per
# template (columns) and split (rows), one line per generation config.
g = sns.relplot(
    data=rdf,
    x="model_size",
    y="is_valid_pred",
    hue="gen_kwargs",
    col="template_alias",
    row="split",
    kind="line",
    style="gen_kwargs",
    marker="o",
    height=3,
    aspect=1.2,
)
g.set_axis_labels("Model Params [B]", "Valid Preds")
for ax in g.axes.flat:
    ax.set_ylim(0, 1.05)  # slight headroom above 1.0 so fully-valid points stay visible
    ax.grid(alpha=0.5)


In [None]:
# Only valid predictions are scored from here on; log how many are discarded.
logger.info(
    "Dropping %d predictions that don't sum to 1", len(rdf.query("~is_valid_pred"))
)
# Reassign rather than query(..., inplace=True): same rows, but no in-place
# mutation of the frame the previous cell's plot was built from.
rdf = rdf.query("is_valid_pred")

In [None]:
from lewidi_lib import assign_cols_perf_metrics_softlabel, join_dataset_and_preds

# Join the dataset rows (ddf) with the predictions (rdf), then add the
# soft-label performance-metric columns.
joint_df = assign_cols_perf_metrics_softlabel(join_dataset_and_preds(ddf, rdf))

# Baselines

In [None]:
from lewidi_lib import (
    compute_average_baseline_and_assing_perf_metrics,
    compute_baseline_entropy,
    compute_majority_baseline,
    compute_target_entropy,
    compute_unif_baseline_perf_metrics,
    agg_perf_metrics,
    compute_smoothed_baseline,
    compute_best_wsloss_baseline,
    process_rdf_and_add_perf_metrics,
)


# Baselines computed from the dataset alone (ddf), independent of any model.
majority_baseline = compute_majority_baseline(ddf)
agg_majority_baseline = agg_perf_metrics(majority_baseline)
# Baselines derived from the model predictions (rdf / joint_df); they appear as
# extra model_type series in the figures below.
# NOTE: "assing" (sic) is the library function's actual name.
average_baseline = compute_average_baseline_and_assing_perf_metrics(rdf)
smoothed_baseline = compute_smoothed_baseline(rdf)
best_wsloss_baseline = compute_best_wsloss_baseline(joint_df)
# Reference values drawn as horizontal lines below; presumably the uniform
# distribution's metrics/entropy and the target labels' entropy — TODO confirm.
unif_baseline_perf_metrics = compute_unif_baseline_perf_metrics(ddf)
unif_baseline_entropy = compute_baseline_entropy(datasets)
target_entropy = compute_target_entropy(ddf)

In [None]:
# Reference predictions stored under ../parquets/baseline (labelled "Gemini 2.5 Pro"
# in the plots), filtered and scored exactly like the main predictions.
strong_baselines_raw = process_rdf_and_add_perf_metrics(
    narrow_down_analysis(load_preds("../parquets/baseline"))
)
strong_baselines_agg = agg_perf_metrics(strong_baselines_raw)

# Is Performance Correlated With Size?

In [None]:
from lewidi_lib import plot_horizontal_lines
import seaborn as sns

# One long frame: predictions from the plain model and from each baseline,
# tagged with model_type and restricted to the set2 generation config.
cols_ = ["template_id", "model_id", "model_size", "gen_kwargs", "dataset", "ws_loss"]
data_ = (
    pd.concat(
        [
            frame[cols_].assign(model_type=model_type)
            for frame, model_type in [
                (joint_df, "Simple"),
                (average_baseline, "Averaging"),
                (smoothed_baseline, "Smoothed"),
                (best_wsloss_baseline, "BoN Oracle"),
            ]
        ]
    )
    .pipe(assign_col_template_alias)
    .query("gen_kwargs == 'set2'")
)

# Wasserstein loss vs. model size; y-axes shared within each dataset row.
g = sns.relplot(
    data=data_,
    x="model_size",
    y="ws_loss",
    hue="model_type",
    col="template_alias",
    col_order=sorted(data_["template_alias"].unique()),
    row="dataset",
    row_order=["CSC", "MP", "Paraphrase"],
    kind="line",
    marker="o",
    height=3,
    aspect=1.2,
    facet_kws={"sharey": "row"},
)
g.set_axis_labels("Model Params [B]", "Wasserstein Distance")

# Reference levels drawn across every panel.
for baseline_df, baseline_label, baseline_color in [
    (unif_baseline_perf_metrics, "Uniform Baseline", "blue"),
    (strong_baselines_agg, "Gemini 2.5 Pro", "red"),
    (agg_majority_baseline, "Majority Baseline", "green"),
]:
    plot_horizontal_lines(
        g,
        baseline_df,
        label=baseline_label,
        color=baseline_color,
        data_col="ws_loss",
    )

# Is Performance Correlated With Average Entropy?

In [None]:
# Predictions from the plain model and each baseline, tagged with model_type.
ent_data_ = pd.concat(
    [
        joint_df.assign(model_type="Simple"),
        average_baseline.assign(model_type="Averaging"),
        smoothed_baseline.assign(model_type="Smoothed"),
        # Same label as the ws_loss-vs-size figure uses for this baseline.
        best_wsloss_baseline.assign(model_type="BoN Oracle"),
    ]
)

# Mean predicted entropy vs. model size.
# NOTE(review): Paraphrase is not in row_order, and the reference lines below are
# drawn for CSC only — confirm this omission is intentional.
g = sns.relplot(
    data=ent_data_,
    x="model_size",
    y="pred_entropy",
    hue="model_type",
    col="template_alias",
    col_order=sorted(ent_data_["template_alias"].unique()),
    row="dataset",
    row_order=["CSC", "MP"],
    kind="line",
    style="model_type",
    marker="o",
    height=3,
    aspect=1.2,
    facet_kws={"sharey": "row", "sharex": True},
)
for ax in g.axes.flat:
    ax.grid(alpha=0.5)
plot_horizontal_lines(
    g,
    unif_baseline_entropy,
    label="Uniform Entropy",
    color="blue",
    data_col="entropy",
    dataset="CSC",
)
plot_horizontal_lines(
    g,
    target_entropy,
    label="Target Entropy",
    color="green",
    data_col="entropy",
    dataset="CSC",
    split="train",
)

# Per (model size, gen config, dataset, split): mean entropy vs. mean ws_loss.
ent_df = joint_df.groupby(
    ["model_size", "gen_kwargs", "dataset", "split"], as_index=False
).agg(
    avg_entropy=("pred_entropy", "mean"),
    avg_ws_loss=("ws_loss", "mean"),
)


g = sns.relplot(
    ent_df,
    x="avg_entropy",
    y="avg_ws_loss",
    hue="model_size",
    style="gen_kwargs",
    col="split",
    row="dataset",
    kind="scatter",
    height=2.5,
    aspect=1.2,
    facet_kws={"sharey": False, "sharex": False},
    palette="viridis",
)
for ax in g.axes.flat:
    ax.grid(alpha=0.5)
# The y-axis is ws_loss, so the reference lines use that column explicitly,
# as the ws_loss-vs-size figure does.
plot_horizontal_lines(
    g,
    unif_baseline_perf_metrics,
    label="Uniform Baseline",
    color="blue",
    data_col="ws_loss",
)
# The aggregated strong-baseline frame, as used in the ws_loss-vs-size figure.
plot_horizontal_lines(
    g, strong_baselines_agg, label="Gemini 2.5 Pro", color="red", data_col="ws_loss"
)
