In [20]:
import os
import matplotlib.pyplot as plt
from PIL import Image, ImageDraw, ImageFont
from datasets import load_dataset
import numpy as np
from typing import List
import io

In [21]:
def create_composite_matrix_image(panels: List[Image.Image]) -> Image.Image:
    """
    Create a composite image showing the 8 panels arranged in a 3x3 grid layout
    with borders around each panel, matching the reference style.
    
    Args:
        panels: List of 8 PIL images representing the matrix panels
        
    Returns:
        PIL Image showing the 3x3 matrix with missing bottom-right panel
    """
    # Assume all panels are the same size
    panel_width, panel_height = panels[0].size
    
    # Add border width and spacing
    border_width = 2
    spacing = 4
    
    # Calculate dimensions with borders and spacing
    cell_width = panel_width + 2 * border_width
    cell_height = panel_height + 2 * border_width
    
    # Add margin for the "Problem Matrix" text on the left side
    left_margin = 60  # Margin for the vertical text - increased from 40 to 60
    
    composite_width = cell_width * 3 + spacing * 2 + left_margin
    composite_height = cell_height * 3 + spacing * 2
    
    # Create composite image with white background
    composite = Image.new('RGB', (composite_width, composite_height), 'white')
    
    draw = ImageDraw.Draw(composite)
    
    # Arrange panels in 3x3 grid (missing bottom-right)
    positions = [
        (0, 0),      # Panel 1: top-left
        (1, 0),      # Panel 2: top-center  
        (2, 0),      # Panel 3: top-right
        (0, 1),      # Panel 4: middle-left
        (1, 1),      # Panel 5: middle-center
        (2, 1),      # Panel 6: middle-right
        (0, 2),      # Panel 7: bottom-left
        (1, 2),      # Panel 8: bottom-center
        # (2, 2) is missing - bottom-right
    ]
    
    for i, (col, row) in enumerate(positions):
        # Calculate position with spacing, accounting for left margin
        x = left_margin + col * (cell_width + spacing)
        y = row * (cell_height + spacing)
        
        # Draw border rectangle
        draw.rectangle([x, y, x + cell_width, y + cell_height], 
                      outline='black', fill='white', width=border_width)
        
        # Paste the panel inside the border
        panel_x = x + border_width
        panel_y = y + border_width
        composite.paste(panels[i], (panel_x, panel_y))
    
    # Add question mark for missing panel
    missing_col, missing_row = 2, 2
    missing_x = left_margin + missing_col * (cell_width + spacing)
    missing_y = missing_row * (cell_height + spacing)
    
    # Draw border for missing panel
    draw.rectangle([missing_x, missing_y, missing_x + cell_width, missing_y + cell_height], 
                  outline='black', fill='white', width=border_width)
    
    # Add question mark in the center
    try:
        font = ImageFont.truetype("/usr/share/fonts/truetype/dejavu/DejaVuSans-Bold.ttf", 
                                min(panel_width, panel_height) // 3)
    except:
        try:
            font = ImageFont.load_default()
        except:
            font = None
    
    question_text = "?"
    if font:
        bbox = draw.textbbox((0, 0), question_text, font=font)
        text_width = bbox[2] - bbox[0]
        text_height = bbox[3] - bbox[1]
    else:
        text_width, text_height = 20, 20
    
    text_x = missing_x + (cell_width - text_width) // 2
    text_y = missing_y + (cell_height - text_height) // 2
    
    draw.text((text_x, text_y), question_text, fill='black', font=font)
    
    # Add "Problem Matrix" text on the left side
    try:
        label_font = ImageFont.truetype("/usr/share/fonts/truetype/dejavu/DejaVuSans-Bold.ttf", 18)
    except:
        try:
            label_font = ImageFont.load_default()
        except:
            label_font = None
    
    if label_font:
        problem_matrix_text = "Problem Matrix"
        text_width, text_height = draw.textbbox((0, 0), problem_matrix_text, font=label_font)[2:4]
        
        # Calculate the position for the vertical text
        # Center it vertically in the matrix area
        matrix_height = 3 * cell_height + 2 * spacing
        text_x = 25  # Fixed distance from the left edge - increased from 15 to 25
        text_y = (composite_height - text_width) // 2  # Center vertically
        
        # Draw the rotated text
        # We need to create a temporary image, draw text, then rotate and paste
        text_img = Image.new('RGBA', (text_width, text_height), (255, 255, 255, 0))
        text_draw = ImageDraw.Draw(text_img)
        text_draw.text((0, 0), problem_matrix_text, fill='black', font=label_font)
        
        # Rotate the text image 90 degrees counter-clockwise
        rotated_text = text_img.rotate(90, expand=True)
        
        # Paste the rotated text onto the composite image
        composite.paste(rotated_text, (text_x, text_y), rotated_text)
    
    return composite


In [22]:
def create_choices_grid(choices: List[Image.Image]) -> Image.Image:
    """
    Create a grid showing the 8 answer choices with index labels below each choice,
    matching the reference style with 2 rows of 4 choices each.
    
    Args:
        choices: List of 8 PIL images representing answer choices
        
    Returns:
        PIL Image showing choices arranged in 2x4 grid with labels below
    """
    choice_width, choice_height = choices[0].size
    
    # Grid configuration: 2 rows, 4 columns
    grid_cols, grid_rows = 4, 2
    
    # Add border width and spacing
    border_width = 2
    spacing = 4
    label_height = 30  # Space for number labels below choices
    
    # Calculate cell dimensions
    cell_width = choice_width + 2 * border_width
    cell_height = choice_height + 2 * border_width + label_height
    
    # Calculate total dimensions
    composite_width = cell_width * grid_cols + spacing * (grid_cols - 1)
    composite_height = cell_height * grid_rows + spacing * (grid_rows - 1)
    
    composite = Image.new('RGB', (composite_width, composite_height), 'white')
    
    draw = ImageDraw.Draw(composite)
    
    try:
        font = ImageFont.truetype("/usr/share/fonts/truetype/dejavu/DejaVuSans-Bold.ttf", 18)
    except:
        try:
            font = ImageFont.load_default()
        except:
            font = None
    
    for i, choice in enumerate(choices):
        col = i % grid_cols
        row = i // grid_cols
        
        # Calculate position with spacing
        x = col * (cell_width + spacing)
        y = row * (cell_height + spacing)
        
        # Draw border rectangle around choice
        choice_border_height = choice_height + 2 * border_width
        draw.rectangle([x, y, x + cell_width, y + choice_border_height], 
                      outline='black', fill='white', width=border_width)
        
        # Paste the choice image inside the border
        choice_x = x + border_width
        choice_y = y + border_width
        composite.paste(choice, (choice_x, choice_y))
        
        # Add index label below the choice
        label = str(i + 1)  # 1-indexed for display
        if font:
            bbox = draw.textbbox((0, 0), label, font=font)
            text_width = bbox[2] - bbox[0]
            text_height = bbox[3] - bbox[1]
        else:
            text_width, text_height = 10, 12
        
        # Center the label below the choice
        label_x = x + (cell_width - text_width) // 2
        label_y = y + choice_border_height + (label_height - text_height) // 2
        
        draw.text((label_x, label_y), label, fill='black', font=font)
    
    return composite


In [23]:
def create_combined_image(matrix_composite: Image.Image, choices_composite: Image.Image) -> Image.Image:
    """
    Combine matrix composite and choices composite into a single image matching
    the reference layout with section labels and proper spacing.
    
    Args:
        matrix_composite: PIL Image showing the 3x3 matrix
        choices_composite: PIL Image showing the 8 choices
        
    Returns:
        PIL Image with labeled sections: Problem Matrix and Answer Set
    """
    # Calculate dimensions
    matrix_width, matrix_height = matrix_composite.size
    choices_width, choices_height = choices_composite.size
    
    # Layout parameters
    margin = 30  # Margin around the entire image
    section_spacing = 40  # Space between Problem Matrix and Answer Set
    label_height = 40  # Height for section labels (increased)
    label_margin = 15  # Space between label and content (increased)
    
    # Add margin for the left side labels
    left_margin = 120  # Increased from 90 to 120 for more space for vertical text
    
    # Calculate total dimensions
    combined_width = max(matrix_width, choices_width) + 2 * margin + left_margin
    combined_height = (margin + label_height + label_margin + matrix_height + 
                      section_spacing + label_height + label_margin + 
                      choices_height + margin)
    
    # Create combined image with white background
    combined = Image.new('RGB', (combined_width, combined_height), 'white')
    draw = ImageDraw.Draw(combined)
    
    # Set up fonts - using bold font for both
    try:
        title_font = ImageFont.truetype("/usr/share/fonts/truetype/dejavu/DejaVuSans-Bold.ttf", 20)
        label_font = ImageFont.truetype("/usr/share/fonts/truetype/dejavu/DejaVuSans-Bold.ttf", 18)
    except:
        try:
            title_font = ImageFont.load_default()
            label_font = ImageFont.load_default()
        except:
            title_font = None
            label_font = None
    
    # Current y position
    current_y = margin
    
    # Add "(a)" label at top left
    # if title_font:
        # draw.text((margin, current_y), "(a)", fill='black', font=title_font)
    current_y += label_height + label_margin
    
    # We don't need to add "Problem Matrix" label here since it's now part of the matrix_composite
    
    # Center and paste the matrix
    matrix_x = margin + (combined_width - matrix_width - 2*margin) // 2
    combined.paste(matrix_composite, (matrix_x, current_y))
    current_y += matrix_height + section_spacing
    
    # Add "Answer Set" label - positioned on the left side vertically
    if label_font:
        answer_set_label = "Answer Set"
        # Create a temporary image for the rotated text
        bbox = draw.textbbox((0, 0), answer_set_label, font=label_font)
        text_width = bbox[2] - bbox[0]
        text_height = bbox[3] - bbox[1]
        
        # Create temporary image for text with transparent background
        text_img = Image.new('RGBA', (text_width + 10, text_height + 10), (255, 255, 255, 0))  # Added padding
        text_draw = ImageDraw.Draw(text_img)
        text_draw.text((5, 5), answer_set_label, fill='black', font=label_font)  # Add some padding inside the text image
        
        # Rotate the text image 90 degrees counter-clockwise
        rotated_text = text_img.rotate(90, expand=True)
        
        # Position the rotated text on the left side, centered vertically with choices grid
        # Ensure text is more centered and not cut off
        label_x = 45  # Increased from 35 to 45 for better spacing
        
        # Calculate better vertical centering based on the actual choices area
        # Make sure the text is fully visible within the choices section
        choices_center_y = current_y + choices_height / 2
        rotated_text_height = rotated_text.size[1]
        label_y = choices_center_y - rotated_text_height / 2
        
        # Ensure the text doesn't go outside the image boundaries
        if label_y < current_y:
            label_y = current_y
        if label_y + rotated_text_height > current_y + choices_height:
            label_y = current_y + choices_height - rotated_text_height
            
        combined.paste(rotated_text, (label_x, int(label_y)), rotated_text)
    
    # Center and paste the choices
    choices_x = margin + (combined_width - choices_width - 2*margin) // 2
    combined.paste(choices_composite, (choices_x, current_y))
    
    return combined


In [24]:
# Load the RAVEN dataset - for all required subsets
subsets = ["center_single", "distribute_four", "distribute_nine"]

# have a different shape
# subsets = ["left_center_single_right_center_single", "up_center_single_down_center_single", "in_center_single_out_center_single", "in_distribute_four_out_center_single"]
datasets = {}

for subset in subsets:
    datasets[subset] = load_dataset("HuggingFaceM4/RAVEN", subset)
    print(f"\nLoaded {subset} dataset with the following splits: {datasets[subset].keys()}")
    print(f"Training examples: {len(datasets[subset]['train'])}")
    print(f"Validation examples: {len(datasets[subset]['validation'])}")



Loaded center_single dataset with the following splits: dict_keys(['train', 'validation', 'test'])
Training examples: 6000
Validation examples: 2000

Loaded distribute_four dataset with the following splits: dict_keys(['train', 'validation', 'test'])
Training examples: 6000
Validation examples: 2000

Loaded distribute_nine dataset with the following splits: dict_keys(['train', 'validation', 'test'])
Training examples: 6000
Validation examples: 2000


In [25]:
# Check the structure of a sample from each dataset
for subset in subsets:
    sample = datasets[subset]['train'][0]
    print(f"\n=== Sample from {subset} ===")
    print("Sample keys:", sample.keys())
    
    # Print shape of panels and choices
    print("Panels length:", len(sample['panels']))
    print("Choices length:", len(sample['choices']))
    print("Target:", sample['target'])
    


=== Sample from center_single ===
Sample keys: dict_keys(['panels', 'choices', 'structure', 'meta_matrix', 'meta_target', 'meta_structure', 'target', 'id', 'metadata'])
Panels length: 8
Choices length: 8
Target: 4

=== Sample from distribute_four ===
Sample keys: dict_keys(['panels', 'choices', 'structure', 'meta_matrix', 'meta_target', 'meta_structure', 'target', 'id', 'metadata'])
Panels length: 8
Choices length: 8
Target: 7

=== Sample from distribute_nine ===
Sample keys: dict_keys(['panels', 'choices', 'structure', 'meta_matrix', 'meta_target', 'meta_structure', 'target', 'id', 'metadata'])
Panels length: 8
Choices length: 8
Target: 6


In [19]:
# Create directory structure for saving images
import os
base_dir = "processed_raven_images"
for subset in subsets:
    subset_dir = os.path.join(base_dir, subset)
    os.makedirs(subset_dir, exist_ok=True)
    os.makedirs(os.path.join(subset_dir, "train"), exist_ok=True)
    os.makedirs(os.path.join(subset_dir, "validation"), exist_ok=True)

In [7]:
# # Process each dataset to create composite images - testing on 3 samples first
# for subset in subsets:
#     print(f"\nProcessing {subset} dataset (testing on 3 samples)...")
    
#     # Process 3 samples from train set
#     train_subset = datasets[subset]['train'].select(range(3))
#     train_subset = train_subset.map(
#         lambda x: {
#             'problem_matrix_image': create_composite_matrix_image(x['panels']),
#             'answer_set_image': create_choices_grid(x['choices'])
#         },
#         # remove_columns=['panels', 'choices']  # Remove original columns since we have composites
#     )
    
#     # Process 3 samples from validation set
#     val_subset = datasets[subset]['validation'].select(range(3))
#     val_subset = val_subset.map(
#         lambda x: {
#             'problem_matrix_image': create_composite_matrix_image(x['panels']),
#             'answer_set_image': create_choices_grid(x['choices'])
#         },
#         # remove_columns=['panels', 'choices']  # Remove original columns since we have composites
#     )
    
#     # Update the dataset in the dictionary (with test samples only)
#     datasets[subset] = {
#         'train': train_subset,
#         'validation': val_subset
#     }
    
#     print(f"Completed processing test samples for {subset} dataset")
#     print(f"Train test set size: {len(train_subset)}")
#     print(f"Validation test set size: {len(val_subset)}")

# # Verify the new structure
# for subset in subsets:
#     sample = datasets[subset]['train'][0]
#     print(f"\n=== Sample from {subset} after processing ===")
#     print("Sample keys:", sample.keys())
#     print("Problem matrix image type:", type(sample['problem_matrix_image']))
#     print("Answer set image type:", type(sample['answer_set_image']))

In [8]:
# # Process and save combined images
# base_dir = "/data/users/brandon/ob1-projects/InternVL/internvl_chat/rollout_generation/preprocessed_prompts/preprocessing_scripts/RAVEN/processed_raven_images"

# for subset in subsets:
#     print(f"\nProcessing and saving combined images for {subset}...")
    
#     # Process train set
#     for idx, sample in enumerate(datasets[subset]['train']):
#         combined_image = create_combined_image(
#             sample['problem_matrix_image'],
#             sample['answer_set_image']
#         )
#         save_path = os.path.join(base_dir, subset, "train", f"{sample['id']}.png")
#         combined_image.save(save_path)
#         if idx < 3:  # Print progress for first 3 samples
#             print(f"Saved train image {idx+1}/3: {save_path}")
    
#     # Process validation set
#     for idx, sample in enumerate(datasets[subset]['validation']):
#         combined_image = create_combined_image(
#             sample['problem_matrix_image'],
#             sample['answer_set_image']
#         )
#         save_path = os.path.join(base_dir, subset, "validation", f"{sample['id']}.png")
#         combined_image.save(save_path)
#         if idx < 3:  # Print progress for first 3 samples
#             print(f"Saved validation image {idx+1}/3: {save_path}")

# print("\nCompleted saving all combined images!")

In [9]:
# # Process and save combined images for full dataset with quality checks
# base_dir = "/data/users/brandon/ob1-projects/InternVL/internvl_chat/rollout_generation/preprocessed_prompts/preprocessing_scripts/RAVEN/processed_raven_images"

# def verify_image_count(directory, expected_count):
#     """Verify the number of images in a directory matches expected count"""
#     actual_count = len([f for f in os.listdir(directory) if f.endswith('.png')])
#     if actual_count != expected_count:
#         raise ValueError(f"Image count mismatch in {directory}. Expected {expected_count}, got {actual_count}")
#     return actual_count

# def verify_image_quality(image_path):
#     """Verify image can be opened and has valid dimensions"""
#     try:
#         img = Image.open(image_path)
#         width, height = img.size
#         if width < 100 or height < 100:  # Basic size check
#             raise ValueError(f"Image {image_path} has invalid dimensions: {width}x{height}")
#         return True
#     except Exception as e:
#         raise ValueError(f"Failed to verify image {image_path}: {str(e)}")

# for subset in subsets:
#     print(f"\nProcessing {subset} dataset...")
    
#     # Get expected counts
#     expected_train = len(datasets[subset]['train'])
#     expected_val = len(datasets[subset]['validation'])
#     print(f"Expected counts - Train: {expected_train}, Validation: {expected_val}")
    
#     # Process full train set
#     print(f"Processing train set ({expected_train} samples)...")
#     train_dataset = datasets[subset]['train'].map(
#         lambda x: {
#             'problem_matrix_image': create_composite_matrix_image(x['panels']),
#             'answer_set_image': create_choices_grid(x['choices'])
#         },
#         remove_columns=['panels', 'choices']
#     )
    
#     # Process full validation set
#     print(f"Processing validation set ({expected_val} samples)...")
#     val_dataset = datasets[subset]['validation'].map(
#         lambda x: {
#             'problem_matrix_image': create_composite_matrix_image(x['panels']),
#             'answer_set_image': create_choices_grid(x['choices'])
#         },
#         remove_columns=['panels', 'choices']
#     )
    
#     # Create directories if they don't exist
#     train_dir = os.path.join(base_dir, subset, "train")
#     val_dir = os.path.join(base_dir, subset, "validation")
#     os.makedirs(train_dir, exist_ok=True)
#     os.makedirs(val_dir, exist_ok=True)
    
#     # Save train images
#     print("Saving train images...")
#     train_ids = set()  # Track unique IDs
#     for idx, sample in enumerate(train_dataset):
#         if sample['id'] in train_ids:
#             raise ValueError(f"Duplicate ID found in train set: {sample['id']}")
#         train_ids.add(sample['id'])
        
#         combined_image = create_combined_image(
#             sample['problem_matrix_image'],
#             sample['answer_set_image']
#         )
#         save_path = os.path.join(train_dir, f"{sample['id']}.png")
#         combined_image.save(save_path)
        
#         # Verify image quality
#         verify_image_quality(save_path)
        
#         if idx % 100 == 0:
#             print(f"Saved {idx}/{expected_train} train images")
    
#     # Save validation images
#     print("Saving validation images...")
#     val_ids = set()  # Track unique IDs
#     for idx, sample in enumerate(val_dataset):
#         if sample['id'] in val_ids:
#             raise ValueError(f"Duplicate ID found in validation set: {sample['id']}")
#         val_ids.add(sample['id'])
        
#         combined_image = create_combined_image(
#             sample['problem_matrix_image'],
#             sample['answer_set_image']
#         )
#         save_path = os.path.join(val_dir, f"{sample['id']}.png")
#         combined_image.save(save_path)
        
#         # Verify image quality
#         verify_image_quality(save_path)
        
#         if idx % 100 == 0:
#             print(f"Saved {idx}/{expected_val} validation images")
    
#     # Verify final counts
#     print("\nVerifying final counts...")
#     actual_train = verify_image_count(train_dir, expected_train)
#     actual_val = verify_image_count(val_dir, expected_val)
#     print(f"Verified counts for {subset}:")
#     print(f"Train: {actual_train}/{expected_train}")
#     print(f"Validation: {actual_val}/{expected_val}")
    
#     # Verify no overlap between train and validation IDs
#     if train_ids.intersection(val_ids):
#         raise ValueError(f"Found overlapping IDs between train and validation sets in {subset}")
    
#     print(f"Completed processing {subset} dataset with all quality checks passed")

# print("\nCompleted processing all datasets with quality checks!")

In [26]:
# Process and save combined images for full dataset with quality checks and parallelization
base_dir = "/data/users/brandon/ob1-projects/InternVL/internvl_chat/rollout_generation/preprocessed_prompts/preprocessing_scripts/RAVEN/processed_raven_images"

def verify_image_count(directory, expected_count):
    """Verify the number of images in a directory matches expected count"""
    actual_count = len([f for f in os.listdir(directory) if f.endswith('.png')])
    if actual_count != expected_count:
        raise ValueError(f"Image count mismatch in {directory}. Expected {expected_count}, got {actual_count}")
    return actual_count

def verify_image_quality(image_path):
    """Verify image can be opened and has valid dimensions"""
    try:
        img = Image.open(image_path)
        width, height = img.size
        if width < 100 or height < 100:  # Basic size check
            raise ValueError(f"Image {image_path} has invalid dimensions: {width}x{height}")
        return True
    except Exception as e:
        raise ValueError(f"Failed to verify image {image_path}: {str(e)}")

def process_and_save_batch(batch, subset, split, base_dir):
    """Process a batch of samples and save their images"""
    # Create the correct directory based on split (train/validation)
    save_dir = os.path.join(base_dir, subset, split)
    os.makedirs(save_dir, exist_ok=True)
    
    results = []
    for sample in batch:
        # Create composite images
        problem_matrix = create_composite_matrix_image(sample['panels'])
        answer_set = create_choices_grid(sample['choices'])
        
        # Create and save combined image
        combined_image = create_combined_image(problem_matrix, answer_set)
        save_path = os.path.join(save_dir, f"{sample['id']}.png")
        combined_image.save(save_path)
        
        # Verify image quality
        verify_image_quality(save_path)
        
        results.append({
            'id': sample['id'],
            'path': save_path
        })
    
    return results

for subset in subsets:
    print(f"\nProcessing {subset} dataset...")
    
    # Get expected counts
    expected_train = len(datasets[subset]['train'])
    expected_val = len(datasets[subset]['validation'])
    print(f"Expected counts - Train: {expected_train}, Validation: {expected_val}")
    
    # Process train set with parallelization
    print(f"Processing train set ({expected_train} samples)...")
    train_results = datasets[subset]['train'].map(
        lambda x: process_and_save_batch([x], subset, "train", base_dir)[0],
        num_proc=os.cpu_count(),  # Use all available CPU cores
        batch_size=32,  # Process in batches for better efficiency
        desc="Processing train set",
        remove_columns=['panels', 'choices']
    )
    
    # Process validation set with parallelization
    print(f"Processing validation set ({expected_val} samples)...")
    val_results = datasets[subset]['validation'].map(
        lambda x: process_and_save_batch([x], subset, "validation", base_dir)[0],
        num_proc=os.cpu_count(),  # Use all available CPU cores
        batch_size=32,  # Process in batches for better efficiency
        desc="Processing validation set",
        remove_columns=['panels', 'choices']
    )
    
    # Collect all IDs for verification
    train_ids = set(result['id'] for result in train_results)
    val_ids = set(result['id'] for result in val_results)
    
    # Verify final counts
    print("\nVerifying final counts...")
    actual_train = verify_image_count(os.path.join(base_dir, subset, "train"), expected_train)
    actual_val = verify_image_count(os.path.join(base_dir, subset, "validation"), expected_val)
    print(f"Verified counts for {subset}:")
    print(f"Train: {actual_train}/{expected_train}")
    print(f"Validation: {actual_val}/{expected_val}")
    
    # Verify no overlap between train and validation IDs
    if train_ids.intersection(val_ids):
        raise ValueError(f"Found overlapping IDs between train and validation sets in {subset}")
    
    print(f"Completed processing {subset} dataset with all quality checks passed")

print("\nCompleted processing all datasets with quality checks!")


Processing center_single dataset...
Expected counts - Train: 6000, Validation: 2000
Processing train set (6000 samples)...


Processing train set (num_proc=192):   0%|          | 0/6000 [00:00<?, ? examples/s]

Processing validation set (2000 samples)...


Processing validation set (num_proc=192):   0%|          | 0/2000 [00:00<?, ? examples/s]


Verifying final counts...
Verified counts for center_single:
Train: 6000/6000
Validation: 2000/2000
Completed processing center_single dataset with all quality checks passed

Processing distribute_four dataset...
Expected counts - Train: 6000, Validation: 2000
Processing train set (6000 samples)...


Processing train set (num_proc=192):   0%|          | 0/6000 [00:00<?, ? examples/s]

Processing validation set (2000 samples)...


Processing validation set (num_proc=192):   0%|          | 0/2000 [00:00<?, ? examples/s]


Verifying final counts...
Verified counts for distribute_four:
Train: 6000/6000
Validation: 2000/2000
Completed processing distribute_four dataset with all quality checks passed

Processing distribute_nine dataset...
Expected counts - Train: 6000, Validation: 2000
Processing train set (6000 samples)...


Processing train set (num_proc=192):   0%|          | 0/6000 [00:00<?, ? examples/s]

Processing validation set (2000 samples)...


Processing validation set (num_proc=192):   0%|          | 0/2000 [00:00<?, ? examples/s]


Verifying final counts...
Verified counts for distribute_nine:
Train: 6000/6000
Validation: 2000/2000
Completed processing distribute_nine dataset with all quality checks passed

Completed processing all datasets with quality checks!


In [27]:
import pandas as pd
import json
import os
from multiprocessing import Pool, cpu_count
from functools import partial
from tqdm import tqdm

def process_subset_split(args):
    """Process a single subset-split combination"""
    subset, split, datasets = args
    data = []
    dataset = datasets[subset][split]
    
    # Process each item in the dataset
    for item in dataset:
        data.append({
            'id': item['id'],
            'combined_image_path': f"/processed_raven_images/{subset}/{split}/{item['id']}.png",
            'correct_answer': item['target'] + 1,
            'subset_split': f"{subset}_{split}"
        })
    
    # Create DataFrame for this subset_split
    df = pd.DataFrame(data)
    
    # Save as JSONL
    output_dir = "raven_processed_jsonl"
    os.makedirs(output_dir, exist_ok=True)
    output_file = os.path.join(output_dir, f"{subset}_{split}.jsonl")
    
    with open(output_file, 'w') as f:
        for _, row in df.iterrows():
            f.write(json.dumps(row.to_dict()) + '\n')
    
    return {
        'subset_split': f"{subset}_{split}",
        'count': len(df),
        'file': output_file,
        'sample': df.head(2).to_dict('records')
    }

# Create list of all subset-split combinations to process
subset_split_combinations = [(subset, split, datasets) 
                           for subset in subsets 
                           for split in ['train', 'validation']]

# Use all available CPU cores
num_cores = cpu_count()
print(f"\nProcessing using {num_cores} CPU cores")

# Process in parallel with progress bar
with Pool(num_cores) as pool:
    results = list(tqdm(
        pool.imap(process_subset_split, subset_split_combinations),
        total=len(subset_split_combinations),
        desc="Processing subset-split combinations"
    ))

# Print results
print("\n=== Processing Complete ===")
for result in results:
    print(f"\n{result['subset_split']}:")
    print(f"Number of examples: {result['count']}")
    print(f"Saved to: {result['file']}")
    print("Sample entries:")
    print(json.dumps(result['sample'], indent=2))


Processing using 192 CPU cores


Processing subset-split combinations: 100%|██████████| 6/6 [00:13<00:00,  2.28s/it]


=== Processing Complete ===

center_single_train:
Number of examples: 6000
Saved to: raven_processed_jsonl/center_single_train.jsonl
Sample entries:
[
  {
    "id": 3023,
    "combined_image_path": "/processed_raven_images/center_single/train/3023.png",
    "correct_answer": 5,
    "subset_split": "center_single_train"
  },
  {
    "id": 214,
    "combined_image_path": "/processed_raven_images/center_single/train/214.png",
    "correct_answer": 5,
    "subset_split": "center_single_train"
  }
]

center_single_validation:
Number of examples: 2000
Saved to: raven_processed_jsonl/center_single_validation.jsonl
Sample entries:
[
  {
    "id": 5047,
    "combined_image_path": "/processed_raven_images/center_single/validation/5047.png",
    "correct_answer": 5,
    "subset_split": "center_single_validation"
  },
  {
    "id": 9356,
    "combined_image_path": "/processed_raven_images/center_single/validation/9356.png",
    "correct_answer": 8,
    "subset_split": "center_single_validation"
 




In [29]:
import pandas as pd
import json
import os
from multiprocessing import Pool, cpu_count
from functools import partial
from tqdm import tqdm
from PIL import Image
import numpy as np

def process_subset_split(args):
    """Process a single subset-split combination"""
    subset, split, datasets = args
    data = []
    dataset = datasets[subset][split]
    
    # Process each item in the dataset
    for item in dataset:
        data.append({
            'id': item['id'],
            'combined_image_path': f"/processed_raven_images/{subset}/{split}/{item['id']}.png",
            'correct_answer': item['target'] + 1,
            'subset_split': f"{subset}_{split}"
        })
    
    # Create DataFrame for this subset_split
    df = pd.DataFrame(data)
    
    # Save as JSONL
    output_dir = "raven_processed_jsonl"
    os.makedirs(output_dir, exist_ok=True)
    output_file = os.path.join(output_dir, f"{subset}_{split}.jsonl")
    
    with open(output_file, 'w') as f:
        for _, row in df.iterrows():
            f.write(json.dumps(row.to_dict()) + '\n')
    
    return {
        'subset_split': f"{subset}_{split}",
        'count': len(df),
        'file': output_file,
        'sample': df.head(2).to_dict('records')
    }

# Create list of all subset-split combinations to process
subset_split_combinations = [(subset, split, datasets) 
                           for subset in subsets 
                           for split in ['train', 'validation']]

# Use all available CPU cores
num_cores = cpu_count()
print(f"\nProcessing using {num_cores} CPU cores")

# Process in parallel with progress bar
with Pool(num_cores) as pool:
    results = list(tqdm(
        pool.imap(process_subset_split, subset_split_combinations),
        total=len(subset_split_combinations),
        desc="Processing subset-split combinations"
    ))

# Print results
print("\n=== Processing Complete ===")
for result in results:
    print(f"\n{result['subset_split']}:")
    print(f"Number of examples: {result['count']}")
    print(f"Saved to: {result['file']}")
    print("Sample entries:")
    print(json.dumps(result['sample'], indent=2))

# Validation code
print("\n=== Starting Validation ===")

def validate_jsonl_file(file_path, original_dataset, subset_split):
    """Validate a single JSONL file"""
    validation_results = {
        'file': file_path,
        'subset_split': subset_split,
        'errors': [],
        'warnings': []
    }
    
    # Read all lines from JSONL
    with open(file_path, 'r') as f:
        lines = f.readlines()
    
    # Convert to list of dicts
    data = [json.loads(line) for line in lines]
    
    # 1. Check count matches original dataset
    original_count = len(original_dataset)
    if len(data) != original_count:
        validation_results['errors'].append(
            f"Count mismatch: JSONL has {len(data)} items, original has {original_count}"
        )
    
    # 2. Check image paths and load images
    valid_image_count = 0
    for item in data:
        try:
            img = Image.open(item['combined_image_path'])
            img.verify()  # Verify it's a valid image
            valid_image_count += 1
        except Exception as e:
            validation_results['errors'].append(
                f"Invalid image path or corrupted image for ID {item['id']}: {str(e)}"
            )
    
    if valid_image_count != len(data):
        validation_results['warnings'].append(
            f"Only {valid_image_count}/{len(data)} images could be loaded"
        )
    
    # 3. Check correct_answer values
    valid_answers = set(range(1, 9))  # 1-8
    invalid_answers = [item['id'] for item in data if item['correct_answer'] not in valid_answers]
    if invalid_answers:
        validation_results['errors'].append(
            f"Invalid correct_answer values found for IDs: {invalid_answers}"
        )
    
    # 4. Check IDs are unique within this file
    ids = [item['id'] for item in data]
    if len(ids) != len(set(ids)):
        validation_results['errors'].append(
            f"Duplicate IDs found in {subset_split}"
        )
    
    return validation_results, ids

# Validate each subset and split combination
print("\nValidating files...")
all_ids = {}  # Dict to store IDs for each subset_split

for subset in subsets:
    for split in ['train', 'validation']:
        subset_split = f"{subset}_{split}"
        file_path = os.path.join("raven_processed_jsonl", f"{subset_split}.jsonl")
        
        # Validate the file
        validation, ids = validate_jsonl_file(
            file_path,
            datasets[subset][split],
            subset_split
        )
        
        # Store IDs for this subset_split
        all_ids[subset_split] = set(ids)
        
        # Print validation results
        print(f"\n=== Validation results for {subset_split} ===")
        if validation['errors']:
            print("Errors:")
            for error in validation['errors']:
                print(f"  - {error}")
        if validation['warnings']:
            print("Warnings:")
            for warning in validation['warnings']:
                print(f"  - {warning}")

# Check for ID overlaps between different subset_splits
print("\n=== Checking for ID overlaps between subset_splits ===")
for subset1 in subsets:
    for split1 in ['train', 'validation']:
        subset_split1 = f"{subset1}_{split1}"
        for subset2 in subsets:
            for split2 in ['train', 'validation']:
                subset_split2 = f"{subset2}_{split2}"
                if subset_split1 < subset_split2:  # Avoid checking same pair twice
                    overlap = all_ids[subset_split1].intersection(all_ids[subset_split2])
                    if overlap:
                        print(f"Found {len(overlap)} overlapping IDs between {subset_split1} and {subset_split2}")
                        print(f"Sample overlapping IDs: {list(overlap)[:5]}")

# Final summary
print("\n=== Final Validation Summary ===")
total_expected = sum(len(datasets[subset][split]) for subset in subsets for split in ['train', 'validation'])
total_actual = sum(len(ids) for ids in all_ids.values())
print(f"Total expected IDs: {total_expected}")
print(f"Total actual unique IDs: {total_actual}")

if total_expected == total_actual:
    print("✓ Total ID count matches expected")
else:
    print(f"✗ Total ID count mismatch: expected {total_expected}, got {total_actual}")


Processing using 192 CPU cores


Processing subset-split combinations: 100%|██████████| 6/6 [00:13<00:00,  2.27s/it]



=== Processing Complete ===

center_single_train:
Number of examples: 6000
Saved to: raven_processed_jsonl/center_single_train.jsonl
Sample entries:
[
  {
    "id": 3023,
    "combined_image_path": "/processed_raven_images/center_single/train/3023.png",
    "correct_answer": 5,
    "subset_split": "center_single_train"
  },
  {
    "id": 214,
    "combined_image_path": "/processed_raven_images/center_single/train/214.png",
    "correct_answer": 5,
    "subset_split": "center_single_train"
  }
]

center_single_validation:
Number of examples: 2000
Saved to: raven_processed_jsonl/center_single_validation.jsonl
Sample entries:
[
  {
    "id": 5047,
    "combined_image_path": "/processed_raven_images/center_single/validation/5047.png",
    "correct_answer": 5,
    "subset_split": "center_single_validation"
  },
  {
    "id": 9356,
    "combined_image_path": "/processed_raven_images/center_single/validation/9356.png",
    "correct_answer": 8,
    "subset_split": "center_single_validation"
 

In [30]:
import pandas as pd
import json
import os
from multiprocessing import Pool, cpu_count
from functools import partial
from tqdm import tqdm
from PIL import Image
import numpy as np

# Define the base path for processed images
BASE_IMAGE_PATH = "/data/users/brandon/ob1-projects/InternVL/internvl_chat/rollout_generation/preprocessed_prompts/preprocessing_scripts/RAVEN/processed_raven_images"

def process_subset_split(args):
    """Process a single subset-split combination"""
    subset, split, datasets = args
    data = []
    dataset = datasets[subset][split]
    
    # Process each item in the dataset
    for item in dataset:
        data.append({
            'id': item['id'],
            'combined_image_path': os.path.join(BASE_IMAGE_PATH, subset, split, f"{item['id']}.png"),
            'correct_answer': item['target'] + 1,
            'subset_split': f"{subset}_{split}"
        })
    
    # Create DataFrame for this subset_split
    df = pd.DataFrame(data)
    
    # Save as JSONL
    output_dir = "raven_processed_jsonl"
    os.makedirs(output_dir, exist_ok=True)
    output_file = os.path.join(output_dir, f"{subset}_{split}.jsonl")
    
    with open(output_file, 'w') as f:
        for _, row in df.iterrows():
            f.write(json.dumps(row.to_dict()) + '\n')
    
    return {
        'subset_split': f"{subset}_{split}",
        'count': len(df),
        'file': output_file,
        'sample': df.head(2).to_dict('records')
    }

# Create list of all subset-split combinations to process
subset_split_combinations = [(subset, split, datasets) 
                           for subset in subsets 
                           for split in ['train', 'validation']]

# Use all available CPU cores
num_cores = cpu_count()
print(f"\nProcessing using {num_cores} CPU cores")

# Process in parallel with progress bar
with Pool(num_cores) as pool:
    results = list(tqdm(
        pool.imap(process_subset_split, subset_split_combinations),
        total=len(subset_split_combinations),
        desc="Processing subset-split combinations"
    ))

# Print results
print("\n=== Processing Complete ===")
for result in results:
    print(f"\n{result['subset_split']}:")
    print(f"Number of examples: {result['count']}")
    print(f"Saved to: {result['file']}")
    print("Sample entries:")
    print(json.dumps(result['sample'], indent=2))


Processing using 192 CPU cores


Processing subset-split combinations: 100%|██████████| 6/6 [00:13<00:00,  2.27s/it]



=== Processing Complete ===

center_single_train:
Number of examples: 6000
Saved to: raven_processed_jsonl/center_single_train.jsonl
Sample entries:
[
  {
    "id": 3023,
    "combined_image_path": "/data/users/brandon/ob1-projects/InternVL/internvl_chat/rollout_generation/preprocessed_prompts/preprocessing_scripts/RAVEN/processed_raven_images/center_single/train/3023.png",
    "correct_answer": 5,
    "subset_split": "center_single_train"
  },
  {
    "id": 214,
    "combined_image_path": "/data/users/brandon/ob1-projects/InternVL/internvl_chat/rollout_generation/preprocessed_prompts/preprocessing_scripts/RAVEN/processed_raven_images/center_single/train/214.png",
    "correct_answer": 5,
    "subset_split": "center_single_train"
  }
]

center_single_validation:
Number of examples: 2000
Saved to: raven_processed_jsonl/center_single_validation.jsonl
Sample entries:
[
  {
    "id": 5047,
    "combined_image_path": "/data/users/brandon/ob1-projects/InternVL/internvl_chat/rollout_generati

In [32]:
# Validation code
print("\n=== Starting Validation ===")

def validate_jsonl_file(file_path, original_dataset, subset_split):
    """Validate a single JSONL file"""
    validation_results = {
        'file': file_path,
        'subset_split': subset_split,
        'errors': [],
        'warnings': []
    }
    
    # Read all lines from JSONL
    with open(file_path, 'r') as f:
        lines = f.readlines()
    
    # Convert to list of dicts
    data = [json.loads(line) for line in lines]
    
    # 1. Check count matches original dataset
    original_count = len(original_dataset)
    if len(data) != original_count:
        validation_results['errors'].append(
            f"Count mismatch: JSONL has {len(data)} items, original has {original_count}"
        )
    
    # 2. Check image paths and load images
    valid_image_count = 0
    for item in data:
        try:
            if not os.path.exists(item['combined_image_path']):
                validation_results['errors'].append(
                    f"Image file does not exist: {item['combined_image_path']}"
                )
                continue
                
            img = Image.open(item['combined_image_path'])
            img.verify()  # Verify it's a valid image
            valid_image_count += 1
        except Exception as e:
            validation_results['errors'].append(
                f"Invalid image path or corrupted image for ID {item['id']}: {str(e)}"
            )
    
    if valid_image_count != len(data):
        validation_results['warnings'].append(
            f"Only {valid_image_count}/{len(data)} images could be loaded"
        )
    
    # 3. Check correct_answer values
    valid_answers = set(range(1, 9))  # 1-8
    invalid_answers = [item['id'] for item in data if item['correct_answer'] not in valid_answers]
    if invalid_answers:
        validation_results['errors'].append(
            f"Invalid correct_answer values found for IDs: {invalid_answers}"
        )
    
    # 4. Check IDs are unique within this file
    ids = [item['id'] for item in data]
    if len(ids) != len(set(ids)):
        validation_results['errors'].append(
            f"Duplicate IDs found in {subset_split}"
        )
    
    return validation_results, ids

# Validate each subset and split combination
print("\nValidating files...")
all_ids = {}  # Dict to store IDs for each subset_split

for subset in subsets:
    for split in ['train', 'validation']:
        subset_split = f"{subset}_{split}"
        file_path = os.path.join("raven_processed_jsonl", f"{subset_split}.jsonl")
        
        # Validate the file
        validation, ids = validate_jsonl_file(
            file_path,
            datasets[subset][split],
            subset_split
        )
        
        # Store IDs for this subset_split
        all_ids[subset_split] = set(ids)
        
        # Print validation results
        print(f"\n=== Validation results for {subset_split} ===")
        if validation['errors']:
            print("Errors:")
            for error in validation['errors']:
                print(f"  - {error}")
        else:
            print("✓ No errors found")
        if validation['warnings']:
            print("Warnings:")
            for warning in validation['warnings']:
                print(f"  - {warning}")
        else:
            print("✓ No warnings found")

# Check for ID overlaps between train and validation splits within each subset
print("\n=== Checking for ID overlaps between train and validation splits ===")
for subset in subsets:
    train_ids = all_ids[f"{subset}_train"]
    val_ids = all_ids[f"{subset}_validation"]
    overlap = train_ids.intersection(val_ids)
    if overlap:
        print(f"Found {len(overlap)} overlapping IDs between {subset}_train and {subset}_validation")
        print(f"Sample overlapping IDs: {list(overlap)[:5]}")
    else:
        print(f"✓ No overlapping IDs found between {subset}_train and {subset}_validation")

# Final summary
print("\n=== Final Validation Summary ===")
total_expected = sum(len(datasets[subset][split]) for subset in subsets for split in ['train', 'validation'])
total_actual = sum(len(ids) for ids in all_ids.values())
print(f"Total expected IDs: {total_expected}")
print(f"Total actual unique IDs: {total_actual}")

if total_expected == total_actual:
    print("✓ Total ID count matches expected")
else:
    print(f"✗ Total ID count mismatch: expected {total_expected}, got {total_actual}")


=== Starting Validation ===

Validating files...

=== Validation results for center_single_train ===
✓ No errors found

=== Validation results for center_single_validation ===
✓ No errors found

=== Validation results for distribute_four_train ===
✓ No errors found

=== Validation results for distribute_four_validation ===
✓ No errors found

=== Validation results for distribute_nine_train ===
✓ No errors found

=== Validation results for distribute_nine_validation ===
✓ No errors found

=== Checking for ID overlaps between train and validation splits ===
✓ No overlapping IDs found between center_single_train and center_single_validation
✓ No overlapping IDs found between distribute_four_train and distribute_four_validation
✓ No overlapping IDs found between distribute_nine_train and distribute_nine_validation

=== Final Validation Summary ===
Total expected IDs: 24000
Total actual unique IDs: 24000
✓ Total ID count matches expected
