In [None]:
# IPython autoreload in mode 2: imported modules are reloaded before each cell
# executes, so edits to project code are picked up without a kernel restart.
%load_ext autoreload
%autoreload 2

In [None]:
from logging import getLogger
from lewidi_lib import (
    discard_failed_rows,
    discard_na_response_rows,
    enable_logging,
    join_dataset,
    preds_file,
)
import pandas as pd
import json_repair

# Call lewidi_lib's enable_logging() and create a logger named after this module.
# NOTE(review): what enable_logging() configures is not visible from this
# notebook — presumably handlers/levels; confirm against lewidi_lib.
enable_logging()
logger = getLogger(__name__)


def is_response_valid(response: dict) -> bool:
    """Report whether a parsed response is a dict carrying a ``final_response`` key.

    Args:
        response: The parsed model response. Any non-dict value is accepted
            and simply reported as invalid.

    Returns:
        True only when ``response`` is a ``dict`` that has the key
        ``"final_response"`` (whatever its value); False otherwise.
    """
    if not isinstance(response, dict):
        return False
    return "final_response" in response


# Location of this run's JSONL predictions file, as built by lewidi_lib.preds_file.
# The file is intentionally not read here: this cell used to load it into `rdf`
# and then immediately overwrite `rdf` with the parquet predictions below, so
# that first read was discarded work (and a needless failure point if the file
# is missing).
path = preds_file(
    dataset="prm800k",
    split="train",
    template="31",
    model_id="Qwen/Qwen3-32B",
    run_name="allex_10loops",
    format="jsonl",
)

# Predictions and two judge outputs (judging by the directory names).
# NOTE(review): absolute, machine-specific paths — consider a configurable base dir.
rdf = pd.read_parquet(
    "/home/tomasruiz/code/lewidi2025/prm800k-poc/preds/responses.parquet"
)
ratings = pd.read_parquet(
    "/home/tomasruiz/code/lewidi2025/prm800k-poc/judge/responses.parquet"
)
is_correct = pd.read_parquet(
    "/home/tomasruiz/code/lewidi2025/prm800k-poc/judge/verify-solution/responses.parquet"
)

# Same cleaning helper applied to all three tables (lewidi_lib.discard_na_response_rows).
ratings = discard_na_response_rows(ratings)
rdf = discard_na_response_rows(rdf)
is_correct = discard_na_response_rows(is_correct)

# Parse each raw response. json_repair.loads can yield non-dict values (e.g. a
# list or a string) for malformed output, and a dict may lack "final_response";
# indexing such a value would raise and abort the whole cell. Filter those rows
# out first, and log how many were dropped so the loss is visible.
rdf["response_parsed"] = rdf["response"].apply(json_repair.loads)
valid_mask = rdf["response_parsed"].apply(is_response_valid).astype(bool)
n_invalid = int((~valid_mask).sum())
if n_invalid:
    logger.warning(
        "Dropping %d of %d rows whose parsed response has no 'final_response'",
        n_invalid,
        len(rdf),
    )
# .copy() so the column assignment below writes to an independent frame rather
# than a slice of the unfiltered one.
rdf = rdf[valid_mask].copy()
rdf["pred"] = rdf["response_parsed"].apply(lambda x: x["final_response"])

In [None]:
# import sympy
# import sympy.parsing
# import sympy.parsing.latex
# from sympy import Symbol, pi


# def are_sympy_expr_equal(prediction: str, solution: str) -> bool:
#     success1, pred = try_parse(prediction)
#     success2, sol = try_parse(solution)
#     if not success1 or not success2:
#         return False
#     if isinstance(sol, tuple):
#         return sol == pred
#     equal = pred.equals(sol)
#     if equal is None:
#         return False
#     return equal


# def try_parse(s: str) -> tuple[bool, sympy.Expr]:
#     """Return (success, sympy_expr)"""
#     try:
#         expr = sympy.parsing.parse_expr(s)
#         return True, replace_pi(expr)
#     except Exception:
#         try:
#             expr = sympy.parsing.latex.parse_latex(s)
#             return True, replace_pi(expr)
#         except Exception:
#             print(f"Could not parse expression: {s}")
#             return False, None


