In [7]:
import gzip
import json
import random
from collections import defaultdict
from typing import Any, Callable, Dict, List, Optional, Tuple
import os
from dataclasses import dataclass

import blobfile as bf
import numpy as np
import orjson

# Type alias: a sample is a string-keyed dict.
Sample = Dict[str, Any]

## global variables

# Root directory of the project. Set the PRM800K_PROJECT_DIR environment
# variable to run the notebook from another location; when it is unset the
# original hard-coded location is used, so existing setups keep working.
project_dir = os.environ.get(
    "PRM800K_PROJECT_DIR", "/data/users/zhangjunlei/tyx/reward-by-prm800k"
)

# JSONL file holding the scored test samples.
scored_test_samples_jsonl_path = os.path.join(project_dir, "datasets/scored-test-samples.jsonl")

# Directory that contains the PRM800K phase files.
prm800k_jsonl_dirpath = os.path.join(project_dir, "prm800k-main/prm800k/data")

# One {"train": path, "test": path} dict per phase; index 0 is phase 1.
prm800k_jsonl_path_phase = [
    {
        split: os.path.join(prm800k_jsonl_dirpath, f"phase{phase_number}_{split}.jsonl")
        for split in ("train", "test")
    }
    for phase_number in (1, 2)
]

## functions


def json_loads(s: str) -> Dict:
    """Decode one JSON document, trying ``orjson`` first and ``json`` second.

    Any exception raised on the ``orjson`` path triggers the fallback; an
    exception raised by ``json.loads`` itself propagates to the caller.
    """
    try:
        decoded = orjson.loads(s)
    except Exception:
        # fallback: let the standard-library parser have a go
        decoded = json.loads(s)
    return decoded


def open_jsonl(file: str):
    """Open a JSONL file for reading through ``blobfile``.

    A name ending in ``.gz`` is opened in binary mode and wrapped in a gzip
    reader (``gzip.open`` defaults to binary mode, so its lines are bytes);
    any other name is opened in text mode.
    """
    if not file.endswith(".gz"):
        return bf.BlobFile(file, "r")
    compressed_stream = bf.BlobFile(file, "rb")
    return gzip.open(compressed_stream)


def read_jsonl(file: str) -> List[Dict]:
    """Read a JSON Lines file and return its records as a list.

    Args:
        file: Path of the file to read; it is opened with ``open_jsonl``.

    Returns:
        One decoded JSON value per non-blank line, in file order.

    Raises:
        AssertionError: If ``bf.exists(file)`` is false; the message is the path.
    """
    assert bf.exists(file), file
    with open_jsonl(file) as f:
        # Each line still carries its trailing newline, so a bare truthiness
        # test would let blank lines through to the parser; strip before
        # testing. Iterating the handle avoids holding every raw line at once.
        return [json_loads(line) for line in f if line.strip()]


def key_by_problem(samples: List[Dict]):
    """Group samples by their problem text.

    The problem text is read from ``sample["problem"]`` when that key exists,
    otherwise from ``sample["question"]["problem"]``.

    Returns:
        A ``defaultdict(list)`` mapping each problem text to its samples, in
        the order they were encountered.
    """
    by_problem = defaultdict(list)
    for entry in samples:
        problem_text = entry["problem"] if "problem" in entry else entry["question"]["problem"]
        by_problem[problem_text].append(entry)
    return by_problem

## classes

@dataclass
class MyDataset:
    """A named collection of samples.

    ``samples`` may be any sized container; ``__str__`` only needs ``len()``.
    """

    # Label shown when the dataset is printed.
    filename: str
    # The sample container, e.g. a list of records or a dict of grouped records.
    samples: Any

    def __str__(self) -> str:
        """Return ``"<filename>: <len(samples)> samples"``."""
        return f"{self.filename}: {len(self.samples)} samples"

In [8]:
# Load the scored test samples and print one picked at random as a quick look.
scored_test_samples = read_jsonl(scored_test_samples_jsonl_path)
print(random.choice(scored_test_samples))

# Read the train and test split of every PRM800K phase.
prm800k_dataset_phase = []
for phase in prm800k_jsonl_path_phase:
    train_dataset, test_dataset = read_jsonl(phase["train"]), read_jsonl(phase["test"])
    phase_dataset = dict(train=train_dataset, test=test_dataset)
    prm800k_dataset_phase += [phase_dataset]

# Report how many samples each split of each phase holds.
for phase_idx, phase_datasets in enumerate(prm800k_dataset_phase):
    print(f"PRM800K Phase {phase_idx + 1}")
    for split_name in phase_datasets:
        samples = phase_datasets[split_name]
        print("\t", split_name, len(samples))
        

