In [None]:
import os
import json
import numpy as np
import random
from tqdm import tqdm

def grid_to_string(grid):
    """Render a 2D grid as plain text.

    Each row becomes one line, with its cells converted via ``str()`` and
    separated by single spaces; rows are joined with newlines.
    """
    lines = []
    for row in grid:
        lines.append(' '.join(map(str, row)))
    return '\n'.join(lines)

def grid_to_json_format(grid):
    """Serialize a grid to a JSON string (alternative to the space-separated text form).

    Uses a default-configured ``json.JSONEncoder``, which produces the same
    output as ``json.dumps`` with no extra options, e.g. ``[[1, 2], [3, 4]]``.
    """
    encoder = json.JSONEncoder()
    return encoder.encode(grid)

def format_for_llada(examples, task_id):
    """
    Turn ARC input/output examples into prompt/response records for LLaDA
    instruction tuning.

    Args:
        examples: iterable of dicts, each with "input" and "output" grids.
        task_id: identifier embedded in every prompt.

    Returns:
        list of {"prompt": str, "response": str} dicts, one per example and in
        the same order. Grids are rendered with ``grid_to_string`` and the
        prompt numbers examples starting from 0.
    """
    def _to_record(index, example):
        # Read both grids up front so a malformed example fails on key lookup
        # before any rendering happens.
        input_grid, output_grid = example["input"], example["output"]
        source_text = grid_to_string(input_grid)
        target_text = grid_to_string(output_grid)
        return {
            "prompt": (
                f"Task {task_id} Example #{index}:\n"
                f"Given the following grid:\n{source_text}\n"
                "Transform it according to the pattern."
            ),
            "response": f"Here is the transformed grid:\n{target_text}",
        }

    return [_to_record(index, example) for index, example in enumerate(examples)]

def _write_jsonl(path, records):
    """Write ``records`` to ``path`` as JSON Lines (one JSON object per line)."""
    with open(path, 'w', encoding='utf-8') as f:
        for record in records:
            f.write(json.dumps(record) + '\n')


def split_task(task_id, train_size, val_size, input_dir='gen10000/tasks', output_dir='output', seed=42):
    """
    Split a task's examples into training and validation sets for LLaDA training.

    Loads ``<input_dir>/<task_id>.json`` (expected to be a JSON list of
    examples, each with "input" and "output" grids), shuffles it
    deterministically, formats it with ``format_for_llada`` and writes
    ``train.jsonl``, ``val.jsonl`` and ``metadata.json`` under
    ``<output_dir>/<task_id>/``.

    If fewer examples exist than ``train_size + val_size``, both sizes are
    scaled down proportionally so that together they use every example.

    Args:
        task_id (str): The task ID (e.g., '0a938d79')
        train_size (int): Number of examples for training set (>= 0)
        val_size (int): Number of examples for validation set (>= 0)
        input_dir (str): Directory containing task JSON files
        output_dir (str): Directory to save output files
        seed (int): Seed for the shuffle. A private ``random.Random`` is used,
            so the global ``random`` state is left untouched; the default (42)
            reproduces the same split order as before.

    Returns:
        dict with "train_sample"/"val_sample" (first two formatted records of
        each split) and "train_path"/"val_path" (the written JSONL files).

    Raises:
        ValueError: if a size is negative or the task file is not a JSON list.
    """
    if train_size < 0 or val_size < 0:
        raise ValueError(
            f"train_size and val_size must be non-negative, got {train_size} and {val_size}"
        )

    # Create <output_dir>/<task_id>/ (makedirs also creates output_dir itself).
    task_output_dir = os.path.join(output_dir, task_id)
    os.makedirs(task_output_dir, exist_ok=True)

    # Load the task data
    input_file = os.path.join(input_dir, f"{task_id}.json")
    print(f"Loading task from {input_file}...")

    with open(input_file, 'r', encoding='utf-8') as f:
        examples = json.load(f)

    # Shuffling/slicing below needs a list; fail clearly on e.g. a dict payload.
    if not isinstance(examples, list):
        raise ValueError(
            f"Expected {input_file} to contain a JSON list of examples, "
            f"got {type(examples).__name__}"
        )

    # Make sure we have enough examples
    total_examples = len(examples)
    requested = train_size + val_size
    if total_examples < requested:
        print(f"Warning: Requested {requested} examples but only {total_examples} available.")
        print("Adjusting split proportionally...")

        # requested > total_examples >= 0 here, so the division is safe.
        ratio = train_size / requested
        train_size = int(total_examples * ratio)
        val_size = total_examples - train_size

    # Private RNG: same order as `random.seed(seed); random.shuffle(...)`, but
    # without resetting the notebook-wide global random state.
    random.Random(seed).shuffle(examples)

    # Split the data
    train_examples = examples[:train_size]
    val_examples = examples[train_size:train_size + val_size]

    print(f"Split {len(train_examples)} training examples and {len(val_examples)} validation examples")

    # Format data for LLaDA training
    train_formatted = format_for_llada(train_examples, task_id)
    val_formatted = format_for_llada(val_examples, task_id)

    # Save the formatted data
    train_output = os.path.join(task_output_dir, "train.jsonl")
    val_output = os.path.join(task_output_dir, "val.jsonl")

    print(f"Saving training data to {train_output}...")
    _write_jsonl(train_output, train_formatted)

    print(f"Saving validation data to {val_output}...")
    _write_jsonl(val_output, val_formatted)

    # Also save a metadata.json file with task information
    metadata = {
        "task_id": task_id,
        "total_examples": total_examples,
        "train_size": len(train_examples),
        "val_size": len(val_examples),
        "format": "JSONL with prompt/response pairs for LLaDA training"
    }

    with open(os.path.join(task_output_dir, "metadata.json"), 'w', encoding='utf-8') as f:
        json.dump(metadata, f, indent=2)

    print(f"Task {task_id} successfully split and formatted for LLaDA training!")

    # Return sample data for examination
    return {
        "train_sample": train_formatted[:2],
        "val_sample": val_formatted[:2],
        "train_path": train_output,
        "val_path": val_output
    }

