In [None]:
import os
import sys
from plot_funcs import plot_task
from   matplotlib import colors
import matplotlib.pyplot as plt

# Add the project root directory (the parent of the current working directory)
# to sys.path so that project-local packages can be imported by later cells.
project_root = os.path.abspath(os.path.join(os.getcwd(), '..'))
# Guard against piling up duplicate entries when this cell is re-run in the
# same kernel.
if project_root not in sys.path:
    sys.path.append(project_root)

In [None]:
import argparse
import torch
from src.model import Transformer
from src.token import SpecialToken, VOCAB_SIZE
from src.load_data import load_from_json, GridDataset
from src.utils.helper import set_deterministic
from src.utils.display_diff import compare_sequences, colorize, split_into_chunks
from src.utils.transformer_helper import create_mask

from checkpoint_helper import format_batch, plot_task, plot_answer

# When True, the device selection further down falls back to the CPU even if
# CUDA is available.
debug_on_cpu = False

def generate_sample(model, input_sequence, max_length, device, mask_hack):
    """Greedily extend ``input_sequence`` one token at a time.

    On every step the model is run on the whole sequence built so far, the
    highest-scoring token at the last position is chosen (argmax), and a row
    ``[token, y, x, -1, -1]`` is appended. Generation stops right after the
    END token has been appended, or once ``max_length - len(input_sequence)``
    rows have been added.

    Args:
        model: Callable as ``model(input_ids, mask)``; must provide ``eval()``
            and ``max_grid_size``. It is switched to eval mode and left there.
        input_sequence: Prompt rows accepted by ``torch.tensor``. Each row has
            to carry five entries so it lines up with the rows appended here.
        max_length: Upper bound on the total number of rows in the result.
        device: Device on which the tensors are created.
        mask_hack: Forwarded unchanged to ``create_mask``.

    Returns:
        The prompt rows followed by the generated rows, as nested Python lists.
    """
    model.eval()

    prompt_length = len(input_sequence)
    steps_allowed = max_length - prompt_length

    # Grid cursor for the next generated token; (-1, -1) is what gets stored
    # for tokens that are neither a cell nor a row separator.
    row, col = 0, 0
    position = (-1, -1)

    with torch.no_grad():
        tokens = torch.tensor(input_sequence, dtype=torch.long)
        tokens = tokens.unsqueeze(0).to(device)  # add the batch dimension

        for _ in range(steps_allowed):
            # NOTE(review): the mask is always built from the original prompt
            # length, not the grown sequence — confirm this is what
            # create_mask expects.
            attention_mask = create_mask(tokens, device, [prompt_length], mask_hack)
            logits = model(tokens, attention_mask)
            predicted = torch.argmax(logits[0, -1, :]).item()

            if predicted < SpecialToken.CELL_TOKEN_SIZE.value:
                # Cell token: record the current cursor, then step right
                # (clamped to the last column).
                position = (row, col)
                col = min(col + 1, model.max_grid_size - 1)
            elif predicted == SpecialToken.ROW_SEPARATOR.value:
                # Row separator: record the cursor, then move to the start of
                # the next row (clamped to the last row).
                position = (row, col)
                col = 0
                row = min(row + 1, model.max_grid_size - 1)
            else:
                # Any other token resets the cursor and carries no coordinate.
                row, col = 0, 0
                position = (-1, -1)

            appended = torch.tensor(
                [[[predicted, position[0], position[1], -1, -1]]],
                dtype=torch.long,
                device=device,
            )
            tokens = torch.cat((tokens, appended), dim=1)

            if predicted == SpecialToken.END.value:
                break

    return tokens.squeeze(0).tolist()


In [None]:
from src.checkpoint_handler import CheckpointHandler

set_deterministic()

# Load data
data_sources = ['arc-agi_evaluation']
all_challenges = {}
# Solutions are accumulated alongside the challenges so that every loaded
# source contributes, and so the name is always defined even if a load fails.
all_solutions = {}

for source in data_sources:
    try:
        challenges, solutions = load_from_json(source, '../input_data/')
    except FileNotFoundError as e:
        print(f"Error loading {source}: {e}. Skipping this data source.")
        continue
    all_challenges.update(challenges)
    # NOTE(review): assumes `solutions` is a mapping like `challenges`
    # — confirm against load_from_json.
    all_solutions.update(solutions)

# Fail fast: without data the dataset construction below cannot succeed, and
# stopping here avoids loading the checkpoint only to crash afterwards.
if not all_challenges:
    raise RuntimeError("No data could be loaded. Please check the file paths and data sources.")

# Use CUDA when available unless debug_on_cpu forces the CPU.
device = torch.device("cuda" if (torch.cuda.is_available() and not debug_on_cpu) else "cpu")
checkpoint_path='../cloud_runs/69.55.141.119/genesis/runs/genesis/20241026_150711_nogit_nobranch_lr2e-06_bl2e-06_ssu0_bs22_h4_es784_nl18_we10_as1_ph3_ac1_ad1_scosine_oadam_ge1_c43/Transformer_latest.pt'
model, max_seq_length, checkpoint_args = CheckpointHandler.load_checkpoint_in_production(checkpoint_path, device, adjust_max_length=12000)
# Checkpoints that do not record 'mask_hack' default to True.
mask_hack = checkpoint_args.get('mask_hack', True)

# Create the dataset: `dataset_ref` is built with the solutions, `dataset`
# without them (None is passed in their place).
dataset_ref = GridDataset.load_from_paired_file(all_challenges, all_solutions, second_only=True)
dataset = GridDataset.load_from_paired_file(all_challenges, None, second_only=True)

print('dataset', len(dataset), type(dataset[0]))
print('dataset_ref', len(dataset_ref), type(dataset_ref[0]))
print('max_seq_length', max_seq_length)

# Counters for the evaluation run. `mismatch_count` is initialised here but is
# not updated in the portion of the loop shown below.
mismatch_count = 0
tested_count = 0
# Generate samples
start_index = 0
for i in range(start_index, len(dataset)):
    print('index', i)
    input_sequence = dataset[i]

    # Only run tasks whose prompt leaves room for generated tokens. The length
    # that matters is that of the token sequence under 'task' (the value handed
    # to generate_sample), not the length of the dataset item that wraps it.
    if max_seq_length > len(input_sequence['task']):
        tested_count += 1

        sample = generate_sample(model, input_sequence['task'], max_seq_length, device, mask_hack)
        # Plot the model's answer for this task.
        plot_answer(input_sequence['task'], sample, i)

        # Plot the reference entry for the same index for visual comparison.
        expected_sequence = dataset_ref[i]
        plot_answer(expected_sequence['task'], expected_sequence['task'], i)