In [1]:
import json
from dataclasses import dataclass
from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset

In [2]:
@dataclass(frozen=True)
class ARCDatasetParams:
    max_grid_size: int = 30
    max_train_grids: int = 10
    color_offset: int = 1


def pad_and_mask_grid(
    grid: list[list[int]], config: ARCDatasetParams
) -> tuple[torch.Tensor, torch.Tensor]:
    h, w = len(grid), len(grid[0])
    if h > config.max_grid_size or w > config.max_grid_size:
        raise Exception("grid size too large")

    padded = torch.zeros((config.max_grid_size, config.max_grid_size), dtype=torch.int)
    mask = torch.zeros((config.max_grid_size, config.max_grid_size), dtype=torch.bool)

    # Calculate padding
    pad_h = (config.max_grid_size - h) // 2
    pad_w = (config.max_grid_size - w) // 2

    # Place the grid in the center
    padded[pad_h : pad_h + h, pad_w : pad_w + w] = (
        torch.tensor(grid, dtype=torch.int) + config.color_offset
    )
    mask[pad_h : pad_h + h, pad_w : pad_w + w] = True

    return (padded, mask)


class ARCKaggleDataset(Dataset):
    challenges: dict[str, dict]
    task_ids: list[str]
    config: ARCDatasetParams

    def __init__(
        self,
        challenges_file: str,
        config: ARCDatasetParams,
    ):
        with open(challenges_file, "r") as f:
            self.challenges = json.load(f)
            self.task_ids = list(self.challenges.keys())
        self.config = config

    def __len__(self) -> int:
        return len(self.task_ids)

    def __getitem__(self, idx: int) -> dict:
        task_id = self.task_ids[idx]
        challenge = self.challenges[task_id]

        all_grids = []
        all_masks = []

        for test in challenge["test"]:
            grids = torch.zeros(
                2 * self.config.max_train_grids + 1,
                self.config.max_grid_size,
                self.config.max_grid_size,
                dtype=torch.int,
            )
            masks = torch.zeros(
                2 * self.config.max_train_grids + 1,
                self.config.max_grid_size,
                self.config.max_grid_size,
                dtype=torch.bool,
            )

            for i, pair in enumerate(challenge["train"]):
                if i >= self.config.max_train_grids:
                    print(
                        "Training pairs exceed max", task_id, i, self.config.max_train_grids
                    )
                    break

                try:
                    input_grid, input_mask = pad_and_mask_grid(pair["input"], self.config)
                    output_grid, output_mask = pad_and_mask_grid(
                        pair["output"], self.config
                    )
                    grids[2 * i] = input_grid
                    masks[2 * i] = input_mask
                    grids[2 * i + 1] = output_grid
                    masks[2 * i + 1] = output_mask
                except Exception as e:
                    print("Got exception for training pair", task_id, i, e)

            try:
                test_input_grid, test_input_mask = pad_and_mask_grid(
                    test["input"], self.config
                )
                grids[-1] = test_input_grid
                masks[-1] = test_input_mask
            except Exception as e:
                print("Got exception on test input", task_id, e)
            
            all_grids.append(grids)
            all_masks.append(masks)

        return {"task_id": task_id, "grids": torch.stack(all_grids), "masks": torch.stack(all_masks)}
    

def collate_arc_fn(
    batch: list[dict],
) -> tuple[list, torch.Tensor, torch.Tensor]:
    task_ids = [item["task_id"] for item in batch]
    grids = torch.stack([item["grids"] for item in batch])
    masks = torch.stack([item["masks"] for item in batch])

    return (task_ids, grids, masks)

In [3]:
@dataclass(frozen=True)
class ARCTransformerEncoderDecoderParams:
    grid_dim: int
    num_train_pairs: int
    num_colors: int
    num_encoder_layers: int
    num_decoder_layers: int
    num_heads: int
    d_model: int
    d_ff: int
    dropout: float



