In [1]:
# Automatically reload edited imported modules (e.g. lewidi_lib, prm800k) before each cell runs.
%load_ext autoreload
%autoreload 2

In [2]:
import logging

from lewidi_lib import enable_logging

# Project helper that switches logging output on (presumably installs handlers /
# sets levels — see lewidi_lib for the details).
enable_logging()

# Notebook-wide logger; later cells report how many rows they drop through it.
logger = logging.getLogger(name=__name__)

In [3]:
import seaborn as sns

from prm800k import load_prm800k_phase2_dataset, problems_with_50pct_correct_solutions

# Number of problem ids passed to `problems_with_50pct_correct_solutions` for the plot.
N_PROBLEM_IDS = 10

# PRM800K phase-2 test split with the human ratings; it is merged with the
# Gemini ratings further down (its columns get the `_humans` suffix there).
dataset = load_prm800k_phase2_dataset(split="test")
half_correct = problems_with_50pct_correct_solutions(dataset, n_problem_ids=N_PROBLEM_IDS)

# Logistic fit of the `correct` flag against the human average rating.
grid = sns.lmplot(half_correct, x="avg_rating", y="correct", logistic=True)
grid.set(
    title="Correctness vs. human average rating",
    xlabel="avg_rating (humans)",
    ylabel="correct",
);


# Join Gemini Ratings

In [None]:
import json_repair
import numpy as np
import pandas as pd

# Gemini's verbal labels -> numeric scores, so they can be compared directly
# with the human ratings later on ("ok" and "okay" map to the same score).
mapping = {"great": 1, "ok": 0, "okay": 0, "bad": -1}

gdf = pd.read_json("./gemini-prm800k-ratings.jsonl", lines=True)
logger.info("Dropping %d rows with success=False", len(gdf.query("not success")))
gdf = gdf.query("success")


def _labels_to_scores(response):
    """Parse one (possibly malformed) JSON response and map each entry's "rating" label to its score."""
    parsed = json_repair.loads(response)
    return [mapping[entry["rating"]] for entry in parsed]


scores = gdf["response"].apply(_labels_to_scores)
gdf = gdf.assign(
    ratings=scores,
    avg_rating=scores.apply(np.mean),
    n_ratings=scores.apply(len),
)
gdf.head()

In [None]:
import json

# Join the human-rated dataset with the Gemini ratings on `dataset_idx`
# (pandas' default inner join); shared column names get `_humans` / `_gemini` suffixes.
joined = dataset.merge(gdf, on="dataset_idx", suffixes=("_humans", "_gemini"))
joined["equal_n_ratings"] = joined["n_ratings_humans"] == joined["n_ratings_gemini"]

# Set the mismatched rows aside *before* filtering; otherwise they are gone
# by the time we want to inspect them below.
mismatched = joined.query("not equal_n_ratings")
logger.info("Dropping %d rows with unequal n_ratings", len(mismatched))
joined = joined.query("equal_n_ratings")

# Print a few mismatches (`texts` vs. the raw Gemini `response`) to see why the
# rating counts differ. default=str keeps printing if `texts` holds non-JSON types
# (e.g. numpy arrays) — TODO confirm the column's dtype.
N_MISMATCHES_TO_PRINT = 5
for _, row in mismatched[["texts", "response"]].head(N_MISMATCHES_TO_PRINT).iterrows():
    print(json.dumps(row["texts"], indent=2, default=str))
    print(row["response"])
    print("=" * 100)

joined.head(2)

In [None]:
# One row per rated item: explode the human and Gemini rating lists side by side.
# Multi-column explode requires matching list lengths per row, which holds because
# rows with unequal n_ratings were dropped when building `joined`.
cols = ["ratings_humans", "ratings_gemini"]
comparison = (
    joined[cols]
    .explode(column=cols)
    .astype(int)
    .reset_index(drop=True)
)
comparison.head(2)

In [None]:
# Exact-agreement rate: share of rating pairs where Gemini gives the same score as the humans.
comparison["ratings_gemini"].eq(comparison["ratings_humans"]).mean().round(2)

In [None]:
# Pearson correlation (pandas' default method) between human and Gemini ratings.
comparison.corr(method="pearson")

In [None]:
import krippendorff

# Krippendorff's alpha with the two "raters" (humans, Gemini) as rows and the
# rated items as columns, hence the transpose.
# level_of_measurement="interval" is the library default, spelled out here.
# NOTE(review): the -1/0/1 scores come from ordered labels (bad/ok/great);
# consider also reporting "ordinal" — TODO decide.
ratings_by_rater = comparison.to_numpy().T
krippendorff.alpha(
    reliability_data=ratings_by_rater,
    level_of_measurement="interval",
).round(2)

# Joint distribution of (human, Gemini) rating pairs, most frequent first.
comparison.value_counts(sort=True, ascending=False)

# Human Annotator Agreement
Human inter-annotator agreement cannot be computed here, because different annotators were given different LLM chains of thought (CoTs) to rate.

from prm800k import load_raw_prm800k_phase2_dataset

# Raw phase-2 training split, loaded to inspect its quality-control questions.
prm800ktrain = load_raw_prm800k_phase2_dataset(split="train")
prm800ktrain.query("is_quality_control_question")

# Gemini Best-of-N (BoN) Sampling

In [None]:
# Per-problem summary: mean of `correct` and mean human / Gemini average ratings.
(
    joined.groupby("problem_id", as_index=False)
    .agg(
        avg_correct=("correct", "mean"),
        avg_rating_humans=("avg_rating_humans", "mean"),
        avg_rating_gemini=("avg_rating_gemini", "mean"),
    )
    .round(2)
)

In [None]:
# BoN based on Gemini ratings is pretty good!
# Best-of-N with Gemini as the judge: within each problem, keep the row with the
# highest mean Gemini rating, then report how often that pick is `correct`.
gemini_best_idx = joined.groupby("problem_id")["avg_rating_gemini"].idxmax()
joined.loc[gemini_best_idx, "correct"].mean().round(2)

In [None]:
import plotly.express as px

# Human vs. Gemini average rating per row, coloured by `correct`, with an OLS
# trendline (plotly fits one per trace; if `correct` is discrete, that means one
# per value). trendline="ols" needs statsmodels installed.
fig = px.scatter(
    joined,
    x="avg_rating_humans",
    y="avg_rating_gemini",
    color="correct",
    hover_data=["dataset_idx"],
    trendline="ols",
    title="Gemini vs Human Ratings by Correctness",
)
fig.show()