In [1]:
import json
import copy
import itertools
from dataclasses import dataclass
from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
import torch.optim as optim
from torch.amp.autocast_mode import autocast
from torch.amp.grad_scaler import GradScaler

In [2]:
# --- Paths for the Kaggle inference run (relative to the notebook's CWD) ---

# Challenges JSON; ARCKaggleDataset reads it as
# task_id -> {"train": [...], "test": [...]}.
test_challenges_file_name = (
    "data/arc-agi_evaluation_challenges.json"
)

# Model checkpoint path (presumably loaded later in the notebook).
kaggle_model_file_path = (
    "kaggle/models/kindly_exact_beagle_4.pth"
)

# Submission output path (presumably written later in the notebook).
kaggle_submission_file_path = (
    "kaggle/submission.json"
)

In [3]:
@dataclass(frozen=True)
class ARCDatasetParams:
    """Immutable settings for turning ARC tasks into fixed-size tensors.

    Fields are declared in positional order and all carry defaults.
    """

    # Side length of the square canvas each grid is centered (or cropped) onto.
    max_grid_size: int = 30
    # Upper bound on training (input, output) pairs kept per task; the
    # dataset stops reading pairs once this many have been placed.
    max_train_grids: int = 10
    # Added to every color before placement; with the default of 1 the canvas
    # value 0 stays reserved for padding.
    color_offset: int = 1


def pad_and_mask_grid(
    grid: list[list[int]], config: "ARCDatasetParams"
) -> tuple[torch.Tensor, torch.Tensor]:
    """Center a grid on a fixed-size canvas and build its validity mask.

    Colors are shifted by ``config.color_offset`` before placement. A grid
    larger than ``config.max_grid_size`` in either dimension is center-cropped
    (a warning is printed) so that it fits.

    Args:
        grid: Row-major list of rows of integer colors.
        config: Provides ``max_grid_size`` (canvas side) and ``color_offset``.

    Returns:
        ``(padded, mask)``: an int tensor of shape
        ``(max_grid_size, max_grid_size)`` holding the offset colors (0
        elsewhere), and a bool tensor of the same shape that is True exactly
        where grid cells were written. A grid with no rows yields an all-zero
        canvas and an all-False mask.
    """
    size = config.max_grid_size
    padded = torch.zeros((size, size), dtype=torch.int)
    mask = torch.zeros((size, size), dtype=torch.bool)

    # ``grid[0]`` below would raise IndexError on an empty grid; nothing to place.
    if not grid:
        return (padded, mask)

    h, w = len(grid), len(grid[0])

    # Portion of the grid that fits on the canvas.
    h_to_use = min(h, size)
    w_to_use = min(w, size)

    # Too big: keep the central h_to_use x w_to_use window.
    if h > size or w > size:
        print("Warning: grid size too large")
        h_start = (h - h_to_use) // 2
        w_start = (w - w_to_use) // 2
        grid = [
            row[w_start : w_start + w_to_use]
            for row in grid[h_start : h_start + h_to_use]
        ]

    h, w = h_to_use, w_to_use

    # Center the (possibly cropped) grid on the canvas.
    pad_h = (size - h) // 2
    pad_w = (size - w) // 2
    padded[pad_h : pad_h + h, pad_w : pad_w + w] = (
        torch.tensor(grid, dtype=torch.int) + config.color_offset
    )
    mask[pad_h : pad_h + h, pad_w : pad_w + w] = True

    return (padded, mask)


