In [None]:
%load_ext autoreload
%autoreload 2
from lewidi_lib import enable_logging

# Turn on logging through the project helper imported above; what exactly it
# configures is defined in lewidi_lib.
enable_logging()

In [None]:
from pathlib import Path
import duckdb
from lewidi_lib import preds_file
import pandas as pd

# Dataset names processed by this notebook. "prm800k" and "aime" are the ones
# is_math() classifies as math; the others are handled as non-math datasets.
datasets = ["MP", "CSC", "Paraphrase", "VariErrNLI", "prm800k", "aime"]


def qwen32b_preds_file(dataset: str) -> Path:
    """Return the Qwen3-32B train-split predictions file for ``dataset``.

    Math datasets (as decided by ``is_math``) are looked up under the
    ``allex_10loops`` run, all other datasets under ``1000ex_10loops``.
    The split ("train"), template ("60") and model id are fixed here; the
    actual path resolution is delegated to ``preds_file``.
    """
    run = "allex_10loops" if is_math(dataset) else "1000ex_10loops"
    lookup = dict(
        dataset=dataset,
        split="train",
        template="60",
        model_id="Qwen/Qwen3-32B",
        run_name=run,
    )
    return preds_file(**lookup)


# NOTE(review): `judge` is not read by any code visible in this section —
# confirm where it is consumed.
judge = "Qwen/Qwen3-32B"

# Judge model id -> path of that judge's responses parquet for the non-math
# datasets (looked up by judge_to_file). The paths are relative; the
# file-collection loop below joins them onto a directory derived from the
# predictions file.
judge_file_nlp = {
    "gemini-2.5-flash": "1000ex_10loops/judge/gemini-2.5-flash/t24/responses.parquet",
    "Qwen/Qwen3-32B": "1000ex_10loops/judge/Qwen/Qwen3-32B/set2/t24/1000ex_10loops_q5div/responses.parquet",
    "deepseek-ai/DeepSeek-R1-0528-Qwen3-8B": "1000ex_10loops/judge/deepseek-ai/DeepSeek-R1-0528-Qwen3-8B/set2/t24/1000ex_10loops_q5div/responses.parquet",
}
# Same mapping for the math datasets (those for which is_math() is true).
judge_to_file_math = {
    "gemini-2.5-flash": "allex_10loops_mixed_perf_subset/judge/gemini-2.5-flash/t24/allex_10loops_mp/responses.parquet",
    "deepseek-ai/DeepSeek-R1-0528-Qwen3-8B": "allex_10loops_mixed_perf_subset/judge/deepseek-ai/DeepSeek-R1-0528-Qwen3-8B/set2/t24/allex_10loops_mp/responses.parquet",
    "Qwen/Qwen3-32B": "allex_10loops_mixed_perf_subset/judge/Qwen/Qwen3-32B/set2/t24/allex_10loops_mp/responses.parquet",
}


def judge_to_file(judge, dataset):
    """Look up the relative judge responses path for ``judge`` on ``dataset``.

    The math table (``judge_to_file_math``) is used when ``is_math(dataset)``
    is true, the non-math table (``judge_file_nlp``) otherwise. A judge that
    is missing from the selected table raises ``KeyError``.
    """
    table = judge_to_file_math if is_math(dataset) else judge_file_nlp
    return table[judge]


def is_math(dataset: str) -> bool:
    """Tell whether ``dataset`` names one of the math datasets.

    The comparison is case-insensitive; only "prm800k" and "aime" count as
    math.
    """
    normalized = dataset.lower()
    return normalized == "prm800k" or normalized == "aime"


def assign_col_domain(df: pd.DataFrame) -> pd.DataFrame:
    """Return a new frame with a ``domain`` column derived from ``dataset``.

    Rows whose ``dataset`` value is a math dataset (per ``is_math``) are
    labelled "Math"; every other row is labelled "LeWiDi".
    """

    def domain_of(name):
        if is_math(name):
            return "Math"
        return "LeWiDi"

    domains = df["dataset"].apply(domain_of)
    return df.assign(domain=domains)


# Collect, for every dataset, the Qwen3-32B predictions file plus one judge
# responses file per judge model. Missing files abort the cell immediately so
# later cells never load a partial set.
preds_files = []
judge_files = []
for dataset in datasets:
    pfile = qwen32b_preds_file(dataset)
    assert pfile.exists(), f"{pfile} does not exist"
    preds_files.append(pfile)
    # Judge paths are relative to the directory three levels above the
    # predictions file.
    judge_root = pfile.parent.parent.parent
    # The loop variable is deliberately NOT called `judge`: that name holds the
    # module-level judge setting and must not be overwritten by this loop.
    for judge_name in judge_file_nlp:
        judge_file = judge_root / judge_to_file(judge_name, dataset)
        assert judge_file.exists(), f"{judge_file} does not exist"
        judge_files.append(judge_file)

