In [1]:
# Destination path of the JSONL file this notebook writes its selected
# problem/answer records to (one JSON object per line).
OUTPUT_FILE = "grpo_tweak_dataset.jsonl"

In [2]:
import os
import pandas as pd
import glob
import json
import re
from pathlib import Path
import matplotlib.pyplot as plt
import numpy as np

## Load Completions

In [3]:
def load_all_jsonl_to_dataframe(folder_path):
    """Load every ``*.jsonl`` file in a folder into a single pandas DataFrame.

    Files are read in sorted path order so the row order of the result is
    reproducible across runs and machines (``glob`` returns matches in
    arbitrary order). Each row is tagged with a ``source_file`` column holding
    the base name of the file it came from. A file that cannot be read is
    reported and skipped rather than aborting the whole load.

    :param folder_path: directory searched (non-recursively) for ``*.jsonl`` files.
    :return: the concatenated DataFrame with a fresh 0..n-1 index, or an empty
        DataFrame when no file could be loaded.
    """
    # Sort for a deterministic load order; glob gives no ordering guarantee.
    jsonl_files = sorted(glob.glob(os.path.join(folder_path, "*.jsonl")))

    dfs = []
    for file_path in jsonl_files:
        file_name = os.path.basename(file_path)
        try:
            df = pd.read_json(file_path, lines=True)
        except Exception as e:
            # Best effort: report the bad file and keep going with the rest.
            print(f"Error loading {file_path}: {e}")
            continue
        df['source_file'] = file_name  # track which file each row came from
        dfs.append(df)
        print(f"Loaded {len(df)} rows from {file_name}")

    if not dfs:
        print("No valid files found")
        return pd.DataFrame()

    combined_df = pd.concat(dfs, ignore_index=True)
    # Report the number of files actually loaded, not merely discovered.
    print(f"Created DataFrame with {len(combined_df)} rows from {len(dfs)} files")
    return combined_df

# Usage: read every *.jsonl file found directly in the results folder into one
# DataFrame; the loader prints per-file row counts as it goes.
folder_path = "../data/DeepScaleR_1_5_B_results/"
combined_data = load_all_jsonl_to_dataframe(folder_path)

Loaded 12 rows from 4.jsonl
Loaded 12 rows from 6.jsonl
Loaded 12 rows from 2.jsonl
Loaded 12 rows from 0.jsonl
Loaded 12 rows from 7.jsonl
Loaded 12 rows from 5.jsonl
Loaded 12 rows from 1.jsonl
Loaded 12 rows from 3.jsonl
Created DataFrame with 96 rows from 8 files


## Display sample data

In [15]:
# Display the first few rows
# combined_data.head(100)

In [16]:
# combined_data.iloc[0]

## Solution Extraction Utilities

Functions to extract solutions from model outputs and compute correctness.

In [6]:
def remove_boxed(s):
    """Strip the LaTeX box wrapper from a boxed expression.

    Accepts the forms that ``last_boxed_only_string`` can produce:

    * ``\\boxed{...}`` / ``\\fbox{...}`` -> the text between the outer braces
    * ``\\boxed ...`` (brace-less, space-delimited) -> the text after the space

    :param s: the boxed expression, or ``None``.
    :return: the inner text, or ``None`` when ``s`` is ``None`` or is not a
        well-formed boxed expression (wrong prefix or missing closing brace).
    """
    if s is None:
        return None

    # Brace-less form, e.g. "\boxed 42": everything after the space is the answer.
    spaced = "\\boxed "
    if s.startswith(spaced):
        return s[len(spaced):]

    # Braced forms: require both the opening wrapper and a final closing brace.
    for left in ("\\boxed{", "\\fbox{"):
        if s.startswith(left) and s.endswith("}"):
            return s[len(left):-1]

    return None


