In [2]:
import json
import os
import glob
from pathlib import Path
from typing import List, Dict, Any

def check_schema_consistency(jsonl_files: List[str]) -> bool:
    """Check that all JSONL files share the same set of top-level keys.

    Only the first non-blank line of each file is inspected; it is parsed
    as a JSON object and its key set is taken as that file's schema. The
    first file that contains any record supplies the reference schema, and
    every later file is compared against it. Files with no records (empty
    or blank-only) are ignored.

    Args:
        jsonl_files: Paths of the JSONL files to compare.

    Returns:
        True if every non-empty file has the same key set (also when the
        list is empty or no file contains a record); False at the first
        mismatch, after printing the offending path and both key sets.

    Raises:
        json.JSONDecodeError: If an inspected line is not valid JSON.
        AttributeError: If an inspected line is valid JSON but not an object.
    """
    reference_schema = None

    for file_path in jsonl_files:
        with open(file_path, 'r', encoding='utf-8') as f:
            # Take the first line that has content; leading blank lines are
            # skipped so they do not make the file look empty.
            stripped_lines = (line.strip() for line in f)
            first_record = next((line for line in stripped_lines if line), None)

        if first_record is None:
            # No records in this file: nothing to compare. Keep going so
            # that an empty first file does not short-circuit the check.
            continue

        current_schema = set(json.loads(first_record).keys())

        if reference_schema is None:
            # First file with data defines what every other file must match.
            reference_schema = current_schema
            continue

        if current_schema != reference_schema:
            print(f"Schema mismatch in {file_path}")
            print(f"Expected: {reference_schema}")
            print(f"Found: {current_schema}")
            return False

    return True

def flatten_jsonl_files(input_files: List[str], output_file: str) -> int:
    """Concatenate several JSONL files into a single JSONL file.

    Each input is read line by line; surrounding whitespace is stripped,
    blank lines are dropped, and every remaining line is written to the
    output followed by a single newline. Lines are copied as text and are
    not parsed as JSON. Inputs are processed in the order given.

    Args:
        input_files: Paths of the JSONL files to merge.
        output_file: Path of the merged file. It is overwritten if it
            exists, and its parent directory is created if missing.

    Returns:
        The number of non-blank lines written to ``output_file``.
    """
    total_lines = 0

    # Create the destination directory up front so open(..., 'w') cannot fail
    # with FileNotFoundError on a fresh machine.
    output_parent = os.path.dirname(os.path.abspath(output_file))
    os.makedirs(output_parent, exist_ok=True)

    # Resolved once so inputs can be compared against the output path.
    output_real = os.path.realpath(output_file)

    with open(output_file, 'w', encoding='utf-8') as outfile:
        for file_path in input_files:
            if os.path.realpath(file_path) == output_real:
                # The output was just truncated and is being appended to;
                # reading it back would re-copy our own writes.
                print(f"Skipping {file_path} (same file as output)")
                continue

            print(f"Processing {file_path}...")
            with open(file_path, 'r', encoding='utf-8') as infile:
                for line in infile:
                    line = line.strip()
                    if line:  # Skip empty lines
                        outfile.write(line + '\n')
                        total_lines += 1

    return total_lines

In [8]:
# Splits to process in this run. Other splits available:
# splits = ["AI2D", "CLEVR_10K", "RAVEN"]
splits = ["RAVEN"]
output_dir = "/mnt/fast10/brandon/mmr_rollout_data/flattened_rollout_files"

# Make sure the destination exists before any split tries to write into it.
os.makedirs(output_dir, exist_ok=True)

# Process each split
for split in splits:
    print(f"\n=== Processing split: {split} ===")

    # Find all JSONL files in the split directory. RAVEN keeps its files in
    # per-configuration subdirectories, so it is searched recursively; the
    # other splits are searched at the top level only.
    split_dir = f"./raw_rollouts/rollouts_only/{split}"
    if split == "RAVEN":
        jsonl_pattern = os.path.join(split_dir, "**", "*.jsonl")
        jsonl_files = glob.glob(jsonl_pattern, recursive=True)
    else:
        jsonl_pattern = os.path.join(split_dir, "*.jsonl")
        jsonl_files = glob.glob(jsonl_pattern)

    # glob returns paths in arbitrary order; sort so the flattened output
    # has the same line order on every run.
    jsonl_files = sorted(jsonl_files)

    if not jsonl_files:
        print(f"No JSONL files found in {split_dir}")
        continue

    print(f"Found {len(jsonl_files)} JSONL files:")
    for file_path in jsonl_files:
        print(f"  - {file_path}")

    # Check schema consistency; a split with mixed schemas is not flattened.
    print("\nChecking schema consistency...")
    if not check_schema_consistency(jsonl_files):
        print(f"ERROR: Schema mismatch detected in {split}. Skipping flattening.")
        continue

    print("Schema check passed!")

    # Flatten files
    output_file = f"{output_dir}/{split}_flattened.jsonl"
    print(f"Flattening to {output_file}...")

    total_lines = flatten_jsonl_files(jsonl_files, output_file)
    print(f"Successfully flattened {total_lines} lines to {output_file}")

print("\n=== All splits processed ===")


=== Processing split: RAVEN ===
Found 54 JSONL files:
  - ./raw_rollouts/rollouts_only/RAVEN/distribute_four/distribute_four_validation_raven_rollouts_5000_9999_streaming.jsonl
  - ./raw_rollouts/rollouts_only/RAVEN/distribute_four/distribute_four_validation_raven_rollouts_0_4999_streaming.jsonl
  - ./raw_rollouts/rollouts_only/RAVEN/distribute_four/distribute_four_train_raven_rollouts_2000_3999_streaming.jsonl
  - ./raw_rollouts/rollouts_only/RAVEN/distribute_four/distribute_four_train_raven_rollouts_8000_9999_streaming.jsonl
  - ./raw_rollouts/rollouts_only/RAVEN/distribute_four/distribute_four_train_raven_rollouts_0_1999_streaming.jsonl
  - ./raw_rollouts/rollouts_only/RAVEN/distribute_four/distribute_four_train_raven_rollouts_6000_7999_streaming.jsonl
  - ./raw_rollouts/rollouts_only/RAVEN/distribute_four/distribute_four_train_raven_rollouts_4000_5999_streaming.jsonl
  - ./raw_rollouts/rollouts_only/RAVEN/distribute_nine/distribute_nine_train_raven_rollouts_2000_3999_streaming.jso

In [4]:
# Verify the flattened files: report the number of non-blank lines in each
# split's output, and say so explicitly when an expected output is missing.
for split in splits:
    output_file = f"{output_dir}/{split}_flattened.jsonl"
    if not os.path.exists(output_file):
        # A silent skip here would make a failed flatten look like success.
        print(f"{split}_flattened.jsonl: MISSING ({output_file})")
        continue
    with open(output_file, 'r') as f:
        line_count = sum(1 for line in f if line.strip())
    print(f"{split}_flattened.jsonl: {line_count} lines")