In [None]:
# Optional: the cudf.pandas extension is left disabled here; uncomment to load it.
# %load_ext cudf.pandas
# Load IPython's autoreload extension and use mode 2, so that imported modules
# are re-imported before each cell executes (picks up edits to local libraries).
%load_ext autoreload
%autoreload 2

In [None]:
import pandas as pd
import logging

# Notebook-level logger, named after the executing module.
logger = logging.getLogger(name=__name__)

In [None]:
from lewidi_lib import load_dataset, enable_logging

# Enable logging through the project helper before anything is loaded.
enable_logging()

# Run configuration shared by the cells below.
datasets = ["CSC", "MP", "Paraphrase", "VariErrNLI"]
splits = ["train"]  # "dev" is currently left out
template_id = "31"
run_name = "allex_10loops"

# Load every (dataset, split) pair in dataset-major order and stack the frames.
dataset_frames = []
for dataset_name in datasets:
    for split_name in splits:
        dataset_frames.append(load_dataset(dataset_name, split=split_name))
ddf = pd.concat(dataset_frames)

In [None]:
import duckdb
from lewidi_lib import (
    assign_col_template_alias,
    process_rdf,
    list_preds,
    get_stable_random_subset,
)
import pandas as pd


# Keep only the train-split prediction files of the configured run and template.
preds_filter = (
    f"split == 'train' and run_name == '{run_name}' and template_id == '{template_id}'"
)
preds_files_df = list_preds().query(preds_filter)
preds_files_df

In [None]:
files = preds_files_df["preds_file"].tolist()
# Fail fast with an actionable message: an empty file list would otherwise be
# rendered as `read_parquet([])` and surface as an opaque DuckDB error.
if not files:
    raise ValueError(
        f"No prediction files found for run_name={run_name!r}, "
        f"template_id={template_id!r}; check the filter in the previous cell."
    )

# The Python list repr of the paths is interpolated as a DuckDB list literal,
# so all parquet files are read in a single scan.
rdf = duckdb.sql(f"SELECT * FROM read_parquet({[str(f) for f in files]})").df()
print(len(rdf))

# Pick a subset of dataset_idx values per dataset and keep only those rows.
# NOTE(review): presumably get_stable_random_subset returns a deterministic
# sample of at most n indices -- confirm in lewidi_lib.
indxs = (
    rdf.groupby("dataset", as_index=False)["dataset_idx"]
    .apply(get_stable_random_subset, n=2000)
    .explode("dataset_idx")
)
rdf = rdf.merge(indxs, on=["dataset", "dataset_idx"], how="inner")
print(len(rdf))

In [None]:
# Run the project's post-processing over the raw predictions frame.
rdf = rdf.pipe(process_rdf)

In [None]:
# Sanity check: which datasets are present in the predictions?
rdf.loc[:, "dataset"].unique()

In [None]:
import seaborn as sns

# Use seaborn's "talk" plotting context for all figures below.
sns.set_context("talk")

In [None]:
from pathlib import Path

# Line plot of the mean of `is_valid_pred` against model size, one line per dataset.
valid_pred_plot_kwargs = dict(
    data=rdf,
    x="model_size",
    y="is_valid_pred",
    hue="dataset",
    style="dataset",
    markers=["o", "s", "D", "P"],
    kind="line",
    marker="o",
    height=3.5,
    aspect=1.5,
)
g = sns.relplot(**valid_pred_plot_kwargs)
g.set_axis_labels("Model Params [B]", "Valid Soft-Labels")

# The y values are fractions, so pin the axis to [0, 1.05] and add a light grid.
for panel in g.axes.flat:
    panel.set_ylim(0, 1.05)
    panel.grid(alpha=0.5)

# Persist the figure, creating the output directory on first use.
tgt_path = Path("./imgs/soft-label/valid_preds_by_model_size.pdf")
tgt_path.parent.mkdir(parents=True, exist_ok=True)
g.figure.savefig(tgt_path, bbox_inches="tight")

In [None]:
# Report how many rows are flagged invalid, then keep only the valid ones.
invalid_preds = rdf.query("~is_valid_pred")
logger.info("Dropping %d predictions that don't sum to 1", len(invalid_preds))
rdf.query("is_valid_pred", inplace=True)

