# Prepare MATH Minus MATH-500

## Load datasets

In [1]:
# !uv pip install datasets

In [None]:
from datasets import load_dataset

# Full MATH dataset, loaded from the "train" split of the
# "qwedsacf/competition_math" Hugging Face Hub repository
math_full = load_dataset("qwedsacf/competition_math", split="train")

# MATH-500 benchmark, loaded from the "test" split of
# "HuggingFaceH4/MATH-500"; its problems are removed from
# `math_full` later in this notebook
math_500 = load_dataset("HuggingFaceH4/MATH-500", split="test")

## Extract answers

In [None]:
# !uv pip install reasoning_from_scratch

In [None]:
# Test extracting the "answer"

from reasoning_from_scratch.ch03 import (
    extract_final_candidate,
)

# Run the extractor on the first MATH-500 solution; as the last expression
# of the cell, its result is what the notebook displays
extract_final_candidate(math_500[0]["solution"], fallback=None)

In [None]:
# Check the reference "answer" stored with the same MATH-500 example
# (index 0), for comparison with the answer extracted from its solution

math_500[0]["answer"]

In [None]:
# Apply extraction to the full MATH dataset, which
# doesn't have an "answer" field yet

def add_answer_and_id(example, idx):
    """Attach an extracted answer and a row-index ID to one dataset row.

    Intended as the callback for ``Dataset.map(..., with_indices=True)``.

    Args:
        example: Row mapping that provides a "solution" entry; it is
            modified in place.
        idx: Index of the row, stored under "unique_id".

    Returns:
        The same ``example`` with "answer" (the result of calling
        ``extract_final_candidate`` on the solution with ``fallback=None``)
        and "unique_id" set.
    """
    extracted = extract_final_candidate(example["solution"], fallback=None)
    example["answer"] = extracted
    example["unique_id"] = idx
    return example

# with_indices=True makes map() pass each row's index as the second
# argument (`idx`) of add_answer_and_id
math_full = math_full.map(add_answer_and_id, with_indices=True)

## Check extraction

In [None]:
# Map each whitespace-stripped MATH-500 problem text to its reference answer

math500_lookup = dict(
    zip(
        (problem.strip() for problem in math_500["problem"]),
        math_500["answer"],
    )
)

In [None]:
# Check if there were any bad extractions
# (we use MATH-500 as a gold standard; if there are no
# mismatches, it is reasonable to assume that the extraction
# method works reliably)

# Collect every overlapping problem whose extracted answer differs from
# the MATH-500 reference answer (exact string comparison)
mismatches = []

for row in math_full:
    key = row["problem"].strip()

    # Only problems that also occur in MATH-500 have a reference answer
    if key not in math500_lookup:
        continue

    extracted = row["answer"]
    reference = math500_lookup[key]
    if extracted == reference:
        continue

    mismatches.append(
        {
            "problem": key,
            "answer_full": extracted,
            "answer_500": reference,
        }
    )

In [None]:
# Count every row of the full dataset whose whitespace-stripped problem
# text also appears in MATH-500. Mismatching rows are already included in
# this count, so len(mismatches) must not be added on top of it.
num_overlapping = sum(
    problem.strip() in math500_lookup for problem in math_full["problem"]
)
print("Overlapping problems:", num_overlapping)
print("Mismatches:", len(mismatches))

## Filter out MATH-500 test set

In [None]:
# Remove MATH-500 test set from full MATH dataset

# Whitespace-stripped MATH-500 problem texts, kept in a set for fast
# membership tests in the filter below
math500_problems = {text.strip() for text in math_500["problem"]}

def is_not_in_math500(example):
    """Return True if the example's stripped problem text is not in MATH-500.

    Used as a ``Dataset.filter`` predicate; the text is looked up in the
    module-level ``math500_problems`` set.
    """
    problem_text = example["problem"].strip()
    return problem_text not in math500_problems

# Keep only the rows for which is_not_in_math500 returns True, i.e. rows
# whose problem text does not occur in MATH-500
math_train_filtered = math_full.filter(is_not_in_math500)

In [None]:
# Row counts before and after removing the MATH-500 problems
print(f"{len(math_full)} -> {len(math_train_filtered)}")

In [None]:
# Show the filtered dataset's repr as the cell output
math_train_filtered

In [None]:
# Export the filtered dataset (MATH-500 problems removed). This must be
# `math_train_filtered`, not `math_full`; otherwise the "minus MATH-500"
# file would still contain the benchmark problems.
math_train_filtered.to_json(
    "math_full_minus_math500.json",
    orient="records",
    indent=2,
)

In [None]:
# Also export the complete, unfiltered dataset (including the added
# "answer" and "unique_id" fields)
math_full.to_json("math_full.json", orient="records", indent=2)