{'problem': 'The quadratic $x^2+(2.6)x+3.6$ can be written in the form $(x+b)^2+c$, where $b$ and $c$ are constants. What is $b+c$ (as a decimal)?', 'answer': '3.21', 'is_correct': True, 'subject': 'Algebra', 'level': 4, 'unique_id': 'test/algebra/621.json', 'steps': ['To write a quadratic in the form $(x+b)^2+c$, I need to complete the square.', 'That means I need to add and subtract the square of half the coefficient of $x$ inside the parentheses.', 'Half of $2.6$ is $1.3$, and its square is $1.69$.', 'So I can rewrite the quadratic as $x^2+(2.6)x+3.6=(x^2+(2.6)x+1.69)-1.69+3.6$.', 'Now I can factor the perfect square trinomial inside the parentheses as $(x+1.3)^2$.', 'Simplifying the constants outside the parentheses, I get $(x+1.3)^2+1.91$.', 'This is the form I wanted, where $b=1.3$ and $c=1.91$.', 'To find $b+c$, I just need to add these two decimals.\n\n# Answer\n\n3.21'], 'rating_probs': [{'1': 0.9701352340725533, '0': 0.02942827889870173, '-1': 0.00043648702874491644}, {'1': 0

In [9]:
# Wrap every loaded sample list in a MyDataset labelled with its file name.
scored_test_dataset = MyDataset(
    os.path.basename(scored_test_samples_jsonl_path), scored_test_samples
)
datasets = [scored_test_dataset]

for phase_idx, phase_datasets in enumerate(prm800k_dataset_phase):
    for split_name, samples in phase_datasets.items():
        split_path = prm800k_jsonl_path_phase[phase_idx][split_name]
        dataset = MyDataset(os.path.basename(split_path), samples)
        datasets.append(dataset)

# One summary line per dataset.
for dataset in datasets:
    print(dataset)

scored-test-samples.jsonl: 815631 samples
phase1_train.jsonl: 949 samples
phase1_test.jsonl: 106 samples
phase2_train.jsonl: 97782 samples
phase2_test.jsonl: 2762 samples


In [10]:
# Same datasets, with samples grouped by problem text. `samples` is now a dict,
# so the count printed by MyDataset.__str__ is the number of distinct problems.
datasets_key_by_problem = [
    MyDataset(filename=ds.filename, samples=key_by_problem(ds.samples))
    for ds in datasets
]

In [11]:
# Print one summary line per grouped dataset; len(samples) here counts the
# keys of the grouped dict rather than individual samples.
for dataset in datasets_key_by_problem:
    print(dataset)

scored-test-samples.jsonl: 500 samples
phase1_train.jsonl: 903 samples
phase1_test.jsonl: 101 samples
phase2_train.jsonl: 10828 samples
phase2_test.jsonl: 458 samples


In [12]:
# def get_dataset_sample_key_set(dataset : MyDataset):
#     return set(dataset.samples.keys())

def calculate_datasets_intersection_infos(datasets: List["MyDataset"]) -> List[Tuple[str, str, int]]:
    """Count the keys shared by every unordered pair of datasets.

    Args:
        datasets: Datasets whose ``samples`` is a mapping (anything with
            ``.keys()``), such as the output of ``key_by_problem``.

    Returns:
        One ``(filename_a, filename_b, shared_key_count)`` tuple per pair
        ``(i, j)`` with ``i < j``, ordered by ``i`` and then ``j``. The list
        is empty when fewer than two datasets are given.
    """
    # With fewer than two datasets there are no pairs and nothing is inspected.
    if len(datasets) < 2:
        return []

    # Build each key set once instead of rebuilding it for every pair.
    keyed = [(dataset.filename, set(dataset.samples.keys())) for dataset in datasets]

    return [
        (name_a, name_b, len(keys_a & keys_b))
        for idx, (name_a, keys_a) in enumerate(keyed)
        for name_b, keys_b in keyed[idx + 1:]
    ]

# Pairwise overlap between the problem-keyed datasets, as a list of
# (filename_a, filename_b, shared_key_count) tuples.
datasets_intersection_infos = calculate_datasets_intersection_infos(datasets_key_by_problem)

# Raw dump of the whole list on a single line.
print(datasets_intersection_infos)

[('scored-test-samples.jsonl', 'phase1_train.jsonl', 0), ('scored-test-samples.jsonl', 'phase1_test.jsonl', 101), ('scored-test-samples.jsonl', 'phase2_train.jsonl', 0), ('scored-test-samples.jsonl', 'phase2_test.jsonl', 458), ('phase1_train.jsonl', 'phase1_test.jsonl', 0), ('phase1_train.jsonl', 'phase2_train.jsonl', 896), ('phase1_train.jsonl', 'phase2_test.jsonl', 0), ('phase1_test.jsonl', 'phase2_train.jsonl', 0), ('phase1_test.jsonl', 'phase2_test.jsonl', 101), ('phase2_train.jsonl', 'phase2_test.jsonl', 0)]


In [13]:
# Print one (filename_a, filename_b, shared_key_count) tuple per line.
for info in datasets_intersection_infos:
    print(info)

('scored-test-samples.jsonl', 'phase1_train.jsonl', 0)
('scored-test-samples.jsonl', 'phase1_test.jsonl', 101)
('scored-test-samples.jsonl', 'phase2_train.jsonl', 0)
('scored-test-samples.jsonl', 'phase2_test.jsonl', 458)
('phase1_train.jsonl', 'phase1_test.jsonl', 0)
('phase1_train.jsonl', 'phase2_train.jsonl', 896)
('phase1_train.jsonl', 'phase2_test.jsonl', 0)
('phase1_test.jsonl', 'phase2_train.jsonl', 0)
('phase1_test.jsonl', 'phase2_test.jsonl', 101)
('phase2_train.jsonl', 'phase2_test.jsonl', 0)
