In [None]:
# Load IPython's autoreload extension; mode 2 re-imports modules before each
# cell runs, so edits to imported project modules are picked up without
# restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
from logging import getLogger
from lewidi_lib import (
    discard_failed_rows,
    discard_na_response_rows,
    enable_logging,
    join_dataset,
    preds_file,
)
import pandas as pd
import json_repair

# Project helper from lewidi_lib -- presumably configures logging handlers and
# levels (TODO confirm against its implementation).
enable_logging()
# Module-level logger; not referenced by the original cells shown in this notebook.
logger = getLogger(__name__)


def is_response_valid(response: dict) -> bool:
    """Return True only when `response` is a dict containing a "final_response" key.

    Any non-dict value (string, list, None, ...) is reported as invalid rather
    than raising.
    """
    if not isinstance(response, dict):
        return False
    return "final_response" in response


# Load the raw model responses and drop rows the project helpers mark as
# failed or as having a missing response.
rdf = pd.read_parquet("../prm800k-poc/preds/responses.parquet")
rdf = discard_failed_rows(rdf)
rdf = discard_na_response_rows(rdf)

# Parse each raw response leniently. `.assign` returns a new frame, so no
# column is written onto a possibly filtered slice.
rdf = rdf.assign(response_parsed=rdf["response"].apply(json_repair.loads))

# json_repair can hand back a non-dict (or a dict without "final_response") for
# badly malformed output; indexing such a value would raise below, so filter
# those rows out and report how many were dropped.
valid_mask = rdf["response_parsed"].apply(is_response_valid).astype(bool)
if not valid_mask.all():
    logger.warning(
        "Dropping %d of %d responses without a 'final_response' field",
        int((~valid_mask).sum()),
        len(rdf),
    )
rdf = rdf[valid_mask].copy()

# The model's final answer for each run.
rdf["pred"] = rdf["response_parsed"].apply(lambda x: x["final_response"])

In [None]:
# Correctness verdicts from the verify-solution judge, one row per
# (dataset_idx, run_idx), with columns renamed so they do not clash with the
# model-response columns after merging.
is_correct = (
    pd.read_parquet("../prm800k-poc/judge/verify-solution/responses.parquet")
    .pipe(discard_na_response_rows)
    .loc[:, ["dataset_idx", "run_idx", "response", "reasoning"]]
    .rename(columns={"response": "is_correct", "reasoning": "is_correct_reasoning"})
)

In [None]:
from lewidi_lib import assign_col_response_parsed, process_ratings
import numpy as np
from prm800k import mapping

# Step-level judge ratings (gemini-2.5-flash), reduced to one score per run.
ratings = pd.read_parquet("../prm800k-poc/judge/gemini-2.5-flash/responses.parquet")
ratings = discard_na_response_rows(ratings)
ratings = assign_col_response_parsed(ratings)
# NOTE(review): both `ok` and `bad` are mapped to 0 here. If process_ratings
# multiplies the mapped per-step values (operation=np.prod), any run containing
# an `ok` or `bad` step gets score 0 -- confirm that this strictness is intended.
ratings = process_ratings(
    ratings, operation=np.prod, cat_mapping=mapping(ok=0.0, bad=0)
)
# Rename via a chained call instead of `inplace=True` on a column subset: this
# avoids SettingWithCopyWarning and keeps the cell idempotent on re-run.
ratings = ratings[["dataset_idx", "run_idx", "score", "reasoning"]].rename(
    columns={"reasoning": "judge_reasoning"}
)

In [None]:
# Attach dataset fields to every model response.
joint_df = join_dataset(rdf, parse_tgt=False)
# Left join: runs without a judge score are kept here with a missing `score`.
joint_df = joint_df.merge(ratings, how="left", on=["dataset_idx", "run_idx"])
# Default (inner) join: only runs that also have a correctness verdict remain.
joint_df = joint_df.merge(is_correct, on=["dataset_idx", "run_idx"])
# Applied to the `score` column -- presumably drops the runs that received no
# judge score in the left join above.
joint_df = discard_na_response_rows(joint_df, col="score")

In [None]:
# For each distinct judge score: mean of `is_correct` and the number of runs.
corr_df = joint_df.groupby("score", as_index=False).agg(
    is_correct=pd.NamedAgg(column="is_correct", aggfunc="mean"),
    n_examples=pd.NamedAgg(column="dataset_idx", aggfunc="count"),
)
corr_df

In [None]:
import seaborn as sns

# Mean correctness at each distinct judge score.
sns.scatterplot(data=corr_df, x="score", y="is_correct")

In [None]:
# Macro average: take the mean within each problem first, then average those
# per-problem means, so every problem carries equal weight.
g_ = joint_df.groupby("dataset_idx")
g_.agg({"score": "mean", "is_correct": "mean"}).mean()

