In [1]:
# import argparse
import gzip
import json
import random
from collections import defaultdict
from typing import Any, Callable, Dict, List, Optional, Tuple

import blobfile as bf
import numpy as np
import orjson

Sample = Dict[str, Any]


def json_loads(s: str) -> Dict:
    """Parse one JSON document, preferring ``orjson`` over the stdlib parser.

    Args:
        s: The serialized JSON document.

    Returns:
        The decoded object.

    Any exception raised while using ``orjson`` triggers a second attempt
    with ``json.loads``; an error from that second attempt propagates.
    """
    try:
        parsed = orjson.loads(s)
    except Exception:
        # Best-effort fallback to the standard-library parser.
        parsed = json.loads(s)
    return parsed


def open_jsonl(file: str):
    """Open a JSON Lines file for reading, decompressing ``.gz`` files.

    Args:
        file: Path handed to ``bf.BlobFile``.

    Returns:
        A readable file object. Paths ending in ``.gz`` are opened in binary
        mode and wrapped with ``gzip.open`` (called with its default mode);
        every other path is opened with mode ``"r"``.
    """
    is_gzipped = file.endswith(".gz")
    if not is_gzipped:
        return bf.BlobFile(file, "r")
    raw_stream = bf.BlobFile(file, "rb")
    return gzip.open(raw_stream)


def _read_jsonl(file: str) -> List[Dict]:
    """Load every record of a JSON Lines file.

    Args:
        file: Path of the file to read; it is opened through ``open_jsonl``.

    Returns:
        One parsed object per non-blank line, in file order.

    Raises:
        AssertionError: If ``bf.exists(file)`` is false (the path is used as
            the assertion message).
    """
    assert bf.exists(file), file
    with open_jsonl(file) as f:
        # Lines read from a file keep their trailing newline, so a plain
        # truthiness check never filters anything. Strip before testing so
        # blank lines (e.g. a trailing empty line) are skipped instead of
        # being handed to the JSON parser. ``strip()`` works for both the
        # str lines of plain files and the bytes lines of gzip files.
        return [json_loads(line) for line in f if line.strip()]


def _key_by_problem(samples: List[Dict]):
    grouped_samples = defaultdict(list)
    for sample in samples:
        grouped_samples[sample["problem"]].append(sample)
    return grouped_samples


def _get_answer(sample: Sample) -> Optional[str]:
    """Return the answer recorded on a sample, if any.

    The ``"answer"`` entry wins whenever that key is present (even when its
    value is ``None``); otherwise the ``"given_answer"`` entry is returned,
    or ``None`` when neither key exists.
    """
    if "answer" in sample:
        return sample["answer"]
    return sample.get("given_answer")


def _choose_sample_by_score(samples: List[Sample], key: str) -> Optional[Sample]:
    """Pick the sample whose ``key`` entry is largest.

    Args:
        samples: Candidate samples; each must contain ``key``.
        key: Name of the score field to compare.

    Returns:
        The first sample holding the maximum score, or ``None`` when
        ``samples`` is empty.
    """
    return max(samples, key=lambda candidate: candidate[key], default=None)

In [2]:
# Set to True for a quick smoke run (a single trial at n = 10) with verbose
# per-problem prints from the sampling helper.
__DEBUG__ = False

# Which score field ranks the candidate solutions: one of ['orm', 'prm'].
method = "prm"

# Number of repeated sampling trials.
n_trials = 1 if __DEBUG__ else 400

# Sample-budget sizes (n) evaluated in every trial.
ns = (
    [10]
    if __DEBUG__
    else [10, 25, 50, 75, 100, 200, 300, 400, 500, 750, 1000, 1250, 1500, 1860]
)

In [3]:
# Alternative (remote) location of the same file name, kept for reference:
# samples_path = "az://openaipublic/process-supervision/scored-test-samples.jsonl"
# NOTE(review): absolute, machine-specific path -- adjust before running elsewhere.
samples_path = "/data/tongyx361/reward-by-prm800k/datasets/scored-test-samples.jsonl"
# Pool size that get_n_subsamples pads every problem up to.
# Presumably the number of solutions generated per problem -- TODO confirm.
num_samples_per_problem = 1860


print(f"Reading {samples_path}, this may take a while...")
samples = _read_jsonl(samples_path)
print("Done.")
samples_by_problem = _key_by_problem(samples)  # group samples by problem
num_problems = len(samples_by_problem)  # number of distinct problems

Reading /data/tongyx361/reward-by-prm800k/datasets/scored-test-samples.jsonl, this may take a while...
Done.


