In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import pandas as pd
import logging

# Notebook-level logger; used below to report how many invalid predictions are dropped.
logger = logging.getLogger(name=__name__)

In [None]:
from lewidi_lib import load_dataset, enable_logging

enable_logging()

datasets = ["CSC", "MP"]
splits = ["train", "dev"]
# Stack every (dataset, split) combination into one frame, dataset-major order.
ddf = pd.concat(
    [
        load_dataset(dataset_name, split=split_name)
        for dataset_name in datasets
        for split_name in splits
    ]
)
print(len(ddf))
ddf.head(2)

In [None]:
import duckdb
from lewidi_lib import process_rdf

# Read all result shards under ../parquets/. union_by_name matches columns by
# name, so shards whose column sets differ can still be combined.
con = duckdb.connect()
rdf = process_rdf(
    con.sql(
        "SELECT * FROM read_parquet('../parquets/*.parquet', union_by_name=True)"
    ).df()
)
rdf.head(2)

In [None]:
import seaborn as sns

# Share of valid predictions (mean of the boolean is_valid_pred) per model size,
# one panel per (split, dataset).
g = sns.relplot(
    data=rdf,
    kind="line",
    x="model_size",
    y="is_valid_pred",
    hue="gen_kwargs",
    row="split",
    row_order=["train", "dev"],
    col="dataset",
    marker="o",
    height=2.5,
    aspect=1.2,
)
g.set_axis_labels("Model Params [B]", "Valid Preds")
g.legend.set_title("Reasoning")
g.set(ylim=(0, 1.05))
for ax in g.axes.flat:
    ax.grid(alpha=0.5)


In [None]:
# Keep only rows whose is_valid_pred flag is True before scoring.
# Rebind rdf to the filtered frame instead of using query(..., inplace=True);
# pandas discourages inplace and it gives no performance benefit.
n_invalid = len(rdf.query("~is_valid_pred"))
logger.info("Dropping %d predictions that don't sum to 1", n_invalid)
rdf = rdf.query("is_valid_pred")

In [None]:
# Inner join (the merge default): keep only predictions that have a matching
# gold target for the same dataset, split and request_idx.
joint_df = ddf[["dataset", "split", "request_idx", "target"]].merge(
    rdf, on=["dataset", "split", "request_idx"]
)
joint_df.head(2)

In [15]:
from lewidi_lib import assign_col_pred_entropy, assign_col_ws_loss, assign_col_l0_loss

# Add the per-row L0 loss, Wasserstein loss and prediction-entropy columns.
joint_df = (
    joint_df.pipe(assign_col_l0_loss)
    .pipe(assign_col_ws_loss)
    .pipe(assign_col_pred_entropy)
)


# Baseline: Uniform Distribution

In [16]:
from lewidi_lib import assign_n_classes, baseline_pred


# Score the baseline prediction (built from each row's n_classes) with the same
# loss helpers used for the model predictions.
bdf = (
    assign_n_classes(ddf)
    .assign(pred=lambda df: df["n_classes"].apply(baseline_pred))
    .pipe(assign_col_ws_loss)
    .pipe(assign_col_l0_loss)
)

In [None]:
# Mean baseline loss per (dataset, split), used as reference lines in later plots.
baseline_losses = bdf.groupby(["dataset", "split"], as_index=False).agg(
    ws_loss=("ws_loss", "mean"),
    l0_loss=("l0_loss", "mean"),
)
baseline_losses

In [18]:
from lewidi_lib import n_classes
import scipy
import scipy.stats

# Entropy of the baseline prediction for each dataset's number of classes.
baseline_entropy = pd.DataFrame(
    {
        "entropy": [
            scipy.stats.entropy(baseline_pred(n_classes(dataset_name)))
            for dataset_name in datasets
        ],
        "dataset": datasets,
    }
)

# Is Performance Correlated With Size?

In [None]:
from lewidi_lib import plot_baseline_losses
import seaborn as sns

# Wasserstein loss vs. model size, one panel per (dataset, split).
# plot_baseline_losses presumably overlays the baseline losses computed above.
g = sns.relplot(
    data=joint_df,
    x="model_size",
    y="ws_loss",
    hue="gen_kwargs",
    col="split",
    col_order=["train", "dev"],
    row="dataset",
    row_order=["CSC", "MP"],
    kind="line",
    marker="o",
    height=2.5,
    aspect=1.2,
    facet_kws={"sharey": False},
)
g.set_axis_labels("Model Params [B]", "Wasserstein Distance")
# Grid for consistency with the other facet plots in this notebook.
for ax in g.axes.flat:
    ax.grid(alpha=0.5)
plot_baseline_losses(g, baseline_losses)


# Is Performance Correlated With Average Entropy?

In [None]:
# plot_baseline_entropy was used without being imported, so the cell failed
# with NameError on a fresh kernel; import it alongside the other lewidi_lib helpers.
from lewidi_lib import plot_baseline_entropy

# Prediction entropy vs. model size, one panel per (dataset, split).
# row_order matches the other dataset-faceted plots so panels line up.
g = sns.relplot(
    data=joint_df,
    x="model_size",
    y="pred_entropy",
    hue="gen_kwargs",
    col="split",
    col_order=["train", "dev"],
    row="dataset",
    row_order=["CSC", "MP"],
    kind="line",
    marker="o",
    height=2.5,
    aspect=1.2,
    facet_kws={"sharey": False, "sharex": True},
)
for ax in g.axes.flat:
    ax.grid(alpha=0.5)
# Presumably draws the per-dataset baseline entropy (baseline_entropy) on each panel.
plot_baseline_entropy(g, baseline_entropy)

In [None]:
# One point per (model_size, gen_kwargs, dataset, split): mean prediction
# entropy against mean Wasserstein loss.
ent_df = (
    joint_df.groupby(["model_size", "gen_kwargs", "dataset", "split"], as_index=False)
    .agg(
        avg_entropy=("pred_entropy", "mean"),
        avg_ws_loss=("ws_loss", "mean"),
    )
)

g = sns.relplot(
    data=ent_df,
    kind="scatter",
    x="avg_entropy",
    y="avg_ws_loss",
    hue="model_size",
    palette="viridis",
    style="gen_kwargs",
    row="dataset",
    row_order=["CSC", "MP"],
    col="split",
    col_order=["train", "dev"],
    height=2.5,
    aspect=1.2,
    facet_kws={"sharey": False, "sharex": False},
)
for ax in g.axes.flat:
    ax.grid(alpha=0.5)
plot_baseline_losses(g, baseline_losses)
