In [1]:
# Automatically reload edited modules (e.g. lewidi_lib) before each cell executes,
# so library changes are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
import logging

import pandas as pd

# Notebook-level logger used by later cells to report data-cleaning steps.
logger = logging.getLogger(__name__)

In [3]:
from lewidi_lib import enable_logging, load_dataset

enable_logging()

# Stack every (dataset, split) frame returned by `load_dataset` into one frame,
# iterating datasets in the outer loop and splits in the inner loop.
datasets = ["CSC", "MP"]
splits = ["train", "dev"]
ddf = pd.concat(
    [
        load_dataset(dataset_name, split=split_name)
        for dataset_name in datasets
        for split_name in splits
    ]
)

In [None]:
import duckdb
from lewidi_lib import assign_col_template_alias, load_preds, process_rdf

# NOTE(review): `con` is not used by any cell visible here; kept in case later cells rely on it.
con = duckdb.connect()

# Prompt-template ids that are excluded from the whole analysis.
EXCLUDED_TEMPLATE_IDS = [0, 1, 4]

# Load raw predictions, drop the excluded templates, then post-process.
rdf = (
    load_preds(parquets_dir="../parquets")
    .pipe(lambda preds: preds[~preds["template_id"].isin(EXCLUDED_TEMPLATE_IDS)])
    .pipe(process_rdf)
)

In [None]:
import seaborn as sns

# Mean of `is_valid_pred` (share of valid predictions) vs. model size,
# faceted by template (columns) and split (rows).
g = sns.relplot(
    data=rdf,
    kind="line",
    x="model_size",
    y="is_valid_pred",
    hue="gen_kwargs",
    style="gen_kwargs",
    col="template_alias",
    row="split",
    marker="o",
    height=3,
    aspect=1.2,
)
g.set_axis_labels("Model Params [B]", "Valid Preds")
g.set(ylim=(0, 1.05))
for ax in g.axes.flat:
    ax.grid(alpha=0.5)


In [None]:
# Keep only predictions flagged as valid (invalid ones are those that don't sum to 1),
# so every downstream metric and baseline uses valid predictions only.
# Reassign instead of `inplace=True`: avoids mutating the frame an earlier cell
# plotted and keeps the cell chain-friendly.
n_invalid_preds = len(rdf.query("~is_valid_pred"))
logger.info("Dropping %d predictions that don't sum to 1", n_invalid_preds)
rdf = rdf.query("is_valid_pred")

In [7]:
from lewidi_lib import assign_cols_perf_metrics, join_dataset_and_preds

# Join the loaded dataset (`ddf`) with the valid predictions (`rdf`),
# then add the performance-metric columns.
joint_df = assign_cols_perf_metrics(join_dataset_and_preds(ddf, rdf))

# Baselines

In [None]:
from lewidi_lib import (
    compute_average_baseline,
    compute_baseline_entropy,
    compute_best_wsloss_baseline,
    compute_smoothed_baseline,
    compute_strong_baselines_perf_metrics,
    compute_target_entropy,
    compute_unif_baseline_perf_metrics,
)

# Baselines derived from the model predictions / joined frame.
average_baseline = compute_average_baseline(rdf)
smoothed_baseline = compute_smoothed_baseline(rdf)
strong_baselines = compute_strong_baselines_perf_metrics()
best_wsloss_baseline = compute_best_wsloss_baseline(joint_df)

# Uniform-baseline metrics, replicated (cross join) once per template id present
# in `rdf`, then labelled with the template alias.
unif_baseline_perf_metrics = (
    compute_unif_baseline_perf_metrics(ddf)
    .merge(pd.DataFrame({"template_id": rdf["template_id"].unique()}), how="cross")
    .pipe(assign_col_template_alias)
)

# Reference entropies, plotted later as "Uniform Entropy" and "Target Entropy".
unif_baseline_entropy = compute_baseline_entropy(datasets)
target_entropy = compute_target_entropy(ddf)

# Is Performance Correlated With Size?

In [None]:
from lewidi_lib import plot_horizontal_lines
import seaborn as sns

# Generation-kwargs setting compared across model sizes in this section.
SIZE_PLOT_GEN_KWARGS = "set2"

cols_ = ["template_id", "model_id", "model_size", "gen_kwargs", "dataset", "ws_loss"]
# Stack the model predictions with each prediction-derived baseline, tagged by source.
# ignore_index=True: the input frames may share row labels, and duplicate labels
# can trip index-aligning operations downstream (pandas/seaborn).
data_ = (
    pd.concat(
        [
            joint_df[cols_].assign(model_type="Simple"),
            average_baseline[cols_].assign(model_type="Averaging"),
            smoothed_baseline[cols_].assign(model_type="Smoothed"),
            best_wsloss_baseline[cols_].assign(model_type="Best WS Loss"),
        ],
        ignore_index=True,
    )
    .pipe(assign_col_template_alias)
    .query("gen_kwargs == @SIZE_PLOT_GEN_KWARGS")
)

