<a href="https://colab.research.google.com/github/weagan/Tiny-Recursive-Models/blob/main/TRM_colab.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# TRM Sudoku Solver - Direct Colab Version
# Run this in a single Colab cell to test the basic functionality

# Install dependencies
!pip install -q torch torchvision tqdm pandas kaggle

import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import pandas as pd
import numpy as np
from tqdm import tqdm
from google.colab import files

print(f"PyTorch: {torch.__version__}")
print(f"CUDA: {torch.cuda.is_available()}")


PyTorch: 2.8.0+cu126
CUDA: False


In [None]:

# Download dataset directly
!wget -q https://github.com/bryanpark/sudoku/raw/master/sudoku.csv
!ls -lh sudoku.csv


ls: cannot access 'sudoku.csv': No such file or directory


In [None]:

# Model implementation
class MultiHeadAttention(nn.Module):
    """Multi-head scaled dot-product self-attention.

    Args:
        d_model: model width; must be divisible by ``n_heads``.
        n_heads: number of heads, each of width ``d_model // n_heads``.
        dropout: dropout probability applied to the attention weights.
    """

    def __init__(self, d_model, n_heads, dropout=0.1):
        super().__init__()
        assert d_model % n_heads == 0
        self.d_model = d_model
        self.n_heads = n_heads
        self.d_k = d_model // n_heads
        # Submodules are created in this exact order (q, k, v, out, dropout) so
        # that seeded initialisation and state_dict keys stay the same.
        self.q_proj = nn.Linear(d_model, d_model)
        self.k_proj = nn.Linear(d_model, d_model)
        self.v_proj = nn.Linear(d_model, d_model)
        self.out_proj = nn.Linear(d_model, d_model)
        self.dropout = nn.Dropout(dropout)

    def _split_heads(self, projected, bsz, length):
        """Reshape (batch, len, d_model) into (batch, heads, len, d_k)."""
        return projected.reshape(bsz, length, self.n_heads, self.d_k).permute(0, 2, 1, 3)

    def forward(self, x, mask=None):
        """Self-attend over ``x`` of shape (batch, seq_len, d_model).

        ``mask``, when given, must broadcast against the (batch, heads,
        seq_len, seq_len) score tensor; entries equal to 0 are blocked.
        Returns a tensor with the same shape as ``x``.
        """
        bsz, length, _ = x.size()
        queries = self._split_heads(self.q_proj(x), bsz, length)
        keys = self._split_heads(self.k_proj(x), bsz, length)
        values = self._split_heads(self.v_proj(x), bsz, length)

        logits = (queries @ keys.transpose(-2, -1)) / math.sqrt(self.d_k)
        if mask is not None:
            logits = logits.masked_fill(mask == 0, float('-inf'))
        weights = self.dropout(logits.softmax(dim=-1))

        # Merge heads back: (batch, heads, len, d_k) -> (batch, len, d_model).
        context = (weights @ values).permute(0, 2, 1, 3).reshape(bsz, length, self.d_model)
        return self.out_proj(context)

