# Tiny Recursive Reasoning Model (TRM)

This notebook contains the TRM implementation converted from `main.py`. It includes:
- configuration dataclass
- model (TinyRecursiveModel) and residual blocks
- a flexible `GenericDataset` that can load JSON, NPY, PT/PTH and CSV files (including the provided `sudoku.csv` with `question`/`answer` columns)
- collate function and trainer

Notes:
- The notebook does not start training automatically. Use the example cell at the end to create the dataset and data loaders, then run training manually.
- Loading all 1M Sudoku rows into memory can use a lot of RAM. Consider streaming the data or converting it to a binary format if needed.

In [None]:
import csv
import json
import os
from dataclasses import dataclass
from pathlib import Path
from typing import Dict, List, Optional

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm


In [None]:
@dataclass
class TRMConfig:
    """Hyperparameters, paths and device for the TRM model, dataset and trainer.

    Notes (review):
        * ``warmup_steps`` and ``train_split`` are not read by any code shown in
          this notebook (the trainer applies no warmup and the dataset performs
          no train/val split).
        * ``device`` is evaluated once, when this class is defined.
    """

    # --- model architecture (read by TinyRecursiveModel) ---
    input_dim: int = 256  # last-dim size of the model input (input_proj in-features)
    hidden_dim: int = 512  # width of the latent/answer states and of every ResidualBlock
    output_dim: int = 256  # last-dim size of the model output (output_proj out-features)
    L_layers: int = 2  # number of residual blocks in the latent-update stack
    H_cycles: int = 3  # outer cycles; each ends with one answer (y) refinement
    L_cycles: int = 6  # latent (z) updates per outer cycle
    dropout: float = 0.1  # dropout probability inside every ResidualBlock

    # --- optimisation ---
    batch_size: int = 32  # not read by TRMTrainer; the DataLoader batch size is chosen by the caller
    epochs: int = 100  # training epochs; also T_max of the per-epoch cosine schedule
    lr: float = 1e-4  # AdamW learning rate
    weight_decay: float = 0.01  # AdamW weight decay
    warmup_steps: int = 1000  # currently unused: no LR warmup is applied

    # --- data ---
    data_dir: str = "data/"  # directory scanned by GenericDataset in the example cell
    train_split: float = 0.8  # currently unused: no train/val split is performed

    # --- checkpointing (read by TRMTrainer) ---
    save_dir: str = "checkpoints/"  # created if missing; all checkpoints are written here
    save_every: int = 10  # additionally save epoch_<n>.pt every this many epochs

    # Chosen once at class-definition time; pass device=... to override.
    device: str = "cuda" if torch.cuda.is_available() else "cpu"


In [None]:
class ResidualBlock(nn.Module):
    """Pre-norm feed-forward branch: LayerNorm -> Linear(dim, 4*dim) -> GELU
    -> Dropout -> Linear(4*dim, dim) -> Dropout.

    Despite the name, ``forward`` returns only the branch output and does NOT
    add its input back; the caller is responsible for the skip connection.
    """

    def __init__(self, dim, dropout=0.1):
        super().__init__()
        wide = dim * 4
        # Everything lives in one nn.Sequential so parameter names in
        # state_dict() stay ``net.<index>.*``.
        stages = [
            nn.LayerNorm(dim),
            nn.Linear(dim, wide),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(wide, dim),
            nn.Dropout(dropout),
        ]
        self.net = nn.Sequential(*stages)

    def forward(self, x):
        branch = self.net(x)
        return branch


class TinyRecursiveModel(nn.Module):
    """Recursive reasoning network that repeatedly refines a latent ``z`` and an answer ``y``.

    The input is projected to ``hidden_dim`` and ``y``/``z`` start as zeros.
    For each of ``H_cycles`` outer cycles, ``z`` is updated ``L_cycles`` times
    from ``x + y + z`` and then ``y`` is updated once from ``y + z``. The final
    ``y`` is projected to ``output_dim``.
    """

    def __init__(self, config: TRMConfig):
        super().__init__()
        self.config = config
        width = config.hidden_dim

        # Registration order is unchanged so state_dict keys/order stay stable.
        self.input_proj = nn.Linear(config.input_dim, width)
        self.latent_layers = nn.ModuleList(
            [ResidualBlock(width, config.dropout) for _ in range(config.L_layers)]
        )
        self.output_layers = nn.ModuleList(
            [ResidualBlock(width, config.dropout) for _ in range(2)]
        )
        self.output_proj = nn.Linear(width, config.output_dim)

        # Learnable scalar gates on each residual branch, initialised to 1.
        self.latent_gate = nn.Parameter(torch.ones(1))
        self.output_gate = nn.Parameter(torch.ones(1))

    @staticmethod
    def _apply_gated_stack(h, blocks, gate):
        """Apply ``h = h + gate * block(h)`` for each block, in order."""
        for block in blocks:
            h = h + gate * block(h)
        return h

    def latent_recursion(self, x, y, z):
        """Return the new latent state computed from ``x + y + z``."""
        return self._apply_gated_stack(x + y + z, self.latent_layers, self.latent_gate)

    def output_refinement(self, y, z):
        """Return the new answer state computed from ``y + z``."""
        return self._apply_gated_stack(y + z, self.output_layers, self.output_gate)

    def forward(self, x):
        h_x = self.input_proj(x)
        answer = torch.zeros_like(h_x)
        latent = torch.zeros_like(h_x)

        for _outer in range(self.config.H_cycles):
            for _inner in range(self.config.L_cycles):
                latent = self.latent_recursion(h_x, answer, latent)
            answer = self.output_refinement(answer, latent)

        return self.output_proj(answer)

    def get_num_params(self):
        """Count trainable parameters (those with ``requires_grad``)."""
        return sum(param.numel() for param in self.parameters() if param.requires_grad)


In [None]:
class GenericDataset(Dataset):
    """Load every JSON, NPY, PT/PTH and CSV file in ``data_dir`` into memory.

    Files are read group by group (``*.json``, then ``*.npy``, then ``*.pt`` and
    ``*.pth``, then ``*.csv``). Each group is sorted by path, so sample order is
    reproducible across machines. Samples are kept in ``self.data`` and
    converted to float32 tensors of shape ``(-1, 1)`` in ``__getitem__``.

    CSV files get special handling for Sudoku. If the header names a
    recognised puzzle column and solution column, those columns are used.
    Otherwise the first two comma-separated fields of each line are used.
    Rows whose fields do not contain exactly 81 ASCII digits are skipped.

    Args:
        data_dir: directory to scan (not recursive).
        split: stored on the instance but NOT used for filtering -- every file
            in ``data_dir`` is loaded regardless of its value.
    """

    # Lower-cased CSV header names recognised as puzzle / solution columns.
    _QUESTION_COLUMNS = ("question", "quiz", "quizzes", "input", "puzzle", "q")
    _ANSWER_COLUMNS = ("answer", "solution", "solutions", "target", "sol", "a")

    def __init__(self, data_dir: str, split: str = "train"):
        self.data_dir = Path(data_dir)
        self.split = split
        self.data = self._load_data()

    def _parse_sudoku_string(self, s: Optional[str]) -> Optional[np.ndarray]:
        """Return a 9x9 int32 grid built from the ASCII digits in ``s``, or None.

        Non-digit characters (separators, quotes, whitespace) are ignored.
        Only ASCII ``0``-``9`` count: ``str.isdigit`` also accepts characters
        such as ``'²'`` that ``int()`` rejects with ValueError.
        Returns None when ``s`` is None or does not contain exactly 81 digits.
        """
        if s is None:
            return None
        digits = "".join(ch for ch in str(s) if "0" <= ch <= "9")
        if len(digits) != 81:
            return None
        # Vectorised char -> int conversion; avoids one int() call per cell
        # when parsing very large CSVs.
        grid = np.frombuffer(digits.encode("ascii"), dtype=np.uint8).astype(np.int32) - ord("0")
        return grid.reshape(9, 9)

    def _load_data(self) -> List[Dict]:
        """Collect samples from all supported files in ``self.data_dir``."""
        data: List[Dict] = []
        data.extend(self._load_json_files())
        data.extend(self._load_npy_files())
        data.extend(self._load_torch_files())
        for csv_file in sorted(self.data_dir.glob("*.csv")):
            data.extend(self._load_csv_file(csv_file))
        print(f"Loaded {len(data)} samples from {self.data_dir}")
        return data

    def _load_json_files(self) -> List:
        """A JSON list contributes its elements; any other JSON value is one sample."""
        samples: List = []
        for json_file in sorted(self.data_dir.glob("*.json")):
            # utf-8-sig also accepts files saved with a UTF-8 byte-order mark.
            with open(json_file, "r", encoding="utf-8-sig") as f:
                file_data = json.load(f)
            if isinstance(file_data, list):
                samples.extend(file_data)
            else:
                samples.append(file_data)
        return samples

    def _load_npy_files(self) -> List:
        """Object arrays contribute their elements; numeric arrays become one
        sample whose input and target are the same array."""
        samples: List = []
        for npy_file in sorted(self.data_dir.glob("*.npy")):
            # SECURITY: allow_pickle=True can execute code embedded in a crafted
            # file -- only point this at trusted data.
            arr = np.load(npy_file, allow_pickle=True)
            if arr.dtype != object:
                samples.append({"input": arr, "target": arr})
            elif arr.ndim == 0:
                # e.g. np.save(path, some_dict): tolist() would return the dict
                # itself and extend() would add its keys as samples.
                samples.append(arr.item())
            else:
                samples.extend(arr.tolist())
        return samples

    def _load_torch_files(self) -> List:
        """A dict is one sample; a list/tuple contributes one self-paired sample
        per element; a bare tensor is one self-paired sample."""
        samples: List = []
        pt_files = sorted(self.data_dir.glob("*.pt")) + sorted(self.data_dir.glob("*.pth"))
        for pt_file in pt_files:
            # SECURITY: depending on the PyTorch version, torch.load may unpickle
            # arbitrary objects -- only load trusted files.
            tensor_data = torch.load(pt_file)
            if isinstance(tensor_data, dict):
                samples.append(tensor_data)
            elif isinstance(tensor_data, (list, tuple)):
                samples.extend({"input": item, "target": item} for item in tensor_data)
            elif isinstance(tensor_data, torch.Tensor):
                samples.append({"input": tensor_data, "target": tensor_data})
        return samples

    def _load_csv_file(self, csv_file: Path) -> List[Dict]:
        """Parse Sudoku (puzzle, solution) pairs from one CSV file."""
        samples: List[Dict] = []
        with open(csv_file, "r", newline="", encoding="utf-8-sig") as f:
            reader = csv.DictReader(f)
            if not reader.fieldnames:
                return samples  # empty file

            q_col = a_col = None
            for name in reader.fieldnames:
                key = name.strip().lower()
                if key in self._QUESTION_COLUMNS:
                    q_col = name
                if key in self._ANSWER_COLUMNS:
                    a_col = name

            if q_col and a_col:
                pairs = ((row.get(q_col), row.get(a_col)) for row in reader)
            else:
                # No recognised header: re-read from the start and use the first
                # two fields of every line. A header line, if present, fails the
                # 81-digit check and is skipped.
                f.seek(0)
                pairs = self._positional_pairs(f)

            for quiz, solution in pairs:
                q_arr = self._parse_sudoku_string(quiz)
                s_arr = self._parse_sudoku_string(solution)
                if q_arr is not None and s_arr is not None:
                    samples.append({"input": q_arr, "target": s_arr})
        return samples

    @staticmethod
    def _positional_pairs(lines):
        """Yield the first two comma-separated fields of each line that has two or more."""
        for line in lines:
            parts = line.strip().split(",")
            if len(parts) >= 2:
                yield parts[0].strip(), parts[1].strip()

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        """Return ``(input, target)`` as float32 tensors of shape ``(-1, 1)``.

        Dict samples are read from ``input``/``question``/``x`` and
        ``target``/``answer``/``y`` (the first key present wins). Any other
        sample is used as both input and target.
        """
        item = self.data[idx]
        if isinstance(item, dict):
            input_data = item.get("input", item.get("question", item.get("x")))
            target_data = item.get("target", item.get("answer", item.get("y")))
        else:
            input_data = target_data = item
        return (
            self._to_tensor(input_data, idx, "input"),
            self._to_tensor(target_data, idx, "target"),
        )

    @staticmethod
    def _to_tensor(value, idx=None, role="value") -> torch.Tensor:
        """Flatten ``value`` into a float32 column tensor of shape ``(-1, 1)``.

        Raises:
            TypeError: if ``value`` is None (the sample lacks that field).
        """
        if value is None:
            raise TypeError(f"sample {idx} has no {role} field")
        if isinstance(value, torch.Tensor):
            # Cast so stored integer/double tensors match the float32 model weights.
            return value.float().reshape(-1, 1)
        if isinstance(value, (np.ndarray, list)):
            return torch.FloatTensor(np.asarray(value)).reshape(-1, 1)
        return torch.FloatTensor([value]).reshape(-1, 1)


In [None]:
def collate_fn(batch):
    """Stack ``(input, target)`` pairs into zero-padded batch tensors.

    Each tensor is first promoted to at least 2-D: a 0-D tensor becomes
    ``(1, 1)`` and a 1-D tensor of length ``n`` becomes ``(n, 1)``. Inputs are
    then zero-padded along dim 0 (the sequence dim) to the longest input in the
    batch, and targets to the longest target. Inputs and targets are padded
    independently, so pairs whose input and target lengths differ can still be
    stacked.

    Note: padding uses 0.0 and no mask is returned, so padded positions are
    indistinguishable from genuine zeros.

    Args:
        batch: non-empty sequence of ``(input, target)`` tensor pairs.

    Returns:
        ``(inputs, targets)`` of shape ``[batch, max_len, ...]`` each.

    Raises:
        ValueError: if ``batch`` is empty.
        RuntimeError: (from ``torch.stack``) if trailing dims differ within the batch.
    """
    if not batch:
        raise ValueError("collate_fn received an empty batch")
    inputs, targets = zip(*batch)

    def _at_least_2d(t):
        if t.dim() == 0:
            return t.reshape(1, 1)
        if t.dim() == 1:
            return t.unsqueeze(-1)
        return t

    def _pad_dim0(t, length):
        missing = length - t.shape[0]
        if missing <= 0:
            return t
        # F.pad lists (left, right) pairs starting from the LAST dim, so leave
        # every trailing dim untouched and pad only dim 0 on the right.
        return F.pad(t, (0, 0) * (t.dim() - 1) + (0, missing))

    def _pad_and_stack(tensors):
        tensors = [_at_least_2d(t) for t in tensors]
        max_len = max(t.shape[0] for t in tensors)
        return torch.stack([_pad_dim0(t, max_len) for t in tensors])

    return _pad_and_stack(inputs), _pad_and_stack(targets)


In [None]:
class TRMTrainer:
    """Train and validate a model with AdamW, per-epoch cosine LR decay and MSE loss.

    Checkpoints are written to ``config.save_dir``. ``best.pt`` is written
    whenever the validation loss improves, and ``epoch_<n>.pt`` every
    ``config.save_every`` epochs.

    Note (review): ``config.warmup_steps`` is not used; no LR warmup is applied.
    """

    def __init__(self, model: "TinyRecursiveModel", config: "TRMConfig"):
        # nn.Module.to() moves parameters in place and returns the same module.
        self.model = model.to(config.device)
        self.config = config

        self.optimizer = torch.optim.AdamW(
            self.model.parameters(), lr=config.lr, weight_decay=config.weight_decay
        )

        # Stepped once per epoch in train(), so T_max is measured in epochs.
        self.scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            self.optimizer, T_max=config.epochs
        )

        self.criterion = nn.MSELoss()

        os.makedirs(config.save_dir, exist_ok=True)

    def _compute_loss(self, inputs, targets):
        """Run the model on one batch and return the loss tensor.

        Targets larger than the outputs (e.g. extra padding) are cropped to
        the output's sequence and feature sizes.

        Raises:
            ValueError: if the shapes still differ after cropping. nn.MSELoss
                would otherwise broadcast them (with only a warning) and return
                a meaningless loss.
        """
        inputs = inputs.to(self.config.device)
        targets = targets.to(self.config.device)

        outputs = self.model(inputs)

        if outputs.shape != targets.shape:
            targets = targets[:, : outputs.shape[1], : outputs.shape[2]]
        if outputs.shape != targets.shape:
            raise ValueError(
                f"target shape {tuple(targets.shape)} does not match "
                f"output shape {tuple(outputs.shape)}"
            )
        return self.criterion(outputs, targets)

    def train_epoch(self, dataloader):
        """Run one optimisation pass over ``dataloader``.

        Returns:
            Mean batch loss, or NaN if the loader yields no batches.
        """
        self.model.train()
        total_loss = 0.0
        num_batches = 0

        pbar = tqdm(dataloader, desc="Training")
        for inputs, targets in pbar:
            loss = self._compute_loss(inputs, targets)

            self.optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1.0)
            self.optimizer.step()

            batch_loss = loss.item()
            total_loss += batch_loss
            num_batches += 1
            pbar.set_postfix({"loss": f"{batch_loss:.4f}"})

        # Batches are counted rather than using len(), which fails on loaders without
        # __len__ and divides by zero on empty ones.
        return total_loss / num_batches if num_batches else float("nan")

    def validate(self, dataloader):
        """Evaluate the mean batch loss without gradients (NaN if no batches)."""
        self.model.eval()
        total_loss = 0.0
        num_batches = 0

        with torch.no_grad():
            for inputs, targets in dataloader:
                total_loss += self._compute_loss(inputs, targets).item()
                num_batches += 1

        return total_loss / num_batches if num_batches else float("nan")

    def train(self, train_loader, val_loader=None):
        """Train for ``config.epochs`` epochs, validating after each one if a loader is given."""
        print(f"Training TRM with {self.model.get_num_params():,} parameters")
        print(f"Device: {self.config.device}")

        best_val_loss = float("inf")

        for epoch in range(self.config.epochs):
            print(f"\nEpoch {epoch + 1}/{self.config.epochs}")

            train_loss = self.train_epoch(train_loader)
            print(f"Train Loss: {train_loss:.4f}")

            # `is not None`: truthiness would call len() on the loader, which
            # fails for unsized loaders and skips validation for empty ones.
            if val_loader is not None:
                val_loss = self.validate(val_loader)
                print(f"Val Loss: {val_loss:.4f}")

                # NaN (no validation batches) never compares as an improvement.
                if val_loss < best_val_loss:
                    best_val_loss = val_loss
                    self.save_checkpoint("best.pt")

            self.scheduler.step()

            if (epoch + 1) % self.config.save_every == 0:
                self.save_checkpoint(f"epoch_{epoch + 1}.pt")

        print("\nTraining complete!")

    def save_checkpoint(self, filename):
        """Save model, optimizer and scheduler state plus the config to ``save_dir/filename``."""
        path = os.path.join(self.config.save_dir, filename)
        torch.save(
            {
                "model_state_dict": self.model.state_dict(),
                "optimizer_state_dict": self.optimizer.state_dict(),
                # Needed to resume the cosine schedule at the right point.
                "scheduler_state_dict": self.scheduler.state_dict(),
                "config": self.config,
            },
            path,
        )
        print(f"Saved checkpoint: {path}")


In [None]:
# Example: create dataset and a single batch to inspect shapes
data_dir = "data/"
sudoku_csv = os.path.join(data_dir, "sudoku.csv")
has_sudoku = os.path.exists(sudoku_csv)
if has_sudoku:
    print("Detected sudoku.csv - will load the CSV (make sure it has 81-digit quiz and solution strings).")
else:
    print("No sudoku.csv detected in data/. If you have other data formats put them in the data/ folder.")

# Configure for Sudoku if CSV is present
if has_sudoku:
    cfg = TRMConfig(
        input_dim=1,
        hidden_dim=128,
        output_dim=1,
        L_layers=2,
        H_cycles=3,
        L_cycles=6,
        batch_size=64,
        epochs=1,
        lr=1e-4,
        data_dir=data_dir,
        save_dir="checkpoints/",
    )
else:
    cfg = TRMConfig(data_dir=data_dir)

ds = GenericDataset(cfg.data_dir)
print(f"Dataset size: {len(ds)}")
if len(ds) > 0:
    inp, tgt = ds[0]
    print("Sample input shape:", inp.shape)
    print("Sample target shape:", tgt.shape)

    # GenericDataset yields (seq_len, features) tensors. Size the model's input
    # and output projections from the data, so the default config (256-dim)
    # does not crash on the forward pass below.
    cfg.input_dim = inp.shape[-1]
    cfg.output_dim = tgt.shape[-1]

    # Create a small DataLoader
    loader = DataLoader(ds, batch_size=min(8, len(ds)), collate_fn=collate_fn)
    batch_inputs, batch_targets = next(iter(loader))
    print("Batch inputs shape:", batch_inputs.shape)  # [batch, seq_len, input_dim]
    print("Batch targets shape:", batch_targets.shape)

    # Model sanity check: eval mode and no autograd graph, since this only
    # inspects output shapes.
    model = TinyRecursiveModel(cfg)
    print("Model parameters:", model.get_num_params())
    model.eval()
    with torch.no_grad():
        out = model(batch_inputs)
    print("Model output shape:", out.shape)

    # Ready to train: create trainer and run one epoch locally (optional)
    # trainer = TRMTrainer(model, cfg)
    # trainer.train(loader)
else:
    print("No samples to inspect. Add data to the data/ folder (sudoku.csv or other formats).")


## Next steps and recommendations

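- To treat Sudoku as per-cell classification over the digits 1–9, set the model's `output_dim` to 9 and use `nn.CrossEntropyLoss`. Its targets must be integer class indices in [0, 8], so map each digit d to d − 1. The logits must also have the class dimension second (e.g. `[batch, 9, 81]`), or be flattened to `[batch * 81, 9]`.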
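- For the 1,000,000-row `sudoku.csv` file, consider converting it once to `.npy` (which can be memory-mapped with `np.load(..., mmap_mode="r")`). Alternatively, write a streaming dataset (e.g. an `IterableDataset`) that reads CSV rows on demand instead of loading everything into memory at once.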
- Possible follow-ups: add a streaming CSV dataset implementation, or switch the output head and loss to per-cell classification.
