In [None]:
%load_ext autoreload
%autoreload 2
from logging import getLogger
from lewidi_lib import enable_logging
import pandas as pd
import seaborn as sns

pd.set_option("display.max_colwidth", 100)
enable_logging()
logger = getLogger(__name__)
sns.set_context("talk")

In [None]:
from lewidi_lib import assert_path_exists, preds_file

# Predictions run under analysis: Qwen3-32B on the MP train split, template 60.
run_spec = {
    "dataset": "MP",
    "split": "train",
    "template": "60",
    "model_id": "Qwen/Qwen3-32B",
    "run_name": "1000ex_10loops",
}
file = preds_file(**run_spec)
# Fail early with a clear error if the predictions file is missing.
rdf = pd.read_parquet(assert_path_exists(file))

In [None]:
from lewidi_lib import (
    assign_cols_perf_metrics_softlabel,
    compute_diversity_by_problem,
    join_dataset,
    keep_only_highest_diversity_problems,
    process_rdf,
)

# NOTE: this cell rebinds `rdf` from its own output. Re-running it without first re-running
# the loading cell feeds an already-processed frame back into process_rdf (which may not be
# idempotent) and merges the diversity columns a second time (producing _x/_y suffixes).
rdf = process_rdf(rdf, response_contains_steps=True)
answer_diversity = compute_diversity_by_problem(rdf)
rdf = pd.merge(rdf, answer_diversity, on="dataset_idx")

In [None]:
# Join each response with its dataset row (presumably the gold soft labels — see join_dataset),
# add the soft-label performance metrics, then restrict to the highest-diversity problems.
joint_df = assign_cols_perf_metrics_softlabel(join_dataset(rdf))
joint_df_subset = keep_only_highest_diversity_problems(joint_df)

In [None]:
# Loss metrics per answer-diversity bin.
# joint_df holds several rows (runs/responses) per dataset_idx, so counting dataset_idx rows
# gives the number of responses, not problems. Report both explicitly.
joint_df.groupby("diversity", as_index=False, observed=True).agg(
    avg_ws_loss=("ws_loss", "mean"),
    avg_pairwise_ws_loss=("avg_pairwise_ws_loss", "mean"),
    n_examples=("dataset_idx", "nunique"),
    n_responses=("dataset_idx", "count"),
).round(2)

In [None]:
# One-off export, kept for reference (disabled): write dataset_idx/diversity for the "Q5"
# diversity bin to high_diversity_ids.parquet. The SQL below (DuckDB-style `copy ... to`
# syntax — presumably run outside this notebook) then keeps only those problems' rows from the
# full responses file and writes them to responses.parquet.
# answer_diversity.query("diversity == 'Q5'")[["dataset_idx", "diversity"]].to_parquet("high_diversity_ids.parquet")

# """
# copy (
#     select rdf.*
#     from (select * from '../../1000ex_10loops/preds/responses.parquet') as rdf
#     join 'high_diversity_ids.parquet' as ids on rdf.dataset_idx = ids.dataset_idx
#     )
# to 'responses.parquet';
# """

# Load Ratings

In [None]:
from lewidi_lib import (
    assert_path_exists,
    assign_col_response_parsed,
    discard_na_response_rows,
    process_ratings,
)
import numpy as np
from prm800k import mapping