class FeedForward(nn.Module):
    """Position-wise MLP: Linear(d_model->d_ff) -> GELU -> Dropout -> Linear(d_ff->d_model)."""

    def __init__(self, d_model, d_ff, dropout=0.1):
        super().__init__()
        self.linear1 = nn.Linear(d_model, d_ff)   # expand
        self.linear2 = nn.Linear(d_ff, d_model)   # project back to model width
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Apply the MLP independently to every position of ``x`` (last dim = d_model)."""
        hidden = F.gelu(self.linear1(x))
        hidden = self.dropout(hidden)
        return self.linear2(hidden)

class TransformerBlock(nn.Module):
    """Post-norm transformer layer: optional self-attention sublayer, then a feed-forward sublayer.

    Every sublayer is wrapped as ``h = LayerNorm(h + Dropout(sublayer(h)))``.
    With ``use_attention=False`` the attention sublayer and its ``attn`` /
    ``norm1`` submodules are not created at all, leaving a feed-forward-only block.
    """

    def __init__(self, d_model, n_heads, d_ff, dropout=0.1, use_attention=True):
        super().__init__()
        self.use_attention = use_attention
        # Creation order is preserved so seeded initialisation is unchanged.
        if use_attention:
            self.attn = MultiHeadAttention(d_model, n_heads, dropout)
            self.norm1 = nn.LayerNorm(d_model)
        self.ffn = FeedForward(d_model, d_ff, dropout)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask=None):
        """Run the block on ``x``; ``mask`` is forwarded to the attention sublayer only."""
        hidden = x
        if self.use_attention:
            attended = self.attn(hidden, mask)
            hidden = self.norm1(hidden + self.dropout(attended))
        mlp_out = self.ffn(hidden)
        return self.norm2(hidden + self.dropout(mlp_out))

class TRM(nn.Module):
    """Tiny Recursive Model: one shared transformer stack applied recursively.

    The question embedding ``x``, the current answer ``y`` and a latent
    scratchpad ``z`` are concatenated along the sequence axis and pushed
    through the same ``blocks`` repeatedly. ``recursive_reasoning`` first runs
    ``n_reasoning_steps`` passes that keep only the updated ``z``, then
    ``n_refinement_steps`` passes that keep only the updated ``y``; ``x`` is
    never overwritten.

    Args:
        vocab_size: number of token ids (10 in the Sudoku demo: digits 0-9).
        d_model, n_heads, d_ff, n_layers: width, heads, feed-forward width and
            number of shared transformer blocks.
        n_reasoning_steps: number of passes that update the latent ``z``.
        n_refinement_steps: number of passes that update the answer ``y``.
        latent_len: number of latent ``z`` tokens.
        use_attention: passed to every TransformerBlock (False = MLP-only blocks).
        tie_embeddings: share the output projection weight with ``token_embedding``.
        max_seq_len: size of the learned position table. Question/answer
            sequences longer than this raise ValueError. Defaults to 512, the
            previously hard-coded size.
    """

    def __init__(self, vocab_size, d_model=128, n_heads=4, d_ff=512, n_layers=4,
                 n_reasoning_steps=4, n_refinement_steps=8, latent_len=16,
                 use_attention=True, tie_embeddings=True, max_seq_len=512):
        super().__init__()
        self.d_model = d_model
        self.latent_len = latent_len
        self.max_seq_len = max_seq_len
        self.n_reasoning_steps = n_reasoning_steps
        self.n_refinement_steps = n_refinement_steps
        self.token_embedding = nn.Embedding(vocab_size, d_model)
        self.position_embedding = nn.Embedding(max_seq_len, d_model)
        self.blocks = nn.ModuleList([TransformerBlock(d_model, n_heads, d_ff, use_attention=use_attention) for _ in range(n_layers)])
        self.reverse_embedding = nn.Linear(d_model, vocab_size, bias=False)
        if tie_embeddings:
            self.reverse_embedding.weight = self.token_embedding.weight
        self._init_weights()

    def _init_weights(self):
        """Normal(0, 0.02) init for Linear/Embedding weights, zeros for Linear biases.

        With tied embeddings the shared matrix is visited twice; both visits
        apply the same distribution.
        """
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, std=0.02)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)
            elif isinstance(m, nn.Embedding):
                nn.init.normal_(m.weight, std=0.02)

    def _position(self, length, device):
        """Learned position embeddings for positions 0..length-1, shape (1, length, d_model)."""
        if length > self.max_seq_len:
            # Fail with a readable message; an out-of-range nn.Embedding lookup
            # is an opaque IndexError on CPU and a device-side assert on CUDA.
            raise ValueError(
                f"sequence length {length} exceeds max_seq_len={self.max_seq_len}"
            )
        return self.position_embedding(torch.arange(length, device=device).unsqueeze(0))

    def _embed(self, ids):
        """Token + position embedding for a (batch, length) tensor of token ids."""
        return self.token_embedding(ids) + self._position(ids.size(1), ids.device)

    def apply_blocks(self, x, mask=None):
        """Run ``x`` through every shared block in order, using the same ``mask``."""
        for block in self.blocks:
            x = block(x, mask)
        return x

    def forward_pass(self, x, y, z, mask=None):
        """Process [x; y; z] jointly along the sequence axis and split it back.

        ``mask`` therefore applies to the concatenated sequence of length
        ``x.size(1) + y.size(1) + z.size(1)``. Returns the updated (x, y, z).
        """
        combined = torch.cat([x, y, z], dim=1)
        combined = self.apply_blocks(combined, mask)
        len_x, len_y = x.size(1), y.size(1)
        return combined[:, :len_x, :], combined[:, len_x:len_x+len_y, :], combined[:, len_x+len_y:, :]

    def recursive_reasoning(self, x, y, z, mask=None):
        """Refine ``z`` with ``y`` fixed, then refine ``y`` with the final ``z`` fixed.

        Only the selected segment is kept after each pass; the transformed
        ``x`` is discarded so every pass sees the original question embedding.
        """
        for _ in range(self.n_reasoning_steps):
            _, _, z = self.forward_pass(x, y, z, mask)
        for _ in range(self.n_refinement_steps):
            _, y, _ = self.forward_pass(x, y, z, mask)
        return y

    def forward(self, question_ids, answer_ids=None, mask=None):
        """Predict answer-token logits for a batch of questions.

        Args:
            question_ids: (batch, qlen) token ids.
            answer_ids: optional (batch, alen) ids used as the initial answer.
                When omitted, the answer starts as small random noise with one
                slot per question token, so the output is stochastic.
            mask: optional attention mask over the concatenated [x; y; z] sequence.

        Returns:
            Logits of shape (batch, answer_len, vocab_size), where answer_len is
            ``alen`` if ``answer_ids`` is given, otherwise ``qlen``.

        Raises:
            ValueError: if a question or answer is longer than ``max_seq_len``.
        """
        batch, qlen = question_ids.size()
        device = question_ids.device
        x = self._embed(question_ids)
        if answer_ids is not None:
            y = self._embed(answer_ids)
        else:
            # One answer slot per question token (previously a hard-coded 81).
            # Position embeddings are added so slot i can be tied to cell i:
            # attention is permutation-equivariant, so unpositioned i.i.d.
            # noise slots would be indistinguishable from one another.
            noise = torch.randn(batch, qlen, self.d_model, device=device, dtype=x.dtype) * 0.02
            y = noise + self._position(qlen, device)
        z = torch.randn(batch, self.latent_len, self.d_model, device=device, dtype=x.dtype) * 0.02
        y_final = self.recursive_reasoning(x, y, z, mask)
        return self.reverse_embedding(y_final)

    def count_parameters(self):
        """Number of trainable parameters; tied weights are counted once."""
        return sum(p.numel() for p in self.parameters() if p.requires_grad)


In [None]:

# Quick test: build an (untrained) model and run one puzzle through it.
# The predicted digits are meaningless until the model is trained; this cell
# only checks that the pipeline runs end to end with the expected shapes.
torch.manual_seed(0)  # weight init and the random answer/latent init are stochastic

vocab_size = 10  # digits 0-9; 0 appears to mark empty cells in the puzzle string
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = TRM(vocab_size=vocab_size).to(device)
print(f"Model created with {model.count_parameters()/1e6:.2f}M parameters")

# Test with a small sample
print("\nTesting with sample puzzle...")
puzzle = "004300209005009001070060043006002087190007400050083000600000105003508690042910300"
puzzle_tensor = torch.tensor([int(c) for c in puzzle], dtype=torch.long, device=device).unsqueeze(0)

# eval() switches off dropout; torch.no_grad() only disables autograd and
# would otherwise leave dropout active during this inference run.
model.eval()
with torch.no_grad():
    output = model(puzzle_tensor)
assert output.shape == (1, len(puzzle), vocab_size), output.shape
pred = output.argmax(dim=-1).squeeze(0)
print(f"Input:  {puzzle}")
print(f"Output: {''.join(str(d.item()) for d in pred)}")

Model created with 0.86M parameters

Testing with sample puzzle...
Input:  004300209005009001070060043006002087190007400050083000600000105003508690042910300
Output: 331141218019542486413855918645166134364486963956413462413820350314561110743543530