def last_boxed_only_string(string):
    """Return the last ``\\boxed`` (or, failing that, ``\\fbox``) expression in a text.

    The last occurrence of ``\\boxed`` decides which form is returned:

    * brace-less ``\\boxed X`` -> ``"\\boxed "`` plus the text up to the next
      ``$`` (or the end of the string);
    * braced ``\\boxed{...}`` / ``\\fbox{...}`` -> the substring from the
      command through its matching closing brace (nested braces are balanced).

    :param string: text to search; non-string input (e.g. ``None``) yields ``None``.
    :return: the boxed expression including its wrapper, or ``None`` when there
        is no box or the braces never balance.
    """
    if not isinstance(string, str):
        return None

    idx = string.rfind("\\boxed")
    if idx >= 0:
        spaced = "\\boxed "
        # Only take the brace-less shortcut when the LAST \boxed is brace-less;
        # an earlier "\boxed 3" must not shadow a later "\boxed{5}".
        if string.startswith(spaced, idx):
            return spaced + string[idx + len(spaced):].split("$")[0]
    else:
        idx = string.rfind("\\fbox")
        if idx < 0:
            return None

    # Walk forward from the command, tracking brace depth until the brace that
    # opened the box is closed again.
    depth = 0
    for pos in range(idx, len(string)):
        ch = string[pos]
        if ch == "{":
            depth += 1
        elif ch == "}":
            depth -= 1
            if depth == 0:
                return string[idx:pos + 1]

    # Ran off the end without balancing the braces.
    return None

def extract_solution(text):
    """Pull the final answer out of a model response.

    Locates the last boxed expression with ``last_boxed_only_string`` and then
    strips its wrapper with ``remove_boxed``; the result is whatever
    ``remove_boxed`` returns for that expression.
    """
    boxed_expression = last_boxed_only_string(text)
    return remove_boxed(boxed_expression)

def correctness_reward_func(response: str, actual_answers: str) -> float:
    """Score a response: 1.0 if its extracted answer equals the reference, else 0.0.

    :param response: raw model output to extract the answer from.
    :param actual_answers: reference answer; compared by its string form so a
        numeric reference matches the (always textual) extracted answer.
    :return: ``1.0`` on an exact string match, otherwise ``0.0``. A response
        with no extractable answer, or a missing reference, always scores ``0.0``.
    """
    extracted_answer = extract_solution(response)
    # A missing extraction must never count as correct (None == None would).
    if extracted_answer is None or actual_answers is None:
        return 0.0
    return 1.0 if extracted_answer == str(actual_answers) else 0.0

## Transform for Analysis

In [17]:
# Distinct (prompt, answer) pairs; column selection already yields a new frame.
check = combined_data.loc[:, ['prompt', 'answer']].drop_duplicates()
# print(f"len(check): {len(check)}")

In [19]:
# Build the per-completion frame: drop the bookkeeping column, alias the
# prompt/answer columns, then fan out to one row per entry in `outputs`.
df = (
    combined_data
    .drop(columns=['source_file'])
    .assign(
        question_text=lambda frame: frame['prompt'],
        ground_truth=lambda frame: frame['answer'],
    )
    .explode('outputs')
)

# Each exploded `outputs` entry is indexed by 'output' to get the response text.
df['response'] = df['outputs'].apply(lambda completion: completion['output'])
df['extracted_solution'] = df['response'].apply(extract_solution)
df['response_word_count'] = df['response'].astype(str).str.split().str.len()

# Optional filter: only keep completions where a solution was extracted.
# df = df[df['extracted_solution'].notna()]

# Reward is 1.0 when the extracted answer matches the ground truth as a string.
df['reward'] = (df['extracted_solution'] == df['ground_truth'].astype(str)).astype(float)
# df
# df[df['extracted_solution'].isna()]