In [None]:
from lewidi_lib import assign_cols_perf_metrics_softlabel, join_dataset_and_preds

# Join the dataset with the predictions, then attach the soft-label metrics.
joint_df = assign_cols_perf_metrics_softlabel(join_dataset_and_preds(ddf, rdf))

# Baselines

In [None]:
from lewidi_lib import (
    compute_average_baseline_and_assing_perf_metrics,
    compute_baseline_entropy,
    compute_majority_baseline,
    compute_target_entropy,
    compute_unif_baseline_perf_metrics,
    agg_perf_metrics,
    compute_smoothed_baseline,
    compute_best_wsloss_baseline,
)


# Baselines derived from the dataset frame alone.
majority_baseline = ddf.pipe(compute_majority_baseline)
agg_majority_baseline = majority_baseline.pipe(agg_perf_metrics)

# Baselines derived from the model predictions.
average_baseline = rdf.pipe(compute_average_baseline_and_assing_perf_metrics)
smoothed_baseline = rdf.pipe(compute_smoothed_baseline)
best_wsloss_baseline = joint_df.pipe(compute_best_wsloss_baseline)

# Reference values later drawn as horizontal lines in the plots.
unif_baseline_perf_metrics = ddf.pipe(compute_unif_baseline_perf_metrics)
unif_baseline_entropy = compute_baseline_entropy(datasets)
target_entropy = ddf.pipe(compute_target_entropy)

# Is Performance Correlated With Size?

In [None]:
from pathlib import Path
from lewidi_lib import plot_horizontal_lines
import seaborn as sns

# Columns shared by all prediction variants that are stacked for plotting.
cols_ = [
    "template_id",
    "model_id",
    "model_size",
    "gen_kwargs",
    "dataset",
    "ws_loss",
    "pred_entropy",
]

# Label each variant in a "Baseline" column; insertion order fixes the stacking order.
variant_frames = {
    "Simple": joint_df,
    "Averaging": average_baseline,
    "Smoothing": smoothed_baseline,
    "BoN Oracle": best_wsloss_baseline,
}
stacked_variants = pd.concat(
    [frame[cols_].assign(Baseline=label) for label, frame in variant_frames.items()]
)
data_ = assign_col_template_alias(stacked_variants).query(
    f"gen_kwargs == 'set2' and template_id == '{template_id}'"
)

# One panel per dataset; the oracle variant is excluded from this figure.
col_wrap = 2
ws_loss_plot_kwargs = dict(
    data=data_.query("Baseline != 'BoN Oracle'"),
    x="model_size",
    y="ws_loss",
    hue="Baseline",
    style="Baseline",
    markers=["o", "s", "D"],
    col="dataset",
    col_order=datasets,
    col_wrap=col_wrap,
    kind="line",
    height=4.0,
    aspect=1.2,
    facet_kws={"sharey": False},
)
g = sns.relplot(**ws_loss_plot_kwargs)
g.set_axis_labels("Model Params [B]", "Wasserstein Distance")

# Overlay the uniform and majority baselines as horizontal reference lines.
plot_horizontal_lines(
    g,
    unif_baseline_perf_metrics,
    label="Uniform Baseline",
    color="blue",
    data_col="ws_loss",
    hpos="right",
    vpos="top",
)
plot_horizontal_lines(
    g,
    agg_majority_baseline,
    label="Majority Baseline",
    color="green",
    data_col="ws_loss",
    hpos="right",
)

# Persist the figure, creating the output directory on first use.
tgt_path = Path("./imgs/soft-label/baselines/ws_loss_vs_model_size.pdf")
tgt_path.parent.mkdir(parents=True, exist_ok=True)
g.figure.savefig(tgt_path, bbox_inches="tight")

In [None]:
# Restrict to models whose id contains "32B" and tabulate the mean ws_loss per
# method (rows) and dataset (columns).
data_32b = data_.query("model_id.str.contains('32B')")
method_order = ["Simple", "Smoothing", "Averaging", "BoN Oracle"]
# DuckDB resolves `data_32b` from the Python variable of the same name, so the
# variable must keep this exact name.
ws_loss_32b = duckdb.sql(
    "PIVOT data_32b ON dataset using mean(ws_loss) GROUP BY Baseline"
).df()
# Fix the row order of the table (raises KeyError if a method is missing).
ws_loss_32b = ws_loss_32b.set_index("Baseline").loc[method_order].reset_index()

