In [None]:
import json
import numpy as np

# Scan the ARC training challenges and report the largest number of train
# pairs in any task and the largest input-grid height/width across all tasks.
# Only the "input" grid of each train pair is measured.
with open("data/arc-agi_training_challenges.json", "r") as f:
  tasks = json.load(f)

# The file is closed before the scan; only the parsed `tasks` dict is needed.
max_height = 0
max_width = 0
max_count = 0
for task in tasks.values():
  max_count = max(max_count, len(task["train"]))
  for pair in task["train"]:
    # Deliberately not bound to a variable called `input`: at notebook (module)
    # scope that would replace the builtin input() for every later cell.
    height, width = np.array(pair["input"]).shape
    max_height = max(max_height, height)
    max_width = max(max_width, width)

print(max_count, max_height, max_width)
      


In [None]:
# Generate the dataset
from arc_prize.synth_data import generate_dataset
from arc_prize.vis import visualize_grids


# Build a training split (20 tasks) and an evaluation split (10 tasks).
challenges, solutions = generate_dataset(num_tasks=20)
eval_challenges, eval_solutions = generate_dataset(num_tasks=10)

# Persist each split as a pair of JSON files sharing a common path prefix.
file_path_root = "data/arc-synth_move_right"
json_outputs = (
  (f"{file_path_root}_training_challenges.json", challenges),
  (f"{file_path_root}_training_solutions.json", solutions),
  (f"{file_path_root}_evaluation_challenges.json", eval_challenges),
  (f"{file_path_root}_evaluation_solutions.json", eval_solutions),
)
for json_path, payload in json_outputs:
  with open(json_path, "w") as f:
    json.dump(payload, f)

# First generated task, kept around for interactive inspection
# (nothing in this cell prints or otherwise uses it).
sample_task = next(iter(challenges.values()))

# Render every training task: its train pairs plus the first test input.
for task in challenges.values():
  visualize_grids(task["train"], task["test"][0]["input"])



In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class ARCTransformer(nn.Module):
    """Transformer encoder over a stack of colour grids, predicting per-cell classes.

    Every cell of every grid is embedded, given a trainable positional
    encoding, and all grids are concatenated into one sequence that is fed to
    a standard ``nn.TransformerEncoder``. Padded cells are excluded as
    attention keys. The per-cell logits of the LAST grid slot are returned.

    Args:
        grid_dim: Side length used to size the positional-encoding table.
        num_train_pairs: Number of context pairs; the positional table holds
            ``(num_train_pairs * 2 + 1) * grid_dim * grid_dim`` positions.
        num_colors: Number of colour ids; one extra class is added for padding.
        num_layers: Number of encoder layers.
        num_heads: Attention heads per layer.
        d_model: Embedding / model width.
        d_ff: Feed-forward width inside each encoder layer.
        dropout: Dropout probability inside each encoder layer.
    """

    def __init__(self, grid_dim=30, num_train_pairs=10, num_colors=10, num_layers=4, num_heads=8, d_model=256, d_ff=1024, dropout=0.1):
        super().__init__()

        self.input_dim = grid_dim
        self.num_classes = num_colors + 1  # one extra class id for padding
        self.d_model = d_model
        self.num_train_pairs = num_train_pairs
        self.num_heads = num_heads

        self.embedding = nn.Embedding(self.num_classes, d_model)
        self.pos_encoding = TrainablePositionalEncoding(d_model, max_len=(num_train_pairs * 2 + 1) * grid_dim * grid_dim)

        encoder_layer = nn.TransformerEncoderLayer(d_model=d_model, nhead=num_heads, dim_feedforward=d_ff, dropout=dropout)
        self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
        # NOTE: an earlier version set `self.transformer_encoder.use_checkpoint = True`
        # here. nn.TransformerEncoder does not read such an attribute, so it never
        # enabled gradient checkpointing; wrap the encoder call with
        # torch.utils.checkpoint if checkpointing is actually required.

        self.output_layer = nn.Linear(d_model, self.num_classes)

    def forward(self, grids, masks, output=None):
        """Compute class logits for the last grid slot.

        Args:
            grids: Integer tensor of shape (batch, num_grids, height, width)
                holding class ids in ``[0, num_classes)``.
            masks: Tensor of the same shape; truthy marks a real cell, falsy
                marks padding.
            output: Accepted for backward compatibility and ignored. Earlier
                versions overwrote the returned logits with
                ``output_layer(embedding(output))`` in training mode, which
                leaked the labels into the prediction and left the encoder
                without any gradient.

        Returns:
            Tensor of shape (batch, height, width, num_classes) with the
            logits at the positions of the last grid.
        """
        batch_size, num_grids, height, width = grids.size()
        masks = masks.bool()

        # Padded cells are replaced by token 0 so every index is a valid embedding row.
        masked_grids = torch.where(masks, grids, torch.zeros_like(grids))

        x = self.embedding(masked_grids)  # (batch, num_grids, height, width, d_model)
        x = x.view(batch_size, num_grids, height * width, self.d_model)
        x = self.pos_encoding(x)

        # True marks a padded position that must not be attended to. This is the
        # only mask needed: the previous explicit per-head `attn_mask` encoded the
        # same information, was laid out head-major while PyTorch expects
        # (batch * num_heads, L, S) batch-major, and so mixed up the padding of
        # different batch items, at O(heads * batch * seq_len^2) memory cost.
        key_padding_mask = ~masks.reshape(batch_size, -1)

        # The encoder layers are built with batch_first=False: (seq_len, batch, d_model).
        x = x.view(batch_size, -1, self.d_model).permute(1, 0, 2)
        x = self.transformer_encoder(x, src_key_padding_mask=key_padding_mask)
        x = x.permute(1, 0, 2)  # back to (batch, seq_len, d_model)

        x = self.output_layer(x)
        x = x.reshape(batch_size, num_grids, height, width, self.num_classes)

        return x[:, -1]  # logits of the last grid only
    