# def replace_pi(expr: sympy.Expr) -> sympy.Expr:
#     if isinstance(expr, tuple):
#         return expr
#     return expr.subs({Symbol("pi"): pi})


# joint_df["is_correct"] = joint_df.apply(
#     lambda row: are_sympy_expr_equal(row["pred"], row["target"]), axis=1
# )

In [None]:
from lewidi_lib import assing_col_score_from_json

# Reduce each judge table to the join keys plus a single value column, then
# attach both to the prediction rows.
# NOTE(review): this cell rebinds `ratings` and `is_correct`; after one run
# `is_correct` no longer has a "response" column, so re-running the cell
# without re-running the loading cell raises a KeyError.
join_keys = ["dataset_idx", "run_idx"]

ratings = assing_col_score_from_json(ratings)
ratings = ratings[[*join_keys, "score"]]

is_correct = is_correct[[*join_keys, "response"]]
is_correct = is_correct.rename(columns={"response": "is_correct"})

# join_dataset is a lewidi_lib helper — presumably it adds the dataset's own
# columns (e.g. `target`) to the predictions; confirm against lewidi_lib.
joint_df = join_dataset(rdf, parse_tgt=False)
# DataFrame.merge defaults to an inner join: prediction rows without a matching
# (dataset_idx, run_idx) in either judge table are dropped.
joint_df = joint_df.merge(ratings, on=join_keys)
joint_df = joint_df.merge(is_correct, on=join_keys)

In [None]:
# One row per distinct score: the mean of `is_correct` among rows with that
# score, and how many rows (non-null dataset_idx) that mean is based on.
corr_df = joint_df.groupby("score", as_index=False).agg(
    is_correct=pd.NamedAgg(column="is_correct", aggfunc="mean"),
    n_examples=pd.NamedAgg(column="dataset_idx", aggfunc="count"),
)

In [None]:
import seaborn as sns

# One point per distinct score: x is the score, y the mean of is_correct for it.
sns.scatterplot(data=corr_df, x="score", y="is_correct")

In [None]:
# Overall means over every joined row, displayed as (mean is_correct, mean score).
tuple(joint_df[col].mean() for col in ("is_correct", "score"))

In [None]:
# For every dataset_idx keep only the row with the highest score (idxmax
# returns the first such row label on ties), then show the same pair of means
# as for the full table: (mean is_correct, mean score).
top_row_labels = joint_df.groupby("dataset_idx")["score"].idxmax()
max_score_df = joint_df.loc[top_row_labels]
tuple(max_score_df[col].mean() for col in ("is_correct", "score"))

In [None]:
import duckdb

# Pivot the joined table so each dataset_idx becomes one row with one column per
# distinct is_correct value. With no USING clause DuckDB's PIVOT counts rows, so
# the cells hold the number of rows per outcome. `joint_df` is referenced by
# name inside the SQL and resolved from the notebook's variables.
pivot_query = "PIVOT joint_df ON is_correct GROUP BY dataset_idx"
improvable_df = duckdb.sql(pivot_query).df()

# The pivoted columns are named after the is_correct values "0" and "1".
outcome_names = {"0": "incorrect", "1": "correct"}
improvable_df = improvable_df.rename(columns=outcome_names)

# Flag dataset items for which no row was judged correct.
has_incorrect = improvable_df["incorrect"] > 0
has_no_correct = improvable_df["correct"] == 0
improvable_df["all_incorrect"] = has_incorrect & has_no_correct

# Display every item that has at least one incorrect row.
improvable_df.query("incorrect > 0")

In [None]:
# Spot-check a single dataset item (dataset_idx 10): prediction, target,
# correctness verdict and score for each of its rows.
inspect_cols = ["pred", "target", "is_correct", "score"]
joint_df.loc[joint_df["dataset_idx"] == 10, inspect_cols]