In [None]:
from arc_prize.train import ARCModelState, ARCTrainParams
from arc_prize.vis import visualize_epochs
import modal
import torch
import petname
from arc_prize.model import ARCTransformerEncoderDecoderParams



In [31]:

# Hyperparameters for the encoder-decoder transformer used in the remote run.
model_params = ARCTransformerEncoderDecoderParams(
    grid_dim=30,
    num_train_pairs=10,
    num_colors=10,
    num_encoder_layers=2,
    num_decoder_layers=2,
    num_heads=4,
    d_model=32,
    d_ff=32 * 2,
    dropout=0.2,
)

# Optimisation settings and dataset location for the remote run.
# NOTE(review): loss_class_weights presumably down-weights class 0 in the
# loss -- confirm against ARCTrainParams.
train_params = ARCTrainParams(
    batch_size=5,
    learning_rate=1e-3,
    weight_decay=1e-4,
    dataset_dir=["/vol/data/arc"],
    loss_class_weights={0: 0.2},
)

num_epochs = 10
num_runs = 1
model_names = []

# Look up the deployed "train" function once, then spawn one call per run.
# Each run gets a freshly generated three-word name, which is collected so the
# models can be referenced later.
fn = modal.Function.lookup("arc-prize", "train")
for i in range(num_runs):
    model_name = petname.generate(words=3, separator="_")
    fn_call = fn.spawn(model_name, num_epochs, model_params, train_params)
    # train_on_mac(model_name, num_epochs, model_params, train_params)
    print("Model name", model_name, fn_call.object_id)
    model_names.append(model_name)

print(model_names)




Model name fully_secure_fly fc-01J5NK97HAS61E47JJAEGZ02NZ
['fully_secure_fly']


In [None]:


from arc_prize.model import ARCTransformerEncoderDecoderParams
from arc_prize.train import train_on_mac


# Hyperparameters for the encoder-decoder transformer used in the local run.
model_params = ARCTransformerEncoderDecoderParams(
    grid_dim=20,
    num_train_pairs=4,
    num_colors=10,
    num_encoder_layers=2,
    num_decoder_layers=2,
    num_heads=4,
    d_model=32,
    d_ff=32 * 2,
    dropout=0.2,
)

# Optimisation settings; the dataset path is relative to the working directory.
train_params = ARCTrainParams(
    batch_size=10,
    learning_rate=1e-3,
    weight_decay=1e-4,
    dataset_dir=["data/move_diagonal"],
)

num_epochs = 100
num_runs = 3
model_names = []

# Train each run in-process via train_on_mac; the remote (Modal) variant is
# kept below as commented-out lines for quick switching.
# fn = modal.Function.lookup("arc-prize", "train")
for i in range(num_runs):
    model_name = petname.generate(words=3, separator="_")
    # fn_call = fn.spawn(model_name, num_epochs, model_params, train_params)
    train_on_mac(model_name, num_epochs, model_params, train_params)
    # print("Model name", model_name, fn_call.object_id)
    model_names.append(model_name)

print(model_names)




In [None]:
num_epochs = 100

# Names of models that already exist and should be submitted for training again.
model_names = ["fully_solid_fly"]

# Spawn the deployed "train" function for each name. No model or training
# parameters are supplied (both are None).
# NOTE(review): presumably the remote function then falls back to the
# parameters stored with the existing checkpoint -- confirm in the train function.
fn = modal.Function.lookup("arc-prize", "train")
for model_name in model_names:
    fn_call = fn.spawn(model_name, num_epochs, None, None)
    print("Model name", model_name, fn_call.object_id)
  


In [None]:
from arc_prize.vis import visualize_all_heads


def visualize_group(model_names: list[str]) -> None:
  """Fetch the stored state of each named model and plot their epochs together.

  For every name, the deployed "get_model" function is called remotely and its
  result is unpacked into an ARCModelState. A one-line summary (name, number of
  epochs, last epoch, model params) is printed per model, and the collected
  per-model epoch histories are passed to visualize_epochs in a single call.

  Models that have no recorded epochs yet are reported and left out of the
  plot instead of raising IndexError and aborting the whole group.

  Args:
    model_names: Names of the models to fetch and compare.
  """
  epochs = {}
  get_model = modal.Function.lookup("arc-prize", "get_model")
  for name in model_names:
    checkpoint = ARCModelState(**get_model.remote(name))

    # A model without any recorded epoch has no "last epoch": indexing [-1]
    # on it would raise and prevent the remaining models from being shown.
    if len(checkpoint.epochs) == 0:
      print(name, 0, None, checkpoint.model_params)
      continue

    print(name, len(checkpoint.epochs), checkpoint.epochs[-1], checkpoint.model_params)
    epochs[name] = checkpoint.epochs

    # Disabled: per-layer encoder attention visualisation.
    # print(len(checkpoint.encoder_attn_weights))
    # for b, batch in enumerate(checkpoint.encoder_attn_weights):
    #   for i, layer in enumerate(batch):
    #     visualize_all_heads(layer, title=f"Batch {b}, layer {i}")

  visualize_epochs(epochs)

  