class TrainablePositionalEncoding(nn.Module):
    """Learned positional encoding: one embedding vector per flattened position.

    Args:
        d_model: Width of each positional vector.
        max_len: Number of positions in the embedding table.
    """

    def __init__(self, d_model, max_len=5000):
        super().__init__()
        self.pos_embedding = nn.Embedding(max_len, d_model)

    def forward(self, x):
        """Add positional vectors to ``x``.

        ``x`` is indexed as (batch, dim1, dim2, features); positions run from
        0 to ``dim1 * dim2 - 1`` in row-major order over the two middle axes
        and are identical for every batch element.
        """
        batch, outer, inner = x.size(0), x.size(1), x.size(2)
        index = torch.arange(outer * inner, device=x.device)
        index = index.unsqueeze(0).expand(batch, -1)
        encodings = self.pos_embedding(index)
        return x + encodings.view(batch, outer, inner, -1)

In [None]:
# import torch
# import torch.nn as nn
# import torch.nn.functional as F

# class GridAttention(nn.Module):
#     def __init__(self, d_model, nhead):
#         super().__init__()
#         self.self_attn = nn.MultiheadAttention(d_model, nhead)
    
#     def forward(self, x, mask):
#         return self.self_attn(x, x, x, key_padding_mask=mask)[0]

# class PairAttention(nn.Module):
#     def __init__(self, d_model, nhead):
#         super().__init__()
#         self.cross_attn = nn.MultiheadAttention(d_model, nhead)
    
#     def forward(self, x, y, mask_x, mask_y):
#         return self.cross_attn(x, y, y, key_padding_mask=mask_y)[0]

# class TransformerBlock(nn.Module):
#     def __init__(self, d_model, nhead, dim_feedforward):
#         super().__init__()
#         self.self_attn = GridAttention(d_model, nhead)
#         self.pair_attn = PairAttention(d_model, nhead)
#         self.feed_forward = nn.Sequential(
#             nn.Linear(d_model, dim_feedforward),
#             nn.ReLU(),
#             nn.Linear(dim_feedforward, d_model)
#         )
#         self.norm1 = nn.LayerNorm(d_model)
#         self.norm2 = nn.LayerNorm(d_model)
#         self.norm3 = nn.LayerNorm(d_model)
    
