In [3]:
# Enable IPython's autoreload extension in mode 2: imported modules are
# reloaded before each cell executes, so edits to local library code take
# effect without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
from lewidi_lib import (
    assign_cols_perf_metrics,
    enable_logging,
    join_correct_responses,
    load_preds,
    process_rdf,
)


# Presumably configures logging for the project helpers used below.
enable_logging()

# Read the predictions from the parquet directory, drop the first 200 rows
# (tail(-200) keeps everything except the leading 200), then post-process
# without discarding invalid predictions so validity itself can be studied.
# NOTE(review): the reason for skipping the first 200 rows is not recorded
# in this notebook — confirm and document it.
rdf = process_rdf(
    load_preds(parquets_dir="../parquets").tail(-200),
    discard_invalid_pred=False,
)

In [5]:
# Cast the validity flag to float so it can be used as a numeric y-value in
# the regression plots below.
rdf = rdf.assign(is_valid_pred=lambda frame: frame["is_valid_pred"].astype(float))

In [None]:
# How do valid preds depend on the temperature, top_p, and presence penalty?
# They seem to have minimal effects on validity.
from matplotlib import pyplot as plt
import seaborn as sns

# Sampling parameters whose relationship to prediction validity is plotted.
indep_vars = ["temperature", "top_p", "presence_penalty"]
fig, axes = plt.subplots(ncols=len(indep_vars), figsize=(12, 3))
# Iterate over fig.axes (always a flat list of Axes) rather than `axes`:
# plt.subplots returns a bare Axes instead of an array when only one panel is
# requested, which would break zip() if indep_vars were reduced to one entry.
for indep_var, ax in zip(indep_vars, fig.axes):
    # One regression panel per parameter: parameter value on x, the
    # float-cast validity flag on y.
    sns.regplot(
        rdf,
        x=indep_var,
        y="is_valid_pred",
        ax=ax,
    )
    ax.set_title(f"{indep_var} vs is_valid_pred")
    ax.set_xlabel(indep_var)
    ax.set_ylabel("is_valid_pred")


# Performance

In [None]:
# Build the frame used for performance analysis: re-run process_rdf with
# invalid predictions discarded, then pipe the result through
# join_correct_responses and assign_cols_perf_metrics (presumably attaching
# reference responses and metric columns such as ws_loss — verify in lewidi_lib).
# NOTE(review): `rdf` has already been through process_rdf once and its
# is_valid_pred column is float by this point — confirm process_rdf handles both.
vrdf = (
    rdf.pipe(process_rdf, discard_invalid_pred=True)
    .pipe(join_correct_responses)
    .pipe(assign_cols_perf_metrics)
)

In [None]:
# ws_loss as a function of presence_penalty on the filtered frame, fitted with
# a 5th-order polynomial (order=5) so a non-linear trend can show up.
perf_ax = sns.regplot(
    vrdf,
    x="presence_penalty",
    y="ws_loss",
    order=5,
)
# Title the figure like the validity plots above; the trailing semicolon
# suppresses the text repr in the cell output.
perf_ax.set_title("presence_penalty vs ws_loss");
