In [6]:
import os
import matplotlib.pyplot as plt
from PIL import Image, ImageDraw, ImageFont
from datasets import load_dataset
import numpy as np
from typing import List
import io

In [7]:
def create_composite_matrix_image(panels: List[Image.Image]) -> Image.Image:
    """
    Create a composite image showing the 8 panels arranged in a 3x3 grid layout
    with borders around each panel, matching the reference style.
    
    Args:
        panels: List of 8 PIL images representing the matrix panels
        
    Returns:
        PIL Image showing the 3x3 matrix with missing bottom-right panel
    """
    # Assume all panels are the same size
    panel_width, panel_height = panels[0].size
    
    # Add border width and spacing
    border_width = 2
    spacing = 4
    
    # Calculate dimensions with borders and spacing
    cell_width = panel_width + 2 * border_width
    cell_height = panel_height + 2 * border_width
    
    # Add margin for the "Problem Matrix" text on the left side
    left_margin = 60  # Margin for the vertical text - increased from 40 to 60
    
    composite_width = cell_width * 3 + spacing * 2 + left_margin
    composite_height = cell_height * 3 + spacing * 2
    
    # Create composite image with white background
    composite = Image.new('RGB', (composite_width, composite_height), 'white')
    
    draw = ImageDraw.Draw(composite)
    
    # Arrange panels in 3x3 grid (missing bottom-right)
    positions = [
        (0, 0),      # Panel 1: top-left
        (1, 0),      # Panel 2: top-center  
        (2, 0),      # Panel 3: top-right
        (0, 1),      # Panel 4: middle-left
        (1, 1),      # Panel 5: middle-center
        (2, 1),      # Panel 6: middle-right
        (0, 2),      # Panel 7: bottom-left
        (1, 2),      # Panel 8: bottom-center
        # (2, 2) is missing - bottom-right
    ]
    
    for i, (col, row) in enumerate(positions):
        # Calculate position with spacing, accounting for left margin
        x = left_margin + col * (cell_width + spacing)
        y = row * (cell_height + spacing)
        
        # Draw border rectangle
        draw.rectangle([x, y, x + cell_width, y + cell_height], 
                      outline='black', fill='white', width=border_width)
        
        # Paste the panel inside the border
        panel_x = x + border_width
        panel_y = y + border_width
        composite.paste(panels[i], (panel_x, panel_y))
    
    # Add question mark for missing panel
    missing_col, missing_row = 2, 2
    missing_x = left_margin + missing_col * (cell_width + spacing)
    missing_y = missing_row * (cell_height + spacing)
    
    # Draw border for missing panel
    draw.rectangle([missing_x, missing_y, missing_x + cell_width, missing_y + cell_height], 
                  outline='black', fill='white', width=border_width)
    
    # Add question mark in the center
    try:
        font = ImageFont.truetype("/usr/share/fonts/truetype/dejavu/DejaVuSans-Bold.ttf", 
                                min(panel_width, panel_height) // 3)
    except:
        try:
            font = ImageFont.load_default()
        except:
            font = None
    
    question_text = "?"
    if font:
        bbox = draw.textbbox((0, 0), question_text, font=font)
        text_width = bbox[2] - bbox[0]
        text_height = bbox[3] - bbox[1]
    else:
        text_width, text_height = 20, 20
    
    text_x = missing_x + (cell_width - text_width) // 2
    text_y = missing_y + (cell_height - text_height) // 2
    
    draw.text((text_x, text_y), question_text, fill='black', font=font)
    
    # Add "Problem Matrix" text on the left side
    try:
        label_font = ImageFont.truetype("/usr/share/fonts/truetype/dejavu/DejaVuSans-Bold.ttf", 18)
    except:
        try:
            label_font = ImageFont.load_default()
        except:
            label_font = None
    
    if label_font:
        problem_matrix_text = "Problem Matrix"
        text_width, text_height = draw.textbbox((0, 0), problem_matrix_text, font=label_font)[2:4]
        
        # Calculate the position for the vertical text
        # Center it vertically in the matrix area
        matrix_height = 3 * cell_height + 2 * spacing
        text_x = 25  # Fixed distance from the left edge - increased from 15 to 25
        text_y = (composite_height - text_width) // 2  # Center vertically
        
        # Draw the rotated text
        # We need to create a temporary image, draw text, then rotate and paste
        text_img = Image.new('RGBA', (text_width, text_height), (255, 255, 255, 0))
        text_draw = ImageDraw.Draw(text_img)
        text_draw.text((0, 0), problem_matrix_text, fill='black', font=label_font)
        
        # Rotate the text image 90 degrees counter-clockwise
        rotated_text = text_img.rotate(90, expand=True)
        
        # Paste the rotated text onto the composite image
        composite.paste(rotated_text, (text_x, text_y), rotated_text)
    
    return composite


In [8]:
def create_choices_grid(choices: List[Image.Image]) -> Image.Image:
    """
    Create a grid showing the 8 answer choices with index labels below each choice,
    matching the reference style with 2 rows of 4 choices each.
    
    Args:
        choices: List of 8 PIL images representing answer choices
        
    Returns:
        PIL Image showing choices arranged in 2x4 grid with labels below
    """
    choice_width, choice_height = choices[0].size
    
    # Grid configuration: 2 rows, 4 columns
    grid_cols, grid_rows = 4, 2
    
    # Add border width and spacing
    border_width = 2
    spacing = 4
    label_height = 30  # Space for number labels below choices
    
    # Calculate cell dimensions
    cell_width = choice_width + 2 * border_width
    cell_height = choice_height + 2 * border_width + label_height
    
    # Calculate total dimensions
    composite_width = cell_width * grid_cols + spacing * (grid_cols - 1)
    composite_height = cell_height * grid_rows + spacing * (grid_rows - 1)
    
    composite = Image.new('RGB', (composite_width, composite_height), 'white')
    
    draw = ImageDraw.Draw(composite)
    
    try:
        font = ImageFont.truetype("/usr/share/fonts/truetype/dejavu/DejaVuSans-Bold.ttf", 18)
    except:
        try:
            font = ImageFont.load_default()
        except:
            font = None
    
    for i, choice in enumerate(choices):
        col = i % grid_cols
        row = i // grid_cols
        
        # Calculate position with spacing
        x = col * (cell_width + spacing)
        y = row * (cell_height + spacing)
        
        # Draw border rectangle around choice
        choice_border_height = choice_height + 2 * border_width
        draw.rectangle([x, y, x + cell_width, y + choice_border_height], 
                      outline='black', fill='white', width=border_width)
        
        # Paste the choice image inside the border
        choice_x = x + border_width
        choice_y = y + border_width
        composite.paste(choice, (choice_x, choice_y))
        
        # Add index label below the choice
        label = str(i + 1)  # 1-indexed for display
        if font:
            bbox = draw.textbbox((0, 0), label, font=font)
            text_width = bbox[2] - bbox[0]
            text_height = bbox[3] - bbox[1]
        else:
            text_width, text_height = 10, 12
        
        # Center the label below the choice
        label_x = x + (cell_width - text_width) // 2
        label_y = y + choice_border_height + (label_height - text_height) // 2
        
        draw.text((label_x, label_y), label, fill='black', font=font)
    
    return composite


In [9]:
def create_combined_image(matrix_composite: Image.Image, choices_composite: Image.Image) -> Image.Image:
    """
    Combine matrix composite and choices composite into a single image matching
    the reference layout with section labels and proper spacing.
    
    Args:
        matrix_composite: PIL Image showing the 3x3 matrix
        choices_composite: PIL Image showing the 8 choices
        
    Returns:
        PIL Image with labeled sections: Problem Matrix and Answer Set
    """
    # Calculate dimensions
    matrix_width, matrix_height = matrix_composite.size
    choices_width, choices_height = choices_composite.size
    
    # Layout parameters
    margin = 30  # Margin around the entire image
    section_spacing = 40  # Space between Problem Matrix and Answer Set
    label_height = 40  # Height for section labels (increased)
    label_margin = 15  # Space between label and content (increased)
    
    # Add margin for the left side labels
    left_margin = 120  # Increased from 90 to 120 for more space for vertical text
    
    # Calculate total dimensions
    combined_width = max(matrix_width, choices_width) + 2 * margin + left_margin
    combined_height = (margin + label_height + label_margin + matrix_height + 
                      section_spacing + label_height + label_margin + 
                      choices_height + margin)
    
    # Create combined image with white background
    combined = Image.new('RGB', (combined_width, combined_height), 'white')
    draw = ImageDraw.Draw(combined)
    
    # Set up fonts - using bold font for both
    try:
        title_font = ImageFont.truetype("/usr/share/fonts/truetype/dejavu/DejaVuSans-Bold.ttf", 20)
        label_font = ImageFont.truetype("/usr/share/fonts/truetype/dejavu/DejaVuSans-Bold.ttf", 18)
    except:
        try:
            title_font = ImageFont.load_default()
            label_font = ImageFont.load_default()
        except:
            title_font = None
            label_font = None
    
    # Current y position
    current_y = margin
    
    # Add "(a)" label at top left
    if title_font:
        draw.text((margin, current_y), "(a)", fill='black', font=title_font)
    current_y += label_height + label_margin
    
    # We don't need to add "Problem Matrix" label here since it's now part of the matrix_composite
    
    # Center and paste the matrix
    matrix_x = margin + (combined_width - matrix_width - 2*margin) // 2
    combined.paste(matrix_composite, (matrix_x, current_y))
    current_y += matrix_height + section_spacing
    
    # Add "Answer Set" label - positioned on the left side vertically
    if label_font:
        answer_set_label = "Answer Set"
        # Create a temporary image for the rotated text
        bbox = draw.textbbox((0, 0), answer_set_label, font=label_font)
        text_width = bbox[2] - bbox[0]
        text_height = bbox[3] - bbox[1]
        
        # Create temporary image for text with transparent background
        text_img = Image.new('RGBA', (text_width + 10, text_height + 10), (255, 255, 255, 0))  # Added padding
        text_draw = ImageDraw.Draw(text_img)
        text_draw.text((5, 5), answer_set_label, fill='black', font=label_font)  # Add some padding inside the text image
        
        # Rotate the text image 90 degrees counter-clockwise
        rotated_text = text_img.rotate(90, expand=True)
        
        # Position the rotated text on the left side, centered vertically with choices grid
        # Ensure text is more centered and not cut off
        label_x = 45  # Increased from 35 to 45 for better spacing
        
        # Calculate better vertical centering based on the actual choices area
        # Make sure the text is fully visible within the choices section
        choices_center_y = current_y + choices_height / 2
        rotated_text_height = rotated_text.size[1]
        label_y = choices_center_y - rotated_text_height / 2
        
        # Ensure the text doesn't go outside the image boundaries
        if label_y < current_y:
            label_y = current_y
        if label_y + rotated_text_height > current_y + choices_height:
            label_y = current_y + choices_height - rotated_text_height
            
        combined.paste(rotated_text, (label_x, int(label_y)), rotated_text)
    
    # Center and paste the choices
    choices_x = margin + (combined_width - choices_width - 2*margin) // 2
    combined.paste(choices_composite, (choices_x, current_y))
    
    return combined


In [17]:
# Load the RAVEN dataset - for all required subsets
subsets = ["center_single", "distribute_four", "distribute_nine"]

# have a different shape
# subsets = ["left_center_single_right_center_single", "up_center_single_down_center_single", "in_center_single_out_center_single", "in_distribute_four_out_center_single"]
datasets = {}

for subset in subsets:
    datasets[subset] = load_dataset("HuggingFaceM4/RAVEN", subset)
    print(f"\nLoaded {subset} dataset with the following splits: {datasets[subset].keys()}")
    print(f"Training examples: {len(datasets[subset]['train'])}")
    print(f"Validation examples: {len(datasets[subset]['validation'])}")



Loaded center_single dataset with the following splits: dict_keys(['train', 'validation', 'test'])
Training examples: 6000
Validation examples: 2000

Loaded distribute_four dataset with the following splits: dict_keys(['train', 'validation', 'test'])
Training examples: 6000
Validation examples: 2000

Loaded distribute_nine dataset with the following splits: dict_keys(['train', 'validation', 'test'])
Training examples: 6000
Validation examples: 2000


In [18]:
# Check the structure of a sample from each dataset
for subset in subsets:
    sample = datasets[subset]['train'][0]
    print(f"\n=== Sample from {subset} ===")
    print("Sample keys:", sample.keys())
    
    # Print shape of panels and choices
    print("Panels length:", len(sample['panels']))
    print("Choices length:", len(sample['choices']))
    print("Target:", sample['target'])
    


=== Sample from center_single ===
Sample keys: dict_keys(['panels', 'choices', 'structure', 'meta_matrix', 'meta_target', 'meta_structure', 'target', 'id', 'metadata'])
Panels length: 8
Choices length: 8
Target: 4

=== Sample from distribute_four ===
Sample keys: dict_keys(['panels', 'choices', 'structure', 'meta_matrix', 'meta_target', 'meta_structure', 'target', 'id', 'metadata'])
Panels length: 8
Choices length: 8
Target: 7

=== Sample from distribute_nine ===
Sample keys: dict_keys(['panels', 'choices', 'structure', 'meta_matrix', 'meta_target', 'meta_structure', 'target', 'id', 'metadata'])
Panels length: 8
Choices length: 8
Target: 6