#     def forward(self, x, y, mask_x, mask_y):
#         x = x + self.self_attn(x, mask_x)
#         x = self.norm1(x)
#         x = x + self.pair_attn(x, y, mask_x, mask_y)
#         x = self.norm2(x)
#         x = x + self.feed_forward(x)
#         x = self.norm3(x)
#         return x

# class ARCTransformer(nn.Module):
#     def __init__(self, d_model, nhead, num_layers, dim_feedforward, max_grid_size, num_colors):
#         super().__init__()
#         self.d_model = d_model
#         self.embedding = nn.Embedding(num_colors + 1, d_model) # account for offset
#         self.pos_encoding = nn.Parameter(torch.randn(1, max_grid_size, max_grid_size, d_model))

        
#         self.input_layers = nn.ModuleList([TransformerBlock(d_model, nhead, dim_feedforward) for _ in range(num_layers)])
#         self.output_layers = nn.ModuleList([TransformerBlock(d_model, nhead, dim_feedforward) for _ in range(num_layers)])
#         self.cross_layers = nn.ModuleList([PairAttention(d_model, nhead) for _ in range(num_layers)])
        
#         self.final_layer = nn.Linear(d_model, num_colors)
    
#     def embed_grid(self, grid, mask):
#         print(f"Grid shape: {grid.shape}")
#         print(f"Mask shape: {mask.shape}")
#         print(f"Pos encoding shape: {self.pos_encoding.shape}")
        
#         embedded = self.embedding(grid)
#         print(f"Embedded shape: {embedded.shape}")
        
#         embedded = embedded + self.pos_encoding[:, :grid.shape[1], :grid.shape[2], :]
#         embedded = embedded.view(embedded.shape[0], -1, self.d_model)
#         mask = mask.view(mask.shape[0], -1)
#         return embedded, mask

    
#     def forward(self, batch):
#         batch_size = len(batch['task_id'])
#         num_train_pairs = 5  # Always 5 pairs, some might be padding

#         # Embed train input and output grids
#         train_inputs = self.embed_grid(batch['train_input_grids'].view(-1, 30, 30),
#                                     batch['train_input_masks'].view(-1, 30, 30))
#         train_outputs = self.embed_grid(batch['train_output_grids'].view(-1, 30, 30),
#                                         batch['train_output_masks'].view(-1, 30, 30))

#         # Reshape to (batch_size, num_train_pairs, 900, d_model)
#         train_inputs = train_inputs[0].view(batch_size, num_train_pairs, -1, self.d_model)
#         train_outputs = train_outputs[0].view(batch_size, num_train_pairs, -1, self.d_model)

#         # Create attention mask for train pairs
#         train_pair_mask = batch['train_input_output_grids_mask'].unsqueeze(-1).unsqueeze(-1)

#         # Process train pairs
#         for i in range(len(self.input_layers)):
#             train_inputs = self.input_layers[i](train_inputs, train_outputs, 
#                                                 batch['train_input_masks'], batch['train_output_masks'])
#             train_outputs = self.output_layers[i](train_outputs, train_inputs, 
#                                                 batch['train_output_masks'], batch['train_input_masks'])

#             # Apply mask to zero out padding pairs
#             train_inputs = train_inputs * train_pair_mask
#             train_outputs = train_outputs * train_pair_mask

#             # Cross-attention between pairs (within each batch item)
#             for j in range(num_train_pairs):
#                 for k in range(num_train_pairs):
#                     if j != k:
#                         train_inputs[:, j] = self.cross_layers[i](train_inputs[:, j], train_inputs[:, k], 
#                                                                 batch['train_input_masks'][:, j], batch['train_input_masks'][:, k])
#                         train_outputs[:, j] = self.cross_layers[i](train_outputs[:, j], train_outputs[:, k], 
#                                                                 batch['train_output_masks'][:, j], batch['train_output_masks'][:, k])