# to_csv does not create missing directories; create `tables/` first, like the
# figure outputs do, so a fresh checkout can run this cell.
table_path = Path("tables/32b_ws_loss.csv")
table_path.parent.mkdir(parents=True, exist_ok=True)
ws_loss_32b.round(3).to_csv(table_path, index=False)

# Is Performance Correlated With Average Entropy?

In [None]:
# Stack the prediction variants (the oracle variant is intentionally left out)
# and keep only the configured template.
ent_data_ = pd.concat(
    [
        joint_df.assign(Baseline="Simple"),
        average_baseline.assign(Baseline="Averaging"),
        smoothed_baseline.assign(Baseline="Smoothing"),
        # best_wsloss_baseline.assign(Baseline="BoN Oracle"),
    ]
).query(f"template_id == '{template_id}'")

# Prediction entropy against model size, one panel per dataset.
g = sns.relplot(
    data=ent_data_,
    x="model_size",
    y="pred_entropy",
    hue="Baseline",
    style="Baseline",
    markers=["o", "s", "D"],
    # col="template_alias",
    # col_order=sorted(ent_data_["template_alias"].unique()),
    col="dataset",
    col_order=datasets,
    col_wrap=col_wrap,
    kind="line",
    marker="o",
    height=4.0,
    aspect=1.2,
    facet_kws={"sharey": False, "sharex": True},
)
g.set_axis_labels("Model Params [B]", "Avg. Entropy")
for ax in g.axes.flat:
    ax.grid(alpha=0.5)

# Reference lines: entropy of the uniform baseline and of the human targets.
plot_horizontal_lines(
    g,
    unif_baseline_entropy,
    label="Uniform Entropy",
    color="blue",
    data_col="entropy",
)
plot_horizontal_lines(
    g,
    target_entropy,
    label="Human Entropy",
    color="green",
    data_col="entropy",
    # Was `pos="right"`; every other call in this notebook uses the `hpos`
    # keyword for horizontal label placement.
    hpos="right",
)

# Persist the figure, creating the output directory on first use.
tgt_path = Path("./imgs/soft-label/baselines/entropy_vs_model_size.pdf")
tgt_path.parent.mkdir(parents=True, exist_ok=True)
g.figure.savefig(tgt_path, bbox_inches="tight")

# Mean prediction entropy and mean ws_loss per (model size, generation
# settings, dataset, split) combination.
ent_df = joint_df.groupby(
    ["model_size", "gen_kwargs", "dataset", "split"], as_index=False
).agg(
    avg_entropy=("pred_entropy", "mean"),
    avg_ws_loss=("ws_loss", "mean"),
)


# Scatter of average ws_loss against average entropy: one row per dataset,
# one column per split, colored by model size.
g = sns.relplot(
    ent_df,
    x="avg_entropy",
    y="avg_ws_loss",
    hue="model_size",
    style="gen_kwargs",
    col="split",
    # col_order=["train", "dev"],
    row="dataset",
    # row_order=["CSC", "MP"],
    kind="scatter",
    height=2.5,
    aspect=1.2,
    facet_kws={"sharey": False, "sharex": False},
    palette="viridis",
)
for ax in g.axes.flat:
    ax.grid(alpha=0.5)

# Pass data_col explicitly, matching the earlier call that draws the same frame.
plot_horizontal_lines(
    g,
    unif_baseline_perf_metrics,
    label="Uniform Baseline",
    color="blue",
    data_col="ws_loss",
)

# `strong_baselines` is not defined anywhere in this notebook, so an
# unconditional call raises NameError on Restart & Run All. Draw the line only
# when the variable has been provided.
# TODO: load `strong_baselines` explicitly in an earlier cell.
_strong_baselines = globals().get("strong_baselines")
if _strong_baselines is not None:
    plot_horizontal_lines(g, _strong_baselines, label="Gemini 2.5 Pro", color="red")
else:
    logger.warning(
        "strong_baselines is not defined; skipping the 'Gemini 2.5 Pro' reference line"
    )
