In [None]:
%load_ext autoreload
%autoreload 2
from lewidi_lib import enable_logging
import pandas as pd
import duckdb
import seaborn as sns

# Route project log messages to the notebook output.
enable_logging()
# Scale fonts and line widths up for presentation-sized figures.
sns.set_context("talk")

# Math Datasets: accuracy vs. number of LLM samples

In [None]:
# Load the best-of-N results for the math datasets from every parquet file
# under ./tables/bon_samples_vs_perf/.
df = (
    duckdb.sql("SELECT * FROM read_parquet('./tables/bon_samples_vs_perf/*')")
    .df()
    # pyarrow stores a non-default pandas index as this column. Ignore it when it
    # is absent, so the cell does not depend on how the files were written.
    .drop(columns=["__index_level_0__"], errors="ignore")
)
# Short model name for the legend, e.g. "org/model-name" -> "model-name".
df["Judge"] = df["judge"].str.split("/").str[-1]
# Upper-cased display name; the plot's col_order uses "PRM800K" and "AIME".
df["Dataset"] = df["dataset"].str.upper()

In [None]:
import numpy as np


def quantiles(xs, lo=0.1, hi=0.9):
    """Return the ``(lo, hi)`` quantiles of ``xs`` as an error-bar interval.

    Meant to be passed as seaborn's ``errorbar=`` callable. seaborn calls it with
    the values aggregated at one x position and expects a ``(min, max)`` pair.

    Parameters
    ----------
    xs : array-like of numbers
        Values to summarize.
    lo, hi : float, default 0.1 and 0.9
        Quantile levels in [0, 1] for the lower and upper bound.

    Returns
    -------
    tuple
        ``(lower, upper)`` quantile values.
    """
    # Compute both bounds with a single np.quantile call instead of two.
    lower, upper = np.quantile(xs, [lo, hi])
    return lower, upper

In [None]:
from pathlib import Path

from lewidi_lib import plot_horizontal_lines

# Qwen3-32B simple-sampling accuracy per dataset, drawn as a horizontal reference line.
math_baseline = pd.DataFrame(
    {"Dataset": ["PRM800K", "AIME"], "is_correct": [0.721, 0.639]}
)
math_fig_path = Path("./imgs/bon-eval/bon_samples_vs_perf_math.pdf")

# seaborn raises ValueError unless there is exactly one marker per "Judge" level,
# so size the list from the data instead of hard-coding four entries.
# Markers cycle if there are more judges than pool entries.
marker_pool = ["o", "s", "D", "P", "X", "^", "v"]
n_judges = df["Judge"].nunique()
judge_markers = [marker_pool[i % len(marker_pool)] for i in range(n_judges)]

fgrid = sns.relplot(
    df,
    x="n_samples",
    y="is_correct",
    col="Dataset",
    col_order=["PRM800K", "AIME"],
    hue="Judge",
    style="Judge",
    markers=judge_markers,
    kind="line",
    facet_kws={"sharey": False},
    # Pass errorbar=quantiles for 10-90% quantile bands instead of seaborn's default CI.
)
fgrid.set_axis_labels("LLM samples $N$", "Correct answers")
sns.move_legend(fgrid, loc="lower left", bbox_to_anchor=(0.05, 0.95), ncol=3)
for ax in fgrid.axes.flat:
    ax.grid(alpha=0.5)

plot_horizontal_lines(
    fgrid,
    math_baseline,
    label="Qwen3-32B Simple Sampling",
    color="blue",
    data_col="is_correct",
)

# savefig fails on a fresh checkout when the output directory does not exist yet.
math_fig_path.parent.mkdir(parents=True, exist_ok=True)
fgrid.savefig(math_fig_path, bbox_inches="tight")

# NLP Datasets: Wasserstein distance vs. number of LLM samples

In [None]:
# Load the best-of-N results for the NLP datasets from every parquet file
# under ./tables/bon_samples_vs_perf_nlp/.
dfnlp = (
    duckdb.sql("SELECT * FROM read_parquet('./tables/bon_samples_vs_perf_nlp/*')")
    .df()
    # pyarrow stores a non-default pandas index as this column. Ignore it when it
    # is absent, so the cell does not depend on how the files were written.
    .drop(columns=["__index_level_0__"], errors="ignore")
)
# Short model name for the legend, e.g. "org/model-name" -> "model-name".
dfnlp["Judge"] = dfnlp["judge"].str.split("/").str[-1]
# Unlike the math datasets, names are not upper-cased. The baseline table in the
# plotting cell uses mixed-case names (e.g. "VariErrNLI"), presumably the same
# spelling as the raw values.
dfnlp["Dataset"] = dfnlp["dataset"]

In [None]:
from pathlib import Path

# Imported here as well so this section does not depend on the math section having run.
from lewidi_lib import plot_horizontal_lines

# Qwen3-32B simple-sampling Wasserstein distance per dataset, drawn as a reference line.
nlp_baseline = pd.DataFrame(
    {
        "Dataset": ["CSC", "MP", "Paraphrase", "VariErrNLI"],
        "ws_loss": [1.175, 0.296, 2.48, 0.293],
    }
)
nlp_fig_path = Path("./imgs/bon-eval/bon_samples_vs_perf_nlp.pdf")

# seaborn raises ValueError unless there is exactly one marker per "Judge" level,
# so size the list from the data instead of hard-coding four entries.
# Markers cycle if there are more judges than pool entries.
marker_pool = ["o", "s", "D", "P", "X", "^", "v"]
n_judges = dfnlp["Judge"].nunique()
judge_markers = [marker_pool[i % len(marker_pool)] for i in range(n_judges)]

fgrid = sns.relplot(
    dfnlp,
    x="n_samples",
    y="ws_loss",
    col="Dataset",
    col_wrap=2,
    hue="Judge",
    style="Judge",
    markers=judge_markers,
    kind="line",
    facet_kws={"sharey": False},
    height=4.0,
    aspect=1.2,
    # Pass errorbar=quantiles for 10-90% quantile bands instead of seaborn's default CI.
)
fgrid.set_axis_labels("LLM samples $N$", "Wasserstein Distance")
sns.move_legend(fgrid, loc="lower left", bbox_to_anchor=(0.25, 1.0))
for ax in fgrid.axes.flat:
    ax.grid(alpha=0.5)

plot_horizontal_lines(
    fgrid,
    nlp_baseline,
    label="Qwen3-32B Simple Sampling",
    color="blue",
    data_col="ws_loss",
)

# savefig fails on a fresh checkout when the output directory does not exist yet.
nlp_fig_path.parent.mkdir(parents=True, exist_ok=True)
fgrid.savefig(nlp_fig_path, bbox_inches="tight")