In [None]:
import json
import numpy as np

# Scan the move_right training challenges and report, across all tasks, the
# largest number of train pairs and the largest input-grid height and width.
with open("data/move_right/arc-synth_move_right_training_challenges.json", "r") as f:
  tasks = json.load(f)

# The file is only needed for parsing, so the scan runs after it is closed.
max_height = 0
max_width = 0
max_count = 0
for task in tasks.values():
  count = len(task["train"])
  max_count = max(max_count, count)
  for pair in task["train"]:
    # Named `input_grid` (not `input`) so the builtin `input()` is not
    # shadowed for the rest of the kernel session.
    input_grid = np.array(pair["input"])
    # Unpacking into two values requires the grid to convert to a 2-D array.
    height, width = input_grid.shape
    max_height = max(max_height, height)
    max_width = max(max_width, width)

print(max_count, max_height, max_width)
      


In [3]:
from arc_prize.env import modal_app
from arc_prize.model import ARCTransformer
import torch
from torch.utils.data import DataLoader
from arc_prize.data import ARCDataset, ARCDatasetConfig, collate_arc_fn

# Hyperparameters, grouped by what they control.

# -- Transformer architecture --
d_model = 128
num_heads = 8
num_encoder_layers = 4
num_decoder_layers = 4
dim_feedforward = 4 * d_model  # 4x d_model heuristic for now
dropout = 0.1

# -- Data limits (trailing comments record alternative values) --
max_grid_size = 10  # 30
max_context_pairs = 4  # 10
num_colors = 10

# -- Optimisation --
batch_size = 10
num_epochs = 100
learning_rate = 1e-3
weight_decay = 1e-5

# One dataset configuration shared by the training and validation splits.
synth_arc_dataset_config = ARCDatasetConfig(
    max_grid_size=max_grid_size,
    max_train_grids=max_context_pairs,
    color_offset=1,
)

# Path prefix of the dataset files; swap the active line to switch datasets.
# dataset_prefix = "data/move_right/arc-synth_move_right"
dataset_prefix = "data/move_random/arc-synth_move_random"

# Training split: challenges + solutions, reshuffled by the loader.
train_dataset = ARCDataset(
    dataset_prefix + "_training_challenges.json",
    dataset_prefix + "_training_solutions.json",
    config=synth_arc_dataset_config,
)
train_loader = DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=True,
    collate_fn=collate_arc_fn,
    num_workers=0,
)

# Validation split: same construction, iterated in a fixed order.
val_dataset = ARCDataset(
    dataset_prefix + "_evaluation_challenges.json",
    dataset_prefix + "_evaluation_solutions.json",
    config=synth_arc_dataset_config,
)
val_loader = DataLoader(
    val_dataset,
    batch_size=batch_size,
    shuffle=False,
    collate_fn=collate_arc_fn,
    num_workers=0,
)

# Pick the compute device: CUDA when available, otherwise CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Build the transformer from the hyperparameters above and move it to the device.
model = ARCTransformer(
    d_model=d_model,
    num_heads=num_heads,
    num_encoder_layers=num_encoder_layers,
    num_decoder_layers=num_decoder_layers,
    d_ff=dim_feedforward,
    grid_dim=max_grid_size,
    num_colors=num_colors,
    num_train_pairs=max_context_pairs,
    dropout=dropout,
).to(device)




In [None]:
from arc_prize.train import train_on_modal, train_arc_transformer


# Local alternative (no Modal):
# train_arc_transformer(model, train_loader, val_loader, num_epochs, learning_rate, weight_decay)

# Launch training remotely inside a single app-run context.
# The remote function is called directly: wrapping it in a second
# `modal_app.run(...)` call would pass the training result as run()'s first
# positional argument and create a context manager that is never entered, so
# any options given to it would have no effect.
# NOTE(review): assumes `modal_app` is a `modal.App` — confirm in arc_prize.env.
with modal_app.run():
    train_on_modal.remote(
        model, train_loader, val_loader, num_epochs, learning_rate, weight_decay
    )

print("Training completed and model saved.")

In [None]:
from arc_prize.vis import visualize_tensors


# Evaluate the `model` instance built earlier in this notebook. When a
# checkpoint is loaded below, its parameters must match that architecture.

# Set to None to evaluate the weights currently held in memory.
model_file_name = "models/model_4q9m0e3x.pth"
if model_file_name is not None:
    # NOTE(review): torch.load unpickles the file — only load checkpoints from
    # a trusted source.
    state_dict = torch.load(model_file_name, map_location=device)
    model.load_state_dict(state_dict)

model.eval()

# One sample per batch, in random order, drawn from the validation set.
eval_loader = DataLoader(val_dataset, batch_size=1, shuffle=True, collate_fn=collate_arc_fn, num_workers=0)
# To inspect a single sample instead: batch = next(iter(eval_loader))

# Inference only: disable autograd so no graph is recorded while generating.
with torch.no_grad():
    for batch in eval_loader:
        grids, grid_masks, output_grid = [item.to(device) for item in batch]

        predictions = model.generate(grids, grid_masks)
        print(predictions.shape)

        # squeeze(0) presumably drops the size-1 batch dimension (batch_size=1) — confirm
        visualize_tensors(grids.squeeze(0), output_grid.squeeze(0), predictions.squeeze(0))