#         # Embed and process test input grid
#         test_input, test_mask = self.embed_grid(batch['test_input_grid'], batch['test_input_mask'])

#         # Process test input with attention to train pairs
#         for i in range(len(self.input_layers)):
#             test_input = self.input_layers[i](test_input, test_input, test_mask, test_mask)
#             for j in range(num_train_pairs):
#                 test_input = self.cross_layers[i](test_input, train_inputs[:, j], 
#                                                 test_mask, batch['train_input_masks'][:, j])
#                 test_input = self.cross_layers[i](test_input, train_outputs[:, j], 
#                                                 test_mask, batch['train_output_masks'][:, j])

#         # Generate output
#         output = self.final_layer(test_input)
#         output = output.view(batch['test_input_grid'].shape)

#         return output
    


In [None]:
import torch
import torch.nn as nn
import torch.optim as optim

def masked_cross_entropy_loss(predictions, targets, mask):
    """Mean cross-entropy over the cells selected by ``mask``.

    Args:
        predictions: Logits of shape (B, H, W, C).
        targets: Class indices with B*H*W elements (reshaped to 1-D).
        mask: Tensor with B*H*W elements; non-zero / True cells contribute
            to the loss, the rest are ignored.

    Returns:
        Scalar tensor: the summed loss of the selected cells divided by the
        number of selected cells. When no cell is selected the result is 0
        rather than NaN.
    """
    B, H, W, C = predictions.shape
    flat_logits = predictions.reshape(B * H * W, C)
    flat_targets = targets.reshape(B * H * W)
    weights = mask.reshape(B * H * W).float()

    per_cell = nn.CrossEntropyLoss(reduction='none')(flat_logits, flat_targets)
    # Clamp the denominator so an all-padding mask yields 0 instead of 0/0 = NaN,
    # which would otherwise poison the optimizer state on backward().
    return (per_cell * weights).sum() / weights.sum().clamp(min=1.0)

def train_arc_transformer(model, train_loader, val_loader, num_epochs, learning_rate, device):
    """Train ``model`` with Adam and track mean train/validation loss per epoch.

    Args:
        model: Module called as ``model(grids, masks, output_grid)`` during
            training and ``model(grids, masks)`` during validation.
        train_loader: Sized iterable of batches; each batch unpacks into
            three tensors ``(grids, grid_masks, output_grid)``.
        val_loader: Same batch format as ``train_loader``.
        num_epochs: Number of passes over ``train_loader``.
        learning_rate: Learning rate for Adam.
        device: Device the model and every batch tensor are moved to.

    Returns:
        ``(model, history)`` where ``history`` has the keys ``'train_loss'``
        and ``'val_loss'``, each a list with one mean loss per epoch. An
        empty loader contributes ``nan`` for that epoch instead of raising
        ``ZeroDivisionError``.
    """
    model.to(device)
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)

    history = {'train_loss': [], 'val_loss': []}
    num_train_batches = len(train_loader)
    num_val_batches = len(val_loader)

    for epoch in range(num_epochs):
        model.train()
        train_loss = 0.0

        for i, batch in enumerate(train_loader):
            grids, grid_masks, output_grid = [item.to(device) for item in batch]

            optimizer.zero_grad()

            predictions = model(grids, grid_masks, output_grid)

            # The loss is restricted to the cells marked valid in the last grid's mask.
            loss = masked_cross_entropy_loss(predictions, output_grid, grid_masks[:, -1])

            loss.backward()
            optimizer.step()

            train_loss += loss.item()

            print(f'Epoch {epoch+1}/{num_epochs}, Batch {i+1}/{len(train_loader)}, Loss: {loss.item():.4f}')

        # Guard against an empty loader: report nan rather than dividing by zero.
        train_loss = train_loss / num_train_batches if num_train_batches else float('nan')
        history['train_loss'].append(train_loss)

        # Validation
        model.eval()
        val_loss = 0.0

        with torch.no_grad():
            for batch in val_loader:
                grids, grid_masks, output_grid = [item.to(device) for item in batch]

                predictions = model(grids, grid_masks)

                loss = masked_cross_entropy_loss(predictions, output_grid, grid_masks[:, -1])

                val_loss += loss.item()

        val_loss = val_loss / num_val_batches if num_val_batches else float('nan')
        history['val_loss'].append(val_loss)

        print(f'Epoch {epoch+1}/{num_epochs}, Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}')

    return model, history