class ARCPositionalEncoding(nn.Module):
    def __init__(self, d_model: int, grid_dim: int, num_train_pairs: int):
        super().__init__()
        self.d_model = d_model
        self.grid_dim = grid_dim
        self.num_train_pairs = num_train_pairs

        # Embeddings for row and column positions
        self.row_embedding = nn.Embedding(self.grid_dim, self.d_model // 4)
        self.col_embedding = nn.Embedding(self.grid_dim, self.d_model // 4)

        # Embedding for input vs output
        self.input_output_embedding = nn.Embedding(2, d_model // 4)

        # Embedding for training pair index
        self.pair_embedding = nn.Embedding(
            self.num_train_pairs + 1, d_model // 4
        )  # +1 for test pair

    def forward(self, x):
        batch_size, num_grids, height, width, _ = x.size()

        row_pos = torch.arange(height, device=x.device).unsqueeze(1).expand(-1, width)
        col_pos = torch.arange(width, device=x.device).unsqueeze(0).expand(height, -1)

        row_emb = self.row_embedding(row_pos)
        col_emb = self.col_embedding(col_pos)

        pos_emb = torch.cat([row_emb, col_emb], dim=-1)

        pos_emb = (
            pos_emb.unsqueeze(0).unsqueeze(0).expand(batch_size, num_grids, -1, -1, -1)
        )

        grid_indices = (
            torch.arange(num_grids, device=x.device).unsqueeze(0).expand(batch_size, -1)
        )

        is_output = (grid_indices % 2 == 1).long()
        io_emb = self.input_output_embedding(is_output)

        pair_indices = torch.div(grid_indices, 2, rounding_mode="floor")

        pair_indices[:, -1] = self.num_train_pairs
        pair_emb = self.pair_embedding(pair_indices)

        io_emb = io_emb.unsqueeze(2).unsqueeze(2).expand(-1, -1, height, width, -1)
        pair_emb = pair_emb.unsqueeze(2).unsqueeze(2).expand(-1, -1, height, width, -1)

        combined_emb = torch.cat([pos_emb, io_emb, pair_emb], dim=-1)

        return x + combined_emb


class EncoderLayerWithAttention(nn.TransformerEncoderLayer):
    def __init__(
        self,
        d_model: int,
        nhead: int,
        dim_feedforward: int = 2048,
        dropout: float = 0.1,
    ) -> None:
        super().__init__(d_model, nhead, dim_feedforward, dropout, batch_first=True)

    def forward(
        self,
        src: torch.Tensor,
        src_mask: Optional[torch.Tensor] = None,
        src_key_padding_mask: Optional[torch.Tensor] = None,
        is_causal: bool = False,
        need_weights: bool = False,
    ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
        src_key_padding_mask = F._canonical_mask(
            mask=src_key_padding_mask,
            mask_name="src_key_padding_mask",
            other_type=F._none_or_dtype(src_mask),
            other_name="src_mask",
            target_type=src.dtype,
        )

        src_mask = F._canonical_mask(
            mask=src_mask,
            mask_name="src_mask",
            other_type=None,
            other_name="",
            target_type=src.dtype,
            check_other=False,
        )

        x = src
        x1, attn_weights = self.self_attn(
            x,
            x,
            x,
            attn_mask=src_mask,
            key_padding_mask=src_key_padding_mask,
            need_weights=need_weights,
            is_causal=is_causal,
            average_attn_weights=False,
        )
        x = self.norm1(x + self.dropout1(x1))
        x2 = self.linear2(self.dropout(self.activation(self.linear1(x))))

        x = self.norm2(x + self.dropout2(x2))

        return x, attn_weights


class EncoderWithAttention(nn.TransformerEncoder):
    def __init__(
        self,
        encoder_layer: "EncoderLayerWithAttention",
        num_layers: int,
    ) -> None:
        super().__init__(encoder_layer, num_layers)

    def forward(
        self,
        src: torch.Tensor,
        mask: Optional[torch.Tensor] = None,
        src_key_padding_mask: Optional[torch.Tensor] = None,
        need_weights: bool = False,
    ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
        src_key_padding_mask = F._canonical_mask(
            mask=src_key_padding_mask,
            mask_name="src_key_padding_mask",
            other_type=F._none_or_dtype(mask),
            other_name="mask",
            target_type=src.dtype,
        )

        mask = F._canonical_mask(
            mask=mask,
            mask_name="mask",
            other_type=None,
            other_name="",
            target_type=src.dtype,
            check_other=False,
        )

        output = src
        attn_weights = []

        for mod in self.layers:
            output, layer_attn_weights = mod(
                output,
                src_mask=mask,
                src_key_padding_mask=src_key_padding_mask,
                need_weights=need_weights,
            )
            if need_weights:
                attn_weights.append(layer_attn_weights)

        if self.norm is not None:
            output = self.norm(output)

        return output, (torch.stack(attn_weights, dim=1) if need_weights else None)


class DecoderLayerWithAttention(nn.TransformerDecoderLayer):
    def __init__(
        self,
        d_model: int,
        nhead: int,
        dim_feedforward: int = 2048,
        dropout: float = 0.1,
    ) -> None:
        super().__init__(d_model, nhead, dim_feedforward, dropout, batch_first=True)

    def forward(
        self,
        tgt: torch.Tensor,
        memory: torch.Tensor,
        tgt_mask: Optional[torch.Tensor] = None,
        memory_mask: Optional[torch.Tensor] = None,
        tgt_key_padding_mask: Optional[torch.Tensor] = None,
        memory_key_padding_mask: Optional[torch.Tensor] = None,
        tgt_is_causal: bool = False,
        memory_is_causal: bool = False,
        need_weights: bool = False,
    ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]:
        x = tgt
        x_sa, sa_attn_weights = self._sa_block(
            x,
            tgt_mask,
            tgt_key_padding_mask,
            is_causal=tgt_is_causal,
            need_weights=need_weights,
        )
        x = self.norm1(x + x_sa)

        x_mha, mha_attn_weights = self._mha_block(
            x,
            memory,
            memory_mask,
            memory_key_padding_mask,
            is_causal=memory_is_causal,
            need_weights=need_weights,
        )
        x = self.norm2(x + x_mha)
        x = self.norm3(x + self._ff_block(x))

        return x, sa_attn_weights, mha_attn_weights

    # self-attention block
    def _sa_block(
        self,
        x: torch.Tensor,
        attn_mask: Optional[torch.Tensor],
        key_padding_mask: Optional[torch.Tensor],
        is_causal: bool = False,
        need_weights: bool = False,
    ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
        x, sa_attn_weights = self.self_attn(
            x,
            x,
            x,
            attn_mask=attn_mask,
            key_padding_mask=key_padding_mask,
            is_causal=is_causal,
            need_weights=need_weights,
            average_attn_weights=False,
        )
        return self.dropout1(x), sa_attn_weights

    # multihead attention block
    def _mha_block(
        self,
        x: torch.Tensor,
        mem: torch.Tensor,
        attn_mask: Optional[torch.Tensor],
        key_padding_mask: Optional[torch.Tensor],
        is_causal: bool = False,
        need_weights: bool = False,
    ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
        x, mha_attn_weights = self.multihead_attn(
            x,
            mem,
            mem,
            attn_mask=attn_mask,
            key_padding_mask=key_padding_mask,
            is_causal=is_causal,
            need_weights=need_weights,
            average_attn_weights=False,
        )
        return self.dropout2(x), mha_attn_weights

    # feed forward block
    def _ff_block(self, x: torch.Tensor) -> torch.Tensor:
        x = self.linear2(self.dropout(self.activation(self.linear1(x))))
        return self.dropout3(x)


class DecoderWithAttention(nn.TransformerDecoder):
    def __init__(
        self,
        decoder_layer: "DecoderLayerWithAttention",
        num_layers: int,
    ) -> None:
        super().__init__(decoder_layer, num_layers)

    def forward(
        self,
        tgt: torch.Tensor,
        memory: torch.Tensor,
        tgt_mask: Optional[torch.Tensor] = None,
        memory_mask: Optional[torch.Tensor] = None,
        tgt_key_padding_mask: Optional[torch.Tensor] = None,
        memory_key_padding_mask: Optional[torch.Tensor] = None,
        tgt_is_causal: Optional[bool] = False,
        memory_is_causal: bool = False,
        need_weights: bool = False,
    ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]:
        output = tgt

        sa_attn_weights = []
        mha_attn_weights = []

        for mod in self.layers:
            output, layer_sa_attn_weights, layer_mha_attn_weights = mod(
                output,
                memory,
                tgt_mask=tgt_mask,
                memory_mask=memory_mask,
                tgt_key_padding_mask=tgt_key_padding_mask,
                memory_key_padding_mask=memory_key_padding_mask,
                tgt_is_causal=tgt_is_causal,
                memory_is_causal=memory_is_causal,
                need_weights=need_weights,
            )
            if need_weights:
                sa_attn_weights.append(layer_sa_attn_weights)
                mha_attn_weights.append(layer_mha_attn_weights)

        if self.norm is not None:
            output = self.norm(output)

        return (
            output,
            torch.stack(sa_attn_weights, dim=1) if need_weights else None,
            torch.stack(mha_attn_weights, dim=1) if need_weights else None,
        )


class ARCTransformerEncoderDecoder(nn.Module):
    grid_dim: int
    num_train_pairs: int
    num_classes: int
    num_encoder_layers: int
    num_decoder_layers: int
    num_heads: int
    d_model: int
    d_ff: int
    dropout: float
    seq_len: int

    def __init__(self, params: ARCTransformerEncoderDecoderParams):
        super().__init__()
        self.grid_dim = params.grid_dim
        self.num_train_pairs = params.num_train_pairs
        self.num_classes = params.num_colors + 1
        self.d_model = params.d_model
        self.num_encoder_layers = params.num_encoder_layers
        self.num_decoder_layers = params.num_decoder_layers
        self.num_heads = params.num_heads
        self.d_ff = params.d_ff
        self.dropout = params.dropout
        self.seq_len = (self.num_train_pairs * 2 + 1) * self.grid_dim * self.grid_dim

        self.embedding = nn.Embedding(self.num_classes, self.d_model)
        self.pos_encoding = ARCPositionalEncoding(
            d_model=self.d_model,
            grid_dim=self.grid_dim,
            num_train_pairs=self.num_train_pairs,
        )

        encoder_layer = EncoderLayerWithAttention(
            self.d_model, self.num_heads, self.d_ff, self.dropout
        )

        self.encoder = EncoderWithAttention(encoder_layer, self.num_encoder_layers)

        decoder_layer = DecoderLayerWithAttention(
            self.d_model, self.num_heads, self.d_ff, self.dropout
        )

        self.decoder = DecoderWithAttention(decoder_layer, self.num_decoder_layers)

        self.output_query = nn.Parameter(torch.randn(1, self.grid_dim**2, self.d_model))
        self.output_layer = nn.Linear(self.d_model, self.num_classes)

    def forward(
        self, src: torch.Tensor, src_mask: torch.Tensor, need_weights: bool = False
    ) -> tuple[
        torch.Tensor,
        Optional[torch.Tensor],
        Optional[torch.Tensor],
        Optional[torch.Tensor],
    ]:
        batch_size = src.shape[0]

        src = self.embedding(src)

        src = self.pos_encoding(src)

        src = src.view(batch_size, self.seq_len, self.d_model)

        padding_mask = ~src_mask.view(batch_size, -1)

        memory, encoder_attn_weights = self.encoder.forward(
            src, src_key_padding_mask=padding_mask, need_weights=need_weights
        )

        output_query = self.output_query.expand(batch_size, -1, -1)

        (
            output,
            decoder_sa_attn_weights,
            decoder_mha_attn_weights,
        ) = self.decoder.forward(
            output_query,
            memory,
            memory_key_padding_mask=padding_mask,
            need_weights=need_weights,
        )

        output = self.output_layer(output)

        output = output.view(batch_size, self.grid_dim, self.grid_dim, self.num_classes)

        return (
            output,
            encoder_attn_weights,
            decoder_sa_attn_weights,
            decoder_mha_attn_weights,
        )

    def generate(
        self,
        src: torch.Tensor,
        src_mask: torch.Tensor,
        need_weights: bool = False,
    ) -> tuple[
        torch.Tensor,
        Optional[torch.Tensor],
        Optional[torch.Tensor],
        Optional[torch.Tensor],
    ]:
        with torch.no_grad():
            (
                output,
                encoder_attn_weights,
                decoder_sa_attn_weights,
                decoder_mha_attn_weights,
            ) = self.forward(src, src_mask, need_weights)
            return (
                torch.argmax(output, dim=-1),
                encoder_attn_weights,
                decoder_sa_attn_weights,
                decoder_mha_attn_weights,
            )

@dataclass
class ARCKaggleModelState:
    model_params: ARCTransformerEncoderDecoderParams
    model_state_dict: dict



In [24]:
# Load dataset
test_challenges_file_name = "data/arc-agi_evaluation_challenges.json"
config = ARCDatasetParams(max_grid_size=12, max_train_grids=4, color_offset=1)
test_dataset = ARCKaggleDataset(challenges_file=test_challenges_file_name, config=config)

In [25]:
kaggle_model_file_path = "kaggle/models/subtly_known_panda.pth"

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model_params = ARCTransformerEncoderDecoderParams(grid_dim=12, num_train_pairs=4, num_colors=10, num_encoder_layers=6, num_decoder_layers=6, num_heads=16, d_model=512, d_ff=2048, dropout=0.3)
model = ARCTransformerEncoderDecoder(model_params).to(device)
model.load_state_dict(torch.load(kaggle_model_file_path, map_location=device))

model.eval()

def unpad_output(output: torch.Tensor) -> torch.Tensor:
    output = output - 1
    filtered_rows = []
    for row in output:
        filtered_row = row[row != -1]
        if len(filtered_row) > 0:
            filtered_rows.append(filtered_row)
    # Hack to ensure there's always at least one value
    if len(filtered_rows) == 0:
        filtered_rows.append(torch.zeros(1, device=device, dtype=torch.long))
    max_length = max(len(row) for row in filtered_rows)
    padded_rows = [torch.cat([row, torch.zeros(max_length - len(row), dtype=row.dtype, device=device)]) for row in filtered_rows]
    return torch.stack(padded_rows)

kaggle_output = {}

for task in test_dataset:
    print(f"Starting {task['task_id']}")
    task_predictions = []
    for grids, masks in zip(task["grids"], task["masks"]):
        prediction = model.generate(grids.unsqueeze(0), masks.unsqueeze(0), need_weights=False)[0][0]
        list_prediction = unpad_output(prediction).cpu().numpy().tolist()
        task_predictions.append({
            "attempt_1":  list_prediction,
            "attempt_2": list_prediction
        })
    kaggle_output[task["task_id"]] = task_predictions


Starting 00576224
Got exception for training pair 009d5c81 0 grid size too large
Got exception for training pair 009d5c81 1 grid size too large
Got exception for training pair 009d5c81 2 grid size too large
Got exception for training pair 009d5c81 3 grid size too large
Training pairs exceed max 009d5c81 4 4
Got exception on test input 009d5c81 grid size too large
Starting 009d5c81
Got exception for training pair 00dbd492 2 grid size too large
Got exception for training pair 00dbd492 3 grid size too large
Got exception on test input 00dbd492 grid size too large
Starting 00dbd492
Starting 03560426
Got exception for training pair 05a7bcf2 0 grid size too large
Got exception for training pair 05a7bcf2 1 grid size too large
Got exception for training pair 05a7bcf2 2 grid size too large
Got exception on test input 05a7bcf2 grid size too large
Starting 05a7bcf2
Got exception for training pair 0607ce86 0 grid size too large
Got exception for training pair 0607ce86 1 grid size too large
Got exc

In [None]:

submission_file_path = "kaggle/submission.json"

with open(submission_file_path, "w") as f:
  json.dump(kaggle_output, f)

FileNotFoundError: [Errno 2] No such file or directory: 'kaggle/working/submission.json'

In [26]:
solutions_file_path = "data/arc-agi_evaluation_solutions.json"
with open(solutions_file_path, "r") as f:
  solutions = json.load(f)
  total = 0
  correct = 0
  for task_id, outputs in solutions.items():
    predictions = kaggle_output[task_id]
    for prediction, output in zip(predictions, outputs):
      total += 1
      if prediction["attempt_1"] == output or prediction["attempt_2"] == output:
        correct += 1
      

In [27]:
print(total, correct)

419 6
