In [5]:
# Peek at the first record of one JSONL file in the directory to list its column names
import os
import json

dir_path = "/mnt/fast10/brandon/mmr_rollout_data/prm_training_data/debug"

# os.listdir() returns entries in arbitrary order, so sort before picking the
# "first" file; otherwise the file that gets sampled can change between runs.
first_jsonl = next(
    (name for name in sorted(os.listdir(dir_path)) if name.endswith(".jsonl")),
    None,
)

# Column names come from the first record of that one file only; stays empty
# when the directory has no .jsonl file or the file has no first record.
columns = []
if first_jsonl:
    with open(os.path.join(dir_path, first_jsonl), "r", encoding="utf-8") as f:
        first_line = f.readline()
    # A whitespace-only first line is truthy but is not valid JSON, so strip
    # before deciding whether there is a record to parse.
    if first_line.strip():
        columns = list(json.loads(first_line).keys())

print(columns)

['id', 'image_url', 'conversations', 'first_incorrect_step', 'steps_with_score', 'consensus_filtering_algo_label', 'verifier_identified_first_incorrect_step_solution']


In [6]:
from collections import defaultdict, Counter
import json
import os

# First pass: check schema consistency across all files.
# Only the first record of each file is inspected, so this is a quick sanity
# check of the top-level keys, not a per-record validation.
print("Checking schema consistency...")
file_schemas = {}
# Sorted because os.listdir() order is arbitrary; this keeps the choice of
# reference schema (and any later loop over jsonl_files) stable between runs.
jsonl_files = sorted(f for f in os.listdir(dir_path) if f.endswith(".jsonl"))

for file in jsonl_files:
    with open(os.path.join(dir_path, file), "r", encoding="utf-8") as f:
        first_line = f.readline()
    # strip(): a whitespace-only line is truthy but would crash json.loads.
    if first_line.strip():
        file_schemas[file] = set(json.loads(first_line).keys())
    else:
        # Surface files that could not contribute a schema instead of
        # silently leaving them out of the comparison.
        print(f"WARNING: {file} has no record on its first line; skipped in schema check")

# Every schema is compared against the first one collected. The default keeps
# this from raising StopIteration when no file produced a schema.
reference_schema = next(iter(file_schemas.values()), set())
inconsistent_files = [
    (file, schema)
    for file, schema in file_schemas.items()
    if schema != reference_schema
]

if not file_schemas:
    print("WARNING: No JSONL records found; nothing to compare")
elif inconsistent_files:
    print("WARNING: Schema inconsistencies found!")
    for file, schema in inconsistent_files:
        print(f"  {file}: {schema}")
        print(f"    Missing: {reference_schema - schema}")
        print(f"    Extra: {schema - reference_schema}")
else:
    print("✓ All files have consistent schema")

Checking schema consistency...
✓ All files have consistent schema


In [8]:
# Statistics collection, flattening, and reporting.
#
# Reads every input JSONL file, tallies two fields per file
# ("first_incorrect_step" section name and "consensus_filtering_algo_label"),
# writes all records to a single flattened JSONL file, then prints per-file
# and overall breakdowns.

# Label values that the report treats specially.
INCORRECT_LABEL = 'o4-mini_incorrect_and_MC_agrees_and_disagrees'
CORRECT_LABEL = 'o4-mini_correct_and_MC_agrees'
UNUSED_LABEL = 'o4-mini_correct_and_MC_disagrees'

OUTPUT_NAME = "flattened_all_data.jsonl"
output_file = os.path.join(dir_path, OUTPUT_NAME)

# The flattened output lives in the same directory that jsonl_files was listed
# from. Once it exists, a re-run would read it back in as an input and count
# every record twice, so it is excluded here. Sorted for a stable record order.
input_files = sorted(name for name in jsonl_files if name != OUTPUT_NAME)