In [None]:
# Best-of-n selection: for each problem keep the run with the highest judge
# score (idxmax returns the first row label on ties), then average.
best_row_labels = joint_df.groupby("dataset_idx")["score"].idxmax()
max_score_df = joint_df.loc[best_row_labels]
max_score_df[["score", "is_correct"]].mean()

In [None]:
import duckdb

# Count runs per problem, split by correctness verdict. PIVOT without a USING
# clause counts rows; the pivoted columns are named after the `is_correct`
# values ("0" / "1"), which are renamed to readable labels.
crosstab = duckdb.sql("PIVOT joint_df ON is_correct GROUP BY dataset_idx").df()
crosstab = crosstab.rename(columns={"0": "incorrect", "1": "correct"})
# If every run is correct (or every run is incorrect) the pivot yields only one
# of the two columns; add the missing one as zeros so the flags below do not
# raise KeyError.
crosstab = crosstab.assign(
    **{
        outcome: 0
        for outcome in ("incorrect", "correct")
        if outcome not in crosstab.columns
    }
)
# Classify each problem by the mix of outcomes across its runs.
crosstab["all_incorrect"] = (crosstab["incorrect"] > 0) & (crosstab["correct"] == 0)
crosstab["all_correct"] = (crosstab["incorrect"] == 0) & (crosstab["correct"] > 0)
crosstab["mixed"] = (crosstab["incorrect"] > 0) & (crosstab["correct"] > 0)
# Every problem must fall into exactly one of the three categories.
assert crosstab[["all_incorrect", "all_correct", "mixed"]].sum().sum() == len(
    crosstab
), "each dataset_idx must be exactly one of all_incorrect / all_correct / mixed"

In [None]:
# Number of problems in each outcome category: reshape the three boolean flags
# to long form, keep the flags that are set, and count them by category.
outcome_flags = crosstab.melt(
    id_vars="dataset_idx",
    value_vars=["all_incorrect", "all_correct", "mixed"],
    var_name="type",
)
outcome_flags.loc[outcome_flags["value"], "type"].value_counts()

In [None]:
# Problems that have at least one correct and at least one incorrect run.
mixd_perf = crosstab[crosstab["mixed"]]
mixd_perf

In [None]:
# Inspect every run of a single problem (dataset_idx 55) side by side:
# prediction vs. target, the correctness verdict, the judge score, and the
# raw response with both judges' reasoning.
joint_df.loc[
    joint_df["dataset_idx"] == 55,
    [
        "run_idx",
        "pred",
        "target",
        "is_correct",
        "score",
        "response",
        "is_correct_reasoning",
        "judge_reasoning",
    ],
]

# Mixed-Performance Cases (problems with both correct and incorrect runs)

In [None]:
from lewidi_lib import bootstrap_avg


# Restrict to runs of the mixed problems (those listed in `mixd_perf`).
mp_cases = joint_df[joint_df["dataset_idx"].isin(mixd_perf["dataset_idx"])]
# Per-problem means of score and correctness, each column then passed to the
# project's bootstrap_avg helper.
mp_case_means = mp_cases.groupby("dataset_idx")[["score", "is_correct"]].mean()
mp_case_means.apply(bootstrap_avg)

In [None]:
# Best-of-n within the mixed problems: keep the highest-scored run of each
# problem (first row on ties), then pass each column to bootstrap_avg.
best_mixed_labels = mp_cases.groupby("dataset_idx")["score"].idxmax()
mp_cases_bon = mp_cases.loc[best_mixed_labels]
mp_cases_bon[["score", "is_correct"]].apply(bootstrap_avg)

In [None]:
def corr(df):
    """Pearson correlation between the `score` and `is_correct` columns of `df`.

    Returns NaN when the correlation is undefined: fewer than two rows, or
    either column constant. np.corrcoef would also produce NaN in those cases,
    but only after dividing by a zero standard deviation and emitting
    RuntimeWarnings.
    """
    score = df["score"].to_numpy(dtype=float)
    correct = df["is_correct"].to_numpy(dtype=float)
    if len(score) < 2:
        return np.nan
    # A constant column has zero variance, so the coefficient is undefined.
    if score.min() == score.max() or correct.min() == correct.max():
        return np.nan
    return np.corrcoef(score, correct)[0, 1]
# Average the per-problem score/correctness correlation over the mixed problems.
# Selecting the two columns keeps the grouping key out of what `corr` receives,
# which avoids the pandas >= 2.2 DeprecationWarning about apply operating on
# grouping columns. Series.mean() skips NaN, so problems whose correlation is
# undefined (e.g. a constant score) do not contribute to the average.
mp_cases.groupby("dataset_idx")[["score", "is_correct"]].apply(corr).mean()