In [9]:
def get_n_subsamples(problem_samples, n):
    """Draw up to ``n`` usable samples for one problem, without replacement.

    The draw is made from a pool of ``num_samples_per_problem`` slots (or
    ``len(problem_samples)`` slots if that is larger). Slots beyond the real
    samples stand for the ``None`` padding of the original implementation, so
    a problem with fewer stored samples yields proportionally fewer
    candidates.
    TODO(review): confirm why missing samples should count against the
    budget ``n`` (presumably they are generations dropped from the dataset).

    Args:
        problem_samples: List of sample dicts belonging to a single problem.
        n: Number of pool slots to draw; expected to be non-negative.

    Returns:
        The drawn samples, in random order, excluding padding slots, literal
        ``None`` entries, and samples for which ``_get_answer`` is ``None``.
    """
    num_available = len(problem_samples)
    if __DEBUG__:
        print("len(problem_samples)", num_available)

    # Sampling slot indices is distributionally identical to shuffling the
    # padded list and keeping its first n entries, but costs O(n) instead of
    # building and shuffling the whole pool on every call.
    pool_size = max(num_samples_per_problem, num_available)
    num_draws = max(0, min(n, pool_size))
    drawn_slots = random.sample(range(pool_size), num_draws)
    if __DEBUG__:
        print("len(subsamples)", len(drawn_slots))

    # Slots >= num_available are padding; drop them along with any stored None.
    subsamples = [
        problem_samples[slot] for slot in drawn_slots if slot < num_available
    ]
    subsamples = [x for x in subsamples if x is not None]
    if __DEBUG__:
        print("len(subsamples_not_none)", len(subsamples))

    subsamples = [x for x in subsamples if _get_answer(x) is not None]
    if __DEBUG__:
        print("len(subsamples_with_answer)", len(subsamples))
    return subsamples

In [10]:
# Seed the global RNG so that re-running this cell (or Restart & Run All)
# reproduces the same pass rates. Change SEED to get an independent run.
SEED = 0
random.seed(SEED)

# Resolve the score field once and fail fast on an unsupported method;
# otherwise `choice` would be undefined (or stale from an earlier run).
score_keys = {"prm": "prm_score", "orm": "orm_score"}
if method not in score_keys:
    raise ValueError(
        f"Unknown method {method!r}; expected one of {sorted(score_keys)}"
    )
score_key = score_keys[method]

# One row per trial; each row holds one pass rate per entry of `ns`.
all_trial_pass_rates = []

for i in range(n_trials):
    pass_rates = []
    for n in ns:
        num_correct = 0
        for problem, problem_samples in samples_by_problem.items():
            subsamples = get_n_subsamples(problem_samples, n)
            # Best-of-n: keep the highest-scoring candidate among the draws.
            choice = _choose_sample_by_score(subsamples, score_key)

            # A problem with no usable candidate counts as unsolved.
            if choice is not None and choice["is_correct"]:
                num_correct += 1
        pass_rates.append(num_correct / num_problems)
    all_trial_pass_rates.append(pass_rates)
    print(f"Trial {i}/{n_trials} {pass_rates}")

len(problem_samples) 1858
len(subsamples) 10
len(subsamples_not_none) 10
len(subsamples_with_answer) 10
len(problem_samples) 1859
len(subsamples) 10
len(subsamples_not_none) 10
len(subsamples_with_answer) 10
len(problem_samples) 1854
len(subsamples) 10
len(subsamples_not_none) 10
len(subsamples_with_answer) 10
len(problem_samples) 1843
len(subsamples) 10
len(subsamples_not_none) 8
len(subsamples_with_answer) 8
len(problem_samples) 1860
len(subsamples) 10
len(subsamples_not_none) 10
len(subsamples_with_answer) 10
len(problem_samples) 1860
len(subsamples) 10
len(subsamples_not_none) 10
len(subsamples_with_answer) 10
len(problem_samples) 1859
len(subsamples) 10
len(subsamples_not_none) 10
len(subsamples_with_answer) 10
len(problem_samples) 1856
len(subsamples) 10
len(subsamples_not_none) 10
len(subsamples_with_answer) 10
len(problem_samples) 1852
len(subsamples) 10
len(subsamples_not_none) 10
len(subsamples_with_answer) 10
len(problem_samples) 1777
len(subsamples) 10
len(subsamples_not_no

In [5]:
# Stack the per-trial pass-rate lists into one array and aggregate along
# axis 0, i.e. across trials, giving one value per entry of each row.
np_all_trial_pass_rates = np.array(all_trial_pass_rates)
mean_per_n = np_all_trial_pass_rates.mean(axis=0)
std_per_n = np_all_trial_pass_rates.std(axis=0)
print("Mean:", list(mean_per_n))
print("Standard deviation:", list(std_per_n))

Mean: [0.666]
Standard deviation: [0.0]
