<a href="https://colab.research.google.com/github/weagan/Tiny-Recursive-Models/blob/main/TRM_colab.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Tiny Recursive Model (TRM) Implementation
Based on **"Let's Build a Tiny Recursive Model from Scratch"** by Azhar.

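This notebook builds a Tiny Recursive Model (TRM): a small transformer that keeps three streams — the question `x`, the current answer `y`, and a latent reasoning state `z` — and improves the answer by repeatedly re-applying the same stack of blocks (recursive reasoning).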

In [None]:
# Install dependencies
!pip install torch tqdm



In [None]:
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from tqdm import tqdm

In [None]:
class MultiHeadAttention(nn.Module):
    """Multi-head scaled dot-product self-attention.

    Queries, keys and values are all projected from the same input ``x``,
    so this module only performs self-attention.

    Args:
        d_model: model width; must be divisible by ``n_heads``.
        n_heads: number of attention heads.
        dropout: dropout probability applied to the attention weights.
    """

    def __init__(self, d_model, n_heads, dropout=0.1):
        super().__init__()
        assert d_model % n_heads == 0, "d_model must be divisible by n_heads"
        self.d_model = d_model
        self.n_heads = n_heads
        self.d_k = d_model // n_heads
        self.q_proj = nn.Linear(d_model, d_model)
        self.k_proj = nn.Linear(d_model, d_model)
        self.v_proj = nn.Linear(d_model, d_model)
        self.out_proj = nn.Linear(d_model, d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask=None):
        """Apply self-attention to ``x``.

        Args:
            x: tensor of shape ``[batch, seq_len, d_model]``.
            mask: optional tensor broadcastable to
                ``[batch, n_heads, seq_len, seq_len]``; entries equal to 0 mark
                key positions a query may not attend to.

        Returns:
            Tensor of shape ``[batch, seq_len, d_model]``.
        """
        batch_size, seq_len, _ = x.size()

        def split_heads(t):
            # [batch, seq_len, d_model] -> [batch, n_heads, seq_len, d_k]
            return t.view(batch_size, seq_len, self.n_heads, self.d_k).transpose(1, 2)

        q = split_heads(self.q_proj(x))
        k = split_heads(self.k_proj(x))
        v = split_heads(self.v_proj(x))

        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_k)
        if mask is not None:
            # Most negative *finite* value instead of -inf: partially masked rows
            # get exactly the same softmax (exp underflows to 0), but a query with
            # every key masked gets uniform weights instead of a row of NaNs
            # (softmax over all -inf is 0/0).
            scores = scores.masked_fill(mask == 0, torch.finfo(scores.dtype).min)
        attn = self.dropout(F.softmax(scores, dim=-1))
        out = torch.matmul(attn, v)

        # Merge heads back: [batch, n_heads, seq_len, d_k] -> [batch, seq_len, d_model]
        out = out.transpose(1, 2).contiguous().view(batch_size, seq_len, self.d_model)
        return self.out_proj(out)

In [None]:
class FeedForward(nn.Module):
    """Position-wise feed-forward network: Linear -> GELU -> Dropout -> Linear.

    Both linear layers act on the last dimension only, so each position is
    transformed independently of the others.

    Args:
        d_model: input/output width.
        d_ff: hidden (expanded) width.
        dropout: dropout probability applied after the GELU activation.
    """

    def __init__(self, d_model, d_ff, dropout=0.1):
        super().__init__()
        self.linear1 = nn.Linear(d_model, d_ff)  # expand to the hidden width
        self.linear2 = nn.Linear(d_ff, d_model)  # project back to d_model
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        hidden = F.gelu(self.linear1(x))
        hidden = self.dropout(hidden)
        return self.linear2(hidden)

In [None]:
class TransformerBlock(nn.Module):
    """Post-norm residual block: optional self-attention, then a feed-forward layer.

    Each sub-layer is applied as ``norm(h + dropout(sublayer(h)))``.

    When ``use_attention`` is False only the feed-forward sub-layer is built and
    used. ``FeedForward`` consists of ``nn.Linear`` layers on the last dimension,
    so in that configuration every position is transformed independently and no
    information moves between positions of the sequence.

    Args:
        d_model: model width.
        n_heads: number of attention heads (unused when ``use_attention`` is False).
        d_ff: hidden width of the feed-forward sub-layer.
        dropout: dropout probability for the sub-layers and residual branches.
        use_attention: whether to include the self-attention sub-layer.
    """

    def __init__(self, d_model, n_heads, d_ff, dropout=0.1, use_attention=True):
        super().__init__()
        self.use_attention = use_attention
        # Sub-modules are registered in a fixed order (attn, norm1, ffn, norm2,
        # dropout); that order determines parameter order and seeded init.
        if self.use_attention:
            self.attn = MultiHeadAttention(d_model, n_heads, dropout)
            self.norm1 = nn.LayerNorm(d_model)
        self.ffn = FeedForward(d_model, d_ff, dropout)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask=None):
        """Run the block on ``x``; ``mask`` is passed through to self-attention."""
        h = x
        if self.use_attention:
            h = self.norm1(h + self.dropout(self.attn(h, mask)))
        return self.norm2(h + self.dropout(self.ffn(h)))

In [None]:
class TRM(nn.Module):
    """Tiny Recursive Model operating on three token streams.

    Streams (each ``[batch, length, d_model]``):
      * ``x`` - embedded question tokens plus learned position embeddings,
      * ``y`` - the answer stream; its final state is decoded to logits,
      * ``z`` - a latent reasoning stream, randomly initialised on every call.

    Every pass concatenates the three streams along the sequence axis, runs them
    through one shared stack of ``TransformerBlock`` layers, and splits the result
    back into three streams. ``recursive_reasoning`` first updates only ``z`` for
    ``n_reasoning_steps`` passes, then only ``y`` for ``n_refinement_steps`` passes.

    NOTE(review): with ``use_attention=False`` the blocks have no cross-position
    mixing, so ``y`` and ``z`` never receive information from ``x`` - confirm this
    is intended for the MLP variant.
    NOTE(review): answer tokens receive no position embedding, so the ``y`` stream
    carries no positional signal - confirm this is intended.

    Args:
        vocab_size: size of the token vocabulary.
        d_model, n_heads, d_ff, n_layers: transformer dimensions.
        n_reasoning_steps: number of passes that update ``z``.
        n_refinement_steps: number of passes that update ``y``.
        latent_len: length of the ``z`` stream.
        use_attention: build blocks with self-attention (False = FFN-only).
        tie_embeddings: share the output projection with the token embedding.
        max_seq_len: maximum question length (rows of the position table).
        answer_len: length of the ``y`` stream when ``answer_ids`` is not given.
    """

    def __init__(
        self,
        vocab_size,
        d_model=256,
        n_heads=4,
        d_ff=1024,
        n_layers=4,
        n_reasoning_steps=8,
        n_refinement_steps=16,
        latent_len=32,
        use_attention=True,
        tie_embeddings=True,
        max_seq_len=512,
        answer_len=32,
    ):
        super().__init__()
        self.d_model = d_model
        self.latent_len = latent_len
        self.n_reasoning_steps = n_reasoning_steps
        self.n_refinement_steps = n_refinement_steps
        self.max_seq_len = max_seq_len
        self.answer_len = answer_len

        self.token_embedding = nn.Embedding(vocab_size, d_model)
        self.position_embedding = nn.Embedding(max_seq_len, d_model)
        self.embedding_dropout = nn.Dropout(0.1)

        self.blocks = nn.ModuleList([
            TransformerBlock(d_model, n_heads, d_ff, use_attention=use_attention)
            for _ in range(n_layers)
        ])

        self.reverse_embedding = nn.Linear(d_model, vocab_size, bias=False)
        if tie_embeddings:
            # Same Parameter object: output logits reuse the input embedding matrix.
            self.reverse_embedding.weight = self.token_embedding.weight

        self._init_weights()

    def _init_weights(self):
        """Initialise Linear/Embedding weights from N(0, 0.02) and Linear biases to 0."""
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, std=0.02)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)
            elif isinstance(m, nn.Embedding):
                nn.init.normal_(m.weight, std=0.02)

    def apply_blocks(self, x, mask=None):
        """Run ``x`` through every block in order, passing ``mask`` to each."""
        for block in self.blocks:
            x = block(x, mask)
        return x

    def forward_pass(self, x, y, z, mask=None):
        """One shared pass over the concatenated ``[x; y; z]`` sequence.

        Returns the updated ``(x, y, z)`` streams, split back to their input lengths.
        """
        combined = torch.cat([x, y, z], dim=1)
        combined = self.apply_blocks(combined, mask)

        len_x = x.size(1)
        len_y = y.size(1)
        x_new = combined[:, :len_x, :]
        y_new = combined[:, len_x: len_x + len_y, :]
        z_new = combined[:, len_x + len_y:, :]
        return x_new, y_new, z_new

    def recursive_reasoning(self, x, y, z, mask=None):
        """Refine the latent ``z`` and then the answer ``y``; return the final ``y``.

        Gradients flow through every pass (nothing is detached), so activation
        memory grows with ``n_reasoning_steps + n_refinement_steps``.
        """
        # Phase 1: reasoning - only z is carried forward.
        for _ in range(self.n_reasoning_steps):
            _, _, z = self.forward_pass(x, y, z, mask)
        # Phase 2: refinement - only y is carried forward, using the final z.
        for _ in range(self.n_refinement_steps):
            _, y, _ = self.forward_pass(x, y, z, mask)
        return y

    def forward(self, question_ids, answer_ids=None, mask=None):
        """Compute answer logits for a batch of questions.

        Args:
            question_ids: LongTensor ``[batch, qlen]`` with ``qlen <= max_seq_len``.
            answer_ids: optional LongTensor ``[batch, alen]`` used to initialise
                the answer stream; when omitted, ``y`` is small Gaussian noise of
                length ``answer_len``.
            mask: optional attention mask forwarded to every block. It applies to
                the concatenated sequence of length ``qlen + alen + latent_len``.

        Returns:
            Logits of shape ``[batch, alen, vocab_size]`` (``alen = answer_len``
            when ``answer_ids`` is None).

        Raises:
            ValueError: if ``qlen`` exceeds ``max_seq_len``.
        """
        batch, qlen = question_ids.size()
        if qlen > self.max_seq_len:
            raise ValueError(
                f"question length {qlen} exceeds max_seq_len={self.max_seq_len}"
            )
        device = question_ids.device

        positions = torch.arange(qlen, device=device).unsqueeze(0)
        x = self.token_embedding(question_ids) + self.position_embedding(positions)
        x = self.embedding_dropout(x)

        # Answer stream initialisation.
        if answer_ids is not None:
            y = self.embedding_dropout(self.token_embedding(answer_ids))
        else:
            y = torch.randn(
                batch, self.answer_len, self.d_model, device=device, dtype=x.dtype
            ) * 0.02

        # Reasoning stream initialisation (fresh noise on every call).
        z = torch.randn(
            batch, self.latent_len, self.d_model, device=device, dtype=x.dtype
        ) * 0.02

        y_final = self.recursive_reasoning(x, y, z, mask)
        return self.reverse_embedding(y_final)

    def generate(self, question_ids, max_length=50, temperature=1.0):
        """Autoregressively produce ``max_length`` answer tokens.

        Decoding starts from a single token with id 0 (NOTE(review): confirm id 0
        is a suitable start token for your tokenizer). Each step re-runs the full
        model on the question and all tokens so far and picks the next token from
        the logits at the last answer position. This method does not switch the
        model to eval mode, so call ``model.eval()`` first to disable dropout.

        Args:
            question_ids: LongTensor ``[batch, qlen]``.
            max_length: number of tokens to generate.
            temperature: sampling temperature; ``0`` selects greedy argmax
                decoding (dividing by zero would otherwise produce NaNs).

        Returns:
            LongTensor ``[batch, max_length + 1]`` including the leading start token.
        """
        batch = question_ids.size(0)
        device = question_ids.device
        generated = torch.zeros(batch, 1, dtype=torch.long, device=device)

        for _ in range(max_length):
            next_logits = self.forward(question_ids, generated)[:, -1, :]
            if temperature == 0:
                next_token = next_logits.argmax(dim=-1, keepdim=True)
            else:
                probs = F.softmax(next_logits / temperature, dim=-1)
                next_token = torch.multinomial(probs, num_samples=1)
            generated = torch.cat([generated, next_token], dim=1)
        return generated

    def count_parameters(self):
        """Number of trainable parameters (tied weights are counted once)."""
        return sum(p.numel() for p in self.parameters() if p.requires_grad)

In [None]:
def _create_trm(vocab_size, use_attention, **overrides):
    """Build a TRM with this notebook's default configuration.

    Any keyword in ``overrides`` replaces the matching default and is passed
    straight to ``TRM``.
    """
    config = dict(
        d_model=256,
        n_heads=4,
        d_ff=1024,
        n_layers=4,
        n_reasoning_steps=8,
        n_refinement_steps=16,
        latent_len=32,
        use_attention=use_attention,
        tie_embeddings=True,
    )
    config.update(overrides)
    return TRM(vocab_size=vocab_size, **config)


def create_trm_att(vocab_size, **overrides):
    """TRM whose blocks use self-attention (optional ``TRM`` kwargs override defaults)."""
    return _create_trm(vocab_size, use_attention=True, **overrides)


def create_trm_mlp(vocab_size, **overrides):
    """TRM whose blocks are feed-forward only (optional ``TRM`` kwargs override defaults)."""
    return _create_trm(vocab_size, use_attention=False, **overrides)

In [None]:
# Example: instantiate the attention-based TRM and report its size.
vocab_size = 100  # set this to your tokenizer's / dataset's vocabulary size
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
model = create_trm_att(vocab_size)
model = model.to(device)
print("# parameters (millions):", model.count_parameters() / 1e6)

# parameters (millions): 3.315712


In [None]:
# Dummy training loop (replace with real data)
# The model sees only the question; the gold answer is used solely as the loss
# target. Feeding `a` in as `answer_ids` and scoring the same positions would let
# the model copy its input, so the loss would fall without anything being learned.
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
criterion = nn.CrossEntropyLoss()

batch_size = 8
seq_len = 10
steps_per_epoch = 10

for epoch in range(5):
    model.train()
    total_loss = 0.0
    for _ in range(steps_per_epoch):
        # fake data
        q = torch.randint(0, vocab_size, (batch_size, seq_len), device=device)
        a = torch.randint(0, vocab_size, (batch_size, seq_len), device=device)
        # Without answer_ids the answer stream has the model's default answer
        # length (32 for the TRM built above), so seq_len must not exceed it;
        # keep the first a.size(1) positions to line up with the targets.
        logits = model(q)[:, : a.size(1), :]
        # logits: [batch, answer_len, vocab_size]. The slice is non-contiguous,
        # hence reshape instead of view.
        loss = criterion(logits.reshape(-1, vocab_size), a.reshape(-1))
        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()
        total_loss += loss.item()
    print(f"Epoch {epoch} — Loss: {total_loss / steps_per_epoch:.4f}")

Epoch 0 — Loss: 4.1506
Epoch 1 — Loss: 3.4262
Epoch 2 — Loss: 2.8170
Epoch 3 — Loss: 2.3642
Epoch 4 — Loss: 1.9402


In [None]:
# Inference example on another random (dummy) question.
model.eval()  # disable dropout for decoding
q = torch.randint(low=0, high=vocab_size, size=(1, seq_len), device=device)
with torch.no_grad():
    out_ids = model.generate(q, max_length=12, temperature=0.7)
for label, ids in (("Question IDs:", q), ("Generated IDs:", out_ids)):
    print(label, ids)