# Groups of model names to compare; each inner list is plotted together.
# Commented-out entries are earlier experiments kept for reference.
groups = [
    # ['kindly_huge_jennet', 'lovely_tidy_lab', 'solely_living_leech'], # BEST
    # ['weekly_enough_moose', 'gently_known_beagle', 'nicely_robust_rhino'], # 20x20 too slow
    # ['wildly_firm_husky', 'surely_brief_bug', 'fully_better_dodo'], # Amazing
    # ['wildly_steady_iguana', 'yearly_smart_donkey', 'mainly_polite_bison'], # Includes scale dataset
    # ['partly_vocal_piglet', 'neatly_needed_liger', 'firmly_game_weevil'], # Scale and diagonal
    ["wholly_tops_heron", "solely_eager_foal", "deeply_one_skink"],  # Tons of data
    ["unduly_glad_swift", "purely_steady_hornet", "humbly_civil_donkey"],  # Basic data
]

# print([group for sublist in groups for group in sublist])

# One visualize_group call (and therefore one plot) per group.
for group in groups:
    visualize_group(group)

In [None]:
from arc_prize.vis import visualize_tensors, visualize_all_heads


# Call the deployed "evaluate_model" function remotely for one model and one
# list of dataset directories; the result is kept in `output` for later cells.
eval_model = modal.Function.lookup("arc-prize", "evaluate_model")
output = eval_model.remote(
    "deeply_one_skink",
    ["/vol/data/move_diagonal_and_scale"],
)


In [None]:





# from arc_prize.vis import visualize_mean_mha_attention



# For every evaluated item, convert the context grids, the expected output grid
# and the predictions to tensors, drop the leading dimension when it has size 1,
# and display the three side by side.
for item in output:
  tensors = [
    torch.Tensor(item[key]).squeeze(0)
    for key in ("grids", "output_grid", "predictions")
  ]
  visualize_tensors(*tensors)
# print(torch.Tensor(item["decoder_sa_attn_weights"]).shape)
# for i, layer in enumerate(torch.Tensor(item["decoder_mha_attn_weights"]).squeeze(0)):
#   mean_attention = layer.mean(dim=1)
#   print(mean_attention.shape)
#   visualize_mean_mha_attention(layer)
  # mean_attention = mean_attention.view(4, 9, 10, 10)
  # print(mean_attention.shape)

  # visualize_all_heads(layer, title=f"Layer {i}")
# for i, layer in enumerate(torch.Tensor(item["decoder_sa_attn_weights"]).squeeze(0)):
#     visualize_mean_sa_attention(layer)
    


In [None]:
from arc_prize.vis import visualize_output_query, visualize_tensors


# model = ARCTransformer(d_model=d_model, num_heads=num_heads, num_layers=num_layers, d_ff=dim_feedforward, grid_dim=max_grid_size, num_colors=num_colors, num_train_pairs=max_context_pairs, dropout=dropout).to(device)

# Evaluate a locally stored checkpoint and visualise its predictions.
#
# NOTE(review): this cell uses `model`, `device`, `val_dataset`,
# `collate_arc_fn` and `DataLoader`, none of which are defined or imported in
# the visible part of the notebook (the only construction of `model` is the
# commented-out line above). It will fail on a fresh kernel until they are
# provided -- confirm where they are meant to come from.
model_file_name = "models/model_75i3sirg.pth"
if model_file_name is not None:
    # Load the saved weights onto the target device before evaluation.
    state_dict = torch.load(model_file_name, map_location=device)
    model.load_state_dict(state_dict)

model.eval()
eval_loader = DataLoader(val_dataset, batch_size=1, shuffle=True, collate_fn=collate_arc_fn, num_workers=0)
# batch = next(iter(eval_loader))

# visualize_output_query(model.output_query)


for i, batch in enumerate(eval_loader):
    grids, grid_masks, output_grid = [item.to(device) for item in batch]

    # Inference only: disable autograd so no graph is built or kept in memory
    # while generating predictions.
    with torch.no_grad():
        predictions = model.generate(grids, grid_masks)
    print(predictions.shape)

    # batch_size is 1, so squeeze(0) removes the batch dimension for display.
    visualize_tensors(grids.squeeze(0), output_grid.squeeze(0), predictions.squeeze(0))