In [None]:
from torch.utils.data import DataLoader
from arc_prize.data import ARCDataset, ARCDatasetConfig, collate_arc_fn

# Hyperparameters. The trailing "# 30" / "# 10" comments record alternative
# values for the grid size and the number of context pairs.
d_model = 32
num_layers = 4
dim_feedforward = 128
max_grid_size = 10 # 30
num_heads = 4
max_context_pairs = 5 # 10
batch_size = 5
num_epochs = 5
num_colors = 10
learning_rate = 1e-4

# One dataset configuration shared by the training and evaluation splits.
synth_arc_dataset_config = ARCDatasetConfig(
    max_grid_size=max_grid_size,
    max_train_grids=max_context_pairs,
    color_offset=1,
)

# Training split: shuffled each epoch.
train_dataset = ARCDataset(
    "data/arc-synth_move_right_training_challenges.json",
    "data/arc-synth_move_right_training_solutions.json",
    config=synth_arc_dataset_config,
)
train_loader = DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=True,
    collate_fn=collate_arc_fn,
    num_workers=0,
)

# Evaluation split: fixed order.
val_dataset = ARCDataset(
    "data/arc-synth_move_right_evaluation_challenges.json",
    "data/arc-synth_move_right_evaluation_solutions.json",
    config=synth_arc_dataset_config,
)
val_loader = DataLoader(
    val_dataset,
    batch_size=batch_size,
    shuffle=False,
    collate_fn=collate_arc_fn,
    num_workers=0,
)

# Build the model on the GPU when one is available, otherwise on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = ARCTransformer(
    d_model=d_model,
    num_heads=num_heads,
    num_layers=num_layers,
    d_ff=dim_feedforward,
    grid_dim=max_grid_size,
    num_colors=num_colors,
    num_train_pairs=max_context_pairs,
).to(device)

# NOTE(review): the training call below is disabled, so the state dict written
# further down holds the freshly initialised (untrained) weights even though
# the final message says training completed.
# train_arc_transformer(model=model, train_loader=train_loader, val_loader=val_loader, num_epochs=num_epochs, learning_rate=learning_rate, device=device)

# Write the model weights to disk.
model_file_name = "arc_synth_transformer_model.pth"
torch.save(model.state_dict(), model_file_name)

print("Training completed and model saved.")

In [None]:

# arc_dataset_config = ARCDatasetConfig(max_grid_size=max_grid_size, max_train_grids=max_context_pairs, color_offset=1)

# train_dataset = ARCDataset("data/arc-agi_training_challenges.json", "data/arc-agi_training_solutions.json", config=arc_dataset_config)
# train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, collate_fn=collate_arc_fn, num_workers=0)

# val_dataset = ARCDataset("data/arc-agi_evaluation_challenges.json", "data/arc-agi_evaluation_solutions.json", config=arc_dataset_config)
# val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, collate_fn=collate_arc_fn, num_workers=0)

# # Initialize model
# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# model = ARCTransformer(d_model=d_model, num_heads=num_heads, num_layers=num_layers, d_ff=dim_feedforward, grid_dim=max_grid_size, num_colors=num_colors, num_train_pairs=max_context_pairs).to(device)

# # Train the model
# train_arc_transformer(model=model, train_loader=train_loader, val_loader=val_loader, num_epochs=num_epochs, learning_rate=learning_rate, device=device)

# model_file_name = "arc_transformer_model.pth"
# # Save the trained model
# torch.save(model.state_dict(), model_file_name)

# print("Training completed and model saved.")