# Judge ratings for the predictions loaded above. Swap in one of the alternatives to
# analyze a different dataset/judge.
# ratings_file = "/Users/tomasruiz/datasets/dss_home/lewidi-data/sbatch/di38bec/Qwen_Qwen3-32B/set2/t60/CSC/train/1000ex_10loops_high_diversity/judge/gemini-2.5-flash/responses.parquet"
# ratings_file = "/Users/tomasruiz/datasets/dss_home/lewidi-data/sbatch/di38bec/Qwen_Qwen3-32B/set2/t60/CSC/train/1000ex_10loops/judge/Qwen/Qwen3-32B/set2/t23/1000ex_10loops/responses.parquet"
ratings_file = "/Users/tomasruiz/datasets/dss_home/lewidi-data/sbatch/di38bec/Qwen_Qwen3-32B/set2/t60/MP/train/1000ex_10loops/judge/Qwen/Qwen3-32B/set2/t23/1000ex_10loops/responses.parquet"
# Fail early with a clear error if the file is missing (same check as for the predictions).
ratings = pd.read_parquet(assert_path_exists(ratings_file))
logger.info(
    "Loaded %d ratings for %d different dataset_idxs",
    len(ratings),
    ratings["dataset_idx"].nunique(),
)
ratings = discard_na_response_rows(ratings)
ratings = assign_col_response_parsed(ratings)
# Aggregate per-step ratings into a single score with np.mean.
# NOTE(review): `ok` and `bad` are both given 0 here; if mapping() assigns these as category
# values, the two labels become indistinguishable in the score — confirm this is intended.
ratings = process_ratings(
    ratings, operation=np.mean, cat_mapping=mapping(ok=0.0, bad=0)
)
# Keep only the columns needed downstream. Rename the judge's reasoning so it cannot be
# confused with columns of the predictions frame after the merge.
ratings = ratings[
    ["dataset_idx", "run_idx", "step_ratings", "score", "reasoning", "response_parsed"]
].rename(columns={"reasoning": "judge_reasoning"})

In [None]:
# Attach the judge's ratings to each (dataset_idx, run_idx) response, then keep only
# responses that received a score.
# This cell rebinds joint_df_subset. Running it twice would merge the ratings again and produce
# score_x/score_y columns. Fail with a clear message instead of a confusing KeyError later.
assert "judge_reasoning" not in joint_df_subset.columns, (
    "ratings are already merged into joint_df_subset; "
    "re-run the cell that builds joint_df_subset first"
)
joint_df_subset = joint_df_subset.merge(
    ratings, on=["dataset_idx", "run_idx"], how="left"
)
joint_df_subset = discard_na_response_rows(joint_df_subset, col="score")

# Best-of-N (BoN) Loss Stats

In [None]:
# Number of distinct (dataset_idx, run_idx) responses that have a judge score.
joint_df_subset[["dataset_idx", "run_idx"]].drop_duplicates().shape[0]

In [None]:
from lewidi_lib import bootstrap_avg

# Baseline: average ws_loss and judge score over all rated responses, via bootstrap_avg.
metric_cols = ["ws_loss", "score"]
joint_df_subset.loc[:, metric_cols].apply(bootstrap_avg)

In [None]:
# Best-of-N: for each problem keep the response the judge scored highest
# (idxmax returns the first matching row on ties).
best_rows = joint_df_subset.groupby("dataset_idx")["score"].idxmax()
bon = joint_df_subset.loc[best_rows]
bon[["ws_loss", "score"]].apply(bootstrap_avg)

In [None]:
# Oracle: for each problem keep the response with the lowest ws_loss. This is the best mean
# ws_loss any per-problem selection among these responses (including BoN) can reach.
lowest_loss_rows = joint_df_subset.groupby("dataset_idx")["ws_loss"].idxmin()
oracle = joint_df_subset.loc[lowest_loss_rows]
oracle[["ws_loss", "score"]].apply(bootstrap_avg)

# Problem-Level Correlation

In [None]:
def corr(df):
    """Return the Pearson correlation between ``df["score"]`` and ``df["ws_loss"]``.

    The result is NaN when either column has zero variance (e.g. the judge gave every
    response the same score); numpy may also emit a RuntimeWarning in that case.
    """
    score, loss = df["score"], df["ws_loss"]
    corr_matrix = np.corrcoef(score, loss)
    return corr_matrix[0][1]


# Per-problem correlation between judge score and ws_loss. Problems where it is undefined
# (NaN, e.g. constant scores) are counted as 0 before averaging.
corrs = joint_df_subset.groupby("dataset_idx")[["score", "ws_loss"]].apply(corr)
corrs = corrs.fillna(value=0)
bootstrap_avg(corrs)

# Individual Examples

In [None]:
# Debug snippet (disabled): for each problem take the response with the highest ws_loss,
# then print the response text of the first such row.
# max_loss = joint_df_subset.loc[joint_df_subset.groupby("dataset_idx")["ws_loss"].idxmax()]
# row = max_loss.iloc[0]
# print(row["response"])