In [None]:
# Re-import edited modules (e.g. lewidi_lib) automatically before executing each cell.
%load_ext autoreload
%autoreload 2

In [None]:
import pandas as pd
import seaborn as sns

from lewidi_lib import (
    assign_cols_perf_metrics,
    enable_logging,
    join_correct_responses,
    load_preds,
    make_query_from_dict,
    process_rdf,
)

enable_logging()

# Judge ratings of the reasoning traces (JSON Lines: one record per line).
ratings = pd.read_json(
    "../parquets/reasoning-ratings/reasoning-judge-responses.jsonl",
    lines=True,
)

# Experiment configuration to analyse. The query built from it below is
# applied to the predictions frame only.
metadata = dict(
    template_id=31,
    model_id="Qwen/Qwen3-32B",
    gen_kwargs="set2",
    dataset="CSC",
    judge_model_id="gemini-2.5-pro-preview-06-05",
)

rdf = load_preds(parquets_dir="../parquets")
query = make_query_from_dict(metadata, rdf.columns)
# Restrict predictions to the selected configuration, then run the
# lewidi_lib post-processing steps in sequence.
rdf = (
    rdf.query(query)
    .pipe(process_rdf)
    .pipe(join_correct_responses)
    .pipe(assign_cols_perf_metrics)
)

In [55]:
join_cols = ["dataset", "dataset_idx", "run_idx"]  # expand when more cols!
ratings_cols = [
    "response",
    "reasoning",
    "judge_model_id",
    "dataset",
    "dataset_idx",
    "run_idx",
]
# The `metadata` selector (including judge_model_id) is only used to query
# `rdf`; `ratings` itself is never filtered, so it may contain scores from
# other judges. Keep only the configured judge so the join on `join_cols`
# cannot fan one prediction out into several rows.
# NOTE(review): if the ratings file also covers other models/templates, those
# are not distinguished by `join_cols` either -- extend it once ratings carry
# those columns.
judge_ratings = ratings.loc[
    ratings["judge_model_id"] == metadata["judge_model_id"], ratings_cols
]
# Judge columns that collide with rdf columns (response, reasoning) get the
# "_judge" suffix; rdf's own columns keep their names.
joint = judge_ratings.merge(
    rdf, on=join_cols, how="inner", suffixes=("_judge", "")
)

In [None]:
# example 8: too lax with spread out distribution
# NOTE(review): the row is picked positionally (iloc 37); confirm it still
# points at "example 8" if the join or row order changes.
shown_cols = [
    "text",
    "reasoning",
    "response",
    "target",
    "reasoning_judge",
    "response_judge",
]
example = joint.iloc[37][shown_cols].to_dict()
for col_name, value in example.items():
    print(col_name, value)
    print("-" * 100)

In [None]:
# Judge rating of each (item, run) against its ws_loss, with a fitted
# regression line and marginal distributions. `sns` comes from the top import
# cell rather than being imported mid-notebook.
sns.set_theme(style="whitegrid", context="talk")

fgrid = sns.jointplot(
    data=joint,
    x="response_judge",
    y="ws_loss",
    kind="reg",
    scatter_kws={"alpha": 0.5},  # forwarded to regplot on the joint axes
)

In [None]:
# Sanity check: every dataset_idx should occur in exactly 10 joined rows
# (presumably one per run_idx).
joint["dataset_idx"].value_counts().eq(10).all()

In [None]:
# Finding: there is almost no performance difference between the normal
# outputs and those selected for top trace ratings.

# Baseline: per-item mean ws_loss over all of that item's runs.
avg_ws_loss = joint.groupby("dataset_idx", as_index=False)["ws_loss"].mean()
# Selection: for each item, the run the judge rated highest (idxmax keeps the
# first such row on ties).
best_by_judge = joint.loc[
    joint.groupby("dataset_idx")["response_judge"].idxmax(),
    ["dataset_idx", "ws_loss"],
]
best_by_judge["ws_loss"].mean(), avg_ws_loss["ws_loss"].mean()

In [None]:
# Stack both per-item ws_loss tables into one long frame, labelled by `type`,
# so seaborn can overlay their distributions. Caveat: "avg" values are means
# over runs, which typically spread less than single best-by-judge picks.
ws_comparison = pd.concat(
    objs=[
        avg_ws_loss.assign(type="avg"),
        best_by_judge.assign(type="best_by_judge"),
    ],
    ignore_index=True,
)
ax = sns.histplot(data=ws_comparison, x="ws_loss", hue="type", kde=True)