<a href="https://colab.research.google.com/github/weagan/Tiny-Recursive-Models/blob/main/Full_TRM.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# ============================
# Tiny Recursive Model (TRM) for 5×5 mazes
# With:
#   - Deep supervision (loss at every step)
#   - Halting probability (optional)
#   - Separate embeddings y (solution) and z (latent state)
# ============================

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import numpy as np
import random
import matplotlib.pyplot as plt
import networkx as nx

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# =====================================================
# 1. MAZE GENERATION + SHORTEST PATH LABELS
# =====================================================

def generate_maze_5x5():
    """
    Sample a random 5×5 maze together with a shortest-path label.

    Cell encoding of the maze: 0 = free, 1 = wall, 2 = start, 3 = goal.
    Every cell is a wall with probability 0.25; start and goal are then
    stamped on two distinct cells (overwriting a wall if one is there).
    Mazes in which the goal is unreachable from the start are discarded and
    a new maze is sampled.

    Returns:
        (maze, label): two float32 arrays of shape (5, 5). `label` is 1.0 on
        every cell of one shortest start→goal path (endpoints included) and
        0.0 everywhere else.
    """
    # Rejection sampling in a loop: no recursion depth to worry about.
    while True:
        maze = np.zeros((5, 5), dtype=np.int32)

        # randomly place walls
        for i in range(5):
            for j in range(5):
                if random.random() < 0.25:
                    maze[i, j] = 1

        # choose start/goal; the goal is redrawn until it differs from the
        # start, otherwise the goal marker would overwrite the start marker
        sx, sy = random.randint(0, 4), random.randint(0, 4)
        gx, gy = random.randint(0, 4), random.randint(0, 4)
        while (gx, gy) == (sx, sy):
            gx, gy = random.randint(0, 4), random.randint(0, 4)
        maze[sx, sy] = 2
        maze[gx, gy] = 3

        # 4-connected grid graph; drop every edge that touches a wall cell
        G = nx.grid_graph([5, 5])
        G.remove_edges_from(
            [(u, v) for u, v in G.edges if maze[u] == 1 or maze[v] == 1]
        )

        try:
            sp = nx.shortest_path(G, (sx, sy), (gx, gy))
        except nx.NetworkXNoPath:
            # start and goal are disconnected: sample a fresh maze
            continue

        label = np.zeros((5, 5), dtype=np.float32)
        for x, y in sp:
            label[x, y] = 1.0
        return maze.astype(np.float32), label

class MazeDataset(Dataset):
    """In-memory dataset of pre-generated mazes and their path labels."""

    def __init__(self, n=2000):
        # Draw all n (maze, label) pairs up front, then split them into the
        # two parallel lists the rest of the class indexes into.
        samples = [generate_maze_5x5() for _ in range(n)]
        self.mazes = [pair[0] for pair in samples]
        self.labels = [pair[1] for pair in samples]

    def __len__(self):
        return len(self.mazes)

    def __getitem__(self, idx):
        # Prepend a channel axis to each stored array before converting it
        # to a tensor, e.g. (5,5) -> (1,5,5).
        maze = np.expand_dims(self.mazes[idx], axis=0)
        label = np.expand_dims(self.labels[idx], axis=0)
        return torch.tensor(maze), torch.tensor(label)


# =====================================================
# 2. TINY RECURSIVE MODEL (TRM)
# =====================================================

