In [1]:
import torch
from arc_prize.model import ARCTransformerEncoder
from arc_prize.train import ARCModelState

# Name of the checkpoint to inspect and the model class used to rebuild it.
model_name = "kindly_exact_beagle_5"
model_klass = ARCTransformerEncoder

# NOTE: absolute, machine-specific path; adjust it when running elsewhere.
model_filename = f"/Users/pfh/work/arc-models/{model_name}.pth"

# Checkpoint tensors are mapped onto the GPU when one is available, else the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# SECURITY: weights_only=False makes torch.load unpickle arbitrary Python
# objects, which can execute code. Only load checkpoint files you trust.
checkpoint_dict = torch.load(model_filename, weights_only=False, map_location=device)
checkpoint = ARCModelState(**checkpoint_dict)

# Rebuild the architecture from the hyper-parameters stored in the checkpoint.
model = model_klass(checkpoint.model_params)

if checkpoint.model_state_dict is not None:
    model.load_state_dict(checkpoint.model_state_dict)
else:
    # Without stored weights the model keeps whatever the constructor
    # initialised, so make that visible instead of failing silently.
    print(
        f"WARNING: checkpoint '{model_name}' has no model_state_dict; "
        "the model keeps its freshly initialised weights"
    )

# This notebook only runs inference: put dropout / normalisation layers into
# evaluation mode so outputs are not perturbed by training-time behaviour.
model.eval()

In [2]:
from arc_prize.data import ARCDatasetParams, ARCKaggleDataset, make_datasets

# Dataset configuration. Presumably these values must match the ones the
# checkpoint was trained with (TODO confirm against the training run).
data_params = ARCDatasetParams(max_grid_size=12, max_train_grids=4, color_offset=1)
# make_datasets yields two values here; the first is discarded and the second
# is kept as the evaluation dataset.
# NOTE: absolute, machine-specific data path; adjust it when running elsewhere.
_, eval_dataset = make_datasets(["/Users/pfh/work/arc-data/eval_dim_12"], data_params)

In [None]:
# Pick one evaluation task to inspect. Dataset indices noted by the author
# (presumably tasks of interest, TODO confirm): 9, 14, 36, 41
task = eval_dataset[41]
# Show which task was selected and which fields the sample dictionary exposes.
print(task["task_id"])
print(task.keys())



In [None]:
from arc_prize.vis import visualize_tensors


# Visualise task["grids"] together with task["output"]. The two remaining
# positional arguments are passed as None; their meaning is defined by
# visualize_tensors, which is not visible in this notebook.
visualize_tensors(task["grids"], task["output"], None, None)

In [None]:
# Run a single forward pass on the selected task. unsqueeze(0) adds a leading
# dimension of size 1 to both tensors (a batch containing one task).
#
# The pass runs under torch.no_grad() because nothing is back-propagated in
# this notebook, so no autograd graph needs to be built or kept in memory.
# The module is called as model(...) rather than model.forward(...) so that
# nn.Module.__call__ runs any registered hooks.
with torch.no_grad():
    output = model(
        task["grids"].unsqueeze(0),
        task["masks"].unsqueeze(0),
        temperature=0.2,
        need_intermediate_outputs=True,
    )

# The result unpacks into four values; only the first two are used below.
prediction, secondary, _, _ = output

In [242]:
# Manually run the embedding, output-query and positional-encoding steps on
# the selected task so the intermediate tensors can be inspected.
# NOTE(review): presumably this mirrors the first steps of the model's own
# forward pass; confirm against ARCTransformerEncoder.
embedded_src = model.embedding.forward(task["grids"].unsqueeze(0))
# expand(1, -1, -1, -1, -1): leading dimension of size 1, the other four sizes
# are kept as they are (-1 means "do not change this dimension").
output_query = model.output_query.expand(1, -1, -1, -1, -1)
# Append the output query to the embedded grids along dimension 1.
combined_input = torch.cat([embedded_src, output_query], dim=1)
pos_enc = model.pos_encoding.forward(combined_input)

