# Prepare ARC-like datasets for training

## Goal

Transform all external datasets into the same format used by the ARC24 competition.

## Imports

In [None]:
import os
import glob
import json
import random
import matplotlib.pyplot as plt
from matplotlib import colors
import matplotlib as mpl
import numpy as np
from tqdm.auto import tqdm

# Draw and immediately discard an empty plot before customising rcParams.
# NOTE(review): presumably this forces the notebook's plotting backend to
# initialise first so it does not override the settings below — confirm.
plt.plot()
plt.close('all')

# Global plot defaults for every figure in this notebook
# (plt.rcParams and mpl.rcParams are the same object).
mpl.rcParams.update({
    'figure.figsize': (15, 4),
    'lines.linewidth': 3,
    'font.size': 16,
})

## Code

In [None]:
def plot_task(task, task_id):
    """Plot all grids of an ARC-style task as two figures: inputs, then outputs.

    Each figure has one row with the train samples first and then the test
    samples. Every subplot is titled ``train i`` / ``test j``.

    Args:
        task: dict with ``'train'`` and ``'test'`` lists of samples. Each sample
            is a dict holding ``'input'`` and (usually) ``'output'`` grids.
        task_id: identifier shown in each figure's super-title.
    """
    _plot_task_row(task, 'input', f'Inputs for task {task_id}')
    _plot_task_row(task, 'output', f'Outputs for task {task_id}')


def _plot_task_row(task, key, suptitle):
    """Show ``sample[key]`` for every train then test sample of *task* in one row."""
    n_train = len(task['train'])
    all_samples = task['train'] + task['test']
    for plot_idx, sample in enumerate(all_samples):
        plt.subplot(1, len(all_samples), plot_idx + 1)
        if key in sample:
            plot_grid(sample[key])
        else:
            # Presumably an unsolved test sample with no 'output': keep the slot
            # (and its title) but draw nothing instead of raising KeyError.
            plt.axis('off')
        if plot_idx < n_train:
            plt.title(f'train {plot_idx}')
        else:
            plt.title(f'test {plot_idx - n_train}')
    plt.suptitle(suptitle)
    plt.show()


def plot_grid(grid):
    """Draw one ARC grid on the current axes and write each cell's value on it.

    *grid* is converted with ``np.array(grid, dtype=int)`` and indexed as
    ``[row, col]``. Values 0-9 map onto a fixed 10-colour palette.
    Cell labels use matplotlib's default text colour, so they may be hard to
    read on dark cells (value 0 is drawn black).
    """
    cells = np.array(grid, dtype=int)
    palette = colors.ListedColormap([
        '#000000', '#0074D9', '#FF4136', '#2ECC40', '#FFDC00',
        '#AAAAAA', '#F012BE', '#FF851B', '#7FDBFF', '#870C25',
    ])
    plt.imshow(cells, cmap=palette, norm=colors.Normalize(vmin=0, vmax=9))

    # Put unlabeled ticks on the cell borders so the grid lines outline each cell.
    plt.grid(True, which='both', color='lightgrey', linewidth=0.5)
    plt.xticks(np.arange(-0.5, cells.shape[1]), [])
    plt.yticks(np.arange(-0.5, cells.shape[0]), [])
    plt.xlim(-0.5, cells.shape[1] - 0.5)

    # Text coordinates are (x, y) = (column, row).
    for (row, col), value in np.ndenumerate(cells):
        plt.text(col, row, value, ha='center', va='center')

In [None]:
def plot_sample_tasks(tasks, n):
    """Plot up to ``n`` distinct, randomly chosen tasks.

    Args:
        tasks: dict mapping task id to task, in the format ``plot_task`` accepts.
        n: number of tasks to plot. It is clamped to ``len(tasks)``, and nothing
            is plotted when the result is not positive.
    """
    n = min(n, len(tasks))
    if n <= 0:
        return
    # replace=False: without it the same task could be drawn more than once.
    sampled_tasks_ids = np.random.choice(list(tasks), n, replace=False)
    for task_id in sampled_tasks_ids:
        plot_task(tasks[task_id], task_id)

In [None]:
def create_single_output_tasks(tasks):
    """Split every task into one new task per test sample.

    The result maps ``f'{task_id}_{idx}'`` to ``{'train': ..., 'test': [sample]}``.
    Each new task reuses (does not copy) the original task's ``'train'`` list.
    Tasks with an empty ``'test'`` list contribute nothing.
    """
    return {
        f'{task_id}_{test_idx}': {'train': task['train'], 'test': [test_sample]}
        for task_id, task in tasks.items()
        for test_idx, test_sample in enumerate(task['test'])
    }

## [ConceptARC](https://github.com/victorvikram/ConceptARC/tree/main)

In [None]:
def collect_concept_arc_tasks(dataset_dir):
    """Load every task JSON matching ``<dataset_dir>/corpus/*/*.json``.

    Args:
        dataset_dir: root directory that contains the ``corpus`` folder.

    Returns:
        dict mapping task id (the file name without its ``.json`` extension)
        to the parsed JSON content. Files are visited in sorted path order.

    Raises:
        ValueError: if two files (e.g. in different sub-folders) share a task id.
    """
    filepaths = sorted(glob.glob(os.path.join(dataset_dir, 'corpus', '*', '*.json')))
    tasks = {}
    for filepath in tqdm(filepaths, desc='Loading tasks'):
        with open(filepath, 'r', encoding='utf-8') as f:
            task = json.load(f)
        # splitext strips only the final extension, so ids containing dots survive intact.
        task_id = os.path.splitext(os.path.basename(filepath))[0]
        # Explicit check instead of `assert`, which is stripped under `python -O`.
        if task_id in tasks:
            raise ValueError(f'Duplicate task id {task_id!r} found at {filepath}')
        tasks[task_id] = task
    return tasks

In [None]:
# Load the raw ConceptARC corpus into a dict keyed by task id.
tasks = collect_concept_arc_tasks(
    dataset_dir='/mnt/hdd0/Kaggle/arc24/data/arc-like_datasets/ConceptARC/')

In [None]:
# Report the task count before and after splitting multi-test tasks
# into one task per test sample.
print(f'There are {len(tasks)} tasks in the dataset')
tasks = create_single_output_tasks(tasks)
print(f'There are {len(tasks)} tasks after creating single output tasks')

In [None]:
# Visual sanity check: plot a few randomly chosen converted tasks.
plot_sample_tasks(tasks=tasks, n=5)

In [None]:
# Save all single-output tasks (task id -> {'train', 'test'}) as a single JSON file.
with open('/mnt/hdd0/Kaggle/arc24/data/arc-like_datasets/ConceptARC.json', mode='w') as f:
    json.dump(obj=tasks, fp=f)