In [1]:
import json
from dataclasses import dataclass
from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset

In [None]:
# ARC challenges JSON that the evaluation dataset is built from.
test_challenges_file_name = "data/arc-agi_evaluation_challenges.json"
# Checkpoint whose state dict is loaded into the model before inference.
kaggle_model_file_path = "kaggle/models/subtly_known_panda.pth"
# Destination of the JSON submission written by the last cell.
kaggle_submission_file_path = "kaggle/submission.json"

In [2]:
@dataclass(frozen=True)
class ARCDatasetParams:
    """Immutable settings that control how ARC grids are padded and encoded."""

    # Side length of the square canvas each grid is centered on; larger grids are rejected.
    max_grid_size: int = 30
    # Maximum number of training pairs kept per task; pairs beyond this are skipped.
    max_train_grids: int = 10
    # Value added to every color so that 0 is left free for padding cells.
    color_offset: int = 1


def pad_and_mask_grid(
    grid: list[list[int]], config: "ARCDatasetParams"
) -> tuple[torch.Tensor, torch.Tensor]:
    """Center ``grid`` on a square canvas and report which cells hold real data.

    Args:
        grid: Rectangular list of rows of integer colors.
        config: Provides ``max_grid_size`` (canvas side length) and
            ``color_offset`` (added to every color so 0 can mean padding).

    Returns:
        ``(padded, mask)``: ``padded`` is an int tensor of shape
        ``(max_grid_size, max_grid_size)`` holding the offset colors with zeros
        elsewhere; ``mask`` is a bool tensor of the same shape that is True
        exactly on the cells covered by the grid.

    Raises:
        ValueError: If the grid has no rows or exceeds ``max_grid_size`` in
            either dimension. ``ValueError`` is an ``Exception`` subclass, so
            callers that catch ``Exception`` keep working.
    """
    if len(grid) == 0:
        # Without this guard grid[0] would fail with an opaque IndexError.
        raise ValueError("grid is empty")

    h, w = len(grid), len(grid[0])
    if h > config.max_grid_size or w > config.max_grid_size:
        raise ValueError("grid size too large")

    canvas_shape = (config.max_grid_size, config.max_grid_size)
    padded = torch.zeros(canvas_shape, dtype=torch.int)
    mask = torch.zeros(canvas_shape, dtype=torch.bool)

    # Floor division: any odd leftover cell of padding ends up on the bottom/right.
    pad_h = (config.max_grid_size - h) // 2
    pad_w = (config.max_grid_size - w) // 2

    # Place the grid in the center, shifting colors so that 0 stays "padding".
    padded[pad_h : pad_h + h, pad_w : pad_w + w] = (
        torch.tensor(grid, dtype=torch.int) + config.color_offset
    )
    mask[pad_h : pad_h + h, pad_w : pad_w + w] = True

    return (padded, mask)