In [None]:
# Sanity check: shape of the concatenated embedding + output-query tensor.
print(combined_input.shape)

In [None]:
# Turn every intermediate layer output of the first batch element into a
# sampled grid, collecting the results in `intermediate_predictions`.
print(prediction.shape, secondary.shape)
intermediate_predictions = []
for layer_out in secondary[0]:
    print(layer_out.shape)

    # Keep only the trailing `output_seq_len` positions of this layer's output.
    output_grid_portion = layer_out[-model.output_seq_len :]

    # Map those positions to per-class scores, then lay them out as a grid.
    logits = model.output_layer(output_grid_portion)
    pred = logits.view(1, model.grid_dim, model.grid_dim, model.num_classes)

    # Sample one class per cell from the softmax distribution over classes.
    probs = torch.softmax(pred, dim=-1)
    flat_probs = probs.view(-1, probs.size(-1))
    flat_sample = torch.multinomial(flat_probs, num_samples=1, replacement=True)
    sample = flat_sample.view(-1, *probs.shape[:-1])

    intermediate_predictions.append(sample[0, 0])

In [None]:
import matplotlib.colors as mcolors
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns

# Colour palette for the display helpers below. It is turned into a
# matplotlib ListedColormap and used with vmin=0 and vmax=len(COLORS) - 1, so
# the colour at list position i is the one drawn for cell value i.
COLORS = [
    "#c2c0c0",  # padding grey
    "#111111",  # black
    "#1E93FF",  # blue
    "#F93C31",  # red
    "#4FCC30",  # green
    "#FFDC00",  # yellow
    "#E6E6E6",  # grey
    "#E53AA3",  # magenta
    "#FF851B",  # orange
    "#87D8F1",  # light blue
    "#921231",  # maroon
    "#FFFFFF",  # white; unlabelled in the original, role not evident here (TODO confirm)
]

def display_grids(grids: list, ncols: int = 4):
    """Draw several 2-D grids as panels of one matplotlib figure.

    Panels are filled row by row. The number of rows is derived from
    ``len(grids)``, so any number of grids can be shown (the previous fixed
    4x4 layout raised ``IndexError`` for more than 16 grids). Panels left
    over in the last row are hidden.

    Args:
        grids: Sequence of 2-D arrays/tensors of integer cell values; each
            value is drawn with the colour at that position in ``COLORS``.
        ncols: Number of panels per row. Defaults to 4, the original layout.

    Raises:
        ValueError: If ``ncols`` is smaller than 1.
    """
    if ncols < 1:
        raise ValueError(f"ncols must be >= 1, got {ncols}")

    n_grids = len(grids)
    if n_grids == 0:
        # Nothing to draw; avoid creating an empty figure.
        return

    # Ceiling division without importing math.
    n_rows = (n_grids + ncols - 1) // ncols

    # 5 inches per panel keeps the original 20x20 size for a full 4x4 layout.
    # squeeze=False guarantees a 2-D array of Axes even for one row/column.
    fig, axes = plt.subplots(
        n_rows, ncols, figsize=(5 * ncols, 5 * n_rows), squeeze=False
    )
    flat_axes = axes.ravel()  # row-major order: left-to-right, top-to-bottom

    cmap = mcolors.ListedColormap(COLORS)

    for ax, grid in zip(flat_axes, grids):
        ax.imshow(grid, cmap=cmap, vmin=0, vmax=len(COLORS) - 1)
        # Minor ticks on the cell boundaries (half-integer positions) give
        # the thin separator lines between cells.
        ax.set_xticks(np.arange(-0.5, grid.shape[1], 1), minor=True)
        ax.set_yticks(np.arange(-0.5, grid.shape[0], 1), minor=True)
        ax.grid(which="minor", color="lightgrey", linestyle="-", linewidth=0.5)
        ax.tick_params(
            which="both", bottom=False, left=False, labelbottom=False, labelleft=False
        )

    # Hide panels that did not receive a grid.
    for ax in flat_axes[n_grids:]:
        ax.set_axis_off()

    plt.tight_layout()
    plt.show()

