In [22]:
%load_ext autoreload
%autoreload 2

In [23]:
import pandas as pd
import logging

# Notebook-level logger; used further down to report how many predictions get dropped.
logger = logging.getLogger(__name__)

In [None]:
from lewidi_lib import load_dataset, enable_logging

# Project helper from lewidi_lib; presumably configures log output (see the library).
enable_logging()

datasets = ["CSC", "MP"]
splits = ["train", "dev"]

# One frame per (dataset, split) pair, iterated dataset-major, then stacked.
frames = [
    load_dataset(dataset_name, split=split_name)
    for dataset_name in datasets
    for split_name in splits
]
ddf = pd.concat(frames)
print(len(ddf))
ddf.head(2)

In [None]:
import duckdb
from lewidi_lib import process_rdf

# Read every parquet file under ../parquets as a single table; union_by_name asks
# DuckDB to align columns by name across files.
query = "SELECT * FROM read_parquet('../parquets/*.parquet', union_by_name=True)"
con = duckdb.connect()
# Materialise the result as a pandas frame and hand it to the project's post-processing.
rdf = process_rdf(con.sql(query).df())
rdf.head(2)

In [None]:
import seaborn as sns

# is_valid_pred against model size for the CSC dataset: one row of panels per
# split, one column per gen_kwargs value, one line per template_id.
csc_rdf = rdf.query("dataset == 'CSC'")
g = sns.relplot(
    data=csc_rdf,
    kind="line",
    x="model_size",
    y="is_valid_pred",
    hue="template_id",
    row="split",
    row_order=["train", "dev"],
    col="gen_kwargs",
    marker="o",
    height=2.5,
    aspect=1.2,
)
g.set_axis_labels("Model Params [B]", "Valid Preds")
for ax in g.axes.flat:
    # Same fixed y-range on every panel, with a little headroom above 1.
    ax.set_ylim(0, 1.05)
    ax.grid(alpha=0.5)


In [None]:
# Count the rows about to be discarded first, so the log line reflects the frame
# as it was before filtering.
n_invalid = len(rdf.query("~is_valid_pred"))
logger.info("Dropping %d predictions that don't sum to 1", n_invalid)

# Re-bind instead of filtering with inplace=True: the frame object that earlier
# cells displayed and plotted is left untouched, and the statement reads as a
# plain transformation.
rdf = rdf.query("is_valid_pred")

In [None]:
# Attach each example's target to its predictions; rows are matched on the
# (dataset, split, request_idx) triple using pandas' default join type.
join_keys = ["dataset", "split", "request_idx"]
joint_df = ddf[[*join_keys, "target"]].merge(rdf, on=join_keys)
joint_df.head(2)

In [31]:
from lewidi_lib import assign_col_pred_entropy, assign_col_ws_loss, assign_col_l0_loss

# Apply the three project helpers in the same order as before: l0 loss,
# then ws loss, then prediction entropy.
joint_df = (
    joint_df.pipe(assign_col_l0_loss)
    .pipe(assign_col_ws_loss)
    .pipe(assign_col_pred_entropy)
)


In [32]:
from lewidi_lib import assign_n_classes, baseline_pred


# Baseline frame: start from the dataset frame with n_classes attached, set
# `pred` to baseline_pred(n_classes) for every row, then score it with the
# ws and l0 loss helpers.
bdf = (
    assign_n_classes(ddf)
    # The callable receives the whole frame, not a single row.
    .assign(pred=lambda frame: frame["n_classes"].apply(baseline_pred))
    .pipe(assign_col_ws_loss)
    .pipe(assign_col_l0_loss)
)

In [None]:
# Mean baseline losses per (dataset, split); as_index=False keeps the grouping
# keys as ordinary columns.
loss_aggregations = {"ws_loss": "mean", "l0_loss": "mean"}
baseline_losses = bdf.groupby(["dataset", "split"], as_index=False).agg(
    loss_aggregations
)
baseline_losses

In [34]:
from lewidi_lib import n_classes
import scipy
import scipy.stats

# Entropy of the baseline prediction for each dataset, in the order of `datasets`.
baseline_entropies = [
    scipy.stats.entropy(baseline_pred(n_classes(dataset_name)))
    for dataset_name in datasets
]
baseline_entropy = pd.DataFrame(
    {"entropy": baseline_entropies, "dataset": datasets}
)

In [None]:
import seaborn as sns

# ws_loss against model size for the CSC dataset: one row of panels per
# gen_kwargs value, one column per split, one line per template_id.
csc_joint_df = joint_df.query("dataset == 'CSC'")
g = sns.relplot(
    data=csc_joint_df,
    kind="line",
    x="model_size",
    y="ws_loss",
    hue="template_id",
    row="gen_kwargs",
    col="split",
    col_order=["train", "dev"],
    marker="o",
    height=2.5,
    aspect=1.2,
    facet_kws={"sharey": False},  # y-axes are not shared between panels
)
# g.set(ylim=(0, None))
g.set_axis_labels("Model Params [B]", "Wasserstein Distance")
for ax in g.axes.flat:
    ax.grid(alpha=0.5)


def plot_baseline_losses(g, baseline_losses, dataset: str | None = None):
    for ax in g.axes.flat:
        ax.grid(alpha=0.5)
        dataset_, split_ = ax.title.get_text().split(" | ")
        dataset_ = dataset_.split(" = ")[1]
        if dataset is not None:
            dataset_ = dataset
        split_ = split_.split(" = ")[1]
        baseline_ws_loss_ = baseline_losses.query(
            "dataset == @dataset_ and split == @split_"
        )["ws_loss"].values[0]
        ax.axhline(baseline_ws_loss_, color="red", linestyle="--", label="Baseline")


def plot_baseline_entropy(g, baseline_entropy):
    for ax in g.axes.flat:
        ax.grid(alpha=0.5)
        dataset_, _ = ax.title.get_text().split(" | ")
        dataset_ = dataset_.split(" = ")[1]
        value = baseline_entropy.query("dataset == @dataset_")["entropy"].values[0]
        ax.axhline(value, color="red", linestyle="--", label="Baseline")


# The grid above is faceted by gen_kwargs and split, not by dataset, so the
# dataset is passed explicitly rather than taken from the panel titles.
plot_baseline_losses(g, baseline_losses, dataset="CSC")