In [None]:
from lewidi_lib import compact_model_name


# Display names for dataset keys; apply_colmaps leaves datasets without an
# entry here unchanged.
ds_map = {"prm800k": "PRM800K", "aime": "AIME"}
# NOTE(review): block_names is not used by any live code visible in this
# section — confirm whether it is still needed.
block_names = {"response_len": "response", "reasoning_len": "reasoning"}


def apply_colmaps(df: pd.DataFrame) -> pd.DataFrame:
    """Return a new frame with display-friendly ``model_id`` and ``dataset``.

    ``model_id`` values are passed through ``compact_model_name``. ``dataset``
    values are replaced by their ``ds_map`` entry when one exists and kept
    as-is otherwise.
    """

    def display_dataset(name):
        return ds_map.get(name, name)

    short_model_ids = df["model_id"].map(compact_model_name)
    display_datasets = df["dataset"].map(display_dataset)
    return df.assign(model_id=short_model_ids, dataset=display_datasets)

# Preds Stats

In [None]:
from lewidi_lib import load_listof_parquets


# Predictions frame: keep successful rows only, select the columns of
# interest, drop the response text and tag every row as coming from the LLM.
cols = ["dataset", "model_id", "n_output_tokens", "response", "dataset_idx"]
rdf = (
    load_listof_parquets(preds_files)
    .query("success")[cols]
    .drop(columns=["response"])
    .assign(model="LLM")
)
rdf.head(2)

In [None]:
# Judge frame: keep successful rows from every judge except gemini-2.5-flash,
# rename judge_model_id to model_id (the column name the predictions frame
# uses), drop the response text and tag every row as coming from a judge.
cols = ["dataset", "judge_model_id", "n_output_tokens", "response", "dataset_idx"]
jdf = (
    load_listof_parquets(judge_files)
    .query("success and judge_model_id != 'gemini-2.5-flash'")[cols]
    .rename(columns={"judge_model_id": "model_id"})
    .drop(columns=["response"])
    .assign(model="Judge")
)
jdf.head(2)

In [None]:
# Stack LLM and judge rows, apply the display-name mappings, build the facet
# title ("<model>: <model_id>") and add the Math / LeWiDi domain column.
joint = (
    pd.concat([rdf, jdf], ignore_index=True)
    .pipe(apply_colmaps)
    .assign(title=lambda frame: frame["model"] + ": " + frame["model_id"])
    .pipe(assign_col_domain)
)
joint

In [None]:
import seaborn as sns
import numpy as np

# Select seaborn's "talk" plotting context for the figures drawn below.
sns.set_context("talk")


def plot_num_chars(len_data):
    """Draw horizontal bars of ``n_output_tokens`` per dataset, one panel per title.

    Despite the function name, the plotted quantity is the ``n_output_tokens``
    column. Bars are coloured by ``domain`` and each distinct ``title`` value
    gets its own column of the grid with an independent x axis. The error bars
    span the 25th to the 75th percentile of the values in each bar.

    Args:
        len_data: frame providing the ``dataset``, ``n_output_tokens``,
            ``domain`` and ``title`` columns.

    Returns:
        The grid object returned by ``sns.catplot``.
    """

    def interquartile_range(values):
        return (np.quantile(values, 0.25), np.quantile(values, 0.75))

    grid = sns.catplot(
        data=len_data,
        x="n_output_tokens",
        y="dataset",
        hue="domain",
        col="title",
        kind="bar",
        errorbar=interquartile_range,
        margin_titles=True,
        sharex="col",
        height=4,
        aspect=1,
    )
    grid.set_titles(col_template="{col_name}")
    grid.set_axis_labels("Output Tokens", "Dataset")
    # Place the legend above the panels, laid out in two columns.
    sns.move_legend(
        grid,
        loc="lower left",
        bbox_to_anchor=(0.3, 1.0),
        ncol=2,
        title="Domain",
    )
    # Light vertical grid lines make the bar lengths easier to compare.
    for panel in grid.axes.flat:
        panel.grid(alpha=0.5, axis="x")
    return grid


# Render the comparison figure for the combined LLM + judge frame.
fgrid = plot_num_chars(joint)

# savefig does not create missing parent directories, so make sure the output
# folder exists; otherwise a fresh checkout fails here with FileNotFoundError.
fig_path = Path("imgs/domain_comp/lens-of-responses-and-reasonings.pdf")
fig_path.parent.mkdir(parents=True, exist_ok=True)
fgrid.savefig(fig_path, bbox_inches="tight")