def display_grid(grid):
    """Show a single 2-D grid in its own 10x10 inch figure.

    Each cell value is drawn with the colour at that position in ``COLORS``.
    Thin light-grey lines separate the cells, and axis ticks and labels are
    hidden.

    Args:
        grid: 2-D array/tensor of cell values exposing ``.shape``.
    """
    palette = mcolors.ListedColormap(COLORS)
    fig, ax = plt.subplots(figsize=(10, 10))

    ax.imshow(grid, cmap=palette, vmin=0, vmax=len(COLORS) - 1)

    # Minor ticks at -0.5, 0.5, ..., n - 0.5 sit exactly on the cell borders.
    ax.set_xticks(np.arange(grid.shape[1] + 1) - 0.5, minor=True)
    ax.set_yticks(np.arange(grid.shape[0] + 1) - 0.5, minor=True)
    ax.grid(which="minor", color="lightgrey", linestyle="-", linewidth=0.5)

    # Hide tick marks and tick labels on both axes.
    ax.tick_params(
        which="both",
        bottom=False,
        left=False,
        labelbottom=False,
        labelleft=False,
    )

    plt.tight_layout()
    plt.show()

def display_sequence(grid):
    """Show a grid flattened into one long row of coloured cells.

    The input is reshaped to ``(1, N)`` and drawn as a single-row image.
    Each cell value is drawn with the colour at that position in ``COLORS``.

    ``reshape`` is used instead of ``view`` so that non-contiguous tensors
    (for example transposed or strided slices) and NumPy arrays are accepted
    too. For contiguous tensors the result is the same as before.

    Args:
        grid: Array/tensor of cell values of any shape, exposing
            ``reshape``.
    """
    seq = grid.reshape(1, -1)
    fig, ax = plt.subplots(1, 1, figsize=(100, 10))

    cmap = mcolors.ListedColormap(COLORS)

    # aspect=20.0 stretches each cell vertically so the single row is visible.
    ax.imshow(seq, cmap=cmap, vmin=0, vmax=len(COLORS) - 1, aspect=20.0)

    # Vertical separators between cells; no ticks are drawn on the y axis.
    ax.set_xticks(np.arange(-0.5, seq.shape[1], 1), minor=True)
    ax.set_yticks([])
    ax.grid(which="minor", color="lightgrey", linestyle="-", linewidth=0.5)
    ax.tick_params(
        which="both", bottom=False, left=False, labelbottom=False, labelleft=False
    )

    plt.tight_layout()
    plt.show()

# Disabled: show the per-layer sampled grids built in an earlier cell.
# display_grids(intermediate_predictions)
# for interm in intermediate_predictions:
#     display_grid(interm)


# Scratch visualisations of flattened sequences.
# NOTE(review): the slices [:3] and [6:] are hard-coded; which grids they
# select depends on the layout of task["grids"]; confirm against the dataset.
display_sequence(task["grids"][:3].view(1, -1))
display_sequence(task["grids"][6:].view(1, -1))
# A 12x12 grid of random integers in [0, 10] (randint's upper bound is
# exclusive). No seed is set, so this picture differs on every run.
display_sequence(torch.Tensor(np.random.randint(0, 11, size=(12,12))))
display_sequence(task["output"])
# Disabled alternative: show the output as a 2-D grid instead.
# display_grid(task["output"])

# NOTE(review): task["output"] was already displayed a few lines above;
# presumably repeated here to sit directly above the per-grid sequences.
display_sequence(task["output"])
for grid in task["grids"]:
    # Disabled alternative: 2-D view of each grid.
#    display_grid(grid)
    display_sequence(grid)
# An all-zero 12x12 grid: every cell is drawn with COLORS[0].
display_grid(torch.zeros(12, 12))