class ARCKaggleDataset(Dataset):
    """Dataset over an ARC challenges JSON file, one item per task.

    For every test input of a task the item holds a stack of
    ``2 * config.max_train_grids + 1`` canvases: training pair ``i`` occupies
    slots ``2 * i`` (input) and ``2 * i + 1`` (output), and the test input
    occupies the last slot. Slots without data stay all-zero with an
    all-False mask.
    """

    challenges: dict[str, dict]
    task_ids: list[str]
    config: ARCDatasetParams

    def __init__(
        self,
        challenges_file: str,
        config: ARCDatasetParams,
    ):
        """Load every challenge in ``challenges_file`` into memory.

        Args:
            challenges_file: Path to a JSON object mapping task id to a
                challenge dict with ``"train"`` and ``"test"`` lists.
            config: Canvas size, training-pair limit and color offset.
        """
        # JSON is UTF-8 by spec; don't depend on the platform's locale encoding.
        with open(challenges_file, "r", encoding="utf-8") as f:
            self.challenges = json.load(f)
        self.task_ids = list(self.challenges.keys())
        self.config = config

    def __len__(self) -> int:
        return len(self.task_ids)

    def _blank_stack(self, dtype: torch.dtype) -> torch.Tensor:
        """Zeroed tensor of shape (2 * max_train_grids + 1, max_grid_size, max_grid_size)."""
        return torch.zeros(
            2 * self.config.max_train_grids + 1,
            self.config.max_grid_size,
            self.config.max_grid_size,
            dtype=dtype,
        )

    def _encode_train_pairs(
        self, task_id: str, challenge: dict
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Pad the task's training pairs into a fresh stack; the test slot stays empty."""
        grids = self._blank_stack(torch.int)
        masks = self._blank_stack(torch.bool)

        for i, pair in enumerate(challenge["train"]):
            if i >= self.config.max_train_grids:
                print(
                    "Training pairs exceed max", task_id, i, self.config.max_train_grids
                )
                break

            # A malformed pair is reported and skipped (its slots stay empty)
            # rather than failing the whole task.
            try:
                input_grid, input_mask = pad_and_mask_grid(pair["input"], self.config)
                output_grid, output_mask = pad_and_mask_grid(
                    pair["output"], self.config
                )
                grids[2 * i] = input_grid
                masks[2 * i] = input_mask
                grids[2 * i + 1] = output_grid
                masks[2 * i + 1] = output_mask
            except Exception as e:
                print("Got exception for training pair", task_id, i, e)

        return grids, masks

    def __getitem__(self, idx: int) -> dict:
        """Return ``{"task_id", "grids", "masks"}`` for task ``idx``.

        ``grids`` (int) and ``masks`` (bool) have shape
        ``(num_tests, 2 * max_train_grids + 1, max_grid_size, max_grid_size)``;
        ``num_tests`` is 0 when the task lists no test inputs.
        """
        task_id = self.task_ids[idx]
        challenge = self.challenges[task_id]
        tests = challenge["test"]

        if not tests:
            # torch.stack([]) raises; return correctly-shaped empty tensors instead.
            return {
                "task_id": task_id,
                "grids": self._blank_stack(torch.int).unsqueeze(0)[:0],
                "masks": self._blank_stack(torch.bool).unsqueeze(0)[:0],
            }

        # Training pairs are identical for every test input: pad them once and
        # copy the stack per test instead of re-padding each time.
        train_grids, train_masks = self._encode_train_pairs(task_id, challenge)

        all_grids = []
        all_masks = []
        for test in tests:
            grids = train_grids.clone()
            masks = train_masks.clone()

            try:
                test_input_grid, test_input_mask = pad_and_mask_grid(
                    test["input"], self.config
                )
                grids[-1] = test_input_grid
                masks[-1] = test_input_mask
            except Exception as e:
                print("Got exception on test input", task_id, e)

            all_grids.append(grids)
            all_masks.append(masks)

        return {
            "task_id": task_id,
            "grids": torch.stack(all_grids),
            "masks": torch.stack(all_masks),
        }
    

def collate_arc_fn(
    batch: list[dict],
) -> tuple[list, torch.Tensor, torch.Tensor]:
    """Merge dataset items into ``(task_ids, grids, masks)``.

    ``grids`` and ``masks`` are the per-item tensors stacked along a new
    leading batch dimension. ``torch.stack`` requires identical shapes, so
    every item in the batch must have the same number of test inputs.
    """

    def column(key: str) -> list:
        # Gather one field from every item, preserving batch order.
        return [item[key] for item in batch]

    task_ids = column("task_id")
    grids = torch.stack(column("grids"))
    masks = torch.stack(column("masks"))
    return (task_ids, grids, masks)

In [4]:
from dataclasses import dataclass
from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F


@dataclass(frozen=True)
class ARCTransformerEncoderDecoderParams:
    """Immutable hyper-parameters shared by the ARC transformer models.

    All fields are required and positional; their order is part of the
    constructor's interface.
    """

    grid_dim: int  # side length of each square grid
    num_train_pairs: int  # training (input, output) pairs per example
    num_colors: int  # models build num_colors + 1 embedding classes
    num_encoder_layers: int
    num_decoder_layers: int  # not read by the encoder-only models
    num_heads: int  # attention heads per layer
    d_model: int  # embedding / hidden width
    d_ff: int  # feed-forward inner width
    dropout: float  # dropout probability inside the transformer layers


class EncoderLayerWithAttention(nn.TransformerEncoderLayer):
    """Batch-first, post-norm encoder layer that can also return attention maps.

    Reuses the sub-modules built by ``nn.TransformerEncoderLayer``
    (``self_attn``, ``linear1``/``linear2``, ``norm1``/``norm2`` and the
    dropouts) but replaces ``forward`` so per-head attention weights can be
    surfaced. LayerNorm is always applied after each residual add.
    """

    def __init__(
        self,
        d_model: int,
        nhead: int,
        dim_feedforward: int = 2048,
        dropout: float = 0.1,
    ) -> None:
        super().__init__(
            d_model,
            nhead,
            dim_feedforward=dim_feedforward,
            dropout=dropout,
            batch_first=True,
        )

    def forward(
        self,
        src: torch.Tensor,
        src_mask: Optional[torch.Tensor] = None,
        src_key_padding_mask: Optional[torch.Tensor] = None,
        is_causal: bool = False,
        need_weights: bool = False,
    ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Self-attention then feed-forward, each with residual add + LayerNorm.

        Args:
            src: Batch-first input sequence (batch, seq, d_model).
            src_mask: Optional attention mask passed to ``self_attn``.
            src_key_padding_mask: Optional per-key padding mask.
            is_causal: Forwarded to ``self_attn`` as a hint.
            need_weights: Whether to return attention weights.

        Returns:
            ``(output, weights)``; ``weights`` holds per-head attention
            (``average_attn_weights=False``) when ``need_weights`` is True,
            else None.
        """
        # The padding mask is canonicalised against the *raw* attention
        # mask's dtype, so it must be converted before ``src_mask`` is.
        key_padding = F._canonical_mask(
            mask=src_key_padding_mask,
            mask_name="src_key_padding_mask",
            other_type=F._none_or_dtype(src_mask),
            other_name="src_mask",
            target_type=src.dtype,
        )
        attn_mask = F._canonical_mask(
            mask=src_mask,
            mask_name="src_mask",
            other_type=None,
            other_name="",
            target_type=src.dtype,
            check_other=False,
        )

        attended, weights = self.self_attn(
            src,
            src,
            src,
            attn_mask=attn_mask,
            key_padding_mask=key_padding,
            need_weights=need_weights,
            is_causal=is_causal,
            average_attn_weights=False,
        )
        hidden = self.norm1(src + self.dropout1(attended))

        expanded = self.activation(self.linear1(hidden))
        projected = self.linear2(self.dropout(expanded))
        return self.norm2(hidden + self.dropout2(projected)), weights


class EncoderWithAttention(nn.TransformerEncoder):
    """Layer stack that can collect every layer's attention weights.

    The stack itself is built by ``nn.TransformerEncoder.__init__``. Each layer
    is called with a ``need_weights`` flag and must return
    ``(output, weights)``, as ``EncoderLayerWithAttention`` does.
    """

    def __init__(
        self,
        encoder_layer: "EncoderLayerWithAttention",
        num_layers: int,
    ) -> None:
        super().__init__(encoder_layer, num_layers)

    def forward(
        self,
        src: torch.Tensor,
        mask: Optional[torch.Tensor] = None,
        src_key_padding_mask: Optional[torch.Tensor] = None,
        need_weights: bool = False,
    ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Run ``src`` through every layer, then the optional final norm.

        Returns:
            ``(output, weights)``; when ``need_weights`` is True, ``weights``
            stacks the per-layer attention maps along dim 1, otherwise None.
        """
        # Convert both masks once up front, padding mask first (it is checked
        # against the raw ``mask`` dtype).
        key_padding = F._canonical_mask(
            mask=src_key_padding_mask,
            mask_name="src_key_padding_mask",
            other_type=F._none_or_dtype(mask),
            other_name="mask",
            target_type=src.dtype,
        )
        attn_mask = F._canonical_mask(
            mask=mask,
            mask_name="mask",
            other_type=None,
            other_name="",
            target_type=src.dtype,
            check_other=False,
        )

        hidden = src
        per_layer: list[torch.Tensor] = []
        for layer in self.layers:
            hidden, layer_weights = layer(
                hidden,
                src_mask=attn_mask,
                src_key_padding_mask=key_padding,
                need_weights=need_weights,
            )
            if need_weights:
                per_layer.append(layer_weights)

        if self.norm is not None:
            hidden = self.norm(hidden)

        stacked = None
        if need_weights:
            stacked = torch.stack(per_layer, dim=1)
        return hidden, stacked


class DecoderLayerWithAttention(nn.TransformerDecoderLayer):
    """Batch-first, post-norm decoder layer that can return attention maps.

    Same sub-modules as ``nn.TransformerDecoderLayer``. ``forward`` and the
    block helpers are overridden so both the self-attention and the
    cross-attention weights (per head) can be returned. LayerNorm always
    follows each residual add.
    """

    def __init__(
        self,
        d_model: int,
        nhead: int,
        dim_feedforward: int = 2048,
        dropout: float = 0.1,
    ) -> None:
        super().__init__(
            d_model,
            nhead,
            dim_feedforward=dim_feedforward,
            dropout=dropout,
            batch_first=True,
        )

    def forward(
        self,
        tgt: torch.Tensor,
        memory: torch.Tensor,
        tgt_mask: Optional[torch.Tensor] = None,
        memory_mask: Optional[torch.Tensor] = None,
        tgt_key_padding_mask: Optional[torch.Tensor] = None,
        memory_key_padding_mask: Optional[torch.Tensor] = None,
        tgt_is_causal: bool = False,
        memory_is_causal: bool = False,
        need_weights: bool = False,
    ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]:
        """Self-attention, cross-attention to ``memory``, then feed-forward.

        Returns:
            ``(output, self_attn_weights, cross_attn_weights)``; both weight
            tensors are None unless ``need_weights`` is True.
        """
        self_out, self_weights = self._sa_block(
            tgt,
            tgt_mask,
            tgt_key_padding_mask,
            is_causal=tgt_is_causal,
            need_weights=need_weights,
        )
        hidden = self.norm1(tgt + self_out)

        cross_out, cross_weights = self._mha_block(
            hidden,
            memory,
            memory_mask,
            memory_key_padding_mask,
            is_causal=memory_is_causal,
            need_weights=need_weights,
        )
        hidden = self.norm2(hidden + cross_out)

        hidden = self.norm3(hidden + self._ff_block(hidden))
        return hidden, self_weights, cross_weights

    # self-attention block: returns (dropped-out output, per-head weights or None)
    def _sa_block(
        self,
        x: torch.Tensor,
        attn_mask: Optional[torch.Tensor],
        key_padding_mask: Optional[torch.Tensor],
        is_causal: bool = False,
        need_weights: bool = False,
    ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
        attended, weights = self.self_attn(
            x,
            x,
            x,
            attn_mask=attn_mask,
            key_padding_mask=key_padding_mask,
            is_causal=is_causal,
            need_weights=need_weights,
            average_attn_weights=False,
        )
        return self.dropout1(attended), weights

    # cross-attention block: queries from x, keys/values from mem
    def _mha_block(
        self,
        x: torch.Tensor,
        mem: torch.Tensor,
        attn_mask: Optional[torch.Tensor],
        key_padding_mask: Optional[torch.Tensor],
        is_causal: bool = False,
        need_weights: bool = False,
    ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
        attended, weights = self.multihead_attn(
            x,
            mem,
            mem,
            attn_mask=attn_mask,
            key_padding_mask=key_padding_mask,
            is_causal=is_causal,
            need_weights=need_weights,
            average_attn_weights=False,
        )
        return self.dropout2(attended), weights

    # position-wise feed-forward block
    def _ff_block(self, x: torch.Tensor) -> torch.Tensor:
        inner = self.dropout(self.activation(self.linear1(x)))
        return self.dropout3(self.linear2(inner))


class DecoderWithAttention(nn.TransformerDecoder):
    """Decoder stack that can collect self- and cross-attention per layer.

    The stack itself is built by ``nn.TransformerDecoder.__init__``. Each
    layer must return ``(output, self_weights, cross_weights)``, as
    ``DecoderLayerWithAttention`` does.
    """

    def __init__(
        self,
        decoder_layer: "DecoderLayerWithAttention",
        num_layers: int,
    ) -> None:
        super().__init__(decoder_layer, num_layers)

    def forward(
        self,
        tgt: torch.Tensor,
        memory: torch.Tensor,
        tgt_mask: Optional[torch.Tensor] = None,
        memory_mask: Optional[torch.Tensor] = None,
        tgt_key_padding_mask: Optional[torch.Tensor] = None,
        memory_key_padding_mask: Optional[torch.Tensor] = None,
        tgt_is_causal: Optional[bool] = False,
        memory_is_causal: bool = False,
        need_weights: bool = False,
    ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]:
        """Decode ``tgt`` against ``memory`` through every layer.

        Returns:
            ``(output, self_weights, cross_weights)``; when ``need_weights``
            is True each weight tensor stacks the per-layer maps along dim 1,
            otherwise both are None.
        """
        hidden = tgt
        self_maps: list[torch.Tensor] = []
        cross_maps: list[torch.Tensor] = []

        for layer in self.layers:
            hidden, self_w, cross_w = layer(
                hidden,
                memory,
                tgt_mask=tgt_mask,
                memory_mask=memory_mask,
                tgt_key_padding_mask=tgt_key_padding_mask,
                memory_key_padding_mask=memory_key_padding_mask,
                tgt_is_causal=tgt_is_causal,
                memory_is_causal=memory_is_causal,
                need_weights=need_weights,
            )
            if need_weights:
                self_maps.append(self_w)
                cross_maps.append(cross_w)

        if self.norm is not None:
            hidden = self.norm(hidden)

        if not need_weights:
            return (hidden, None, None)
        return (
            hidden,
            torch.stack(self_maps, dim=1),
            torch.stack(cross_maps, dim=1),
        )


class ARCTransformerEncoderDecoder(nn.Module):
    """Encoder-decoder transformer predicting one output grid from a task's grids.

    Every cell of the ``2 * num_train_pairs + 1`` input grids becomes one
    encoder token (color embedding plus ``ARCPositionalEncoding``). The
    decoder cross-attends to that memory with ``grid_dim ** 2`` learned query
    tokens, and a linear head turns each decoded token into color logits.
    """

    grid_dim: int
    num_train_pairs: int
    num_classes: int
    num_encoder_layers: int
    num_decoder_layers: int
    num_heads: int
    d_model: int
    d_ff: int
    dropout: float
    seq_len: int

    def __init__(self, params: ARCTransformerEncoderDecoderParams):
        super().__init__()

        # Scalar hyper-parameters (colors plus one extra class).
        self.grid_dim = params.grid_dim
        self.num_train_pairs = params.num_train_pairs
        self.num_classes = params.num_colors + 1
        self.d_model = params.d_model
        self.num_encoder_layers = params.num_encoder_layers
        self.num_decoder_layers = params.num_decoder_layers
        self.num_heads = params.num_heads
        self.d_ff = params.d_ff
        self.dropout = params.dropout

        num_grids = 2 * self.num_train_pairs + 1
        self.seq_len = num_grids * self.grid_dim**2

        # Sub-modules are created in a fixed order; keep it, since it decides
        # parameter order and RNG consumption during initialisation.
        self.embedding = nn.Embedding(self.num_classes, self.d_model)
        self.pos_encoding = ARCPositionalEncoding(
            d_model=self.d_model,
            grid_dim=self.grid_dim,
            num_train_pairs=self.num_train_pairs,
        )
        self.encoder = EncoderWithAttention(
            EncoderLayerWithAttention(
                self.d_model, self.num_heads, self.d_ff, self.dropout
            ),
            self.num_encoder_layers,
        )
        self.decoder = DecoderWithAttention(
            DecoderLayerWithAttention(
                self.d_model, self.num_heads, self.d_ff, self.dropout
            ),
            self.num_decoder_layers,
        )
        # One learned query per output cell.
        self.output_query = nn.Parameter(
            torch.randn(1, self.grid_dim**2, self.d_model)
        )
        self.output_layer = nn.Linear(self.d_model, self.num_classes)

    def forward(
        self, src: torch.Tensor, src_mask: torch.Tensor, need_weights: bool = False
    ) -> tuple[
        torch.Tensor,
        Optional[torch.Tensor],
        Optional[torch.Tensor],
        Optional[torch.Tensor],
    ]:
        """Compute color logits for every output cell.

        Args:
            src: Integer colors; must hold ``seq_len`` cells per example
                (presumably (batch, num_grids, grid_dim, grid_dim) — TODO confirm).
            src_mask: Bool validity mask for ``src`` (True = real cell).
            need_weights: Also return encoder/decoder attention weights.

        Returns:
            ``(logits, encoder_w, decoder_self_w, decoder_cross_w)`` with
            logits of shape (batch, grid_dim, grid_dim, num_classes).
        """
        batch_size = src.shape[0]

        tokens = self.embedding.forward(src)
        tokens.add_(self.pos_encoding.forward(tokens))
        tokens = tokens.view(batch_size, self.seq_len, self.d_model)

        # MultiheadAttention masks keys where the padding mask is True.
        key_padding = ~src_mask.view(batch_size, -1)

        memory, encoder_w = self.encoder.forward(
            tokens, src_key_padding_mask=key_padding, need_weights=need_weights
        )

        queries = self.output_query.expand(batch_size, -1, -1)
        decoded, decoder_self_w, decoder_cross_w = self.decoder.forward(
            queries,
            memory,
            memory_key_padding_mask=key_padding,
            need_weights=need_weights,
        )

        logits = self.output_layer(decoded).view(
            batch_size, self.grid_dim, self.grid_dim, self.num_classes
        )
        return (logits, encoder_w, decoder_self_w, decoder_cross_w)

    def generate(
        self,
        src: torch.Tensor,
        src_mask: torch.Tensor,
        need_weights: bool = False,
    ) -> tuple[
        torch.Tensor,
        Optional[torch.Tensor],
        Optional[torch.Tensor],
        Optional[torch.Tensor],
    ]:
        """Greedy prediction: argmax color per cell, computed without autograd."""
        with torch.no_grad():
            logits, encoder_w, decoder_self_w, decoder_cross_w = self.forward(
                src, src_mask, need_weights
            )
            predictions = logits.argmax(dim=-1)
        return (predictions, encoder_w, decoder_self_w, decoder_cross_w)


class ARCTransformerEncoder(nn.Module):
    """Encoder-only transformer that predicts the output grid in one pass.

    The ``2 * num_train_pairs + 1`` input grids and one extra grid of output
    tokens (learned queries, or the embedded ``tgt`` when given) are embedded,
    given ``ARCPositionalEncoding`` and run through one encoder as a single
    sequence. An attention mask stops input tokens from attending to output
    tokens; output tokens attend everywhere. The output tokens are then
    projected to per-cell color logits.
    """

    grid_dim: int
    num_train_pairs: int
    num_classes: int
    num_layers: int
    num_heads: int
    d_model: int
    d_ff: int
    dropout: float
    input_seq_len: int
    output_seq_len: int
    seq_len: int

    def __init__(self, params: "ARCTransformerEncoderDecoderParams"):
        super().__init__()
        self.grid_dim = params.grid_dim
        self.num_train_pairs = params.num_train_pairs
        self.num_classes = params.num_colors + 1
        self.d_model = params.d_model
        self.num_layers = params.num_encoder_layers
        self.num_heads = params.num_heads
        self.d_ff = params.d_ff
        self.dropout = params.dropout

        # One token per cell of every input grid, plus one grid of output tokens.
        self.input_seq_len = (
            (self.num_train_pairs * 2 + 1) * self.grid_dim * self.grid_dim
        )
        self.output_seq_len = self.grid_dim * self.grid_dim
        self.seq_len = self.input_seq_len + self.output_seq_len

        self.embedding = nn.Embedding(self.num_classes, self.d_model)
        self.pos_encoding = ARCPositionalEncoding(
            d_model=self.d_model,
            grid_dim=self.grid_dim,
            num_train_pairs=self.num_train_pairs,
        )

        encoder_layer = EncoderLayerWithAttention(
            self.d_model, self.num_heads, self.d_ff, self.dropout
        )
        self.encoder = EncoderWithAttention(encoder_layer, self.num_layers)

        # Learned output tokens, used when no ``tgt`` is supplied.
        self.output_query = nn.Parameter(
            torch.randn(1, 1, self.grid_dim, self.grid_dim, self.d_model)
        )
        self.output_layer = nn.Linear(self.d_model, self.num_classes)

    def forward(
        self,
        src: torch.Tensor,
        src_mask: torch.Tensor,
        tgt: Optional[torch.Tensor] = None,
        temperature: float = 0.0,
    ) -> tuple[
        torch.Tensor,
        Optional[torch.Tensor],
        Optional[torch.Tensor],
        Optional[torch.Tensor],
    ]:
        """Compute color logits for the output grid.

        Args:
            src: Integer colors. Its embedding is concatenated along dim 1 with
                a (batch, 1, grid_dim, grid_dim, d_model) block, so presumably
                (batch, num_grids, grid_dim, grid_dim) — TODO confirm.
            src_mask: Bool validity mask for ``src`` (True = real cell).
            tgt: Optional colors for the output grid (batch * grid_dim**2
                values). When given, their embedding replaces the learned
                output tokens.
            temperature: When > 0, the logits are divided by it.

        Returns:
            ``(logits, None, None, None)``, where logits has shape
            (batch, grid_dim, grid_dim, num_classes). The None slots keep the
            4-tuple layout of the encoder-decoder models.
        """
        batch_size = src.shape[0]

        embedded_src = self.embedding.forward(src)

        if tgt is not None:
            output_query = self.embedding.forward(tgt).view(
                batch_size, 1, self.grid_dim, self.grid_dim, self.d_model
            )
        else:
            output_query = self.output_query.expand(batch_size, -1, -1, -1, -1)

        combined_input = torch.cat([embedded_src, output_query], dim=1)
        embedded = combined_input + self.pos_encoding(combined_input)
        embedded = embedded.view(batch_size, self.seq_len, self.d_model)

        # Boolean attention mask: True = not allowed. Input rows may not look
        # at output columns; output rows see the whole sequence.
        attn_mask = torch.zeros(
            self.seq_len, self.seq_len, dtype=torch.bool, device=src.device
        )
        attn_mask[: self.input_seq_len, self.input_seq_len :] = True

        # Input padding comes from src_mask; output tokens are never padding.
        padding_mask = torch.cat(
            [
                ~src_mask.view(batch_size, -1),
                torch.zeros(
                    (batch_size, self.output_seq_len),
                    dtype=torch.bool,
                    device=src.device,
                ),
            ],
            dim=1,
        )

        output = self.encoder.forward(
            embedded, mask=attn_mask, src_key_padding_mask=padding_mask
        )[0]

        # Only the trailing output-grid tokens are projected to colors.
        logits = self.output_layer(output[:, -self.output_seq_len :, :])
        output = logits.view(batch_size, self.grid_dim, self.grid_dim, self.num_classes)

        if temperature > 0:
            output = output / temperature

        return (output, None, None, None)

    def generate(
        self,
        src: torch.Tensor,
        src_mask: torch.Tensor,
        tgt: Optional[torch.Tensor] = None,
        temperature: float = 0.0,
        need_weights: bool = False,
    ) -> tuple[
        torch.Tensor,
        Optional[torch.Tensor],
        Optional[torch.Tensor],
        Optional[torch.Tensor],
    ]:
        """Predict one color per output cell.

        With ``temperature > 0`` each cell is sampled from the softmax of the
        temperature-scaled logits; otherwise the argmax is taken. In both
        cases the prediction has shape (batch, grid_dim, grid_dim).

        ``need_weights`` is accepted for interface compatibility only. This
        model's ``forward`` never returns attention weights.
        """
        with torch.no_grad():
            (
                output,
                encoder_attn_weights,
                decoder_sa_attn_weights,
                decoder_mha_attn_weights,
            ) = self.forward(src, src_mask, tgt=tgt, temperature=temperature)

        if temperature > 0:
            # forward() has already divided the logits by ``temperature``.
            probs = torch.softmax(output, dim=-1)
            samples = torch.multinomial(
                probs.reshape(-1, probs.size(-1)),
                num_samples=1,
                replacement=True,
            )
            # One sample per cell -> restore (batch, grid_dim, grid_dim), the
            # same layout the argmax branch produces.
            prediction = samples.view(probs.shape[:-1])
        else:
            prediction = torch.argmax(output, dim=-1)

        return (
            prediction,
            encoder_attn_weights,
            decoder_sa_attn_weights,
            decoder_mha_attn_weights,
        )


class PatchEmbedding(nn.Module):
    """Embed a color grid as non-overlapping square patches.

    Each cell is one-hot encoded over ``num_classes`` colors. A conv with
    ``kernel_size == stride == patch_size`` then maps every patch to an
    ``embed_dim`` vector. The cell mask is reduced over the same patches.
    """

    def __init__(
        self,
        num_classes: int,
        patch_size: int,
        embed_dim: int,
    ):
        super().__init__()
        self.num_classes = num_classes
        self.patch_size = patch_size
        self.embed_dim = embed_dim

        # kernel == stride, so each output position sees exactly one patch.
        self.conv_embed = nn.Conv2d(
            in_channels=num_classes,
            out_channels=embed_dim,
            kernel_size=patch_size,
            stride=patch_size,
        )

    def forward(
        self, x: torch.Tensor, mask: torch.Tensor
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Return ``(tokens, patch_mask)``.

        Args:
            x: Integer colors of shape (batch, H, W); ``one_hot`` followed by
                ``permute(0, 3, 1, 2)`` requires exactly three dims.
            mask: Cell validity laid out like ``x`` (assumed bool — TODO
                confirm for other dtypes).

        Returns:
            tokens: (batch, num_patches, embed_dim), patches in row-major order.
            patch_mask: bool (batch, num_patches). For a boolean mask it is True
                where at least one cell of the patch is valid.
        """
        n = x.shape[0]
        p = self.patch_size

        channels_first = (
            F.one_hot(x.long(), num_classes=self.num_classes)
            .permute(0, 3, 1, 2)
            .float()
        )
        feature_map = self.conv_embed(channels_first)
        tokens = feature_map.permute(0, 2, 3, 1).reshape(n, -1, self.embed_dim)

        # Fraction of valid cells per patch; any positive value marks it valid.
        coverage = F.avg_pool2d(mask.float(), p, stride=p)
        patch_mask = (coverage > 0).reshape(n, -1)

        return (tokens, patch_mask)


class ARCVisionEncoderDecoder(nn.Module):
    """Encoder-decoder variant whose encoder sees 2x2 patches of the input grids.

    Input grids are embedded with ``PatchEmbedding`` (patch size 2), so the
    encoder gets ``patch_grid_dim ** 2`` tokens per grid. The decoder still
    emits ``grid_dim ** 2`` per-cell color logits from learned query tokens.
    """

    def __init__(self, params: ARCTransformerEncoderDecoderParams):
        super().__init__()

        # Scalar hyper-parameters (colors plus one extra class).
        self.grid_dim = params.grid_dim
        self.num_train_pairs = params.num_train_pairs
        self.num_classes = params.num_colors + 1
        self.d_model = params.d_model
        self.num_encoder_layers = params.num_encoder_layers
        self.num_decoder_layers = params.num_decoder_layers
        self.num_heads = params.num_heads
        self.d_ff = params.d_ff
        self.dropout = params.dropout
        self.patch_size = 2

        grid_count = 2 * self.num_train_pairs + 1
        self.patch_grid_dim = self.grid_dim // self.patch_size
        self.seq_len = grid_count * self.patch_grid_dim**2

        # Sub-module creation order is kept fixed (parameter order / init RNG).
        self.embedding = PatchEmbedding(
            num_classes=self.num_classes,
            patch_size=self.patch_size,
            embed_dim=self.d_model,
        )
        self.pos_encoding = ARCPositionalEncoding(
            d_model=self.d_model,
            grid_dim=self.patch_grid_dim,
            num_train_pairs=self.num_train_pairs,
        )
        self.encoder = EncoderWithAttention(
            EncoderLayerWithAttention(
                self.d_model, self.num_heads, self.d_ff, self.dropout
            ),
            self.num_encoder_layers,
        )
        self.decoder = DecoderWithAttention(
            DecoderLayerWithAttention(
                self.d_model, self.num_heads, self.d_ff, self.dropout
            ),
            self.num_decoder_layers,
        )
        # One learned query per output cell (not per patch).
        self.output_query = nn.Parameter(
            torch.randn(1, self.grid_dim**2, self.d_model)
        )
        self.output_layer = nn.Linear(self.d_model, self.num_classes)

    def forward(
        self, src: torch.Tensor, src_mask: torch.Tensor, need_weights: bool = False
    ):
        """Compute per-cell color logits for the output grid.

        Args:
            src: Integer colors unpacked as (batch, num_grids, side, side);
                the reshapes below require square grids.
            src_mask: Bool validity mask with the same layout.
            need_weights: Also return encoder/decoder attention weights.

        Returns:
            ``(logits, encoder_w, decoder_self_w, decoder_cross_w)`` with
            logits of shape (batch, grid_dim, grid_dim, num_classes).
        """
        batch_size, num_grids, _, side = src.shape

        # Stack the grids vertically so one strided conv patches all of them.
        tall_src = src.reshape(batch_size, num_grids * side, side)
        tall_mask = src_mask.reshape(batch_size, num_grids * side, side)
        patches, patch_valid = self.embedding.forward(tall_src, tall_mask)

        # Positional encoding is computed on the (grid, row, col) view of the patches.
        patch_grid_view = patches.reshape(
            batch_size,
            num_grids,
            self.patch_grid_dim,
            self.patch_grid_dim,
            self.d_model,
        )
        positions = self.pos_encoding.forward(patch_grid_view).reshape(-1, self.d_model)
        tokens = patches.reshape(batch_size, -1, self.d_model) + positions

        # True marks fully padded patches, which attention then ignores.
        key_padding = ~patch_valid

        memory, encoder_w = self.encoder.forward(
            tokens, src_key_padding_mask=key_padding, need_weights=need_weights
        )

        queries = self.output_query.expand(batch_size, -1, -1)
        decoded, decoder_self_w, decoder_cross_w = self.decoder.forward(
            queries,
            memory,
            memory_key_padding_mask=key_padding,
            need_weights=need_weights,
        )

        logits = self.output_layer.forward(decoded).view(
            batch_size, self.grid_dim, self.grid_dim, self.num_classes
        )
        return (logits, encoder_w, decoder_self_w, decoder_cross_w)

    def generate(
        self, src: torch.Tensor, src_mask: torch.Tensor, need_weights: bool = False
    ):
        """Greedy prediction: argmax color per cell, computed without autograd."""
        with torch.no_grad():
            logits, encoder_w, decoder_self_w, decoder_cross_w = self.forward(
                src, src_mask, need_weights
            )
            predictions = logits.argmax(dim=-1)
        return (predictions, encoder_w, decoder_self_w, decoder_cross_w)


class ARCVisionEncoder(nn.Module):
    """Encoder-only model with patch-embedded inputs and per-cell output tokens.

    Input grids are embedded as ``patch_size x patch_size`` patches: one token
    per patch, with the positional encoding averaged over the patch. The
    output grid keeps one token per cell. Both parts share a single encoder,
    and an attention mask stops input tokens from attending to output tokens.
    The output tokens are projected to per-cell color logits.
    """

    def __init__(self, params: "ARCTransformerEncoderDecoderParams"):
        super().__init__()
        self.grid_dim = params.grid_dim
        self.num_train_pairs = params.num_train_pairs
        self.num_classes = params.num_colors + 1
        self.d_model = params.d_model
        self.num_layers = params.num_encoder_layers
        self.num_heads = params.num_heads
        self.d_ff = params.d_ff
        self.dropout = params.dropout
        self.patch_size = 2

        self.patch_grid_dim = self.grid_dim // self.patch_size

        # Input grids contribute one token per patch; the output grid one per cell.
        self.input_seq_len = (
            (self.num_train_pairs * 2 + 1) * self.patch_grid_dim * self.patch_grid_dim
        )
        self.output_seq_len = self.grid_dim * self.grid_dim
        self.seq_len = self.input_seq_len + self.output_seq_len

        self.embedding = PatchEmbedding(
            num_classes=self.num_classes,
            patch_size=self.patch_size,
            embed_dim=self.d_model,
        )
        # Per-cell embedding used for ``tgt`` on the output grid.
        self.tgt_embedding = nn.Embedding(self.num_classes, self.d_model)
        self.pos_encoding = ARCPositionalEncodingV2(
            d_model=self.d_model,
            grid_dim=self.grid_dim,
            num_train_pairs=self.num_train_pairs,
        )

        encoder_layer = EncoderLayerWithAttention(
            self.d_model, self.num_heads, self.d_ff, self.dropout
        )
        self.encoder = EncoderWithAttention(encoder_layer, self.num_layers)

        # Learned output tokens, used when no ``tgt`` is supplied.
        self.output_query = nn.Parameter(
            torch.randn(1, 1, self.grid_dim, self.grid_dim, self.d_model)
        )
        self.output_layer = nn.Linear(self.d_model, self.num_classes)

    def embed(
        self,
        src: torch.Tensor,
        src_mask: torch.Tensor,
        tgt: Optional[torch.Tensor] = None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Build the combined token sequence and its key-padding mask.

        Args:
            src: Integer colors unpacked as (batch, num_grids, side, side).
            src_mask: Bool validity mask with the same layout.
            tgt: Optional output-grid colors; when given, their embedding
                replaces the learned output tokens.

        Returns:
            ``(combined_seq, padding_mask)``. ``combined_seq`` is the input
            patch tokens followed by ``grid_dim ** 2`` output tokens.
            ``padding_mask`` is True for fully padded input patches and never
            True for output tokens.
        """
        batch_size, num_grids, _, grid_dim = src.shape

        # Positional encodings for every input grid plus one extra (last) grid
        # used by the output tokens.
        pos_emb = self.pos_encoding.forward(
            num_grids=num_grids + 1, grid_dim=grid_dim, device=src.device
        )

        # Stack the grids vertically so one strided conv patches all of them.
        src = src.reshape(batch_size, num_grids * grid_dim, grid_dim)
        src_mask = src_mask.reshape(batch_size, num_grids * grid_dim, grid_dim)

        src_patched, mask_patched = self.embedding.forward(src, src_mask)

        # Average the per-cell positional encodings over each patch.
        input_pos_emb = pos_emb[:-1, :, :, :]
        input_pos_emb_patched = (
            input_pos_emb.unfold(1, self.patch_size, self.patch_size)
            .unfold(2, self.patch_size, self.patch_size)
            .mean(dim=(-2, -1))
        )

        src_patched = src_patched.reshape(batch_size, -1, self.d_model)
        input_pos_emb_patched = input_pos_emb_patched.reshape(-1, self.d_model)
        input_seq = src_patched + input_pos_emb_patched

        if tgt is not None:
            output_query = self.tgt_embedding.forward(tgt).view(
                batch_size, 1, self.grid_dim, self.grid_dim, self.d_model
            )
        else:
            output_query = self.output_query.expand(batch_size, -1, -1, -1, -1)
        output_pos_emb = pos_emb[-1:, :, :, :]

        output_query = output_query.reshape(batch_size, -1, self.d_model)
        output_pos_emb = output_pos_emb.reshape(-1, self.d_model)
        output_seq = output_query + output_pos_emb

        combined_seq = torch.cat([input_seq, output_seq], dim=1)

        # Input padding comes from the patch mask; output tokens are never padding.
        padding_mask = torch.cat(
            [
                ~mask_patched,
                torch.zeros(
                    (batch_size, self.grid_dim**2), dtype=torch.bool, device=src.device
                ),
            ],
            dim=1,
        )

        return combined_seq, padding_mask

    def forward(
        self,
        src: torch.Tensor,
        src_mask: torch.Tensor,
        tgt: Optional[torch.Tensor] = None,
        temperature: float = 0.0,
        need_weights: bool = False,
    ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Compute color logits for the output grid.

        Args:
            src, src_mask, tgt: See ``embed``.
            temperature: When > 0, the logits are divided by it.
            need_weights: Forwarded to the encoder. When True, the returned
                weights stack each layer's per-head attention along dim 1.

        Returns:
            ``(logits, attn_weights)``, where logits has shape
            (batch, grid_dim, grid_dim, num_classes).
        """
        batch_size = src.shape[0]

        combined_seq, padding_mask = self.embed(src, src_mask, tgt)

        # Boolean attention mask: True = not allowed. Input rows may not look
        # at output columns; output rows see the whole sequence.
        attn_mask = torch.zeros(
            self.seq_len, self.seq_len, dtype=torch.bool, device=src.device
        )
        attn_mask[: self.input_seq_len, self.input_seq_len :] = True

        output, attn_weights = self.encoder.forward(
            combined_seq,
            mask=attn_mask,
            src_key_padding_mask=padding_mask,
            need_weights=need_weights,
        )

        # Only the trailing output-grid tokens are projected to colors.
        logits = self.output_layer.forward(output[:, -self.output_seq_len :, :])
        output = logits.view(batch_size, self.grid_dim, self.grid_dim, self.num_classes)

        if temperature > 0:
            output = output / temperature

        return (output, attn_weights)

    def generate(
        self,
        src: torch.Tensor,
        src_mask: torch.Tensor,
        tgt: Optional[torch.Tensor] = None,
        temperature: float = 0.0,
        need_weights: bool = False,
    ):
        """Predict one color per output cell.

        With ``temperature > 0`` each cell is sampled from the softmax of the
        temperature-scaled logits; otherwise the argmax is taken. The
        prediction has shape (batch, grid_dim, grid_dim) in both cases.

        Returns:
            ``(prediction, encoder_attn_weights)``. The weights are None unless
            ``need_weights`` is True.
        """
        with torch.no_grad():
            output, encoder_attn_weights = self.forward(
                src,
                src_mask,
                tgt=tgt,
                temperature=temperature,
                need_weights=need_weights,
            )

        if temperature > 0:
            # forward() has already divided the logits by ``temperature``.
            probs = torch.softmax(output, dim=-1)
            samples = torch.multinomial(
                probs.reshape(-1, probs.size(-1)),
                num_samples=1,
                replacement=True,
            )
            # One sample per cell -> restore (batch, grid_dim, grid_dim), the
            # same layout the argmax branch produces.
            prediction = samples.view(probs.shape[:-1])
        else:
            prediction = torch.argmax(output, dim=-1)
        return prediction, encoder_attn_weights


class ARCPositionalEncodingV2(nn.Module):
    def __init__(
        self,
        d_model: int,
        grid_dim: int,
        num_train_pairs: int,
    ):
        super().__init__()
        self.d_model = d_model
        self.grid_dim = grid_dim
        self.num_train_pairs = num_train_pairs

        # Embeddings for row and column positions
        self.row_embedding = nn.Embedding(self.grid_dim, self.d_model // 4)
        self.col_embedding = nn.Embedding(self.grid_dim, self.d_model // 4)

        # Embedding for input vs output
        self.input_output_embedding = nn.Embedding(2, d_model // 4)

        # Embedding for training pair index
        self.pair_embedding = nn.Embedding(
            self.num_train_pairs + 1, d_model // 4
        )  # +1 for test pair

    @torch.compiler.disable
    def forward(
        self, num_grids: int, grid_dim: int, device: torch.device
    ) -> torch.Tensor:
        grid_pos = torch.arange(grid_dim, device=device)

        # Row pos embedding
        row_emb = (
            self.row_embedding.forward(grid_pos)
            .unsqueeze(1)
            .expand(num_grids, -1, grid_dim, -1)
        )

        # Column pos embedding
        col_emb = (
            self.col_embedding.forward(grid_pos)
            .unsqueeze(0)
            .expand(num_grids, grid_dim, -1, -1)
        )

        # Input/output embedding
        grid_indices = torch.arange(num_grids, device=device)
        is_output = (grid_indices % 2 == 1).long()
        io_emb = (
            self.input_output_embedding(is_output)
            .unsqueeze(1)
            .unsqueeze(1)
            .expand(num_grids, grid_dim, grid_dim, -1)
        )

        # Pair embedding
        pair_indices = torch.div(grid_indices, 2, rounding_mode="floor")
        pair_emb = (
            self.pair_embedding(pair_indices)
            .unsqueeze(1)
            .unsqueeze(1)
            .expand(num_grids, grid_dim, grid_dim, -1)
        )

        # Combine all embeddings (1, num_grids, height, width, d_model)
        combined_emb = torch.cat([row_emb, col_emb, io_emb, pair_emb], dim=-1)

        return combined_emb


class ARCPositionalEncoding(nn.Module):
    """Learned positional encoding derived from the shape of a 5-D grid tensor.

    Each cell receives the concatenation of four learned vectors, each
    ``d_model // 4`` wide:

    * row index,
    * column index,
    * input vs output (even grid index -> 0, odd -> 1),
    * pair index (``grid_index // 2``), except that the last grid is always given
      the extra slot ``num_train_pairs``. That grid is presumably the test input,
      matching the layout built by ``FinetuneDataset``.

    Args:
        d_model: Model width; each component uses ``d_model // 4`` features.
        grid_dim: Number of row/column positions that have an embedding.
        num_train_pairs: The pair table has ``num_train_pairs + 1`` rows.
    """

    def __init__(
        self,
        d_model: int,
        grid_dim: int,
        num_train_pairs: int,
    ):
        super().__init__()
        self.d_model = d_model
        self.grid_dim = grid_dim
        self.num_train_pairs = num_train_pairs

        part_dim = d_model // 4
        # Keep this construction order: it fixes parameter-initialisation order
        # (RNG draws) and state_dict key order.
        self.row_embedding = nn.Embedding(grid_dim, part_dim)
        self.col_embedding = nn.Embedding(grid_dim, part_dim)
        self.input_output_embedding = nn.Embedding(2, part_dim)
        # +1 slot reserved for the test pair.
        self.pair_embedding = nn.Embedding(num_train_pairs + 1, part_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the encoding matching ``x``'s grid layout.

        ``x`` must be 5-D; only its sizes at dims 1-3 (grids, height, width) and its
        device are used, never its values.

        Returns:
            Tensor of shape ``(num_grids, height, width, 4 * (d_model // 4))``.
            There is no batch axis.
        """
        _, num_grids, height, width, _ = x.size()
        device = x.device
        target_shape = (num_grids, height, width, -1)

        rows = self.row_embedding.forward(torch.arange(height, device=device))
        cols = self.col_embedding.forward(torch.arange(width, device=device))

        grid_ids = torch.arange(num_grids, device=device)
        ios = self.input_output_embedding((grid_ids % 2 == 1).long())

        pair_ids = torch.div(grid_ids, 2, rounding_mode="floor")
        # The final grid always uses the dedicated test-pair slot.
        pair_ids[-1] = self.num_train_pairs
        pairs = self.pair_embedding(pair_ids)

        return torch.cat(
            [
                rows[None, :, None, :].expand(*target_shape),
                cols[None, None, :, :].expand(*target_shape),
                ios[:, None, None, :].expand(*target_shape),
                pairs[:, None, None, :].expand(*target_shape),
            ],
            dim=-1,
        )


In [5]:
def unpad_grid(grid: torch.Tensor) -> list[list[int]]:
    """Convert a padded, colour-offset grid tensor back into nested Python lists.

    Values are shifted down by one, so padding (stored as 0) becomes -1. The -1
    cells are dropped row by row. Rows left empty are removed, and the remaining
    rows are right-padded with 0 to a common width. An all-padding grid yields
    ``[[0]]`` so the result always has at least one cell.
    """
    shifted = grid - 1
    candidate_rows = [row[row != -1] for row in shifted]
    kept_rows: list[list] = [row.tolist() for row in candidate_rows if len(row) > 0]

    # Never return an empty grid.
    if not kept_rows:
        kept_rows = [[0]]

    width = max(map(len, kept_rows))
    return [row + [0] * (width - len(row)) for row in kept_rows]

class FinetuneDataset(Dataset):
    """Dataset of fine-tuning tasks, each a list of ``[input_grid, output_grid]`` pairs.

    The last pair of every task is treated as the test pair. All earlier pairs
    (at most ``config.max_train_grids``) are training pairs. Grids are padded and
    masked with ``pad_and_mask_grid``.
    """

    tasks: list[list[list[list[list[int]]]]]
    config: ARCDatasetParams

    def __init__(
        self,
        tasks: list[list[list[list[list[int]]]]],
        config: ARCDatasetParams,
    ):
        self.tasks = tasks
        self.config = config

    def __len__(self) -> int:
        return len(self.tasks)

    def __getitem__(self, idx: int) -> dict:
        """Return ``{"grids", "masks", "output"}`` tensors for task ``idx``.

        ``grids``/``masks`` have ``2 * max_train_grids + 1`` slots. Training pair i
        fills slots 2i (input) and 2i+1 (output), the test input goes in the last
        slot, and unused slots stay zero/False. ``output`` is the padded test output.
        """
        cfg = self.config
        task = self.tasks[idx]

        slot_count = 2 * cfg.max_train_grids + 1
        side = cfg.max_grid_size
        grids = torch.zeros(slot_count, side, side, dtype=torch.int)
        masks = torch.zeros(slot_count, side, side, dtype=torch.bool)

        for pair_idx, pair in enumerate(task[:-1]):
            if pair_idx >= cfg.max_train_grids:
                print("Training pairs exceed max", pair_idx, cfg.max_train_grids)
                break
            slot_assignments = ((2 * pair_idx, pair[0]), (2 * pair_idx + 1, pair[1]))
            for slot, raw_grid in slot_assignments:
                grids[slot], masks[slot] = pad_and_mask_grid(raw_grid, cfg)

        test_input, test_output = task[-1][0], task[-1][1]
        grids[-1], masks[-1] = pad_and_mask_grid(test_input, cfg)
        padded_test_output, _ = pad_and_mask_grid(test_output, cfg)

        return {
            "grids": grids,
            "masks": masks,
            "output": padded_test_output,
        }


def make_finetune_dataset(
    grids: torch.Tensor, config: ARCDatasetParams
) -> FinetuneDataset:
    """Build a fine-tuning dataset from the training pairs of one or more tasks.

    ``grids`` is a single task of ``2 * max_train_grids + 1`` padded grids (3-D) or
    a batch of such tasks (4-D). The final grid of each task (the test input) is
    ignored. The training pairs are unpadded, and every ordered selection of at
    least three of them becomes one fine-tuning task, with the last pair acting as
    the held-out pair. When only two real pairs exist, selections of two are used.

    Unused (all-zero) training slots are skipped when ``config.color_offset > 0``.
    Previously ``unpad_grid`` turned them into fake 1x1 ``[[0]] -> [[0]]`` pairs
    that polluted the fine-tuning data.

    Raises:
        Exception: if ``grids`` is neither 3-D nor 4-D.
    """
    if len(grids.shape) == 3:
        grids = grids.unsqueeze(0)
    if len(grids.shape) != 4:
        raise Exception("incorrect grids dimension")

    tasks: list[list[list[list[list[int]]]]] = []
    for task in grids:
        pairs = task[:-1].reshape(
            config.max_train_grids,
            2,
            config.max_grid_size,
            config.max_grid_size,
        )
        all_pairs = [[unpad_grid(grid) for grid in pair] for pair in pairs]

        # With a positive colour offset every real cell is >= 1, so an all-zero pair
        # can only be an unused padding slot.
        if config.color_offset > 0:
            real_pairs = [
                unpadded
                for unpadded, raw_pair in zip(all_pairs, pairs)
                if bool(raw_pair.any())
            ]
        else:
            real_pairs = all_pairs

        # A task needs at least one training pair plus a held-out pair. With fewer
        # than two real pairs, keep the previous unfiltered behaviour so that
        # fine-tuning still receives data.
        finetune_pairs = real_pairs if len(real_pairs) >= 2 else all_pairs

        min_length = 2 if len(finetune_pairs) == 2 else 3
        for length in range(min_length, len(finetune_pairs) + 1):
            for combination in itertools.combinations(finetune_pairs, length):
                tasks.extend(
                    list(permutation)
                    for permutation in itertools.permutations(combination)
                )

    return FinetuneDataset(tasks=tasks, config=config)

In [6]:
@dataclass(frozen=True)
class ARCTrainParams:
    """Immutable (frozen) bundle of training / fine-tuning hyperparameters.

    In this part of the notebook only ``batch_size``, ``learning_rate``,
    ``weight_decay`` and ``loss_class_weights`` are read (by
    ``fine_tune_transformer``). The remaining fields are not used here.
    """

    # DataLoader batch size.
    batch_size: int
    # AdamW learning rate.
    learning_rate: float
    # AdamW weight decay.
    weight_decay: float
    # Not read in this chunk; presumably dataset locations for (pre)training elsewhere.
    dataset_dir: list[str]
    # Optional {class index: weight} overrides for the cross-entropy loss;
    # classes not listed keep weight 1.0.
    loss_class_weights: Optional[dict[int, float]] = None
    # The fields below are not read in this chunk; their meaning is presumably
    # defined by meta-learning / training code elsewhere — confirm before use.
    meta_batch_size: Optional[int] = None
    meta_learning_rate: Optional[float] = None
    meta_weight_decay: Optional[float] = None
    meta_num_epochs: Optional[int] = None
    train_steps_per_epoch: Optional[int] = None
    eval_steps_per_epoch: Optional[int] = None
    warmup_epochs: Optional[int] = None
    refinement_ratio: Optional[float] = None


def finetune_collate_arc_fn(
    batch: list[dict],
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Collate ``FinetuneDataset`` items into batched tensors.

    Stacks each item's ``"grids"``, ``"masks"`` and ``"output"`` along a new leading
    batch axis and returns them in that order.
    """

    def stacked(key: str) -> torch.Tensor:
        return torch.stack([item[key] for item in batch])

    return stacked("grids"), stacked("masks"), stacked("output")

def fine_tune_transformer(
    model: nn.Module,
    finetune_params: ARCTrainParams,
    dataset: FinetuneDataset,
    num_epochs: int,
    accuracy_cutoff: float = 0.99,
) -> nn.Module:
    """Fine-tune a deep copy of ``model`` on ``dataset`` and return that copy.

    The caller's model is not modified; it is deep-copied before training. Training
    uses AdamW (learning rate / weight decay from ``finetune_params``), a
    class-weighted cross-entropy loss, and automatic mixed precision with a
    ``GradScaler``. Training stops early once an epoch's mean per-cell accuracy
    reaches ``accuracy_cutoff``. Note that this accuracy is measured over the full
    padded grid, padding cells included.

    Args:
        model: Must expose ``num_classes`` and ``forward(grids, masks)`` returning a
            tuple whose first element is the logits.
        finetune_params: Uses ``batch_size``, ``learning_rate``, ``weight_decay``
            and ``loss_class_weights``.
        dataset: Fine-tuning tasks. If empty, no training happens.
        num_epochs: Maximum number of epochs.
        accuracy_cutoff: Early-stopping threshold on mean training accuracy.

    Returns:
        The fine-tuned copy, switched to eval mode (dropout disabled) so it can be
        used for inference directly.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # NOTE: process-wide switch; it disables the MultiheadAttention fast path for
    # everything that runs afterwards, not just this function.
    torch.backends.mha.set_fastpath_enabled(False)

    model = copy.deepcopy(model).to(device)

    # DataLoader(shuffle=True) rejects an empty dataset, and the per-epoch
    # averages below would divide by zero.
    if len(dataset) == 0:
        print("Fine-tuning dataset is empty; returning the model without fine-tuning")
        model.eval()
        return model

    data_loader = DataLoader(
        dataset,
        batch_size=finetune_params.batch_size,
        shuffle=True,
        collate_fn=finetune_collate_arc_fn,
        num_workers=0,
    )

    print(f"Starting fine-tuning run with dataset of {len(dataset)} training items")
    print(f"Using batch size of {finetune_params.batch_size}")

    class_weights = torch.ones(model.num_classes).to(device)
    if finetune_params.loss_class_weights is not None:
        for cls, weight in finetune_params.loss_class_weights.items():
            class_weights[cls] = weight

    criterion = nn.CrossEntropyLoss(weight=class_weights)

    optimizer = optim.AdamW(
        model.parameters(),
        lr=finetune_params.learning_rate,
        weight_decay=finetune_params.weight_decay,
    )

    scaler = GradScaler(device.type)

    model.train()

    for epoch in range(num_epochs):
        train_loss = 0.0
        train_accuracy = 0.0

        for batch in data_loader:
            grids, masks, target_grid = [item.to(device) for item in batch]

            optimizer.zero_grad()

            with autocast(device.type):
                output = model.forward(grids, masks)[0]
                loss = criterion(
                    output.view(-1, model.num_classes),
                    target_grid.view(-1).long(),
                )

            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()

            train_loss += loss.item()

            predictions = torch.argmax(output, dim=-1)
            train_accuracy += (predictions == target_grid).float().mean().item()

        train_loss /= len(data_loader)
        train_accuracy /= len(data_loader)

        print(f"Epoch {epoch+1}/{num_epochs}:")
        print(f"Train Loss: {train_loss:.4f}, Train Accuracy: {train_accuracy:.4f}")

        if train_accuracy >= accuracy_cutoff:
            print("Stopping early because accuracy exceeds accuracy cut-off")
            break

    print("Fine-tuning completed")
    # Hand back an inference-ready model: dropout must not stay active.
    model.eval()
    return model

def _first_sample(prediction: torch.Tensor) -> torch.Tensor:
    """Return a ``generate`` prediction as a ``(batch, H, W)`` tensor.

    The transformer's ``generate`` returns ``(batch, H, W)`` for greedy decoding but
    prepends a sample axis of size 1 when it samples. Drop that axis so both paths
    agree.
    """
    return prediction[0] if prediction.dim() > 3 else prediction


def finetune_and_predict(
    model: ARCVisionEncoder | ARCTransformerEncoder,
    finetune_params: ARCTrainParams,
    dataset_params: ARCDatasetParams,
    grids: torch.Tensor,
    masks: torch.Tensor,
    num_finetune_epochs: int = 2,
    temperature: list[float] = [0.0],
    accuracy_cutoff: float = 0.99,
    num_predictions: int = 1,
) -> list[torch.Tensor]:
    """Fine-tune on a task's training pairs, then predict its test output.

    Decoding happens in two passes. A first prediction is generated at
    ``temperature[0]``. It is then fed back as ``tgt`` to produce
    ``num_predictions`` refined predictions at ``temperature[1]`` (or
    ``temperature[0]`` if only one value is given). The list default is never
    mutated.

    Args:
        grids, masks: Batched task tensors (leading batch axis), as passed to
            ``generate``.

    Returns:
        ``num_predictions`` tensors, each of shape ``(batch, H, W)`` regardless of
        whether sampling or greedy decoding was used.
    """
    model.eval()

    finetune_dataset = make_finetune_dataset(grids, dataset_params)

    finetune_model = fine_tune_transformer(
        model,
        finetune_params,
        finetune_dataset,
        num_finetune_epochs,
        accuracy_cutoff,
    )
    # The fine-tuned copy may still be in training mode; disable dropout so
    # predictions are not perturbed by it.
    finetune_model.eval()

    finetune_predictions = _first_sample(
        finetune_model.generate(
            grids,
            masks,
            temperature=temperature[0],
            need_weights=False,
        )[0]
    )

    refine_temperature = temperature[1] if len(temperature) > 1 else temperature[0]

    refined_predictions = []
    for _ in range(num_predictions):
        refined_prediction = _first_sample(
            finetune_model.generate(
                grids,
                masks,
                temperature=refine_temperature,
                tgt=finetune_predictions,
                need_weights=False,
            )[0]
        )
        refined_predictions.append(refined_prediction)

    return refined_predictions

In [7]:
# Load the evaluation challenges. Grids are 12x12, a task keeps at most 4
# training pairs, and colours are shifted by 1 so that 0 can mark padding.
data_params = ARCDatasetParams(
    max_grid_size=12,
    max_train_grids=4,
    color_offset=1,
)
test_dataset = ARCKaggleDataset(
    challenges_file=test_challenges_file_name,
    config=data_params,
)

In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# The model's grid size and pair budget must match the dataset layout
# (`data_params`): logits are reshaped to grid_dim x grid_dim and compared with
# max_grid_size-padded targets. Derive them instead of repeating literals.
model_params = ARCTransformerEncoderDecoderParams(
    grid_dim=data_params.max_grid_size,
    num_train_pairs=data_params.max_train_grids,
    num_colors=10,
    num_encoder_layers=16,
    num_decoder_layers=0,
    num_heads=16,
    d_model=512,
    d_ff=3072,
    dropout=0.2,
)
model = ARCTransformerEncoder(model_params).to(device)
model.load_state_dict(
    torch.load(kaggle_model_file_path, map_location=device, weights_only=True)
)

finetune_params = ARCTrainParams(
    batch_size=4,
    learning_rate=1e-5,
    # Class 0 is the padding value (colours are offset by 1); down-weight it.
    loss_class_weights={0: 0.2},
    dataset_dir=[],
    weight_decay=1e-5,
)
num_finetune_epochs = 12
accuracy_cutoff = 0.99
# [first-pass temperature, refinement temperature]
temperature = [0.4, 0.2]

model.eval()

def unpad_output(output: torch.Tensor) -> torch.Tensor:
    """Strip padding from a predicted 2-D grid and undo the colour offset.

    Values are shifted down by one, so padding (0) becomes -1. The -1 cells are
    dropped row by row and empty rows are removed. Ragged rows are right-padded with
    0 so the result is rectangular. An all-padding grid yields ``[[0]]``.

    The result is created on ``output``'s own device rather than the global
    ``device``, so CPU and GPU inputs both work.
    """
    shifted = output - 1
    rows = [row[row != -1] for row in shifted]
    rows = [row for row in rows if len(row) > 0]
    # Always return at least one cell so the submission is a valid grid.
    if not rows:
        rows = [torch.zeros(1, device=output.device, dtype=torch.long)]
    width = max(len(row) for row in rows)
    padded_rows = [torch.cat([row, row.new_zeros(width - len(row))]) for row in rows]
    return torch.stack(padded_rows)

kaggle_output = {}


for task in test_dataset:
    task_id = task["task_id"]
    print(f"Starting {task_id}")
    task_predictions = []
    # task["grids"] / task["masks"] hold one entry per test input; each entry gets
    # its own fine-tuning run.
    for grids, masks in zip(task["grids"], task["masks"]):
        predictions = finetune_and_predict(
            model,
            finetune_params,
            data_params,
            grids.unsqueeze(0).to(device),
            masks.unsqueeze(0).to(device),
            num_finetune_epochs=num_finetune_epochs,
            temperature=temperature,
            accuracy_cutoff=accuracy_cutoff,
            num_predictions=2,
        )
        # Index 0 selects the single batch item; convert to nested Python lists.
        attempt_1, attempt_2 = (
            unpad_output(prediction[0]).cpu().numpy().tolist()
            for prediction in predictions[:2]
        )
        task_predictions.append({"attempt_1": attempt_1, "attempt_2": attempt_2})
    kaggle_output[task_id] = task_predictions

In [14]:
# Write the predictions in the Kaggle submission format.
with open(kaggle_submission_file_path, "w") as submission_file:
    json.dump(kaggle_output, submission_file)

In [None]:
import numpy as np

def grade_submission(
    submission_filename: str,
    solutions_filename: str,
    num_attempts: int = 1,
    max_grid_size: int = 12,
) -> tuple[int, int, int]:
    """Score a Kaggle-format ARC submission against a solutions JSON file.

    Only the first test output of each task is graded. A task counts as correct
    when any of ``attempt_1`` .. ``attempt_{num_attempts}`` equals the expected grid.
    The default checks ``attempt_1`` only, as before. Submission tasks with no
    entry in the solutions file are skipped with a warning.

    Args:
        submission_filename: JSON of ``{task_id: [{"attempt_1": grid, ...}, ...]}``.
        solutions_filename: JSON of ``{task_id: [expected_grid, ...]}``.
        num_attempts: How many attempts per task to accept.
        max_grid_size: Tasks whose expected output fits in this square count as
            "attempted".

    Returns:
        ``(correct, attempted, total)``. The same numbers, plus the two ratios
        (``correct/attempted`` and ``correct/total``, each 0.0 when its
        denominator is 0), are also printed.
    """
    with open(solutions_filename, "r") as f:
        solutions = json.load(f)
    with open(submission_filename, "r") as f:
        submissions = json.load(f)

    correct = 0
    total = 0
    attempted = 0

    for task_id, submission in submissions.items():
        expected_outputs = solutions.get(task_id)
        if expected_outputs is None:
            print(f"Warning: no solution for task {task_id}; skipping")
            continue
        expected = expected_outputs[0]
        if len(expected) <= max_grid_size and len(expected[0]) <= max_grid_size:
            attempted += 1
        first_test = submission[0]
        if any(
            first_test.get(f"attempt_{n}") == expected
            for n in range(1, num_attempts + 1)
        ):
            correct += 1
        total += 1

    attempted_rate = correct / attempted if attempted else 0.0
    total_rate = correct / total if total else 0.0
    print(correct, attempted, total, attempted_rate, total_rate)
    return correct, attempted, total

# Grade a previously written submission. For evaluation-set predictions use
# "data/arc-agi_evaluation_solutions.json" instead.
test_solutions_file_name = "data/arc-agi_training_solutions.json"
# Use a separate name: reassigning `kaggle_submission_file_path` here would make
# a later re-run of the submission-writing cell overwrite this file instead.
graded_submission_file_path = "kaggle/submission_1.json"
grade_submission(graded_submission_file_path, test_solutions_file_name)