In [2]:
# Configuration: the task to split and where the generated task files live.
# Set ARC_TASKS_DIR to point at a different copy of the data; the default is
# the original cluster path.
task_id = '0a938d79'  # Replace with your task ID
TASKS_DIR = os.environ.get('ARC_TASKS_DIR', '/scratch/sz4972/DiCoRGI/data_gen/gen10000/tasks')
SPLIT_OUTPUT_DIR = 'split_output'

test_results = split_task(task_id, train_size=9500, val_size=500,
                          input_dir=TASKS_DIR,
                          output_dir=SPLIT_OUTPUT_DIR)


def print_samples(title, samples):
    """Print a section title followed by each prompt/response sample."""
    print(title)
    for i, example in enumerate(samples, start=1):
        print(f"\nExample {i}:")
        print(f"PROMPT:\n{example['prompt']}\n")
        print(f"RESPONSE:\n{example['response']}\n")
        print("-" * 50)


def verify_jsonl(lines, label):
    """Check that every line parses as a JSON object with "prompt" and "response".

    Prints the first problem found and returns False; returns True otherwise.
    An empty file is reported but not treated as an error (e.g. val_size=0).
    """
    if not lines:
        print(f"Warning: {label} file is empty.")
    for line_no, line in enumerate(lines, start=1):
        try:
            record = json.loads(line)
        except json.JSONDecodeError as exc:
            print(f"ERROR: {label} line {line_no} is not valid JSON: {exc}")
            return False
        if not isinstance(record, dict) or not {"prompt", "response"} <= record.keys():
            print(f"ERROR: {label} line {line_no} is missing 'prompt'/'response'.")
            return False
    return True


print_samples("SAMPLE TRAINING DATA:", test_results["train_sample"])
print_samples("SAMPLE VALIDATION DATA:", test_results["val_sample"])

print("\nVerifying saved files...")

with open(test_results["train_path"], 'r', encoding='utf-8') as f:
    train_lines = f.readlines()

with open(test_results["val_path"], 'r', encoding='utf-8') as f:
    val_lines = f.readlines()

print(f"Number of lines in training file: {len(train_lines)}")
print(f"Number of lines in validation file: {len(val_lines)}")

# Validate every line of both files, not just the first one.
train_ok = verify_jsonl(train_lines, "training")
val_ok = verify_jsonl(val_lines, "validation")
if train_ok and val_ok:
    print("JSON parsing successful!")
    print("Files are correctly formatted.")
else:
    print("ERROR: Files are not correctly formatted JSON.")


Loading task from /scratch/sz4972/DiCoRGI/data_gen/gen10000/tasks/0a938d79.json...
Split 9500 training examples and 500 validation examples
Saving training data to split_output/0a938d79/train.jsonl...
Saving validation data to split_output/0a938d79/val.jsonl...
Task 0a938d79 successfully split and formatted for LLaDA training!
SAMPLE TRAINING DATA:

Example 1:
PROMPT:
Task 0a938d79 Example #0:
Given the following grid:
4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 3 4 4 4 4 4 6 4 4 4 4 4 4
4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4
4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4
4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4
4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4
4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4
4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4
4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4
4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4
4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4

Number of lines in training file: 9500
Number of lines in validation file: 500
JSON parsing successful!
Files are correctly formatted.


'\nfull_results = split_task(task_id, train_size=9500, val_size=500, \n                         input_dir=\'/scratch/sz4972/DiCoRGI/data_gen/gen10000/tasks\',\n                         output_dir=\'output\')\nprint("Full split complete!")\n'