In [9]:
def pass_at_k(n, c, k):
    """Estimate pass@k from ``n`` samples of which ``c`` are correct.

    Computes ``1 - C(n - c, k) / C(n, k)`` in its numerically stable product
    form: the probability that at least one of ``k`` samples drawn without
    replacement from the ``n`` is correct.

    :param n: total number of samples.
    :param c: number of correct samples.
    :param k: the k in pass@k.
    :return: the pass@k estimate in ``[0, 1]``.
    """
    # With no correct sample the estimate is 0, even when fewer than k samples
    # exist (the shortcut below would otherwise report 1.0 in that case).
    if c <= 0:
        return 0.0
    # Fewer than k incorrect samples: any draw of k must contain a correct one.
    if n - c < k:
        return 1.0
    return 1.0 - np.prod(1.0 - k / np.arange(n - c + 1, n + 1))

# pass_at_k(16, 2, 1)

In [20]:
# Per-question summary of the reward column: number of completions, number of
# correct ones, and their mean (accuracy).
perf = (
    df.groupby(['question_text', 'ground_truth'])['reward']
    .agg(count='count', sum_reward='sum', accuracy='mean')
)

print(f"len(perf): {len(perf)}")

# If only answers with an extracted \boxed{...} answer are kept, count isn't always 16.
# assert(len(perf[perf['count'] != 16]) == 0)
# perf[f'pass@{k}'] = perf['sum_reward'].apply(lambda x: pass_at_k(16, x, k))

# pass@k per question, using each question's own sample count.
k = 8
pass_column = f'pass@{k}'
perf[pass_column] = perf.apply(
    lambda row: pass_at_k(row['count'], row['sum_reward'], k),
    axis=1,
)
# perf
# perf

len(perf): 87


In [21]:
# overall_perf = perf[[f'pass@{k}']].mean()
# overall_perf

In [23]:
# Split the questions into three mutually exclusive buckets by accuracy:
#   too_hard      accuracy == 0
#   ripe          0 < accuracy <= accuracy_cap
#   already_done  accuracy > accuracy_cap
# The upper bucket uses a strict '>' so a question sitting exactly on the cap
# is counted once (as ripe) instead of landing in both buckets.
accuracy_cap = 0.8

total = len(perf)
print(f"total: {total}")

# Too hard.
too_hard = perf[perf['accuracy'] == 0.0]
total_too_hard = len(too_hard)
print(f"total_too_hard: {total_too_hard}")

# Ripe for improvement
ripe = perf[(perf['accuracy'] > 0.0) & (perf['accuracy'] <= accuracy_cap)]
total_ripe = len(ripe)
print(f"total_ripe: {total_ripe}")

# Not worth improving.
already_done = perf[perf['accuracy'] > accuracy_cap]
total_already_done = len(already_done)
print(f"total_already_done: {total_already_done}")

# Sanity check: the buckets must partition the questions exactly.
check_total = total_too_hard + total_ripe + total_already_done
assert total == check_total, (
    f"buckets cover {check_total} questions, expected {total}"
)

# ripe

total: 87
total_too_hard: 38
total_ripe: 16
total_already_done: 33


In [25]:
# One {'problem', 'answer'} record per ripe question; reset_index() turns the
# group keys back into ordinary columns, and the answer is stored as a string.
tweak_dataset = []
for _, row in ripe.reset_index().iterrows():
    record = {
        'problem': row['question_text'],
        'answer': str(row['ground_truth']),
    }
    tweak_dataset.append(record)
# len(tweak_dataset)

In [26]:
# Make sure the parent directory exists (skipped when OUTPUT_FILE has no
# directory component).
output_dir = os.path.dirname(OUTPUT_FILE)
if output_dir.strip():
    os.makedirs(output_dir, exist_ok=True)

# Write the records as JSONL: one JSON object per line.
with open(OUTPUT_FILE, 'w') as f:
    f.writelines(json.dumps(item) + '\n' for item in tweak_dataset)

print(f"Saved {len(tweak_dataset)} records to {OUTPUT_FILE}")

Saved 16 records to grpo_tweak_dataset.jsonl
