In [None]:
import os
import sys

# Add the project root (the parent of the current working directory) to
# sys.path so that the `src.*` modules can be imported from this notebook.
project_root = os.path.abspath(os.path.join(os.getcwd(), '..'))
# Guard the append so that re-running this cell does not pile up duplicates.
if project_root not in sys.path:
    sys.path.append(project_root)

In [None]:
import torch

from src.prepare_data import load_dataset

# Load the dumped batch with all storages mapped to the CPU.
# NOTE: torch.load unpickles the file, so only open dumps from a trusted source.
input_dump = torch.load(
    "../temp/0000_0002_adaptive_train_input_ids.pt",
    map_location="cpu",
)

# The token tensor is stored under 'input_ids' -> 'data'.
input_tensor = input_dump['input_ids']['data']

# Compact tensor printing: one decimal place, summarise beyond 1024 elements.
torch.set_printoptions(threshold=1024, precision=1)

print('input_tensor: ', input_tensor.shape)

### Function to plot input/output pairs of a task

In [None]:
from   matplotlib import colors
import matplotlib.pyplot as plt

# ARC palette, indexed by cell value:
# 0:black, 1:blue, 2:red, 3:green, 4:yellow, 5:gray, 6:magenta, 7:orange, 8:sky, 9:brown
cmap = colors.ListedColormap([
    '#000000', '#0074D9', '#FF4136', '#2ECC40', '#FFDC00',
    '#AAAAAA', '#F012BE', '#FF851B', '#7FDBFF', '#870C25',
])
# Map the integer values 0..9 onto the ten palette entries.
norm = colors.Normalize(vmin=0, vmax=9)

# Draw a one-row legend strip showing each value in its colour.
plt.figure(figsize=(3, 1), dpi=150)
plt.imshow([[*range(10)]], cmap=cmap, norm=norm)
plt.xticks(range(10))
plt.yticks([])
plt.show()

In [None]:
def plot_task(task, i, t):
    """Plot every train pair of a task, using the same color scheme as the ARC app.

    The figure has one column per train pair and four rows: the input grid
    rendered from cell component 0 and from cell component 3, then the output
    grid rendered from the same two components. What component 3 encodes is
    decided by the tokenizer and is not visible here.

    Args:
        task: Dict with a 'train' list; each entry may hold an 'input' and/or
            an 'output' grid. Panels whose grid is missing are left blank.
        i: Set number shown in the figure title.
        t: Label shown after the set number in the figure title.

    Returns:
        None. The figure is displayed with ``plt.show()``.
    """
    num_train = len(task['train'])
    print('num_train', num_train)

    if num_train == 0:
        # plt.subplots cannot build a grid with zero columns; there is
        # nothing to draw, so skip the figure instead of raising.
        return

    # squeeze=False keeps axs two-dimensional even for a single column, so
    # one indexing scheme covers both the single-pair and multi-pair cases.
    fig, axs = plt.subplots(4, num_train, figsize=(5 * num_train, 5 * 2), squeeze=False)
    fig.suptitle(f'Set #{i}, {t}:', fontsize=20, fontweight='bold', y=1)

    # (grid key, cell component) drawn in each row, top to bottom.
    panels = (('input', 0), ('input', 3), ('output', 0), ('output', 3))
    for j in range(num_train):
        for row, (input_or_output, cell_index) in enumerate(panels):
            plot_one(axs[row, j], task, j, 'train', input_or_output, cell_index)

    # Frame the figure so consecutive tasks are visually separated.
    fig.patch.set_linewidth(5)
    fig.patch.set_edgecolor('black')
    fig.patch.set_facecolor('#dddddd')

    plt.tight_layout()

    plt.show()

