<a href="https://colab.research.google.com/github/weagan/Tiny-Recursive-Models/blob/main/TRM_Demonstration.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Understanding Tiny Recursive Models (TRM)

This Colab notebook provides a hands-on implementation of a Tiny Recursive Model (TRM), based on the concepts outlined in the [learn-tiny-recursive-models GitHub repository](https://github.com/vukrosic/learn-tiny-recursive-models).

We will:
1.  Implement the core `RecursiveBlock`.
2.  Build the full `TRM` model.
3.  Run a forward pass to see how it processes a sequence.
4.  Set up a simple training loop to watch the model learn.


## Setup

First, make sure PyTorch is installed. Google Colab ships with it preinstalled, so the install command below is usually a no-op there, but it keeps the notebook working in other environments.


In [None]:
!pip install -q torch
import torch
import torch.nn as nn
import torch.optim as optim

print(f"PyTorch version: {torch.__version__}")


PyTorch version: 2.8.0+cu126


## 1. The Core Component: `RecursiveBlock`

The fundamental building block of a TRM is the `RecursiveBlock`. It takes an input tensor and a hidden state from the previous step, and produces an output and an updated hidden state. This recursive nature allows it to process sequences step-by-step.

The block consists of:
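-   Two Layer Normalization layers, one for the state and one for the input, applied before concatenation for stability.
-   Two Linear layers (`2 * d_model` → `2 * d_model`) that transform the concatenated state and input.
-   A GELU activation between them for non-linearity.
-   A residual split: the first half of the result is added to the previous state to form the new state, and the second half becomes the step's output.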
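Written out as equations (our own notation, read off the `forward` method below: $d$ is `d_model`, $[\cdot\,;\cdot]$ is concatenation along the feature dimension, and the `TRM` wrapper starts from $\mathbf{s}_0 = \mathbf{0}$), one step of the block computes:

$$
\begin{aligned}
\mathbf{h}_t &= W_2\,\mathrm{GELU}\big(W_1\,[\mathrm{LN}_s(\mathbf{s}_{t-1})\,;\,\mathrm{LN}_x(\mathbf{x}_t)] + \mathbf{b}_1\big) + \mathbf{b}_2 \in \mathbb{R}^{2d}, \qquad W_1, W_2 \in \mathbb{R}^{2d \times 2d} \\
\mathbf{s}_t &= \mathbf{s}_{t-1} + \mathbf{h}_t[1{:}d] \\
\mathbf{y}_t &= \mathbf{h}_t[d{+}1{:}2d]
\end{aligned}
$$

The state update is residual: the block only learns a correction to $\mathbf{s}_{t-1}$, while $\mathbf{y}_t$ is read from the other half of the same vector.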


In [None]:
class RecursiveBlock(nn.Module):
    """
    A single recursive block for the TRM.
    """
    def __init__(self, d_model):
        super().__init__()
        self.d_model = d_model

        # Layer normalization for the state and input
        self.norm_state = nn.LayerNorm(d_model)
        self.norm_input = nn.LayerNorm(d_model)

        # Linear layers to process the combined state and input
        # The input dimension is 2 * d_model because we concatenate state and input
        self.linear1 = nn.Linear(2 * d_model, 2 * d_model)
        self.activation = nn.GELU()
        self.linear2 = nn.Linear(2 * d_model, 2 * d_model)

    def forward(self, x, state):
        """
        Forward pass for the recursive block.
        Args:
            x (torch.Tensor): Input tensor for the current step. Shape: (batch_size, d_model)
            state (torch.Tensor): Hidden state from the previous step. Shape: (batch_size, d_model)
        Returns:
            Tuple[torch.Tensor, torch.Tensor]: The output and the new state. Both have shape (batch_size, d_model)
        """
        # Normalize the state and input separately
        normalized_state = self.norm_state(state)
        normalized_input = self.norm_input(x)

        # Concatenate along the feature dimension
        combined_input = torch.cat([normalized_state, normalized_input], dim=1)

        # Pass through the linear layers
        hidden = self.linear1(combined_input)
        hidden = self.activation(hidden)
        processed_output = self.linear2(hidden)

        # Split the processed tensor in two: the first half is added to the previous
        # state (a residual update), and the second half becomes this step's output.
        new_state = state + processed_output[:, :self.d_model]
        output = processed_output[:, self.d_model:]

        return output, new_state


## 2. The TRM Model: Processing Sequences

The full `TRM` wraps the `RecursiveBlock`. It initializes the hidden state with zeros and then iterates over the input sequence, feeding one element at a time into the same recursive block, so the block's weights are shared across all time steps. This step-by-step processing is the core of its operation.

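We also add an input embedding layer that maps token IDs to `d_model`-dimensional vectors, and an output head (a LayerNorm followed by a Linear layer) that projects each step's output back to vocabulary logits.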
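To make the step-by-step processing concrete, here is a stripped-down sketch of the same unrolling loop. It is an illustration under simplifying assumptions, not the notebook's model: the hypothetical `step` is a single `nn.Linear` standing in for `RecursiveBlock` (no LayerNorm or GELU), but the zero initial state, the residual state update, and the output split follow the same pattern. Editing the input at one position shows that earlier outputs are unaffected, because each output depends only on the tokens seen so far.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical stand-in for RecursiveBlock (assumption: one Linear layer, no norms).
d_model = 8
step = nn.Linear(2 * d_model, 2 * d_model)

def unroll(xs):
    """Run the recurrence over xs of shape (batch, seq_len, d_model)."""
    batch, seq_len, _ = xs.shape
    state = torch.zeros(batch, d_model)          # s_0 = 0, as in TRM.forward
    outputs = []
    for t in range(seq_len):
        h = step(torch.cat([state, xs[:, t, :]], dim=1))
        state = state + h[:, :d_model]           # residual state update
        outputs.append(h[:, d_model:])           # per-step output
    return torch.stack(outputs, dim=1), state

xs = torch.randn(1, 6, d_model)
ys, final_state = unroll(xs)

# Perturb only position 3 and re-run the recurrence.
xs_edit = xs.clone()
xs_edit[:, 3, :] += 1.0
ys_edit, _ = unroll(xs_edit)

print(torch.equal(ys[:, :3], ys_edit[:, :3]))     # outputs 0..2 are unchanged
print(torch.allclose(ys[:, 3:], ys_edit[:, 3:]))  # outputs from position 3 on differ
```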


In [None]:
class TRM(nn.Module):
    """
    The Tiny Recursive Model.
    """
    def __init__(self, vocab_size, d_model):
        super().__init__()
        self.d_model = d_model

        # Embedding layer to convert token IDs to vectors
        self.embedding = nn.Embedding(vocab_size, d_model)

        # The recursive block
        self.recursive_block = RecursiveBlock(d_model)

        # A final layer normalization and linear layer to produce logits
        self.norm_out = nn.LayerNorm(d_model)
        self.output_linear = nn.Linear(d_model, vocab_size)

    def forward(self, input_sequence):
        """
        Forward pass for the entire sequence.
        Args:
            input_sequence (torch.Tensor): A sequence of token IDs. Shape: (batch_size, seq_len)
        Returns:
            torch.Tensor: The output logits for each step in the sequence. Shape: (batch_size, seq_len, vocab_size)
        """
        batch_size, seq_len = input_sequence.shape

        # 1. Embed the input sequence
        embedded_input = self.embedding(input_sequence)  # Shape: (batch_size, seq_len, d_model)

        # 2. Initialize the hidden state with zeros
        state = torch.zeros(batch_size, self.d_model, device=input_sequence.device)

        # 3. Process the sequence step-by-step
        outputs = []
        for i in range(seq_len):
            # Get the input for the current time step
            step_input = embedded_input[:, i, :]  # Shape: (batch_size, d_model)

            # Pass through the recursive block
            output, state = self.recursive_block(step_input, state)
            outputs.append(output)

        # 4. Stack the outputs and project to vocabulary size
        # Stack along the sequence dimension
        outputs_tensor = torch.stack(outputs, dim=1)  # Shape: (batch_size, seq_len, d_model)

        # Final normalization and linear projection
        normalized_outputs = self.norm_out(outputs_tensor)
        logits = self.output_linear(normalized_outputs)

        return logits


## 3. Forward Pass Demonstration

Let's create a dummy input sequence and pass it through the model to check the shapes of the input and output tensors. Getting logits of shape `(batch_size, seq_len, vocab_size)` is a quick sanity check that the layers are wired together correctly.


In [None]:
# Model parameters
VOCAB_SIZE = 20
D_MODEL = 32
SEQ_LEN = 5
BATCH_SIZE = 1

# Instantiate the model
model = TRM(vocab_size=VOCAB_SIZE, d_model=D_MODEL)
print("Model Architecture:")
print(model)

# Create a dummy input sequence (batch_size, seq_len)
# These are random token IDs from our vocabulary
dummy_input = torch.randint(0, VOCAB_SIZE, (BATCH_SIZE, SEQ_LEN))
print(f"\nDummy Input Shape: {dummy_input.shape}")
print(f"Dummy Input Tensor:\n{dummy_input}")

# Perform a forward pass
output_logits = model(dummy_input)

print(f"\nOutput Logits Shape: {output_logits.shape}")
print("This shape (batch_size, seq_len, vocab_size) is what we expect!")


Model Architecture:
TRM(
  (embedding): Embedding(20, 32)
  (recursive_block): RecursiveBlock(
    (norm_state): LayerNorm((32,), eps=1e-05, elementwise_affine=True)
    (norm_input): LayerNorm((32,), eps=1e-05, elementwise_affine=True)
    (linear1): Linear(in_features=64, out_features=64, bias=True)
    (activation): GELU(approximate='none')
    (linear2): Linear(in_features=64, out_features=64, bias=True)
  )
  (norm_out): LayerNorm((32,), eps=1e-05, elementwise_affine=True)
  (output_linear): Linear(in_features=32, out_features=20, bias=True)
)

Dummy Input Shape: torch.Size([1, 5])
Dummy Input Tensor:
tensor([[ 1,  1, 15,  1, 19]])

Output Logits Shape: torch.Size([1, 5, 20])
This shape (batch_size, seq_len, vocab_size) is what we expect!
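As a cross-check on the printed architecture, the parameter count can be written in closed form. This is our own arithmetic for the layers listed above (the hypothetical helper `trm_param_count` is not part of the original code). It works out to 8·d² + 10·d + 2·V·d + V for vocabulary size V and `d_model` d, so the two 2d × 2d linear layers inside the block dominate as d grows. For any instantiated model, `sum(p.numel() for p in model.parameters())` should give the same number.

```python
def trm_param_count(vocab_size, d_model):
    """Closed-form parameter count for the TRM architecture printed above."""
    d, v = d_model, vocab_size
    embedding = v * d                                  # nn.Embedding(vocab_size, d_model)
    layer_norms = 3 * (2 * d)                          # norm_state, norm_input, norm_out: weight + bias
    block_linears = 2 * ((2 * d) * (2 * d) + 2 * d)    # linear1, linear2: 2d x 2d weight + 2d bias
    output = d * v + v                                 # output_linear: d x V weight + V bias
    return embedding + layer_norms + block_linears + output

print(trm_param_count(20, 32))  # 9812: the demo model above
print(trm_param_count(10, 16))  # 2538: the training model in Section 4
```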


## 4. A Simple Training Example

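To show that the model can learn, let's create a simple "next token prediction" task: train the model to predict the next number in a sequence. Because we train and evaluate on a single sequence, success here shows that the model can fit this mapping, not that it generalizes to unseen sequences.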

-   **Input:** `[1, 2, 3, 4]`
-   **Target:** `[2, 3, 4, 5]`

We'll use Cross-Entropy Loss and the Adam optimizer.
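The input/target pair above is one sequence shifted by one position. A hypothetical helper such as `make_next_token_pairs` (not part of the original notebook) sketches how these pairs could be built from longer sequences, which makes it easy to train on more than one example:

```python
import torch

def make_next_token_pairs(sequences):
    """Hypothetical helper: split equal-length token sequences into
    (input, target) pairs where target[t] = sequence[t + 1]."""
    data = torch.as_tensor(sequences)   # shape: (batch, seq_len + 1)
    inputs = data[:, :-1]               # drop the last token
    targets = data[:, 1:]               # drop the first token
    return inputs, targets

inputs, targets = make_next_token_pairs([[1, 2, 3, 4, 5]])
print(inputs.tolist())   # [[1, 2, 3, 4]]
print(targets.tolist())  # [[2, 3, 4, 5]]
```

Passing several sequences, e.g. `make_next_token_pairs([[0, 1, 2, 3, 4], [3, 4, 5, 6, 7]])`, yields a batch of two pairs. All token IDs must stay below the vocabulary size (10 in the training cell).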


In [None]:
# --- Training Setup ---
# A small vocabulary (token IDs 0-9) is enough for this task
TRAINING_VOCAB_SIZE = 10
TRAINING_D_MODEL = 16

# Create a new model instance for training
training_model = TRM(vocab_size=TRAINING_VOCAB_SIZE, d_model=TRAINING_D_MODEL)
optimizer = optim.Adam(training_model.parameters(), lr=0.01)
criterion = nn.CrossEntropyLoss()

# --- Data ---
# Our simple sequence prediction task
input_data = torch.tensor([[1, 2, 3, 4]])      # Shape: (1, 4)
target_data = torch.tensor([[2, 3, 4, 5]])     # Shape: (1, 4)

print(f"Input:  {input_data[0].tolist()}")
print(f"Target: {target_data[0].tolist()}")

# --- Training Loop ---
epochs = 100
for epoch in range(epochs):
    # Zero the gradients
    optimizer.zero_grad()

    # Forward pass
    logits = training_model(input_data) # Shape: (batch, seq_len, vocab_size)

    # Reshape for the loss function
    # The loss function expects (N, C) where C is number of classes
    # Logits: (1, 4, 10) -> (4, 10)
    # Target: (1, 4) -> (4)
    loss = criterion(logits.view(-1, TRAINING_VOCAB_SIZE), target_data.view(-1))

    # Backward pass and optimization
    loss.backward()
    optimizer.step()

    if (epoch + 1) % 10 == 0:
        print(f"Epoch [{epoch+1}/{epochs}], Loss: {loss.item():.4f}")

# --- Inference after training ---
print("\n--- Testing after training ---")
with torch.no_grad():
    test_input = torch.tensor([[1, 2, 3, 4]])
    predictions = training_model(test_input)
    # Get the predicted token ID by finding the max logit
    predicted_ids = torch.argmax(predictions, dim=2)

    print(f"Input:           {test_input[0].tolist()}")
    print(f"Predicted sequence: {predicted_ids[0].tolist()}")
    print(f"Target sequence:    {target_data[0].tolist()}")
    print("\nThe model has learned to predict the next number in the sequence!")


Input:  [1, 2, 3, 4]
Target: [2, 3, 4, 5]
Epoch [10/100], Loss: 0.3444
Epoch [20/100], Loss: 0.0682
Epoch [30/100], Loss: 0.0196
Epoch [40/100], Loss: 0.0091
Epoch [50/100], Loss: 0.0058
Epoch [60/100], Loss: 0.0044
Epoch [70/100], Loss: 0.0037
Epoch [80/100], Loss: 0.0031
Epoch [90/100], Loss: 0.0028
Epoch [100/100], Loss: 0.0024

--- Testing after training ---
Input:           [1, 2, 3, 4]
Predicted sequence: [2, 3, 4, 5]
Target sequence:    [2, 3, 4, 5]

The model has learned to predict the next number in the sequence!
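The test above feeds the whole input at once and takes the argmax at every position (teacher forcing). To generate new tokens instead, a greedy loop can append the prediction made after the last token and run the model again. The sketch below is a hypothetical helper, `greedy_generate`, exercised with a toy stand-in model rather than the trained TRM:

```python
import torch
import torch.nn.functional as F

@torch.no_grad()
def greedy_generate(model, prompt_ids, num_new_tokens):
    """Hypothetical helper: extend prompt_ids (shape (1, L)) one token at a time.

    `model` is any callable mapping token IDs of shape (batch, seq_len) to
    logits of shape (batch, seq_len, vocab_size), like TRM.forward.
    """
    ids = prompt_ids
    for _ in range(num_new_tokens):
        logits = model(ids)                          # re-runs the whole prefix
        next_id = logits[:, -1, :].argmax(dim=-1)    # prediction after the last token
        ids = torch.cat([ids, next_id.unsqueeze(1)], dim=1)
    return ids

# Toy stand-in model (assumption, for illustration): predicts (token + 1) mod 10.
def plus_one_model(ids, vocab_size=10):
    return F.one_hot((ids + 1) % vocab_size, vocab_size).float()

print(greedy_generate(plus_one_model, torch.tensor([[1]]), 4).tolist())  # [[1, 2, 3, 4, 5]]
```

With the trained model, `greedy_generate(training_model, torch.tensor([[1, 2, 3, 4]]), 1)` appends the model's prediction for the position after 4. Two caveats: the model saw a single training sequence, so its predictions for unseen contexts (such as what follows 5) are not constrained by training; and because `TRM.forward` always restarts from a zero state, each call re-processes the whole prefix. Carrying the hidden state across calls would keep generation linear in length.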


## 5. Key Differences from Transformers

As summarized in the original repository, TRMs differ from Transformers in several key ways:

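-   **Computation:** TRMs are **recursive**: they process a sequence one step at a time, so their cost grows linearly with sequence length, O(L), but the steps cannot be computed in parallel. Transformers use **self-attention**, which processes all positions in parallel but compares every pair of positions, giving quadratic cost, O(L²).
-   **State Management:** TRMs explicitly carry a fixed-size **hidden state** that evolves from step to step. Transformers have no recurrent state; each layer attends over the representations of the whole (preceding) sequence, so context is recomputed from the tokens themselves rather than summarized in a state.
-   **Positional Information:** TRMs encode order implicitly, because tokens are processed sequentially. Unmasked self-attention is order-agnostic (permuting the inputs only permutes the outputs), so Transformers typically add explicit **positional encodings** to their inputs.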
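The positional-information point can be checked numerically. The sketch below uses a generic single-head self-attention (no mask, no positional encoding) and PyTorch's `nn.RNNCell` as stand-ins; neither is the TRM itself. Reordering the inputs only reorders the attention outputs, whereas the recurrent final state changes:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
seq_len, d = 5, 8
x = torch.randn(seq_len, d)
perm = torch.tensor([4, 2, 0, 3, 1])   # an arbitrary reordering of the positions

# Single-head self-attention with no mask and no positional encoding.
q_proj, k_proj, v_proj = nn.Linear(d, d), nn.Linear(d, d), nn.Linear(d, d)

def self_attention(x):
    scores = q_proj(x) @ k_proj(x).T / d ** 0.5   # (seq_len, seq_len): quadratic in length
    return torch.softmax(scores, dim=-1) @ v_proj(x)

# A minimal recurrence (assumption: a tanh RNN cell) carrying a fixed-size state.
cell = nn.RNNCell(d, d)

def final_state(x):
    h = torch.zeros(1, d)                         # state size does not grow with seq_len
    for t in range(x.shape[0]):
        h = cell(x[t].unsqueeze(0), h)
    return h

with torch.no_grad():
    # Attention: permuting the inputs just permutes the outputs.
    print(torch.allclose(self_attention(x[perm]), self_attention(x)[perm], atol=1e-6))
    # Recurrence: the final state depends on the order the tokens arrived in.
    print(torch.allclose(final_state(x[perm]), final_state(x), atol=1e-6))
```

Causal masking, as used in decoder-only Transformers, breaks this symmetry, so the argument applies to unmasked attention.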


In [None]:
print(f"Batch size of dummy_input: {dummy_input.shape[0]}")

Batch size of dummy_input: 1


## Conclusion

This notebook provided a brief, practical introduction to Tiny Recursive Models. We implemented the core components in PyTorch and demonstrated that even a simple TRM can fit a basic next-token prediction example.

TRMs offer an interesting alternative to Transformers, particularly for applications where linear-time sequential processing and an explicit, fixed-size state are important.