g = sns.relplot(
    data=data_,
    x="model_size",
    y="ws_loss",
    hue="model_type",
    col="template_alias",
    col_order=sorted(data_["template_alias"].unique()),
    kind="line",
    marker="o",
    height=3,
    aspect=1.2,
    facet_kws={"sharey": True},
)
g.set_axis_labels("Model Params [B]", "Wasserstein Distance")
# Reference lines: the uniform-distribution baseline and Gemini 2.5 Pro results.
plot_horizontal_lines(
    g,
    unif_baseline_perf_metrics,
    label="Uniform Baseline",
    color="blue",
    data_col="ws_loss",
)
plot_horizontal_lines(
    g, strong_baselines, label="Gemini 2.5 Pro", color="red", data_col="ws_loss"
)

In [None]:
# Model size [B] shown in the per-template bar comparison.
BAR_MODEL_SIZE_B = 32.0

# Compare numerically: the string test `model_size == '32.0'` only matches when the
# column holds strings and silently selects nothing (empty plot) if it is numeric.
# Casting to float works for numeric values and float-like strings alike.
bar_data_ = data_[data_["model_size"].astype(float) == BAR_MODEL_SIZE_B].sort_values(
    "template_alias"
)
ax = sns.barplot(
    bar_data_,
    x="template_alias",
    y="ws_loss",
    hue="model_type",
)
ax.set(
    xlabel="Template",
    ylabel="Wasserstein Distance",
    title=f"Model Params = {BAR_MODEL_SIZE_B:g}B",
)
ax.grid(alpha=0.5)
ax.legend(loc="upper right")

# Is Performance Correlated With Average Prediction Entropy?

In [None]:
# Dataset / split that the entropy reference lines below are computed for.
ENTROPY_REF_DATASET = "CSC"
ENTROPY_REF_SPLIT = "train"

# Predictions and prediction-derived baselines stacked together, tagged by source.
ent_data_ = pd.concat(
    [
        joint_df.assign(model_type="Simple"),
        average_baseline.assign(model_type="Averaging"),
        smoothed_baseline.assign(model_type="Smoothed"),
        best_wsloss_baseline.assign(model_type="Best WS Loss"),
    ],
    ignore_index=True,
)

# The reference lines describe a single dataset, so plot only that dataset's rows;
# if several datasets are present, seaborn would otherwise pool them into one line.
# NOTE(review): the plotted rows still pool splits, while "Target Entropy" is
# drawn for ENTROPY_REF_SPLIT only — confirm that is intended.
g = sns.relplot(
    data=ent_data_.query("dataset == @ENTROPY_REF_DATASET"),
    x="model_size",
    y="pred_entropy",
    hue="model_type",
    col="template_id",
    row="gen_kwargs",
    kind="line",
    style="model_type",
    marker="o",
    height=2.5,
    aspect=1.2,
    facet_kws={"sharey": True, "sharex": True},
)
for ax in g.axes.flat:
    ax.grid(alpha=0.5)
plot_horizontal_lines(
    g,
    unif_baseline_entropy,
    label="Uniform Entropy",
    color="blue",
    data_col="entropy",
    dataset=ENTROPY_REF_DATASET,
)
plot_horizontal_lines(
    g,
    target_entropy,
    label="Target Entropy",
    color="green",
    data_col="entropy",
    dataset=ENTROPY_REF_DATASET,
    split=ENTROPY_REF_SPLIT,
)

# Per (model size, gen kwargs, dataset, split): mean prediction entropy and mean WS loss.
ent_df = joint_df.groupby(
    ["model_size", "gen_kwargs", "dataset", "split"], as_index=False
).agg(
    avg_entropy=("pred_entropy", "mean"),
    avg_ws_loss=("ws_loss", "mean"),
)

g = sns.relplot(
    ent_df,
    x="avg_entropy",
    y="avg_ws_loss",
    hue="model_size",
    style="gen_kwargs",
    col="split",
    row="dataset",
    kind="scatter",
    height=2.5,
    aspect=1.2,
    facet_kws={"sharey": False, "sharex": False},
    palette="viridis",
)
g.set_axis_labels("Avg Prediction Entropy", "Avg Wasserstein Distance")
for ax in g.axes.flat:
    ax.grid(alpha=0.5)
# The y axis is a WS loss, so draw each reference frame's `ws_loss` column,
# passing data_col explicitly as the model-size plot does.
plot_horizontal_lines(
    g,
    unif_baseline_perf_metrics,
    label="Uniform Baseline",
    color="blue",
    data_col="ws_loss",
)
plot_horizontal_lines(
    g, strong_baselines, label="Gemini 2.5 Pro", color="red", data_col="ws_loss"
)