def plot_one(ax, task_data, i, train_or_test, input_or_output, cell_index):
    """Draw one grid of a task on ``ax``, annotating every cell with its value.

    Args:
        ax: Matplotlib axes to draw on.
        task_data: Task dict, indexed as
            ``task_data[train_or_test][i][input_or_output]`` to get a grid
            (a list of rows, each row a list of indexable cells).
        i: Index of the pair within ``task_data[train_or_test]``.
        train_or_test: Key of the pair list; also used in the panel title.
        input_or_output: Key of the grid inside the pair; also used in the
            panel title.
        cell_index: Which component of each cell to display.

    Returns:
        None. If the requested grid is missing, malformed or empty, the axes
        are left untouched.
    """
    try:
        token_grid = task_data[train_or_test][i][input_or_output]
        # Keep only the requested component of every cell.
        input_matrix = [[cell[cell_index] for cell in row] for row in token_grid]
    except (KeyError, IndexError, TypeError):
        # Best effort: the pair, grid or component does not exist (for
        # example a pair without an 'output'), so leave this panel blank.
        # Only lookup errors are swallowed; anything else propagates.
        return

    if not input_matrix or not input_matrix[0]:
        # An empty grid cannot be rendered by imshow; treat it like a missing one.
        return

    num_rows = len(input_matrix)
    num_cols = len(input_matrix[0])

    ax.imshow(input_matrix, cmap=cmap, norm=norm)
    ax.grid(True, which='both', color='lightgrey', linewidth=0.5)

    # Clear the tick labels on every axes of this figure (including panels
    # that were skipped earlier), then put the ticks on the cell borders so
    # the grid lines fall between cells.
    plt.setp(ax.figure.get_axes(), xticklabels=[], yticklabels=[])
    ax.set_xticks([x - 0.5 for x in range(1 + num_cols)])
    ax.set_yticks([y - 0.5 for y in range(1 + num_rows)])

    # Shrink the annotation font as the grid grows, down to a readable floor.
    grid_size = max(num_rows, num_cols)
    base_font_size = 10  # Font size for small grids
    min_font_size = 4    # Never go below this size
    font_size = max(base_font_size - (grid_size - 5) * 0.5, min_font_size)

    # Annotate each cell, using white text for 0 and for values above 5.
    for y, row in enumerate(input_matrix):
        for x, value in enumerate(row):
            text_color = 'white' if value > 5 or value == 0 else 'black'
            ax.text(x, y, str(value), ha='center', va='center', color=text_color, fontsize=font_size)

    ax.set_title(f'{train_or_test} {input_or_output}')

In [None]:
from pprint import pprint
from pprint import pformat
from src.prepare_data import load_dataset
from src.token import SpecialToken
import random

# Print tensors in full (an infinite threshold disables summarisation) on very wide lines.
torch.set_printoptions(threshold=float('inf'), linewidth=1000)

def detokenize_grid(tokenized_sequence):
    """Split a flat run of cell tokens into grid rows.

    A token whose first component equals ``SpecialToken.ROW_SEPARATOR.value``
    ends the current row and is not kept. Rows that would be empty (leading,
    trailing or repeated separators) are dropped.

    Args:
        tokenized_sequence: Iterable of indexable tokens.

    Returns:
        List of rows, each row being the list of tokens between separators.
    """
    row_separator = SpecialToken.ROW_SEPARATOR.value
    rows = []
    pending = []
    for cell in tokenized_sequence:
        if cell[0] != row_separator:
            pending.append(cell)
            continue
        # Separator reached: flush the row collected so far, if any.
        if pending:
            rows.append(pending)
            pending = []
    # The last row is not necessarily followed by a separator.
    if pending:
        rows.append(pending)
    return rows
    
# Decode every sequence of the dump back into a task dict and plot it, in order.
for i in range(input_tensor.shape[0]):
    sequence = input_tensor[i]

    # One dict per pair: 'input' is filled first, 'output' is attached later.
    task = {'train': []}

    current_section = None
    current_data = []

    # Convert the whole sequence once rather than calling .tolist() per token.
    for token in sequence.tolist():
        marker = token[0]
        if marker == SpecialToken.START_INPUT.value:
            current_section = 'input'
            current_data = []
        elif marker == SpecialToken.END_INPUT.value:
            if current_section == 'input':
                task['train'].append({'input': detokenize_grid(current_data)})
        elif marker == SpecialToken.START_OUTPUT.value:
            current_section = 'output'
            current_data = []
        elif marker == SpecialToken.END_OUTPUT.value:
            if current_section == 'output':
                output_grid = detokenize_grid(current_data)
                if task['train']:
                    task['train'][-1]['output'] = output_grid
                else:
                    # An output closed before any input pair exists; keep it
                    # as an output-only pair instead of raising IndexError.
                    task['train'].append({'output': output_grid})
        elif marker < 10 or marker == SpecialToken.ROW_SEPARATOR.value:
            # Grid content: cell values below 10 and row separators.
            current_data.append(token)
        else:
            # Anything else must be one of the known structural tokens.
            assert marker in (SpecialToken.START.value, SpecialToken.END.value, SpecialToken.PAD.value), token

    # Plot task
    plot_task(task, i, f"{i}")