In [None]:
import os
import sys

# Make the project root (the parent of the current working directory)
# importable, so that project packages can be imported from this notebook.
project_root = os.path.abspath(os.path.join(os.getcwd(), '..'))
# Guard the append so that re-running this cell does not stack up duplicate
# sys.path entries.
if project_root not in sys.path:
    sys.path.append(project_root)

In [None]:
from src.prepare_data import load_dataset

# Prepared dataset to inspect. Swap in the alternative below to look at the
# other prepared file instead:
# dataset_file = "../intermediate_data/prepared_dataset_using_arc_training_second.pth"
dataset_file = "../intermediate_data/prepared_dataset_using_barc.pth"  # _using_arc_training

# load_dataset returns three values: the dataset object, its data sources and
# the per-source ranges (consumed later as (start, end) pairs).
(dataset,
 data_sources,
 source_ranges) = load_dataset(dataset_file)

# Show which sources are present and the range each one covers.
print(data_sources, source_ranges)

### Color palette and helper functions to plot the input/output pairs of a task

In [None]:
from   matplotlib import colors
import matplotlib.pyplot as plt

# Color palette, indexed by cell value:
#   0 black, 1 blue, 2 red, 3 green, 4 yellow,
#   5 gray, 6 magenta, 7 orange, 8 sky, 9 brown
_PALETTE_HEX = [
    '#000000', '#0074D9', '#FF4136', '#2ECC40', '#FFDC00',
    '#AAAAAA', '#F012BE', '#FF851B', '#7FDBFF', '#870C25',
]
cmap = colors.ListedColormap(_PALETTE_HEX)
# Cell values 0..9 are normalised onto the ten palette entries.
norm = colors.Normalize(vmin=0, vmax=9)

# Render the palette as a single-row swatch strip, one tick per color index.
_swatch_values = list(range(10))
plt.figure(figsize=(3, 1), dpi=150)
plt.imshow([_swatch_values], cmap=cmap, norm=norm)
plt.xticks(_swatch_values)
plt.yticks([])
plt.show()

In [None]:
def plot_task(task, i, t):
    """Plot every train input/output pair of a task in one figure.

    The figure has two rows and one column per train pair: inputs on the top
    row, the matching outputs on the bottom row. Colors come from the
    module-level ``cmap``/``norm`` used by ``plot_one``.

    Args:
        task: mapping with a 'train' entry holding a sequence of pairs; each
            pair is indexed by 'input' and 'output'.
        i: identifier shown in the figure title as ``Set #<i>``.
        t: label appended to the figure title.
    """
    pairs = task['train']
    num_train = len(pairs)
    if num_train == 0:
        # plt.subplots cannot build a grid with zero columns; report and skip.
        print(f'Set #{i}, {t}: no train pairs to plot')
        return

    # squeeze=False keeps `axs` two-dimensional even when there is only one
    # train pair; otherwise a single column collapses to a 1-D array and
    # axs[row, col] indexing fails.
    fig, axs = plt.subplots(
        2, num_train, figsize=(3 * num_train, 3 * 2), squeeze=False
    )
    fig.suptitle(f'Set #{i}, {t}:', fontsize=20, fontweight='bold', y=1)

    for j, pair in enumerate(pairs):
        for row, kind in enumerate(('input', 'output')):
            if kind in pair:
                plot_one(axs[row, j], task, j, 'train', kind)
            else:
                # Pair is incomplete (e.g. no output was recorded): leave the
                # panel blank instead of failing on the missing key.
                axs[row, j].axis('off')

    # Frame the figure so consecutive tasks are visually separated.
    fig.patch.set_linewidth(5)
    fig.patch.set_edgecolor('black')
    fig.patch.set_facecolor('#dddddd')

    fig.tight_layout()

    plt.show()
    # Release the figure; this function is called in a loop and open figures
    # would otherwise accumulate.
    plt.close(fig)

