<a href="https://colab.research.google.com/github/weagan/Tiny-Recursive-Models/blob/main/harder_5x5_maze_TRM.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
# ===========================================================
# HARD 5x5 TRM NOTEBOOK — min path length ≥ 7
# ===========================================================

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import numpy as np
import random
import matplotlib.pyplot as plt
import networkx as nx

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# ===========================================================
# 1. HARD 5×5 MAZE GENERATOR (shortest path spans >= 7 cells, start and goal included)
# ===========================================================

def generate_hard_maze_5x5(min_path_len=7):
    """
    Generate one hard 5×5 maze together with its shortest-path label.

    Candidate mazes are sampled until one qualifies:
      - a wall density is drawn once per candidate from [0.40, 0.55] and every
        cell becomes a wall with that probability
      - start and goal cells are picked uniformly at random and forced open
      - the candidate is kept only if start and goal are connected and the
        shortest path visits at least ``min_path_len`` cells
        (start and goal included)

    Args:
        min_path_len: minimum number of cells on the shortest start→goal path.

    Returns:
        (maze, label): two float32 arrays of shape (5, 5).
        ``maze`` uses 0 = open, 1 = wall, 2 = start, 3 = goal.
        ``label`` is 1.0 on the cells of the shortest path and 0.0 elsewhere.

    Raises:
        ValueError: if ``min_path_len`` exceeds the number of cells in the
            grid, in which case no maze could ever qualify.
    """
    size = 5
    if min_path_len > size * size:
        raise ValueError(
            f"min_path_len={min_path_len} cannot be satisfied on a "
            f"{size}x{size} grid"
        )

    while True:
        # One density per maze: redrawing the threshold for every cell would
        # average out to a fixed probability and lose the 40–55% spread.
        wall_prob = random.uniform(0.40, 0.55)

        maze = np.zeros((size, size), dtype=np.int32)
        for i in range(size):
            for j in range(size):
                if random.random() < wall_prob:
                    maze[i, j] = 1

        # random start + goal, both forced to be open cells
        sx, sy = random.randint(0, size - 1), random.randint(0, size - 1)
        gx, gy = random.randint(0, size - 1), random.randint(0, size - 1)
        maze[sx, sy] = 0
        maze[gx, gy] = 0

        # Grid graph in which every edge touching a wall cell is removed, so
        # wall cells stay in the graph but are unreachable.
        G = nx.grid_graph([size, size])
        G.remove_edges_from(
            [(u, v) for u, v in G.edges if maze[u] == 1 or maze[v] == 1]
        )

        # Only the "no route" case is expected here; anything else is a real
        # error and must not be silently retried.
        try:
            sp = nx.shortest_path(G, (sx, sy), (gx, gy))
        except nx.NetworkXNoPath:
            continue  # start and goal are disconnected → regenerate

        if len(sp) < min_path_len:
            continue  # too easy

        # Start/goal markers are part of the returned grid, i.e. they are
        # what downstream consumers of this maze see.
        maze_vis = maze.copy()
        maze_vis[sx, sy] = 2
        maze_vis[gx, gy] = 3

        # label (binary mask of the shortest path)
        label = np.zeros((size, size), dtype=np.float32)
        for x, y in sp:
            label[x, y] = 1.0

        return maze_vis.astype(np.float32), label


class HardMazeDataset(Dataset):
    """In-memory dataset of pre-generated hard mazes.

    All ``n`` (maze, label) pairs are produced up front by
    ``generate_hard_maze_5x5`` and stored as NumPy arrays. Indexing returns
    the pair as tensors with a leading channel axis added.
    """

    def __init__(self, n=2000, min_path_len=7):
        samples = [generate_hard_maze_5x5(min_path_len) for _ in range(n)]
        self.mazes = [grid for grid, _ in samples]
        self.labels = [path for _, path in samples]

    def __len__(self):
        return len(self.mazes)

    def __getitem__(self, idx):
        # Prepend a channel axis: (H, W) -> (1, H, W)
        grid = self.mazes[idx][np.newaxis, :, :]
        path = self.labels[idx][np.newaxis, :, :]
        return torch.tensor(grid), torch.tensor(path)


# ===========================================================
# 2. TRM MODEL (unchanged)
# ===========================================================