class ARCKaggleDataset(Dataset):
    """Dataset over ARC challenge tasks, one item per task.

    Each item stacks, for every test input of the task, the padded training
    pairs followed by the padded test input:

    * slots ``2*i`` / ``2*i + 1`` hold the input / output grid of training pair ``i``
    * the last slot holds the test input

    Slots that are not filled (fewer training pairs than ``max_train_grids``, or
    a grid that could not be padded) stay all-zero with an all-False mask.
    """

    challenges: dict[str, dict]
    task_ids: list[str]
    config: ARCDatasetParams

    def __init__(
        self,
        challenges_file: str,
        config: ARCDatasetParams,
    ):
        """Load every task of ``challenges_file`` (a JSON object keyed by task id)."""
        with open(challenges_file, "r") as f:
            self.challenges = json.load(f)
        self.task_ids = list(self.challenges.keys())
        self.config = config

    def __len__(self) -> int:
        """Number of tasks in the challenges file."""
        return len(self.task_ids)

    def _encode_train_pairs(
        self, task_id: str, train_pairs: list[dict]
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Pad the training pairs of one task into fresh grid/mask stacks.

        The last slot (reserved for the test input) is left empty. Problems are
        reported with ``print`` and the affected slots are left blank, so one bad
        pair does not abort the whole task.
        """
        num_slots = 2 * self.config.max_train_grids + 1
        size = self.config.max_grid_size
        grids = torch.zeros(num_slots, size, size, dtype=torch.int)
        masks = torch.zeros(num_slots, size, size, dtype=torch.bool)

        for i, pair in enumerate(train_pairs):
            if i >= self.config.max_train_grids:
                print(
                    "Training pairs exceed max", task_id, i, self.config.max_train_grids
                )
                break

            try:
                input_grid, input_mask = pad_and_mask_grid(pair["input"], self.config)
                output_grid, output_mask = pad_and_mask_grid(
                    pair["output"], self.config
                )
                # Written only after both grids padded successfully, so a pair is
                # either fully present or fully blank.
                grids[2 * i] = input_grid
                masks[2 * i] = input_mask
                grids[2 * i + 1] = output_grid
                masks[2 * i + 1] = output_mask
            except Exception as e:
                print("Got exception for training pair", task_id, i, e)

        return grids, masks

    def __getitem__(self, idx: int) -> dict:
        """Return the padded grids and masks for task number ``idx``.

        Returns:
            A dict with ``"task_id"`` (str), ``"grids"`` (int tensor) and
            ``"masks"`` (bool tensor). Both tensors have shape
            ``(num_test_inputs, 2 * max_train_grids + 1, max_grid_size, max_grid_size)``;
            a task without test inputs yields a leading dimension of 0.
        """
        task_id = self.task_ids[idx]
        challenge = self.challenges[task_id]

        # The training pairs are identical for every test input of the task, so
        # they are encoded (and any warnings printed) once and then copied.
        train_grids, train_masks = self._encode_train_pairs(task_id, challenge["train"])

        all_grids = []
        all_masks = []

        for test in challenge["test"]:
            grids = train_grids.clone()
            masks = train_masks.clone()

            try:
                test_input_grid, test_input_mask = pad_and_mask_grid(
                    test["input"], self.config
                )
                grids[-1] = test_input_grid
                masks[-1] = test_input_mask
            except Exception as e:
                print("Got exception on test input", task_id, e)

            all_grids.append(grids)
            all_masks.append(masks)

        if not all_grids:
            # torch.stack rejects an empty list; return zero-length stacks instead.
            return {
                "task_id": task_id,
                "grids": torch.zeros((0, *train_grids.shape), dtype=torch.int),
                "masks": torch.zeros((0, *train_masks.shape), dtype=torch.bool),
            }

        return {
            "task_id": task_id,
            "grids": torch.stack(all_grids),
            "masks": torch.stack(all_masks),
        }
    

def collate_arc_fn(
    batch: list[dict],
) -> tuple[list, torch.Tensor, torch.Tensor]:
    """Merge dataset items into ``(task_ids, grids, masks)``.

    Every item must provide ``"task_id"``, ``"grids"`` and ``"masks"``. The grid
    and mask tensors are stacked along a new leading batch dimension, which
    requires all items to have identically shaped tensors.
    """
    task_ids = []
    grid_stacks = []
    mask_stacks = []

    for item in batch:
        task_ids.append(item["task_id"])
        grid_stacks.append(item["grids"])
        mask_stacks.append(item["masks"])

    return (task_ids, torch.stack(grid_stacks), torch.stack(mask_stacks))

In [3]:
@dataclass(frozen=True)
class ARCTransformerEncoderDecoderParams:
    """Immutable hyper-parameters describing the transformer architecture."""

    # Side length of the (square) padded grids fed to the model.
    grid_dim: int
    # Number of training input/output pairs per sample.
    num_train_pairs: int
    # Number of colors; ARCVisionEncoder adds one extra class on top of this.
    num_colors: int
    # Number of stacked encoder layers.
    num_encoder_layers: int
    # Not read by ARCVisionEncoder in this notebook (it is set to 0 where the model is built).
    num_decoder_layers: int
    # Number of attention heads per layer.
    num_heads: int
    # Token embedding width.
    d_model: int
    # Hidden width of each layer's feed-forward block.
    d_ff: int
    # Dropout probability passed to the encoder layers.
    dropout: float

class ARCPositionalEncodingV2(nn.Module):
    def __init__(
        self,
        d_model: int,
        grid_dim: int,
        num_train_pairs: int,
    ):
        super().__init__()
        self.d_model = d_model
        self.grid_dim = grid_dim
        self.num_train_pairs = num_train_pairs

        # Embeddings for row and column positions
        self.row_embedding = nn.Embedding(self.grid_dim, self.d_model // 4)
        self.col_embedding = nn.Embedding(self.grid_dim, self.d_model // 4)

        # Embedding for input vs output
        self.input_output_embedding = nn.Embedding(2, d_model // 4)

        # Embedding for training pair index
        self.pair_embedding = nn.Embedding(
            self.num_train_pairs + 1, d_model // 4
        )  # +1 for test pair

    @torch.compiler.disable
    def forward(
        self, num_grids: int, grid_dim: int, device: torch.device
    ) -> torch.Tensor:
        grid_pos = torch.arange(grid_dim, device=device)

        # Row pos embedding
        row_emb = (
            self.row_embedding.forward(grid_pos)
            .unsqueeze(1)
            .expand(num_grids, -1, grid_dim, -1)
        )

        # Column pos embedding
        col_emb = (
            self.col_embedding.forward(grid_pos)
            .unsqueeze(0)
            .expand(num_grids, grid_dim, -1, -1)
        )

        # Input/output embedding
        grid_indices = torch.arange(num_grids, device=device)
        is_output = (grid_indices % 2 == 1).long()
        io_emb = (
            self.input_output_embedding(is_output)
            .unsqueeze(1)
            .unsqueeze(1)
            .expand(num_grids, grid_dim, grid_dim, -1)
        )

        # Pair embedding
        pair_indices = torch.div(grid_indices, 2, rounding_mode="floor")
        pair_emb = (
            self.pair_embedding(pair_indices)
            .unsqueeze(1)
            .unsqueeze(1)
            .expand(num_grids, grid_dim, grid_dim, -1)
        )

        # Combine all embeddings (1, num_grids, height, width, d_model)
        combined_emb = torch.cat([row_emb, col_emb, io_emb, pair_emb], dim=-1)

        return combined_emb

class PatchEmbedding(nn.Module):
    """Turn a grid of class indices into a sequence of patch embeddings.

    The grid is one-hot encoded and passed through a convolution whose kernel
    size and stride both equal ``patch_size``, so every non-overlapping patch
    becomes one ``embed_dim``-wide token.
    """

    def __init__(
        self,
        num_classes: int,
        patch_size: int,
        embed_dim: int,
    ):
        super().__init__()
        self.num_classes = num_classes
        self.patch_size = patch_size
        self.embed_dim = embed_dim

        # Kernel == stride, so the convolution sees each patch exactly once.
        self.conv_embed = nn.Conv2d(
            in_channels=self.num_classes,
            out_channels=self.embed_dim,
            kernel_size=self.patch_size,
            stride=self.patch_size,
        )

    def forward(
        self, x: torch.Tensor, mask: torch.Tensor
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Embed ``x`` patch by patch and reduce ``mask`` to patch resolution.

        Args:
            x: Class indices of shape ``(batch, height, width)``.
            mask: Validity mask pooled with the same patch size; assumed to have
                the same shape as ``x``.

        Returns:
            ``(tokens, token_mask)``: ``tokens`` has shape
            ``(batch, num_patches, embed_dim)`` in row-major patch order;
            ``token_mask`` has shape ``(batch, num_patches)`` and is True for
            every patch that contains at least one True mask cell.
        """
        num_samples = x.shape[0]

        # (batch, H, W) -> (batch, H, W, classes) -> (batch, classes, H, W)
        one_hot = F.one_hot(x.long(), num_classes=self.num_classes)
        channels_first = one_hot.permute(0, 3, 1, 2).float()

        patch_features = self.conv_embed(channels_first)

        # (batch, embed, H/p, W/p) -> (batch, H/p * W/p, embed)
        tokens = patch_features.permute(0, 2, 3, 1).reshape(
            num_samples, -1, self.embed_dim
        )

        # The mean of a 0/1 mask is positive iff any cell of the patch is set.
        coverage = F.avg_pool2d(
            mask.float(),
            self.patch_size,
            stride=self.patch_size,
        )
        token_mask = (coverage > 0).reshape(num_samples, -1)

        return (tokens, token_mask)


class EncoderLayerWithAttention(nn.TransformerEncoderLayer):
    """Batch-first transformer encoder layer that also returns attention weights.

    The forward pass applies self-attention and the feed-forward block, each
    followed by a residual connection and then layer normalization. It returns
    the per-head attention weights alongside the layer output.
    """

    def __init__(
        self,
        d_model: int,
        nhead: int,
        dim_feedforward: int = 2048,
        dropout: float = 0.1,
    ) -> None:
        super().__init__(
            d_model=d_model,
            nhead=nhead,
            dim_feedforward=dim_feedforward,
            dropout=dropout,
            batch_first=True,
        )

    def forward(
        self,
        src: torch.Tensor,
        src_mask: Optional[torch.Tensor] = None,
        src_key_padding_mask: Optional[torch.Tensor] = None,
        is_causal: bool = False,
        need_weights: bool = False,
    ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Run the layer on ``src``.

        Returns:
            ``(output, attn_weights)``. ``attn_weights`` is whatever the
            self-attention module returns for ``need_weights`` with
            ``average_attn_weights=False`` (per-head weights, or ``None`` when
            ``need_weights`` is False).
        """
        # Bring both masks into the canonical form expected by the attention module.
        key_padding = F._canonical_mask(
            mask=src_key_padding_mask,
            mask_name="src_key_padding_mask",
            other_type=F._none_or_dtype(src_mask),
            other_name="src_mask",
            target_type=src.dtype,
        )
        attention_mask = F._canonical_mask(
            mask=src_mask,
            mask_name="src_mask",
            other_type=None,
            other_name="",
            target_type=src.dtype,
            check_other=False,
        )

        # Self-attention sub-layer: residual connection, then normalization.
        attended, attn_weights = self.self_attn(
            src,
            src,
            src,
            attn_mask=attention_mask,
            key_padding_mask=key_padding,
            need_weights=need_weights,
            is_causal=is_causal,
            average_attn_weights=False,
        )
        hidden = self.norm1(src + self.dropout1(attended))

        # Feed-forward sub-layer: residual connection, then normalization.
        expanded = self.activation(self.linear1(hidden))
        projected = self.linear2(self.dropout(expanded))
        hidden = self.norm2(hidden + self.dropout2(projected))

        return hidden, attn_weights


class EncoderWithAttention(nn.TransformerEncoder):
    """Stack of encoder layers that can also collect each layer's attention weights.

    Every layer is expected to return an ``(output, attn_weights)`` tuple and to
    accept a ``need_weights`` keyword.
    """

    def __init__(
        self,
        encoder_layer: "EncoderLayerWithAttention",
        num_layers: int,
    ) -> None:
        super().__init__(encoder_layer, num_layers)

    def forward(
        self,
        src: torch.Tensor,
        mask: Optional[torch.Tensor] = None,
        src_key_padding_mask: Optional[torch.Tensor] = None,
        need_weights: bool = False,
    ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Pass ``src`` through all layers in order.

        Returns:
            ``(output, attn_weights)``. When ``need_weights`` is True the
            per-layer weights are stacked along dimension 1 (the layer axis);
            otherwise the second element is ``None``.
        """
        # Canonicalize the masks once; every layer receives the same pair.
        key_padding = F._canonical_mask(
            mask=src_key_padding_mask,
            mask_name="src_key_padding_mask",
            other_type=F._none_or_dtype(mask),
            other_name="mask",
            target_type=src.dtype,
        )
        attention_mask = F._canonical_mask(
            mask=mask,
            mask_name="mask",
            other_type=None,
            other_name="",
            target_type=src.dtype,
            check_other=False,
        )

        hidden = src
        collected_weights = []

        for layer in self.layers:
            hidden, layer_weights = layer(
                hidden,
                src_mask=attention_mask,
                src_key_padding_mask=key_padding,
                need_weights=need_weights,
            )
            if need_weights:
                collected_weights.append(layer_weights)

        if self.norm is not None:
            hidden = self.norm(hidden)

        if not need_weights:
            return hidden, None
        return hidden, torch.stack(collected_weights, dim=1)
    
class ARCVisionEncoder(nn.Module):
    """Encoder-only transformer that predicts an output grid from a stack of grids.

    The input grids are split into 2x2 patches and embedded as tokens. A set of
    per-cell output tokens (one per cell of the output grid) is appended to the
    sequence, and the encoder output at those positions is projected to class
    logits.

    Submodule and parameter attribute names are part of the checkpoint format
    (they are the state-dict keys) and must not be renamed.
    """

    def __init__(self, params: ARCTransformerEncoderDecoderParams):
        super().__init__()
        self.grid_dim = params.grid_dim
        self.num_train_pairs = params.num_train_pairs
        # One class more than there are colors.
        # NOTE(review): presumably class 0 is the padding value - confirm.
        self.num_classes = params.num_colors + 1
        self.d_model = params.d_model
        self.num_layers = params.num_encoder_layers
        self.num_heads = params.num_heads
        self.d_ff = params.d_ff
        self.dropout = params.dropout
        self.patch_size = 2

        self.patch_grid_dim = self.grid_dim // self.patch_size

        # Input tokens: one per patch of every train input/output grid plus the
        # test input grid. Output tokens: one per cell of the output grid.
        self.input_seq_len = (
            (self.num_train_pairs * 2 + 1) * self.patch_grid_dim * self.patch_grid_dim
        )
        self.output_seq_len = self.grid_dim * self.grid_dim
        self.seq_len = self.input_seq_len + self.output_seq_len

        self.embedding = PatchEmbedding(
            num_classes=self.num_classes,
            patch_size=self.patch_size,
            embed_dim=self.d_model,
        )
        self.tgt_embedding = nn.Embedding(self.num_classes, self.d_model)
        self.pos_encoding = ARCPositionalEncodingV2(
            d_model=self.d_model,
            grid_dim=self.grid_dim,
            num_train_pairs=self.num_train_pairs,
        )

        encoder_layer = EncoderLayerWithAttention(
            self.d_model, self.num_heads, self.d_ff, self.dropout
        )
        self.encoder = EncoderWithAttention(encoder_layer, self.num_layers)

        # Learned query used for the output cells when no target grid is supplied.
        self.output_query = nn.Parameter(
            torch.randn(1, 1, self.grid_dim, self.grid_dim, self.d_model)
        )
        self.output_layer = nn.Linear(self.d_model, self.num_classes)

    def embed(
        self,
        src: torch.Tensor,
        src_mask: torch.Tensor,
        tgt: Optional[torch.Tensor] = None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Build the token sequence and key-padding mask fed to the encoder.

        Args:
            src: Grid stack of shape ``(batch, num_grids, grid_dim, grid_dim)``.
            src_mask: Validity mask reshaped exactly like ``src``.
            tgt: Optional target grid. When given, its embedding replaces the
                learned output query.

        Returns:
            ``(combined_seq, padding_mask)``: input patch tokens followed by the
            output cell tokens, and a bool mask that is True for input patches
            containing no valid cell (output tokens are never masked).
        """
        batch_size, num_grids, grid_dim, grid_dim = src.shape

        # One extra grid position: the last slice encodes the output grid.
        pos_emb = self.pos_encoding(
            num_grids=num_grids + 1, grid_dim=grid_dim, device=src.device
        )

        # Stack all grids vertically into one tall image so a single convolution
        # embeds every patch.
        src = src.reshape(batch_size, num_grids * grid_dim, grid_dim)
        src_mask = src_mask.reshape(batch_size, num_grids * grid_dim, grid_dim)

        src_patched, mask_patched = self.embedding(src, src_mask)

        # Average the per-cell positional encoding over each patch so it lines up
        # with the patch tokens.
        input_pos_emb = pos_emb[:-1, :, :, :]
        input_pos_emb_patched = (
            input_pos_emb.unfold(1, self.patch_size, self.patch_size)
            .unfold(2, self.patch_size, self.patch_size)
            .mean(dim=(-2, -1))
        )

        src_patched = src_patched.reshape(batch_size, -1, self.d_model)
        input_pos_emb_patched = input_pos_emb_patched.reshape(-1, self.d_model)
        input_seq = src_patched + input_pos_emb_patched

        if tgt is not None:
            output_query = self.tgt_embedding(tgt).view(
                batch_size, 1, self.grid_dim, self.grid_dim, self.d_model
            )
        else:
            output_query = self.output_query.expand(batch_size, -1, -1, -1, -1)
        output_pos_emb = pos_emb[-1:, :, :, :]

        output_query = output_query.reshape(batch_size, -1, self.d_model)
        output_pos_emb = output_pos_emb.reshape(-1, self.d_model)
        output_seq = output_query + output_pos_emb

        combined_seq = torch.cat([input_seq, output_seq], dim=1)

        # Make padding mask: True marks tokens that must be ignored as keys.
        padding_mask = ~mask_patched
        padding_mask = torch.cat(
            [
                padding_mask,
                torch.zeros(
                    (batch_size, self.grid_dim**2), dtype=torch.bool, device=src.device
                ),
            ],
            dim=1,
        )

        return combined_seq, padding_mask

    def forward(
        self,
        src: torch.Tensor,
        src_mask: torch.Tensor,
        tgt: Optional[torch.Tensor] = None,
        temperature: float = 0.0,
        need_weights: bool = False,
    ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Compute class logits for every cell of the output grid.

        Args:
            src: Grid stack of shape ``(batch, num_grids, grid_dim, grid_dim)``.
            src_mask: Validity mask matching ``src``.
            tgt: Optional target grid forwarded to :meth:`embed`.
            temperature: When positive, the logits are divided by it.
            need_weights: Whether the encoder should also return attention weights.

        Returns:
            ``(logits, attn_weights)``: ``logits`` has shape
            ``(batch, grid_dim, grid_dim, num_classes)``; ``attn_weights`` is the
            encoder's stacked weights, or ``None`` when ``need_weights`` is False.
        """
        batch_size = src.shape[0]

        combined_seq, padding_mask = self.embed(src, src_mask, tgt)

        # Attention mask (True = blocked): input tokens may not attend to the
        # output tokens, while output tokens may attend to everything.
        causal_mask = torch.zeros(self.seq_len, self.seq_len, device=src.device)
        causal_mask[: self.input_seq_len, self.input_seq_len :] = 1
        causal_mask = causal_mask.bool()

        output, attn_weights = self.encoder(
            combined_seq,
            mask=causal_mask,
            src_key_padding_mask=padding_mask,
            need_weights=need_weights,
        )

        # Only the trailing output tokens are decoded into the grid.
        output_grid_portion = output[:, -self.output_seq_len :, :]

        logits = self.output_layer(output_grid_portion)

        output = logits.view(batch_size, self.grid_dim, self.grid_dim, self.num_classes)

        if temperature > 0:
            output = output / temperature

        return (output, attn_weights)

    def generate(
        self,
        src: torch.Tensor,
        src_mask: torch.Tensor,
        tgt: Optional[torch.Tensor] = None,
        temperature: float = 0.0,
        need_weights: bool = False,
    ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Predict an output grid without tracking gradients.

        With ``temperature <= 0`` the most likely class is taken per cell and the
        prediction has shape ``(batch, grid_dim, grid_dim)``. With a positive
        temperature one class per cell is sampled from the tempered softmax.

        Returns:
            ``(prediction, attn_weights)``; ``attn_weights`` is ``None`` unless
            ``need_weights`` is True.
        """
        with torch.no_grad():
            # need_weights is forwarded so the caller actually gets the weights it
            # asked for; it was previously accepted but never passed on.
            (
                output,
                encoder_attn_weights,
            ) = self(
                src,
                src_mask,
                tgt=tgt,
                temperature=temperature,
                need_weights=need_weights,
            )
        if temperature > 0:
            probs = torch.softmax(output, dim=-1)
            # NOTE(review): the leading -1 makes the sampled prediction
            # (1, batch, grid_dim, grid_dim), unlike the argmax branch which is
            # (batch, grid_dim, grid_dim). Left unchanged - confirm with callers.
            prediction = torch.multinomial(
                probs.view(-1, probs.size(-1)),
                num_samples=1,
                replacement=True,
            ).view(-1, *probs.size()[:-1])
        else:
            prediction = torch.argmax(output, dim=-1)
        return prediction, encoder_attn_weights


@dataclass
class ARCKaggleModelState:
    """Model hyper-parameters bundled with a state dict.

    NOTE(review): not referenced anywhere else in this notebook; presumably the
    layout of a saved checkpoint - confirm against the training code.
    """

    # Architecture hyper-parameters.
    model_params: ARCTransformerEncoderDecoderParams
    # Mapping of parameter names to weights.
    model_state_dict: dict

In [4]:
# Build the evaluation dataset: 20x20 canvas, at most 4 training pairs per task.
config = ARCDatasetParams(
    max_grid_size=20,
    max_train_grids=4,
    color_offset=1,
)
test_dataset = ARCKaggleDataset(
    challenges_file=test_challenges_file_name,
    config=config,
)

In [None]:
# Run on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# The architecture must match the checkpoint that is loaded below.
model_params = ARCTransformerEncoderDecoderParams(
    grid_dim=20,
    num_train_pairs=4,
    num_colors=10,
    num_encoder_layers=24,
    num_decoder_layers=0,
    num_heads=16,
    d_model=1024,
    d_ff=1024 * 4,
    dropout=0.1,
)
model = ARCVisionEncoder(model_params).to(device)
model.load_state_dict(
    torch.load(
        kaggle_model_file_path,
        map_location=device,
        weights_only=True,
    )
)

# Inference only: switch off dropout.
model.eval()

def unpad_output(output: torch.Tensor) -> torch.Tensor:
    """Strip padding cells from a predicted grid and return a rectangular tensor.

    The prediction is shifted down by one so padding (class 0) becomes -1. All
    -1 cells are then dropped from each row and rows that end up empty are
    removed. Remaining rows of unequal length are right-padded with 0 so they
    can be stacked.

    Args:
        output: 2-D integer tensor of predicted classes.

    Returns:
        2-D tensor of colors on the same device and with the same dtype as
        ``output``. If every cell was padding, a single ``[[0]]`` is returned so
        there is always at least one value.
    """
    colors = output - 1

    kept_rows = []
    for row in colors:
        kept = row[row != -1]
        if len(kept) > 0:
            kept_rows.append(kept)

    # Hack to ensure there's always at least one value
    if len(kept_rows) == 0:
        return torch.zeros((1, 1), dtype=colors.dtype, device=colors.device)

    width = max(len(row) for row in kept_rows)
    # Device and dtype come from the rows themselves, not from a notebook global,
    # so this works wherever the prediction lives.
    padded_rows = [
        torch.cat(
            [row, torch.zeros(width - len(row), dtype=row.dtype, device=row.device)]
        )
        for row in kept_rows
    ]
    return torch.stack(padded_rows)

# Maps each task id to a list with one entry per test input of that task.
kaggle_output = {}

for task in test_dataset:
    task_id = task["task_id"]
    print(f"Starting {task_id}")

    task_predictions = []
    for grids, masks in zip(task["grids"], task["masks"]):
        # Add a batch dimension of 1 before handing the sample to the model.
        batched_grids = grids.unsqueeze(0).to(device)
        batched_masks = masks.unsqueeze(0).to(device)
        predicted_batch = model.generate(
            batched_grids, batched_masks, need_weights=False
        )[0]
        list_prediction = unpad_output(predicted_batch[0]).cpu().numpy().tolist()
        # Both attempts carry the same prediction.
        task_predictions.append(
            {
                "attempt_1": list_prediction,
                "attempt_2": list_prediction,
            }
        )

    kaggle_output[task_id] = task_predictions

In [None]:
# Write the collected predictions to the submission file as JSON.
with open(kaggle_submission_file_path, "w") as submission_file:
    json.dump(kaggle_output, submission_file)