def plot_one(ax, task_data, i, train_or_test, input_or_output):
    """Draw a single grid of a task on ``ax`` with cell borders and values.

    Args:
        ax: matplotlib Axes to draw on.
        task_data: task mapping; the grid is read from
            ``task_data[train_or_test][i][input_or_output]``.
        i: index of the pair within ``task_data[train_or_test]``.
        train_or_test: key selecting the list of pairs (e.g. 'train').
        input_or_output: key selecting the grid inside the pair
            ('input' or 'output'); also used in the panel title.
    """
    grid = task_data[train_or_test][i][input_or_output]
    n_rows = len(grid)
    n_cols = len(grid[0])

    ax.imshow(grid, cmap=cmap, norm=norm)
    ax.grid(True, which='both', color='lightgrey', linewidth=0.5)

    # Ticks sit on half-integer positions, i.e. on the cell boundaries, so the
    # grid lines drawn at the ticks outline every cell.
    ax.set_xticks([x - 0.5 for x in range(1 + n_cols)])
    ax.set_yticks([y - 0.5 for y in range(1 + n_rows)])
    # Hide the tick labels on this Axes only. Going through the Axes (rather
    # than the current pyplot figure) keeps the function free of side effects
    # on other panels and independent of which figure is current.
    ax.set_xticklabels([])
    ax.set_yticklabels([])

    # Annotate every cell with its value.
    for y in range(n_rows):
        for x in range(n_cols):
            value = grid[y][x]
            # White text on black (0) and on the colors with index above 5,
            # black text on the remaining colors.
            text_color = 'white' if value > 5 or value == 0 else 'black'
            ax.text(x, y, str(value), ha='center', va='center', color=text_color)

    ax.set_title(f'{train_or_test} {input_or_output}')

In [None]:
from pprint import pprint
from pprint import pformat
from src.prepare_data import load_dataset
from src.token import SpecialToken
import random
from src.utils.helper import detokenize_grid

def find_source(index):
    """Return the name of the data source whose range contains ``index``.

    The lookup walks the module-level ``source_ranges`` mapping
    (source -> (start, end)) in iteration order and returns the first source
    for which ``start <= index <= end``. Returns "Unknown source" when no
    range matches.
    """
    # NOTE(review): both bounds are treated as inclusive. If the ranges are
    # actually half-open, an index equal to `end` is attributed to the earlier
    # source -- confirm against load_dataset.
    matching_sources = (
        name
        for name, (lower, upper) in source_ranges.items()
        if lower <= index <= upper
    )
    return next(matching_sources, "Unknown source")

In [None]:
# Plot one sample out of every SAMPLE_STRIDE dataset entries.
# (To visit samples in random order instead, build a list of indices,
# random.shuffle it, and iterate over that list.)
SAMPLE_STRIDE = 1000


def _tokens_to_task(token_sequence):
    """Rebuild the train pairs of a task from its token sequence.

    The first item of every element is taken as the token. Tokens found
    between START_INPUT/END_INPUT and START_OUTPUT/END_OUTPUT markers are
    collected and passed to ``detokenize_grid``; parsing stops at the END
    token.

    Args:
        token_sequence: iterable of indexable elements whose first item is
            the token value compared against ``SpecialToken`` values.

    Returns:
        dict of the form ``{'train': [{'input': grid, 'output': grid}, ...]}``.
        A pair has no 'output' key if its output section never closed.
    """
    task = {'train': []}
    current_section = None
    current_data = []

    for element in token_sequence:
        token = element[0]
        if token == SpecialToken.START_INPUT.value:
            current_section = 'input'
            current_data = []
        elif token == SpecialToken.END_INPUT.value:
            if current_section == 'input':
                task['train'].append({'input': detokenize_grid(current_data)})
            # Close the section so a stray repeated END marker is ignored.
            current_section = None
        elif token == SpecialToken.START_OUTPUT.value:
            current_section = 'output'
            current_data = []
        elif token == SpecialToken.END_OUTPUT.value:
            # Only attach the output when there is an input pair to attach it
            # to; an output with no preceding input is dropped rather than
            # raising IndexError.
            if current_section == 'output' and task['train']:
                task['train'][-1]['output'] = detokenize_grid(current_data)
            current_section = None
        elif token == SpecialToken.END.value:
            break
        else:
            current_data.append(token)

    return task


# NOTE(review): presumably fixes the augmentation seed so that the plotted
# samples are the same on every run -- confirm in the dataset implementation.
dataset.set_augment_seed(0)

# Iterate over every SAMPLE_STRIDE-th index (fixed order, not shuffled).
for i in range(0, len(dataset), SAMPLE_STRIDE):
    source = find_source(i)
    sequence = dataset[i]

    # Extract the task data from the token sequence and plot it.
    task = _tokens_to_task(sequence['task'])
    plot_task(task, i, f"{source}, {i}")

    # Optional debugging output:
    # print('#train', pformat(task['train']).replace('\n', ''))
    # print('original', dataset.data[i])
    # print('fetched', sequence)