def _print_breakdown(sections, labels, indent="", prefix=""):
    """Print section/label distributions and the training split for one scope.

    Args:
        sections: mapping of first-incorrect-step section name -> count.
        labels: mapping of consensus filtering label -> count.
        indent: leading whitespace for the headings; items get two more spaces.
        prefix: text placed before each heading (e.g. "Overall "); when given,
            the heading's first letter is lower-cased to follow it.

    Returns:
        Tuple (correct_count, incorrect_count, unused_count, training_total).
    """
    incorrect_count = labels.get(INCORRECT_LABEL, 0)
    correct_count = labels.get(CORRECT_LABEL, 0)
    unused_count = labels.get(UNUSED_LABEL, 0)
    item_indent = indent + "  "

    def heading(text):
        if prefix:
            text = text[0].lower() + text[1:]
        return f"{indent}{prefix}{text}"

    # Section percentages use the number of INCORRECT_LABEL records as the
    # denominator, not the number of records that carry a section.
    print(heading("First incorrect step sections:"))
    for section, count in sorted(sections.items()):
        if incorrect_count > 0:
            pct = (count / incorrect_count) * 100
            print(f"{item_indent}{section}: {count} ({pct:.1f}% of incorrect samples)")
        else:
            print(f"{item_indent}{section}: {count} (no incorrect samples)")

    print(heading("Consensus filtering labels:"))
    for label, count in sorted(labels.items()):
        print(f"{item_indent}{label}: {count}")

    # Correct + incorrect are reported as the training set; UNUSED_LABEL
    # records are reported separately as not used.
    training_total = correct_count + incorrect_count
    print(heading("Training samples breakdown:"))
    print(f"{item_indent}Used for training: {training_total} ({correct_count} correct + {incorrect_count} incorrect)")
    print(f"{item_indent}Not used for training: {unused_count} ({UNUSED_LABEL})")

    if training_total > 0:
        correct_pct = (correct_count / training_total) * 100
        incorrect_pct = (incorrect_count / training_total) * 100
        print(f"{item_indent}Training split: {correct_pct:.1f}% correct, {incorrect_pct:.1f}% incorrect")
    else:
        print(f"{item_indent}Training split: No training samples")

    return correct_count, incorrect_count, unused_count, training_total


file_stats = {}
all_data = []

print(f"\nProcessing {len(input_files)} files...")

for file in input_files:
    print(f"Processing {file}...")

    # Per-file tallies
    first_incorrect_step_sections = Counter()
    consensus_filtering_labels = Counter()
    line_count = 0

    with open(os.path.join(dir_path, file), "r", encoding="utf-8") as f:
        for line in f:
            # Blank lines (e.g. a trailing newline) are not records and would
            # make json.loads raise.
            if not line.strip():
                continue
            item = json.loads(line)
            all_data.append(item)
            line_count += 1

            # The section name is the first element of first_incorrect_step;
            # only non-empty list/tuple values are counted.
            first_incorrect = item.get("first_incorrect_step")
            if isinstance(first_incorrect, (list, tuple)) and first_incorrect:
                first_incorrect_step_sections[first_incorrect[0]] += 1

            label = item.get("consensus_filtering_algo_label")
            if label is not None:
                consensus_filtering_labels[label] += 1

    file_stats[file] = {
        "line_count": line_count,
        "first_incorrect_step_sections": dict(first_incorrect_step_sections),
        "consensus_filtering_labels": dict(consensus_filtering_labels)
    }

    print(f"  Lines processed: {line_count}")

# Write flattened data to single file
print(f"\nWriting flattened data to {output_file}...")

with open(output_file, "w", encoding="utf-8") as f:
    for item in all_data:
        f.write(json.dumps(item) + "\n")

print(f"✓ Flattened {len(all_data)} total records to {output_file}")

# Print statistics
print("\n" + "="*60)
print("STATISTICS BY FILE")
print("="*60)

for file, stats in file_stats.items():
    print(f"\n{file}:")
    print(f"  Total records: {stats['line_count']}")
    _print_breakdown(
        stats['first_incorrect_step_sections'],
        stats['consensus_filtering_labels'],
        indent="  ",
    )

# Overall statistics
print("\n" + "="*60)
print("OVERALL STATISTICS")
print("="*60)

# Sum the per-file tallies; Counter.update adds counts rather than replacing.
all_sections = Counter()
all_labels = Counter()
for stats in file_stats.values():
    all_sections.update(stats['first_incorrect_step_sections'])
    all_labels.update(stats['consensus_filtering_labels'])

print(f"Total records across all files: {len(all_data)}")

# The overall counts are kept as notebook-level names for later use.
(
    overall_correct_count,
    overall_incorrect_count,
    overall_unused_count,
    overall_training_total,
) = _print_breakdown(all_sections, all_labels, prefix="Overall ")


Processing 3 files...
Processing InfoVQA_final_mc_rollouts_with_all_models_verification_merged_prm_training_data_mc0.8.jsonl...
  Lines processed: 25557
Processing vqav2_final_mc_rollouts_with_all_models_verification_merged_prm_training_data_mc0.8.jsonl...
  Lines processed: 8169
Processing CLEVR_final_mc_rollouts_with_all_models_verification_merged_prm_training_data_mc0.8.jsonl...
  Lines processed: 27571

Writing flattened data to /mnt/fast10/brandon/mmr_rollout_data/prm_training_data/debug/flattened_all_data.jsonl...
✓ Flattened 61297 total records to /mnt/fast10/brandon/mmr_rollout_data/prm_training_data/debug/flattened_all_data.jsonl

STATISTICS BY FILE

InfoVQA_final_mc_rollouts_with_all_models_verification_merged_prm_training_data_mc0.8.jsonl:
  Total records: 25557
  First incorrect step sections:
    Reasoning: 713 (18.6% of incorrect samples)
    Visual Elements: 3128 (81.4% of incorrect samples)
  Consensus filtering labels:
    o4-mini_correct_and_MC_agrees: 15901
    o4-m