class TRM(nn.Module):
    """
    One refinement step of the Tiny Recursive Model.

    A small two-layer convolutional core reads the maze, the current solution
    estimate y and the latent state z, and three 1×1 heads produce the next
    solution logits, the next latent state and (optionally) a halting
    probability.

    Args:
        hidden_dim: number of channels inside the convolutional core.
        halting:    if True, a halting head is created and used in forward();
                    otherwise forward() returns an all-zero halting map.
        latent_dim: number of channels of the latent state z (default 16).
                    Also stored as `self.latent_dim`.
    """

    def __init__(self, hidden_dim=64, halting=True, latent_dim=16):
        super().__init__()
        self.halting = halting
        self.latent_dim = latent_dim

        # Core tiny network: input is maze (1) + y (1) + z (latent_dim) channels
        self.f = nn.Sequential(
            nn.Conv2d(1 + 1 + latent_dim, hidden_dim, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(hidden_dim, hidden_dim, kernel_size=3, padding=1),
            nn.ReLU()
        )

        # y-head (solution logits)
        self.y_head = nn.Conv2d(hidden_dim, 1, kernel_size=1)

        # latent z update: map hidden → latent
        self.z_head = nn.Conv2d(hidden_dim, latent_dim, kernel_size=1)

        # halting head
        if halting:
            self.halt_head = nn.Conv2d(hidden_dim, 1, kernel_size=1)

    def forward(self, maze, y, z):
        """
        maze: (B,1,H,W)
        y:    (B,1,H,W)
        z:    (B,latent_dim,H,W)

        Returns (logits_y, new_y, new_z, halt_p) where new_y is
        sigmoid(logits_y) and halt_p is a per-cell probability map
        (all zeros when the model was built with halting=False).
        """
        inp = torch.cat([maze, y, z], dim=1)
        h = self.f(inp)

        logits_y = self.y_head(h)
        new_y = torch.sigmoid(logits_y)

        new_z = self.z_head(h)

        if self.halting:
            halt_logits = self.halt_head(h)
            halt_p = torch.sigmoid(halt_logits)
        else:
            halt_p = torch.zeros_like(new_y)

        return logits_y, new_y, new_z, halt_p


# =====================================================
# 3. RECURSIVE LOOP WITH DEEP SUPERVISION
# =====================================================

def recursive_pass(model, maze, label, T=5, use_halting=False):
    """
    Run up to T refinement steps with deep supervision.

    At every step the model is called with the current (y, z), a
    BCE-with-logits loss against `label` is added to the running total, and
    (y, z) are replaced by the model's outputs. y, z and the halting
    accumulator start at zero.

    Args:
        model:       callable `model(maze, y, z)` returning
                     (logits_y, new_y, new_z, halt_p). If it exposes a
                     `latent_dim` attribute, that many channels are used for
                     z; otherwise 16.
        maze:        (B,1,H,W) input tensor.
        label:       target tensor with the same shape as the logits.
        T:           maximum number of refinement steps.
        use_halting: if True, accumulate `halt_p` per cell and stop early
                     once every cell of every sample has reached mass >= 1.

    Returns:
        (loss, logits): the loss averaged over the steps actually executed,
        and the logits of the last executed step.

    NOTE: `halt_p` only drives the early-stop decision; it is not part of the
    returned loss, so this function sends no gradient to a halting head.
    """
    B = maze.size(0)
    spatial = maze.shape[2:]                       # grid size taken from the input
    latent_dim = getattr(model, "latent_dim", 16)  # 16 is the historical default

    y = torch.zeros_like(maze)     # (B,1,H,W)
    z = torch.zeros((B, latent_dim, *spatial), device=maze.device)
    halting_mass = torch.zeros((B, 1, *spatial), device=maze.device)

    total_loss = 0.0
    criterion = nn.BCEWithLogitsLoss()
    final_output = None

    for t in range(T):
        logits_y, new_y, new_z, halt_p = model(maze, y, z)

        # deep supervision loss
        total_loss += criterion(logits_y, label)
        final_output = logits_y

        # halting accumulation
        if use_halting:
            # cells that already reached mass >= 1 stop accumulating
            still_running = (halting_mass < 1.0).float()
            halting_mass = halting_mass + halt_p * still_running

            # if all halted → early stop, averaging over the t+1 steps taken
            if (halting_mass >= 1.0).all():
                return total_loss / (t + 1), final_output

        # update y,z
        y, z = new_y, new_z

    return total_loss / T, final_output


# =====================================================
# 4. TRAINING
# =====================================================

# Data: 2000 pre-generated mazes, served in shuffled mini-batches of 32
train_ds = MazeDataset(2000)
train_loader = DataLoader(train_ds, batch_size=32, shuffle=True)

# Model with a halting head, optimised with Adam
model = TRM(hidden_dim=64, halting=True).to(device)
optimizer = optim.Adam(model.parameters(), lr=1e-3)

EPOCHS = 10
T = 5

for epoch in range(EPOCHS):
    model.train()
    weighted_loss_sum = 0
    for maze, label in train_loader:
        maze, label = maze.to(device), label.to(device)

        # clear old gradients, then forward / backward / update
        optimizer.zero_grad()
        loss, _ = recursive_pass(model, maze, label, T=T, use_halting=True)
        loss.backward()
        optimizer.step()

        # weight each batch loss by its size so the epoch figure is a
        # per-sample mean even when the last batch is smaller
        weighted_loss_sum += loss.item() * maze.size(0)

    mean_loss = weighted_loss_sum / len(train_loader.dataset)
    print(f"Epoch {epoch+1}, loss = {mean_loss:.4f}")


# =====================================================
# 5. INFERENCE
# =====================================================

def infer(model, maze, T=10, halting=True):
    """
    Run the recursive model in inference mode and threshold its prediction.

    Puts `model` in eval mode (it is left in eval mode afterwards) and runs
    up to T refinement steps without gradient tracking, starting from zero
    y and z.

    Args:
        model:   module called as `model(maze, y, z)` and returning
                 (logits_y, new_y, new_z, halt_p). If it exposes a
                 `latent_dim` attribute, that many channels are used for z;
                 otherwise 16.
        maze:    (B,1,H,W) tensor; only the first sample is returned.
        T:       maximum number of refinement steps.
        halting: if True, stop as soon as the accumulated halting mass of
                 every cell has reached 1.

    Returns:
        (pred_map, probs): for sample 0, channel 0 — an int array that is 1
        where the probability exceeds 0.5, and the probability array itself.
    """
    model.eval()
    with torch.no_grad():
        B = maze.size(0)
        spatial = maze.shape[2:]                       # grid size taken from the input
        latent_dim = getattr(model, "latent_dim", 16)  # 16 is the historical default

        y = torch.zeros_like(maze)
        z = torch.zeros((B, latent_dim, *spatial), device=maze.device)

        halting_mass = torch.zeros_like(maze)
        final_logits = None

        for t in range(T):
            logits_y, new_y, new_z, halt_p = model(maze, y, z)
            y, z = new_y, new_z
            final_logits = logits_y

            if halting:
                # cells that already reached mass >= 1 stop accumulating
                still_running = (halting_mass < 1).float()
                halting_mass += halt_p * still_running
                if (halting_mass >= 1).all():
                    break

        probs = torch.sigmoid(final_logits).cpu().numpy()[0, 0]
        return (probs > 0.5).astype(int), probs


# =====================================================
# 6. DEMO ON UNSEEN MAZE
# =====================================================

# Sample a fresh maze and add batch and channel axes: (5,5) -> (1,1,5,5)
maze, label = generate_maze_5x5()
maze_t = torch.tensor(maze[None,None,:,:], device=device)

pred_map, prob = infer(model, maze_t)

print("Input maze:")
# print the whole grid; maze[0] would show only its first row
print(maze)

print("\nGround truth path:")
print(label)

print("\nPredicted path:")
print(pred_map)


Epoch 1, loss = 0.4852
Epoch 2, loss = 0.3465
Epoch 3, loss = 0.2699
Epoch 4, loss = 0.2417
Epoch 5, loss = 0.2172
Epoch 6, loss = 0.1990
Epoch 7, loss = 0.1834
Epoch 8, loss = 0.1746
Epoch 9, loss = 0.1667
Epoch 10, loss = 0.1575
Input maze:
[3. 2. 0. 0. 0.]

Ground truth path:
[[1. 1. 0. 0. 0.]
 [0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0.]]

Predicted path:
[[1 1 0 0 0]
 [0 0 0 0 0]
 [0 0 0 0 0]
 [0 0 0 0 0]
 [0 0 0 0 0]]


In [None]:
# NOTE: this rebinds train_ds to a freshly sampled dataset; the DataLoader
# built earlier still wraps the dataset the model was trained on.
train_ds = MazeDataset(2000)

# Print the maze grid (without its channel axis) of samples 0 and 9
for title, index in (("Maze 1:", 0), ("\nMaze 10:", 9)):
    print(title)
    print(train_ds[index][0].squeeze().numpy())

Maze 1:
[[0. 1. 0. 0. 0.]
 [0. 1. 0. 0. 0.]
 [1. 1. 0. 0. 1.]
 [3. 1. 0. 1. 0.]
 [0. 0. 2. 1. 1.]]

Maze 10:
[[1. 0. 0. 2. 0.]
 [1. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0.]
 [1. 3. 0. 0. 0.]
 [1. 0. 0. 0. 0.]]


In [None]:
# NOTE: this rebinds train_ds to a freshly sampled dataset; the DataLoader
# built earlier still wraps the dataset the model was trained on.
train_ds = MazeDataset(2000)

# Sample 0: maze grid and its ground-truth path
maze1, label1 = train_ds[0]
print("Maze 1:", maze1.squeeze().numpy(), sep="\n")
print("Ground truth path for Maze 1:", label1.squeeze().numpy(), sep="\n")

# Sample 9: maze grid and its ground-truth path
maze10, label10 = train_ds[9]
print("\nMaze 10:", maze10.squeeze().numpy(), sep="\n")
print("Ground truth path for Maze 10:", label10.squeeze().numpy(), sep="\n")

# Newly generated maze: add batch and channel axes, then run inference
unseen_maze, unseen_label = generate_maze_5x5()
unseen_maze_t = torch.tensor(unseen_maze[None, None, ...], device=device)

pred_map, prob = infer(model, unseen_maze_t)

print("\nUnseen Maze:", unseen_maze, sep="\n")
print("\nGround truth path for Unseen Maze:", unseen_label, sep="\n")
print("\nPredicted path for Unseen Maze:", pred_map, sep="\n")

Maze 1:
[[2. 0. 0. 0. 0.]
 [0. 3. 0. 0. 0.]
 [1. 1. 0. 1. 1.]
 [1. 0. 0. 0. 0.]
 [0. 0. 0. 0. 1.]]
Ground truth path for Maze 1:
[[1. 1. 0. 0. 0.]
 [0. 1. 0. 0. 0.]
 [0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0.]]

Maze 10:
[[1. 1. 0. 0. 0.]
 [0. 3. 1. 0. 0.]
 [1. 0. 0. 0. 0.]
 [2. 0. 0. 0. 1.]
 [0. 0. 1. 0. 0.]]
Ground truth path for Maze 10:
[[0. 0. 0. 0. 0.]
 [0. 1. 0. 0. 0.]
 [0. 1. 0. 0. 0.]
 [1. 1. 0. 0. 0.]
 [0. 0. 0. 0. 0.]]

Unseen Maze:
[[1. 1. 0. 0. 1.]
 [0. 0. 0. 3. 0.]
 [1. 0. 0. 0. 2.]
 [1. 0. 0. 1. 1.]
 [0. 0. 0. 1. 0.]]

Ground truth path for Unseen Maze:
[[0. 0. 0. 0. 0.]
 [0. 0. 0. 1. 0.]
 [0. 0. 0. 1. 1.]
 [0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0.]]

Predicted path for Unseen Maze:
[[0 0 0 0 0]
 [0 0 0 1 0]
 [0 0 0 1 1]
 [0 0 0 0 0]
 [0 0 0 0 0]]


In [None]:
import numpy as np

# Search the dataset the model was actually trained on, i.e. the one wrapped
# by train_loader. The name train_ds is rebound to freshly sampled datasets
# by later cells, so it no longer refers to the training data.
# The stored mazes are raw (5,5) arrays, directly comparable to unseen_maze.
trained_mazes = train_loader.dataset.mazes
found_in_training = any(
    np.array_equal(stored, unseen_maze) for stored in trained_mazes
)

if found_in_training:
    print("The unseen maze WAS found in the training dataset.")
else:
    print("The unseen maze was NOT found in the training dataset.")

The unseen maze was NOT found in the training dataset.