class TRM(nn.Module):
    """Single refinement step of the tiny recursive model.

    The step takes the maze, the current answer ``y`` and the latent ``z``,
    concatenates them along the channel axis, runs them through a two-layer
    3×3 convolutional trunk, and produces an updated answer, an updated
    latent and (optionally) a per-cell halting probability through 1×1 heads.
    """

    def __init__(self, hidden_dim=64, halting=True):
        super().__init__()
        self.halting = halting

        in_channels = 1 + 1 + 16  # maze + answer y + latent z
        trunk_layers = [
            nn.Conv2d(in_channels, hidden_dim, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(hidden_dim, hidden_dim, kernel_size=3, padding=1),
            nn.ReLU(),
        ]
        self.f = nn.Sequential(*trunk_layers)

        self.y_head = nn.Conv2d(hidden_dim, 1, kernel_size=1)
        self.z_head = nn.Conv2d(hidden_dim, 16, kernel_size=1)

        # The halting head only exists when halting is enabled.
        if halting:
            self.halt_head = nn.Conv2d(hidden_dim, 1, kernel_size=1)

    def forward(self, maze, y, z):
        """Return ``(logits_y, new_y, new_z, halt_p)`` for one step.

        ``new_y`` is the sigmoid of ``logits_y``; ``halt_p`` is all zeros
        when the model was built with ``halting=False``.
        """
        features = self.f(torch.cat((maze, y, z), dim=1))

        logits_y = self.y_head(features)
        new_y = logits_y.sigmoid()
        new_z = self.z_head(features)

        if not self.halting:
            return logits_y, new_y, new_z, torch.zeros_like(new_y)

        halt_p = self.halt_head(features).sigmoid()
        return logits_y, new_y, new_z, halt_p


# ===========================================================
# 3. Recursive loop with deep supervision
# ===========================================================

def recursive_pass(model, maze, label, T=5, use_halting=False, z_dim=16):
    """
    Run up to ``T`` refinement steps of ``model`` with deep supervision.

    Every step's answer logits are scored against ``label`` with
    BCE-with-logits and the per-step losses are averaged over the number of
    steps actually executed.

    Args:
        model: callable ``model(maze, y, z)`` returning
            ``(logits_y, new_y, new_z, halt_p)``.
        maze: 4-D input tensor ``(B, C, H, W)``; the answer state ``y`` starts
            as zeros of the same shape.
        label: target tensor compared against ``logits_y``.
        T: maximum number of refinement steps (must be >= 1).
        use_halting: if True, accumulate ``halt_p`` per cell and stop as soon
            as every cell's accumulated halting mass has reached 1.
        z_dim: number of channels of the latent state ``z``; must match what
            ``model`` expects (16 for the default TRM).

    Returns:
        (loss, final_output): mean per-step loss and the answer logits of the
        last executed step.

    Raises:
        ValueError: if ``T`` is smaller than 1.
    """
    if T < 1:
        raise ValueError(f"T must be at least 1, got {T}")

    # Take the spatial size from the input instead of assuming a 5×5 grid.
    B, _, H, W = maze.shape
    y = torch.zeros_like(maze)
    z = torch.zeros((B, z_dim, H, W), device=maze.device)

    total_loss = 0.0
    criterion = nn.BCEWithLogitsLoss()

    halting_mass = torch.zeros((B, 1, H, W), device=maze.device)
    final_output = None

    for t in range(T):
        logits_y, new_y, new_z, halt_p = model(maze, y, z)

        total_loss = total_loss + criterion(logits_y, label)
        final_output = logits_y

        if use_halting:
            # Cells that already reached mass 1 stop accumulating.
            still = (halting_mass < 1.0).float()
            # The loss never reads halting_mass, so it is kept out of the
            # autograd graph; it only decides when to stop.
            halting_mass += halt_p.detach() * still
            if (halting_mass >= 1.0).all():
                return total_loss / (t + 1), final_output

        y, z = new_y, new_z

    return total_loss / T, final_output


# ===========================================================
# 4. TRAINING
# ===========================================================

# Data: 2000 pre-generated hard mazes, shuffled into mini-batches of 32.
train_ds = HardMazeDataset(n=2000, min_path_len=7)
train_loader = DataLoader(train_ds, batch_size=32, shuffle=True)

# Model with a halting head, optimised with Adam.
model = TRM(hidden_dim=64, halting=True).to(device)
optimizer = optim.Adam(model.parameters(), lr=1e-3)

EPOCHS = 10   # passes over the training set
T = 5         # maximum refinement steps per forward pass

for epoch in range(EPOCHS):
    model.train()
    total = 0  # sum of per-sample losses over the epoch

    for maze, label in train_loader:
        maze, label = maze.to(device), label.to(device)

        # Clear stale gradients, then forward / backward / update.
        optimizer.zero_grad()
        loss = recursive_pass(model, maze, label, T=T, use_halting=True)[0]
        loss.backward()
        optimizer.step()

        # Weight the batch-mean loss by the batch size so the epoch average
        # is per sample even when the last batch is smaller.
        total += loss.item() * maze.size(0)

    print(f"Epoch {epoch+1}, loss = {total / len(train_loader.dataset):.4f}")


# ===========================================================
# 5. INFERENCE
# ===========================================================

def infer(model, maze, T=10, halting=True, z_dim=16):
    """
    Run the recursive model on ``maze`` without gradients and return the
    prediction for the first sample in the batch.

    Args:
        model: module whose call ``model(maze, y, z)`` returns
            ``(logits_y, new_y, new_z, halt_p)``; it is switched to eval mode.
        maze: 4-D input tensor ``(B, C, H, W)``.
        T: maximum number of refinement steps (must be >= 1).
        halting: if True, stop early once every cell's accumulated halting
            probability has reached 1.
        z_dim: number of channels of the latent state ``z``; must match what
            ``model`` expects (16 for the default TRM).

    Returns:
        (pred_map, probs): for sample 0 / channel 0 of the batch, the integer
        0/1 map obtained by thresholding at 0.5 and the underlying
        probabilities as NumPy arrays.

    Raises:
        ValueError: if ``T`` is smaller than 1.
    """
    if T < 1:
        raise ValueError(f"T must be at least 1, got {T}")

    model.eval()
    with torch.no_grad():
        # Size the latent from the input rather than assuming one 5×5 maze.
        B, _, H, W = maze.shape
        y = torch.zeros_like(maze)
        z = torch.zeros((B, z_dim, H, W), device=maze.device)

        halting_mass = torch.zeros_like(maze)
        final_logits = None

        for t in range(T):
            logits_y, new_y, new_z, halt_p = model(maze, y, z)
            y, z = new_y, new_z
            final_logits = logits_y

            if halting:
                # Cells that already reached mass 1 stop accumulating.
                still = (halting_mass < 1).float()
                halting_mass += halt_p * still
                if (halting_mass >= 1).all():
                    break

        # Only the first sample of the batch is reported.
        probs = torch.sigmoid(final_logits).cpu().numpy()[0, 0]
        return (probs > 0.5).astype(int), probs


# ===========================================================
# 6. DEMO ON ONE HARD MAZE
# ===========================================================

# Sample a fresh maze and add batch and channel axes: (5, 5) -> (1, 1, 5, 5).
maze, label = generate_hard_maze_5x5(min_path_len=7)
maze_t = torch.tensor(maze.reshape(1, 1, *maze.shape), device=device)

pred_map, prob = infer(model, maze_t)

# Show the input, the reference path and the model's prediction in turn.
for heading, grid in (
    ("Maze input:", maze),
    ("\nGround truth path:", label),
    ("\nPredicted path:", pred_map),
):
    print(heading)
    print(grid)


Epoch 1, loss = 0.5573
Epoch 2, loss = 0.3403
Epoch 3, loss = 0.2674
Epoch 4, loss = 0.2295
Epoch 5, loss = 0.2005
Epoch 6, loss = 0.1812
Epoch 7, loss = 0.1667
Epoch 8, loss = 0.1564
Epoch 9, loss = 0.1518
Epoch 10, loss = 0.1385
Maze input:
[1. 1. 0. 1. 1.]

Ground truth path:
[[0. 0. 0. 0. 0.]
 [0. 1. 1. 0. 0.]
 [0. 1. 0. 1. 1.]
 [0. 1. 1. 1. 0.]
 [0. 0. 0. 0. 0.]]

Predicted path:
[[0 0 0 0 0]
 [0 1 1 0 0]
 [0 1 0 1 1]
 [0 1 1 1 0]
 [0 0 0 0 0]]


In [6]:
# Add up the element counts of every parameter tensor in the model.
total_params = 0
for param in model.parameters():
    total_params += param.numel()
print(f"Total number of parameters in the model: {total_params}")

Total number of parameters in the model: 48530


In [5]:

# Re-display the demo maze, its reference path and the model's prediction.
print("Maze input:", maze, sep="\n")
print("\nGround truth path:", label, sep="\n")
print("\nPredicted path:", pred_map, sep="\n")

Maze input:
[[1. 1. 0. 1. 1.]
 [0. 0. 2. 1. 0.]
 [0. 0. 1. 0. 3.]
 [0. 0. 0. 0. 1.]
 [1. 1. 0. 0. 1.]]

Ground truth path:
[[0. 0. 0. 0. 0.]
 [0. 1. 1. 0. 0.]
 [0. 1. 0. 1. 1.]
 [0. 1. 1. 1. 0.]
 [0. 0. 0. 0. 0.]]

Predicted path:
[[0 0 0 0 0]
 [0 1 1 0 0]
 [0 1 0 1 1]
 [0 1 1 1 0]
 [0 0 0 0 0]]


# Task
Review and present the 1st, 10th, and a newly generated unseen maze, along with their ground truth solutions.

## Print First Training Maze

### Subtask:
Access the first maze and its corresponding solution from the `train_ds` dataset and print them.


**Reasoning**:
To fulfill the subtask, I need to access the first maze and its ground truth label from the `train_ds` dataset, convert them to NumPy arrays, and then print them with descriptive headers as requested in the instructions.



In [2]:
# Fetch sample 0 from the training set as (maze, label) tensors.
first_maze_tensor, first_label_tensor = train_ds[0]

# Move to NumPy first, then drop the singleton channel axis.
first_maze = np.squeeze(first_maze_tensor.numpy())
first_label = np.squeeze(first_label_tensor.numpy())

print("First Training Maze Input:", first_maze, sep="\n")
print("\nFirst Training Maze Ground Truth Path:", first_label, sep="\n")

First Training Maze Input:
[[0. 1. 1. 0. 1.]
 [1. 0. 0. 0. 1.]
 [1. 0. 1. 1. 3.]
 [0. 0. 0. 0. 0.]
 [2. 0. 0. 0. 1.]]

First Training Maze Ground Truth Path:
[[0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 1.]
 [1. 1. 1. 1. 1.]
 [1. 0. 0. 0. 0.]]


**Reasoning**:
Following the previous step of printing the first maze, the next logical step is to print the 10th maze and its ground truth path from the `train_ds` dataset, as requested by the main task.



In [3]:
# Index 9 is the 10th sample of the training set.
tenth_maze_tensor, tenth_label_tensor = train_ds[9]

# Move to NumPy first, then drop the singleton channel axis.
tenth_maze = np.squeeze(tenth_maze_tensor.numpy())
tenth_label = np.squeeze(tenth_label_tensor.numpy())

print("10th Training Maze Input:", tenth_maze, sep="\n")
print("\n10th Training Maze Ground Truth Path:", tenth_label, sep="\n")

10th Training Maze Input:
[[1. 0. 1. 1. 0.]
 [1. 0. 1. 1. 1.]
 [1. 0. 0. 0. 2.]
 [1. 0. 1. 1. 1.]
 [3. 0. 1. 1. 1.]]

10th Training Maze Ground Truth Path:
[[0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0.]
 [0. 1. 1. 1. 1.]
 [0. 1. 0. 0. 0.]
 [1. 1. 0. 0. 0.]]


**Reasoning**:
Having presented the 1st and 10th training mazes, the next step, as per the main task, is to generate and present a new, unseen maze along with its ground truth solution. The `generate_hard_maze_5x5` function can be used for this purpose.



In [4]:
# Draw a brand-new maze (not taken from train_ds) with its reference path.
unseen_maze, unseen_label = generate_hard_maze_5x5(min_path_len=7)

for heading, grid in (
    ("Newly Generated Unseen Maze Input:", unseen_maze),
    ("\nNewly Generated Unseen Maze Ground Truth Path:", unseen_label),
):
    print(heading)
    print(grid)

Newly Generated Unseen Maze Input:
[[0. 0. 1. 1. 1.]
 [0. 0. 0. 3. 0.]
 [0. 0. 0. 1. 0.]
 [0. 0. 1. 2. 1.]
 [0. 0. 0. 0. 0.]]

Newly Generated Unseen Maze Ground Truth Path:
[[0. 0. 0. 0. 0.]
 [0. 1. 1. 1. 0.]
 [0. 1. 0. 0. 0.]
 [0. 1. 0. 1. 0.]
 [0. 1. 1. 1. 0.]]


## Final Task

### Subtask:
Review and present the 1st, 10th, and the newly generated unseen maze along with their ground truth solutions.


## Summary:

### Data Analysis Key Findings
*   The first training maze and its ground truth path were successfully accessed, converted to NumPy arrays, and printed.
*   The tenth training maze and its ground truth path were successfully accessed, converted to NumPy arrays, and printed.
*   A newly generated 5x5 unseen maze and its corresponding ground truth path were successfully created and displayed.

### Insights or Next Steps
*   The ability to present existing training data and newly generated, unseen data with their solutions is critical for understanding the dataset and for evaluating model performance on diverse maze configurations.
