<a href="https://colab.research.google.com/github/weagan/Tiny-Recursive-Models/blob/main/TRM_Demonstration_20x20_maze.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Understanding Tiny Recursive Models (TRM)

This Colab notebook provides a hands-on implementation of a Tiny Recursive Model (TRM), based on the concepts outlined in the [learn-tiny-recursive-models GitHub repository](https://github.com/vukrosic/learn-tiny-recursive-models).

We will:
1.  Implement the core `RecursiveBlock`.
2.  Build the full `TRM` model.
3.  Run a forward pass to see how it processes a sequence.
4.  Set up a simple training loop to watch the model learn.
5.  Train the model on a more complex task: solving 20x20 mazes.


## Setup

First, let's ensure we have PyTorch installed. Google Colab usually comes with it pre-installed, but it's good practice to run the installation command.


In [None]:
!pip install -q torch
import torch
import torch.nn as nn
import torch.optim as optim

print(f"PyTorch version: {torch.__version__}")


PyTorch version: 2.8.0+cu126


## 8. Leveraging GPU for Faster Training (Optional)

To utilize a GPU, we need to perform two main steps:
1.  **Check for GPU availability:** Determine if a CUDA-enabled GPU is present.
2.  **Move model and data to GPU:** Transfer the model parameters and input/target tensors to the GPU device.

In [None]:
import torch
# Check if CUDA (GPU) is available, otherwise use CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

Using device: cuda


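Now, let's adapt the simple training example from Section 4 to use the GPU if available. The cell below also defines `RecursiveBlock` and `TRM`, which Sections 1 and 2 explain in detail, so that it can run on its own.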

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim

class RecursiveBlock(nn.Module):
    """
    A single recursive block for the TRM.
    """
    def __init__(self, d_model):
        super().__init__()
        self.d_model = d_model

        # Layer normalization for the state and input
        self.norm_state = nn.LayerNorm(d_model)
        self.norm_input = nn.LayerNorm(d_model)

        # Linear layers to process the combined state and input
        # The input dimension is 2 * d_model because we concatenate state and input
        self.linear1 = nn.Linear(2 * d_model, 2 * d_model)
        self.activation = nn.GELU()
        self.linear2 = nn.Linear(2 * d_model, 2 * d_model)

    def forward(self, x, state):
        """
        Forward pass for the recursive block.
        Args:
            x (torch.Tensor): Input tensor for the current step. Shape: (batch_size, d_model)
            state (torch.Tensor): Hidden state from the previous step. Shape: (batch_size, d_model)
        Returns:
            Tuple[torch.Tensor, torch.Tensor]: The output and the new state. Both have shape (batch_size, d_model)
        """
        # Normalize the state and input separately
        normalized_state = self.norm_state(state)
        normalized_input = self.norm_input(x)

        # Concatenate along the feature dimension
        combined_input = torch.cat([normalized_state, normalized_input], dim=1)

        # Pass through the linear layers
        hidden = self.linear1(combined_input)
        hidden = self.activation(hidden)
        processed_output = self.linear2(hidden)

        # The magic of TRM: the new state and output are derived from the same processed tensor.
        # This is a simple but effective way to update the state.
        new_state = state + processed_output[:, :self.d_model]
        output = processed_output[:, self.d_model:]

        return output, new_state

class TRM(nn.Module):
    """
    The Tiny Recursive Model.
    """
    def __init__(self, vocab_size, d_model):
        super().__init__()
        self.d_model = d_model

        # Embedding layer to convert token IDs to vectors
        self.embedding = nn.Embedding(vocab_size, d_model)

        # The recursive block
        self.recursive_block = RecursiveBlock(d_model)

        # A final layer normalization and linear layer to produce logits
        self.norm_out = nn.LayerNorm(d_model)
        self.output_linear = nn.Linear(d_model, vocab_size)

    def forward(self, input_sequence):
        """
        Forward pass for the entire sequence.
        Args:
            input_sequence (torch.Tensor): A sequence of token IDs. Shape: (batch_size, seq_len)
        Returns:
            torch.Tensor: The output logits for each step in the sequence. Shape: (batch_size, seq_len, vocab_size)
        """
        batch_size, seq_len = input_sequence.shape

        # 1. Embed the input sequence
        embedded_input = self.embedding(input_sequence)  # Shape: (batch_size, seq_len, d_model)

        # 2. Initialize the hidden state with zeros
        state = torch.zeros(batch_size, self.d_model, device=input_sequence.device)

        # 3. Process the sequence step-by-step
        outputs = []
        for i in range(seq_len):
            # Get the input for the current time step
            step_input = embedded_input[:, i, :]  # Shape: (batch_size, d_model)

            # Pass through the recursive block
            output, state = self.recursive_block(step_input, state)
            outputs.append(output)

        # 4. Stack the outputs and project to vocabulary size
        # Stack along the sequence dimension
        outputs_tensor = torch.stack(outputs, dim=1)  # Shape: (batch_size, seq_len, d_model)

        # Final normalization and linear projection
        normalized_outputs = self.norm_out(outputs_tensor)
        logits = self.output_linear(normalized_outputs)

        return logits

# --- Training Setup with GPU --- (adapted from previous example)
TRAINING_VOCAB_SIZE = 10
TRAINING_D_MODEL = 16

# Create a new model instance and move it to the device
training_model_gpu = TRM(vocab_size=TRAINING_VOCAB_SIZE, d_model=TRAINING_D_MODEL).to(device)
optimizer_gpu = optim.Adam(training_model_gpu.parameters(), lr=0.01)
criterion_gpu = nn.CrossEntropyLoss()

# --- Data ---
# Move input and target data to the device
input_data_gpu = torch.tensor([[1, 2, 3, 4]]).to(device)
target_data_gpu = torch.tensor([[2, 3, 4, 5]]).to(device)

print(f"Input on device: {input_data_gpu.device}")
print(f"Target on device: {target_data_gpu.device}")
print(f"Model on device: {next(training_model_gpu.parameters()).device}")

# --- Training Loop --- (identical to CPU version, but with GPU tensors)
epochs = 100
print("\n--- Training on GPU/CPU ---")
for epoch in range(epochs):
    optimizer_gpu.zero_grad()
    logits = training_model_gpu(input_data_gpu)
    loss = criterion_gpu(logits.view(-1, TRAINING_VOCAB_SIZE), target_data_gpu.view(-1))
    loss.backward()
    optimizer_gpu.step()

    if (epoch + 1) % 10 == 0:
        print(f"Epoch [{epoch+1}/{epochs}], Loss: {loss.item():.4f}")

# --- Inference after training with GPU --- (adapted)
print("\n--- Testing after training (GPU/CPU) ---")
with torch.no_grad():
    test_input_gpu = torch.tensor([[1, 2, 3, 4]]).to(device)
    predictions_gpu = training_model_gpu(test_input_gpu)
    predicted_ids_gpu = torch.argmax(predictions_gpu, dim=2)

    print(f"Input:              {test_input_gpu[0].tolist()}")
    print(f"Predicted sequence: {predicted_ids_gpu[0].tolist()}")
    print(f"Target sequence:    {target_data_gpu[0].tolist()}")
    print("The model has learned to predict the next number in the sequence (on GPU/CPU)!")

Input on device: cuda:0
Target on device: cuda:0
Model on device: cuda:0

--- Training on GPU/CPU ---
Epoch [10/100], Loss: 0.3100
Epoch [20/100], Loss: 0.0621
Epoch [30/100], Loss: 0.0187
Epoch [40/100], Loss: 0.0090
Epoch [50/100], Loss: 0.0058
Epoch [60/100], Loss: 0.0044
Epoch [70/100], Loss: 0.0037
Epoch [80/100], Loss: 0.0032
Epoch [90/100], Loss: 0.0028
Epoch [100/100], Loss: 0.0025

--- Testing after training (GPU/CPU) ---
Input:              [1, 2, 3, 4]
Predicted sequence: [2, 3, 4, 5]
Target sequence:    [2, 3, 4, 5]
The model has learned to predict the next number in the sequence (on GPU/CPU)!


You would apply the same `.to(device)` method to your `maze_model` and all relevant data tensors (`input_seq`, `target_seq`) in the maze-solving example to train it on a GPU.
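
As a minimal sketch of that pattern, the snippet below moves a model and a shifted input/target pair to the selected device. The `nn.Linear` layer is only a stand-in for `maze_model`, and the five-token sequence is a placeholder for a preprocessed maze sequence; both are assumptions made to keep the example self-contained.

```python
import torch
import torch.nn as nn

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Stand-in for maze_model (assumption: any nn.Module is moved the same way).
stand_in_model = nn.Linear(4, 4).to(device)

# Placeholder for one preprocessed sequence from training_sequences.
seq = torch.tensor([1, 2, 3, 4, 5])

# Tensor.to() returns a new tensor, so the result must be assigned.
input_seq = seq[:-1].unsqueeze(0).to(device)
target_seq = seq[1:].unsqueeze(0).to(device)

param_device = next(stand_in_model.parameters()).device
print(param_device, input_seq.device, target_seq.device)
```

Any tensor created during inference, such as the one-token tensor appended in the generation loop, also has to live on the same device as the model, otherwise `torch.cat` and the forward pass will raise a device mismatch error.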

## 1. The Core Component: `RecursiveBlock`

The fundamental building block of a TRM is the `RecursiveBlock`. It takes an input tensor and a hidden state from the previous step, and produces an output and an updated hidden state. This recursive nature allows it to process sequences step-by-step.

The block consists of:
-   Two Layer Normalization layers for stability.
-   Two Linear layers to transform the concatenated input and state.
-   A GELU activation function for non-linearity.
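
To get a feel for how small this block is, the sketch below counts the parameters of the components listed above. It rebuilds the four layers on their own rather than importing the notebook's class, and it assumes `d_model = 32`, the value used later in the forward pass demonstration. The helper name `count_params` is hypothetical.

```python
import torch.nn as nn

def count_params(module):
    """Total number of learnable scalars in a module."""
    return sum(p.numel() for p in module.parameters())

d_model = 32  # assumption: same value as D_MODEL in the forward pass demo

# The same layers the block uses; GELU has no parameters, so it is omitted.
layers = nn.ModuleDict({
    "norm_state": nn.LayerNorm(d_model),
    "norm_input": nn.LayerNorm(d_model),
    "linear1": nn.Linear(2 * d_model, 2 * d_model),
    "linear2": nn.Linear(2 * d_model, 2 * d_model),
})

per_layer = {name: count_params(layer) for name, layer in layers.items()}
total = count_params(layers)

# Closed form: each LayerNorm has 2*d values (scale and shift);
# each Linear has (2d)^2 weights plus 2d biases.
expected = 2 * (2 * d_model) + 2 * ((2 * d_model) ** 2 + 2 * d_model)

print(per_layer)
print(total, expected)
```

The two linear layers dominate the count, and their size grows quadratically with `d_model`.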


In [None]:
class RecursiveBlock(nn.Module):
    """
    A single recursive block for the TRM.
    """
    def __init__(self, d_model):
        super().__init__()
        self.d_model = d_model

        # Layer normalization for the state and input
        self.norm_state = nn.LayerNorm(d_model)
        self.norm_input = nn.LayerNorm(d_model)

        # Linear layers to process the combined state and input
        # The input dimension is 2 * d_model because we concatenate state and input
        self.linear1 = nn.Linear(2 * d_model, 2 * d_model)
        self.activation = nn.GELU()
        self.linear2 = nn.Linear(2 * d_model, 2 * d_model)

    def forward(self, x, state):
        """
        Forward pass for the recursive block.
        Args:
            x (torch.Tensor): Input tensor for the current step. Shape: (batch_size, d_model)
            state (torch.Tensor): Hidden state from the previous step. Shape: (batch_size, d_model)
        Returns:
            Tuple[torch.Tensor, torch.Tensor]: The output and the new state. Both have shape (batch_size, d_model)
        """
        # Normalize the state and input separately
        normalized_state = self.norm_state(state)
        normalized_input = self.norm_input(x)

        # Concatenate along the feature dimension
        combined_input = torch.cat([normalized_state, normalized_input], dim=1)

        # Pass through the linear layers
        hidden = self.linear1(combined_input)
        hidden = self.activation(hidden)
        processed_output = self.linear2(hidden)

        # The magic of TRM: the new state and output are derived from the same processed tensor.
        # This is a simple but effective way to update the state.
        new_state = state + processed_output[:, :self.d_model]
        output = processed_output[:, self.d_model:]

        return output, new_state


## 2. The TRM Model: Processing Sequences

The full TRM model wraps the `RecursiveBlock`. It initializes the hidden state (usually with zeros) and then iterates through the input sequence, feeding each element into the recursive block one at a time. This step-by-step processing is the core of its operation.

We also add an input embedding layer to map token IDs to the model's dimension, and a final layer normalization and linear layer to map the outputs back to vocabulary logits.
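
One consequence of this step-by-step loop is that the output at a given step depends only on the current and earlier inputs. The sketch below illustrates this with `TinyStep`, a hypothetical and deliberately simplified stand-in for `RecursiveBlock` (one linear layer, no normalization or activation), and a hypothetical `run` helper that mirrors the loop described above.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

class TinyStep(nn.Module):
    """Simplified stand-in for RecursiveBlock: one linear layer, residual state update."""
    def __init__(self, d_model):
        super().__init__()
        self.d_model = d_model
        self.linear = nn.Linear(2 * d_model, 2 * d_model)

    def forward(self, x, state):
        mixed = self.linear(torch.cat([state, x], dim=1))
        new_state = state + mixed[:, :self.d_model]   # residual state update
        output = mixed[:, self.d_model:]
        return output, new_state

def run(step, inputs):
    """Feed inputs of shape (batch, seq_len, d_model) one step at a time."""
    state = torch.zeros(inputs.shape[0], step.d_model)
    outputs = []
    for i in range(inputs.shape[1]):
        output, state = step(inputs[:, i, :], state)
        outputs.append(output)
    return torch.stack(outputs, dim=1)

step = TinyStep(d_model=8)
seq_a = torch.randn(1, 6, 8)
seq_b = seq_a.clone()
seq_b[:, 4:, :] = torch.randn(1, 2, 8)  # change only the last two steps

with torch.no_grad():
    out_a = run(step, seq_a)
    out_b = run(step, seq_b)

# Outputs for the shared prefix are identical; only the changed steps differ.
prefix_matches = torch.allclose(out_a[:, :4], out_b[:, :4])
suffix_differs = not torch.allclose(out_a[:, 4:], out_b[:, 4:])
print(prefix_matches, suffix_differs)
```

This is why the same model can be used for next-token prediction without any explicit causal mask.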


In [None]:
class TRM(nn.Module):
    """
    The Tiny Recursive Model.
    """
    def __init__(self, vocab_size, d_model):
        super().__init__()
        self.d_model = d_model

        # Embedding layer to convert token IDs to vectors
        self.embedding = nn.Embedding(vocab_size, d_model)

        # The recursive block
        self.recursive_block = RecursiveBlock(d_model)

        # A final layer normalization and linear layer to produce logits
        self.norm_out = nn.LayerNorm(d_model)
        self.output_linear = nn.Linear(d_model, vocab_size)

    def forward(self, input_sequence):
        """
        Forward pass for the entire sequence.
        Args:
            input_sequence (torch.Tensor): A sequence of token IDs. Shape: (batch_size, seq_len)
        Returns:
            torch.Tensor: The output logits for each step in the sequence. Shape: (batch_size, seq_len, vocab_size)
        """
        batch_size, seq_len = input_sequence.shape

        # 1. Embed the input sequence
        embedded_input = self.embedding(input_sequence)  # Shape: (batch_size, seq_len, d_model)

        # 2. Initialize the hidden state with zeros
        state = torch.zeros(batch_size, self.d_model, device=input_sequence.device)

        # 3. Process the sequence step-by-step
        outputs = []
        for i in range(seq_len):
            # Get the input for the current time step
            step_input = embedded_input[:, i, :]  # Shape: (batch_size, d_model)

            # Pass through the recursive block
            output, state = self.recursive_block(step_input, state)
            outputs.append(output)

        # 4. Stack the outputs and project to vocabulary size
        # Stack along the sequence dimension
        outputs_tensor = torch.stack(outputs, dim=1)  # Shape: (batch_size, seq_len, d_model)

        # Final normalization and linear projection
        normalized_outputs = self.norm_out(outputs_tensor)
        logits = self.output_linear(normalized_outputs)

        return logits


## 3. Forward Pass Demonstration

Let's create a dummy input sequence and pass it through our model to check the shapes of the input and output tensors. This is a quick sanity check that the implementation is wired together correctly.


In [None]:
# Model parameters
VOCAB_SIZE = 20
D_MODEL = 32
SEQ_LEN = 5
BATCH_SIZE = 1

# Instantiate the model
model = TRM(vocab_size=VOCAB_SIZE, d_model=D_MODEL)
print("Model Architecture:")
print(model)

# Create a dummy input sequence (batch_size, seq_len)
# These are random token IDs from our vocabulary
dummy_input = torch.randint(0, VOCAB_SIZE, (BATCH_SIZE, SEQ_LEN))
print(f"\nDummy Input Shape: {dummy_input.shape}")
print(f"Dummy Input Tensor:\n{dummy_input}")

# Perform a forward pass
output_logits = model(dummy_input)

print(f"\nOutput Logits Shape: {output_logits.shape}")
print("This shape (batch_size, seq_len, vocab_size) is what we expect!")


Model Architecture:
TRM(
  (embedding): Embedding(20, 32)
  (recursive_block): RecursiveBlock(
    (norm_state): LayerNorm((32,), eps=1e-05, elementwise_affine=True)
    (norm_input): LayerNorm((32,), eps=1e-05, elementwise_affine=True)
    (linear1): Linear(in_features=64, out_features=64, bias=True)
    (activation): GELU(approximate='none')
    (linear2): Linear(in_features=64, out_features=64, bias=True)
  )
  (norm_out): LayerNorm((32,), eps=1e-05, elementwise_affine=True)
  (output_linear): Linear(in_features=32, out_features=20, bias=True)
)

Dummy Input Shape: torch.Size([1, 5])
Dummy Input Tensor:
tensor([[10,  2,  4,  2,  8]])

Output Logits Shape: torch.Size([1, 5, 20])
This shape (batch_size, seq_len, vocab_size) is what we expect!


## 4. A Simple Training Example

To show that the model can learn, let's create a simple "next token prediction" task. Our goal is to train the model to predict the next number in a sequence.

-   **Input:** `[1, 2, 3, 4]`
-   **Target:** `[2, 3, 4, 5]`

We'll use Cross-Entropy Loss and the Adam optimizer.
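
The training loop flattens the logits and targets before computing the loss. The sketch below, which uses random logits in place of real model output, suggests why that is safe: flattening gives the same value as passing the class dimension second, and the same value as averaging the negative log-probability of each target token by hand.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

batch_size, seq_len, vocab_size = 1, 4, 10
logits = torch.randn(batch_size, seq_len, vocab_size)  # stand-in for model output
targets = torch.tensor([[2, 3, 4, 5]])

criterion = nn.CrossEntropyLoss()

# Option 1: flatten to (N, C) logits and (N,) targets, as in the training loop.
loss_flat = criterion(logits.view(-1, vocab_size), targets.view(-1))

# Option 2: keep the batch structure but move classes to dim 1: (N, C, L).
loss_permuted = criterion(logits.permute(0, 2, 1), targets)

# Option 3: mean negative log-probability of the target token at each step.
log_probs = torch.log_softmax(logits, dim=-1)
loss_manual = -log_probs.gather(-1, targets.unsqueeze(-1)).mean()

print(loss_flat.item(), loss_permuted.item(), loss_manual.item())
```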


In [None]:
# --- Training Setup ---
# A small vocabulary is enough here: the task only uses the tokens 1 through 5
TRAINING_VOCAB_SIZE = 10
TRAINING_D_MODEL = 16

# Create a new model instance for training
training_model = TRM(vocab_size=TRAINING_VOCAB_SIZE, d_model=TRAINING_D_MODEL)
optimizer = optim.Adam(training_model.parameters(), lr=0.01)
criterion = nn.CrossEntropyLoss()

# --- Data ---
# Our simple sequence prediction task
input_data = torch.tensor([[1, 2, 3, 4]])      # Shape: (1, 4)
target_data = torch.tensor([[2, 3, 4, 5]])     # Shape: (1, 4)

print(f"Input:  {input_data[0].tolist()}")
print(f"Target: {target_data[0].tolist()}")

# --- Training Loop ---
epochs = 100
for epoch in range(epochs):
    # Zero the gradients
    optimizer.zero_grad()

    # Forward pass
    logits = training_model(input_data) # Shape: (batch, seq_len, vocab_size)

    # Reshape for the loss function
    # The loss function expects (N, C) where C is number of classes
    # Logits: (1, 4, 10) -> (4, 10)
    # Target: (1, 4) -> (4)
    loss = criterion(logits.view(-1, TRAINING_VOCAB_SIZE), target_data.view(-1))

    # Backward pass and optimization
    loss.backward()
    optimizer.step()

    if (epoch + 1) % 10 == 0:
        print(f"Epoch [{epoch+1}/{epochs}], Loss: {loss.item():.4f}")

# --- Inference after training ---
print("\n--- Testing after training ---")
with torch.no_grad():
    test_input = torch.tensor([[1, 2, 3, 4]])
    predictions = training_model(test_input)
    # Get the predicted token ID by finding the max logit
    predicted_ids = torch.argmax(predictions, dim=2)

    print(f"Input:           {test_input[0].tolist()}")
    print(f"Predicted sequence: {predicted_ids[0].tolist()}")
    print(f"Target sequence:    {target_data[0].tolist()}")
    print("\nThe model has learned to predict the next number in the sequence!")


Input:  [1, 2, 3, 4]
Target: [2, 3, 4, 5]
Epoch [10/100], Loss: 0.3315
Epoch [20/100], Loss: 0.0663
Epoch [30/100], Loss: 0.0192
Epoch [40/100], Loss: 0.0091
Epoch [50/100], Loss: 0.0059
Epoch [60/100], Loss: 0.0045
Epoch [70/100], Loss: 0.0038
Epoch [80/100], Loss: 0.0033
Epoch [90/100], Loss: 0.0029
Epoch [100/100], Loss: 0.0026

--- Testing after training ---
Input:           [1, 2, 3, 4]
Predicted sequence: [2, 3, 4, 5]
Target sequence:    [2, 3, 4, 5]

The model has learned to predict the next number in the sequence!


## 5. Advanced Example: Solving 20x20 Mazes

Now for a more challenging task. Let's train the TRM to solve larger 20x20 mazes. This requires the model to process a much longer sequence (400 maze tokens + path tokens) and maintain its state to find the solution.

**The Task:** The model will be given a sequence representing the maze structure, followed by the starting position of the path. Its goal is to predict the rest of the coordinate sequence that solves the maze.

**Data Representation:**
We will convert the maze and its solution path into a single sequence of integers (tokens).

- **Vocabulary:**
    - `0`: Wall (`#`)
    - `1`: Path (`.`)
    - `2`: Start (`S`)
    - `3`: End (`E`)
    - `4`: Separator (`|`) - A special token to divide the maze layout from the path coordinates.
    - `5-404`: Path Coordinates - Each of the 400 cells `(row, col)` is mapped to a unique token `5 + row * 20 + col`.

- **Sequence Format:**
    - The model is trained on a single continuous sequence using next-token prediction: the input is the sequence without its last token, and the target is the same sequence shifted left by one position, so the target at each step is the token that follows the input at that step.
    - We only care about predicting the path. Therefore, we will use a **loss mask** to ignore the model's predictions for the maze layout part of the sequence during training.
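
The coordinate-to-token mapping can be sketched as a pair of small functions. The constants match the values used in this notebook; the helper names `coord_to_token` and `token_to_coord` are hypothetical and only serve this illustration.

```python
PATH_TOKEN_OFFSET = 5   # tokens 0-4 are reserved for wall, path, start, end, separator
MAZE_SIZE = 20

def coord_to_token(row, col):
    """Map a cell (row, col) to its path-coordinate token."""
    return PATH_TOKEN_OFFSET + row * MAZE_SIZE + col

def token_to_coord(token):
    """Invert the mapping: recover (row, col) from a path-coordinate token."""
    return divmod(token - PATH_TOKEN_OFFSET, MAZE_SIZE)

print(coord_to_token(0, 0))    # 5, the smallest path token
print(coord_to_token(19, 19))  # 404, the largest path token
print(token_to_coord(65))      # (3, 0)
```

Because `divmod` undoes the `row * 20 + col` flattening exactly, every cell round-trips to itself, which is what lets the inference code turn predicted tokens back into coordinates.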


In [None]:
# --- Maze Data Setup ---

# Define a 20x20 maze and its solutions
# 0: Wall, 1: Path, 2: Start, 3: End
MAZE_DATASET = [
    {
        "maze": [
            [2, 1, 0, 1, 1, 1, 0, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1],
            [1, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1],
            [1, 0, 0, 1, 0, 1, 1, 1, 0, 1, 1, 1, 0, 1, 1, 1, 1, 1, 0, 1],
            [1, 1, 1, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 0, 0, 1, 0, 1],
            [1, 0, 1, 0, 0, 1, 0, 1, 1, 1, 0, 1, 1, 1, 0, 1, 1, 1, 0, 1],
            [1, 0, 1, 1, 1, 1, 0, 0, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 1, 1],
            [1, 0, 0, 0, 0, 1, 1, 1, 1, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1],
            [1, 1, 1, 1, 1, 1, 0, 1, 0, 1, 1, 1, 0, 1, 1, 1, 0, 1, 0, 1],
            [0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 0, 0, 1, 0, 1, 0, 1, 0, 1],
            [1, 1, 0, 1, 0, 1, 0, 1, 0, 1, 1, 1, 1, 1, 0, 1, 0, 1, 1, 1],
            [1, 0, 0, 1, 0, 1, 0, 1, 1, 1, 0, 1, 0, 1, 0, 1, 0, 0, 0, 1],
            [1, 1, 1, 1, 0, 1, 1, 1, 0, 1, 0, 1, 0, 1, 0, 1, 1, 1, 1, 1],
            [1, 0, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1],
            [1, 1, 1, 1, 1, 1, 0, 1, 0, 1, 1, 1, 1, 1, 0, 1, 0, 1, 0, 1],
            [1, 0, 0, 0, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1],
            [1, 1, 1, 1, 0, 1, 0, 1, 1, 1, 0, 1, 0, 1, 1, 1, 0, 1, 0, 1],
            [1, 0, 0, 1, 0, 1, 0, 0, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 1, 1],
            [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 0, 1, 0, 1, 0, 1],
            [1, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1],
            [1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 3]
        ],
        "path": [(0, 0), (1, 0), (2, 0), (3, 0), (3, 1), (3, 2), (3, 3), (4, 3), (4, 2), (5, 2), (5, 3), (5, 4), (5, 5), (6, 5), (7, 5), (7, 4), (7, 3), (7, 2), (7, 1), (8, 1), (9, 1), (9, 0), (10, 0), (11, 0), (11, 1), (11, 2), (11, 3), (12, 3), (13, 3), (13, 4), (13, 5), (14, 5), (15, 5), (15, 4), (15, 3), (16, 3), (17, 3), (17, 4), (17, 5), (17, 6), (17, 7), (17, 8), (17, 9), (18, 9), (19, 9), (19, 10), (19, 11), (19, 12), (19, 13), (18, 13), (17, 13), (16, 13), (15, 13), (14, 13), (13, 13), (12, 13), (11, 13), (10, 13), (9, 13), (8, 13), (7, 13), (6, 13), (5, 13), (4, 13), (3, 13), (2, 13), (1, 13), (0, 13), (0, 14), (0, 15), (0, 16), (0, 17), (0, 18), (0, 19), (1, 19), (2, 19), (3, 19), (4, 19), (5, 19), (5, 18), (6, 18), (7, 18), (8, 18), (9, 18), (9, 17), (10, 17), (11, 17), (11, 18), (12, 18), (13, 18), (14, 18), (15, 18), (16, 18), (16, 19), (17, 19), (18, 19), (19, 19)]
    }
]

# --- Vocabulary and Preprocessing ---
WALL, PATH, START, END, SEP = 0, 1, 2, 3, 4
PATH_TOKEN_OFFSET = 5
MAZE_SIZE = 20  # Side length of the square maze
MAZE_TOKENS = MAZE_SIZE * MAZE_SIZE

def preprocess_maze_data(dataset):
    sequences = []
    for item in dataset:
        maze_flat = [token for row in item["maze"] for token in row]
        path_tokens = [PATH_TOKEN_OFFSET + r * MAZE_SIZE + c for r, c in item["path"]]

        full_sequence = maze_flat + [SEP] + path_tokens
        sequences.append(torch.tensor(full_sequence))
    return sequences

training_sequences = preprocess_maze_data(MAZE_DATASET)

# Let's inspect the first preprocessed sequence
print("Original Maze (first example):")
for row in MAZE_DATASET[0]['maze']:
    print("".join([{0:'#', 1:'.', 2:'S', 3:'E'}[c] for c in row]))

print(f"\nPreprocessed sequence length: {len(training_sequences[0])}")


Original Maze (first example):
S.#...#...#.........
..#.#.#.#.#.#######.
.##.#...#...#.....#.
....#.#.#.#.#.###.#.
.#.##.#...#...#...#.
.#....###.#.#.#.#...
.####.....#.#.#.#.#.
......#.#...#...#.#.
#.#.#.#.#.###.#.#.#.
..#.#.#.#.....#.#...
.##.#.#...#.#.#.###.
....#...#.#.#.#.....
.##.#.#.#.#.#.#.#.#.
......#.#.....#.#.#.
.####.#.#.#.#.#.#.#.
....#.#...#.#...#.#.
.##.#.###.#.#.#.#...
..........#...#.#.#.
.######.#.#.#.#.#.#.
........#...#......E

Preprocessed sequence length: 498
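
The length of 498 breaks down as 400 maze tokens, 1 separator, and 97 path tokens. The sketch below shows how the shifted input/target pair and the loss mask line up for a sequence with that layout. The sequence here is a placeholder (zeros for the maze, consecutive integers for the path), not the real maze data.

```python
import torch

MAZE_TOKENS = 400
SEP = 4
PATH_LEN = 97

# Placeholder sequence with the same layout: maze tokens, SEP, path tokens.
seq = torch.cat([
    torch.zeros(MAZE_TOKENS, dtype=torch.long),  # stand-in maze layout
    torch.tensor([SEP]),
    torch.arange(5, 5 + PATH_LEN),               # stand-in path tokens
])

input_seq = seq[:-1].unsqueeze(0)   # everything except the last token
target_seq = seq[1:].unsqueeze(0)   # everything except the first token

# Score only the positions whose target is a path token.
loss_mask = torch.zeros_like(target_seq, dtype=torch.float)
loss_mask[:, MAZE_TOKENS:] = 1.0

first_scored = MAZE_TOKENS
print(len(seq), tuple(input_seq.shape), int(loss_mask.sum()))
print(input_seq[0, first_scored].item(), target_seq[0, first_scored].item())
```

At the first scored position the input is the separator and the target is the first path token, so exactly 97 predictions contribute to the loss.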


In [None]:
# --- Maze Model Training ---
MAZE_VOCAB_SIZE = PATH_TOKEN_OFFSET + MAZE_TOKENS # 5 special tokens + 400 path tokens
MAZE_D_MODEL = 64 # Increased model size for complexity
lr = 0.001 # Use a smaller learning rate

maze_model = TRM(vocab_size=MAZE_VOCAB_SIZE, d_model=MAZE_D_MODEL)
maze_optimizer = optim.Adam(maze_model.parameters(), lr=lr)
maze_criterion = nn.CrossEntropyLoss(reduction='none') # Use 'none' to apply mask later

epochs = 2000 # Increased epochs for this harder task
print("Starting maze training... (This will take a few minutes)")

for epoch in range(epochs):
    total_loss = 0
    total_tokens = 0

    for seq in training_sequences:
        maze_optimizer.zero_grad()

        # Prepare input and target (shifted input)
        input_seq = seq[:-1].unsqueeze(0)
        target_seq = seq[1:].unsqueeze(0)

        # Define the loss mask: we only care about predicting the path tokens.
        # The path starts after the maze layout (400 tokens) and the SEP token.
        # The target for the SEP token is the first path token, so we start the mask there.
        loss_mask = torch.zeros_like(target_seq, dtype=torch.float)
        loss_mask[:, MAZE_TOKENS:] = 1.0

        # Forward pass
        logits = maze_model(input_seq)

        # Calculate loss
        loss = maze_criterion(logits.view(-1, MAZE_VOCAB_SIZE), target_seq.view(-1))
        masked_loss = loss * loss_mask.view(-1)

        # Average the loss over the number of path tokens only
        final_loss = masked_loss.sum() / loss_mask.sum()

        # Backward pass and optimization
        final_loss.backward()
        maze_optimizer.step()

        total_loss += masked_loss.sum().item()
        total_tokens += loss_mask.sum().item()

    avg_loss = total_loss / total_tokens
    if (epoch + 1) % 200 == 0:
        print(f"Epoch [{epoch+1}/{epochs}], Average Loss: {avg_loss:.4f}")

print("Maze training finished.")


Starting maze training... (This will take a few minutes)
Epoch [200/2000], Average Loss: 0.0316
Epoch [400/2000], Average Loss: 0.0098
Epoch [600/2000], Average Loss: 0.0049
Epoch [800/2000], Average Loss: 0.0030
Epoch [1000/2000], Average Loss: 0.0020
Epoch [1200/2000], Average Loss: 0.0014
Epoch [1400/2000], Average Loss: 0.0011
Epoch [1600/2000], Average Loss: 0.0008
Epoch [1800/2000], Average Loss: 0.0007
Epoch [2000/2000], Average Loss: 0.0005
Maze training finished.


In [None]:
# --- Inference on a Maze ---
def solve_maze(model, maze_grid, max_len=400):  # max_len caps the number of generated path tokens
    model.eval() # Set model to evaluation mode

    maze_flat = [token for row in maze_grid for token in row]

    # Find start position to begin generation
    start_pos_flat = maze_flat.index(START)
    start_r, start_c = divmod(start_pos_flat, MAZE_SIZE)
    start_path_token = PATH_TOKEN_OFFSET + start_r * MAZE_SIZE + start_c

    # Initial input: maze layout + separator + start token
    input_seq = torch.tensor(maze_flat + [SEP] + [start_path_token]).unsqueeze(0)

    generated_path_tokens = [start_path_token]

    # The END cell's coordinate token marks the end of the path; compute it once.
    end_pos_flat = maze_flat.index(END)
    end_r, end_c = divmod(end_pos_flat, MAZE_SIZE)
    end_path_token = PATH_TOKEN_OFFSET + end_r * MAZE_SIZE + end_c

    with torch.no_grad():
        for _ in range(max_len - 1):
            logits = model(input_seq)

            # Get the prediction for the very last token in the sequence
            next_token_logits = logits[:, -1, :]
            predicted_token = torch.argmax(next_token_logits, dim=1).item()

            # Append the prediction and update the input for the next step
            generated_path_tokens.append(predicted_token)
            next_token = torch.tensor([[predicted_token]], device=input_seq.device)
            input_seq = torch.cat([input_seq, next_token], dim=1)

            # Stop once the model predicts the END cell's coordinate token
            if predicted_token == end_path_token:
                break

    # Convert token IDs back to (row, col) coordinates
    path_coords = []
    for token in generated_path_tokens:
        flat_pos = token - PATH_TOKEN_OFFSET
        r, c = divmod(flat_pos, MAZE_SIZE)
        path_coords.append((r, c))

    return path_coords

# Test with the first maze from our dataset
test_maze_grid = MAZE_DATASET[0]["maze"]
predicted_path = solve_maze(maze_model, test_maze_grid)

print("Maze to solve:")
for row in test_maze_grid:
    print("".join([{0:'#', 1:'.', 2:'S', 3:'E'}[c] for c in row]))

print(f"\nPredicted Path has {len(predicted_path)} steps.")
print(f"Correct Path has {len(MAZE_DATASET[0]['path'])} steps.")

# Visualize the path
print("\nVisualized solution (* = predicted path):")
path_cells = set(predicted_path)  # a set gives fast membership checks
for r in range(MAZE_SIZE):
    line = ""
    for c in range(MAZE_SIZE):
        if test_maze_grid[r][c] == WALL:
            line += '#'
        elif (r, c) in path_cells:
            line += '*'
        else:
            line += '.'
    print(line)


Maze to solve:
S.#...#...#.........
..#.#.#.#.#.#######.
.##.#...#...#.....#.
....#.#.#.#.#.###.#.
.#.##.#...#...#...#.
.#....###.#.#.#.#...
.####.....#.#.#.#.#.
......#.#...#...#.#.
#.#.#.#.#.###.#.#.#.
..#.#.#.#.....#.#...
.##.#.#...#.#.#.###.
....#...#.#.#.#.....
.##.#.#.#.#.#.#.#.#.
......#.#.....#.#.#.
.####.#.#.#.#.#.#.#.
....#.#...#.#...#.#.
.##.#.###.#.#.#.#...
..........#...#.#.#.
.######.#.#.#.#.#.#.
........#...#......E

Predicted Path has 97 steps.
Correct Path has 97 steps.

Visualized solution (* = predicted path):
*.#...#...#..*******
*.#.#.#.#.#.#######*
*##.#...#...#*....#*
****#.#.#.#.#*###.#*
.#*##.#...#..*#...#*
.#****###.#.#*#.#.**
.####*....#.#*#.#.#.
.*****#.#...#*..#.#.
#*#.#.#.#.###*#.#.#.
**#.#.#.#....*#.#**.
*##.#.#...#.#*#.###.
****#...#.#.#*#..**.
.##*#.#.#.#.#*#.#.#.
...***#.#....*#.#.#.
.####*#.#.#.#*#.#.#.
...*#*#...#.#*..#.#.
.##*#.###.#.#*#.#.**
...*******#..*#.#.#*
.######.#*#.#*#.#.#*
........#***#*.....*


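The model learns to generate the sequence of coordinates for this maze: the predicted path has the same number of steps as the reference path. Note that the training set contains a single maze and the test above uses that same maze, so this result shows that the TRM can carry the maze context through a long sequence (about 500 tokens) and reproduce a 97-step path, not that it generalizes to unseen mazes. Testing generalization would require training on many mazes and evaluating on held-out ones.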
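
Matching step counts do not by themselves show that a path is legal. A minimal sketch of a stricter check follows: it flags steps that land on a wall, leave the grid, or jump to a non-adjacent cell. The function name `find_path_problems` and the 3x3 toy maze are hypothetical; the cell encoding (0 for wall) follows the notebook.

```python
WALL = 0

def find_path_problems(maze, path):
    """Return a list of (step_index, reason) for every illegal step in the path."""
    problems = []
    size = len(maze)
    for i, (r, c) in enumerate(path):
        if not (0 <= r < size and 0 <= c < len(maze[r])):
            problems.append((i, "out of bounds"))
        elif maze[r][c] == WALL:
            problems.append((i, "wall"))
        if i > 0:
            prev_r, prev_c = path[i - 1]
            # Legal moves change exactly one coordinate by exactly one.
            if abs(r - prev_r) + abs(c - prev_c) != 1:
                problems.append((i, "not adjacent to previous step"))
    return problems

# Toy 3x3 maze: 2 = start, 3 = end, 1 = open cell, 0 = wall.
toy_maze = [
    [2, 1, 0],
    [0, 1, 0],
    [0, 1, 3],
]
good_path = [(0, 0), (0, 1), (1, 1), (2, 1), (2, 2)]
bad_path = [(0, 0), (1, 0), (2, 2)]

print(find_path_problems(toy_maze, good_path))  # []
print(find_path_problems(toy_maze, bad_path))
```

Running a check like this on both the reference path and the predicted path is a useful sanity test before drawing conclusions from the visualization.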


## 6. Key Differences from Transformers

As summarized in the original repository, TRMs differ from Transformers in several key ways:

-   **Computation:** TRMs are **recursive** and process sequences step-by-step, making their complexity linear, O(L). Transformers use **self-attention**, which is parallelizable but has quadratic complexity, O(L²).
-   **State Management:** TRMs explicitly manage a **hidden state** that evolves over time. Transformers are stateless and recompute context from scratch at each layer.
-   **Positional Information:** TRMs inherently understand sequence order due to their sequential processing. Transformers require explicit **positional encodings** to be added to their inputs.
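
To make the O(L) versus O(L²) contrast concrete, the sketch below simply counts units of work for two sequence lengths that appear in this notebook: 4 (the toy task) and 497 (the maze input). It is a rough counting exercise, not a benchmark, and the function names are hypothetical.

```python
def recurrent_steps(seq_len):
    """One block application per token: grows linearly with sequence length."""
    return seq_len

def attention_pairs(seq_len):
    """Full self-attention scores every token against every token."""
    return seq_len * seq_len

for seq_len in (4, 497):
    print(seq_len, recurrent_steps(seq_len), attention_pairs(seq_len))
```

These counts describe a single pass over a sequence. The `solve_maze` function above reruns the model on the whole growing sequence for every generated token, so its generation cost is higher than one linear pass.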


## 7. Conclusion

This notebook provided a brief, practical introduction to Tiny Recursive Models. We implemented the core components in PyTorch, demonstrated that a simple TRM can learn a basic sequence prediction task, and showed its effectiveness on a more complex 20x20 maze-solving problem.

TRMs offer an interesting alternative to Transformers, particularly for applications where computational efficiency and explicit state management are important.
