<a href="https://colab.research.google.com/github/weagan/Tiny-Recursive-Models/blob/main/100_mazes.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
from collections import deque
import random
import torch

# --- Maze Generation Helper ---
def generate_random_maze(size=5, wall_prob=0.3):
    """Generate a random, guaranteed-solvable square maze with its shortest path.

    Cells are encoded as 0 = wall, 1 = open path, 2 = start, 3 = end. The
    start is the top-left cell and the end is the bottom-right cell. Random
    grids are drawn until one admits a start-to-end route, so the call can
    take a long time when ``wall_prob`` is high or ``size`` is large.

    Args:
        size: Side length of the grid; must be at least 1. With ``size == 1``
            the single cell is both start and end and is stored as 3.
        wall_prob: Probability, in ``[0, 1)``, that any cell other than the
            start/end becomes a wall.

    Returns:
        A dict with two keys: ``"maze"`` (list of rows of ints) and ``"path"``
        (list of ``(row, col)`` tuples from start to end inclusive, a
        shortest 4-connected route found by breadth-first search).

    Raises:
        ValueError: If ``size < 1`` or ``wall_prob`` is outside ``[0, 1)``.
    """
    if size < 1:
        raise ValueError(f"size must be at least 1, got {size}")
    # random.random() never returns 1.0, so wall_prob >= 1 would wall off
    # every cell and the retry loop below would never terminate.
    if not 0.0 <= wall_prob < 1.0:
        raise ValueError(f"wall_prob must be in [0, 1), got {wall_prob}")

    start = (0, 0)
    goal = (size - 1, size - 1)

    while True:
        # Start from an all-open grid, then knock random cells down to walls.
        maze = [[1] * size for _ in range(size)]
        for r in range(size):
            for c in range(size):
                # Start/end stay clear; no random draw is consumed for them.
                if (r, c) not in (start, goal) and random.random() < wall_prob:
                    maze[r][c] = 0

        maze[0][0] = 2
        maze[size - 1][size - 1] = 3

        # BFS from the start. `parent` doubles as the visited set and lets the
        # path be rebuilt once instead of copying a path list per queue entry.
        parent = {start: None}
        queue = deque([start])
        while queue:
            cell = queue.popleft()
            if cell == goal:
                break
            r, c = cell
            for dr, dc in ((0, 1), (0, -1), (1, 0), (-1, 0)):
                nr, nc = r + dr, c + dc
                if (
                    0 <= nr < size
                    and 0 <= nc < size
                    and maze[nr][nc] != 0
                    and (nr, nc) not in parent
                ):
                    parent[(nr, nc)] = cell
                    queue.append((nr, nc))

        if goal in parent:
            # Walk the parent links back from the goal, then flip the order.
            path = []
            cell = goal
            while cell is not None:
                path.append(cell)
                cell = parent[cell]
            path.reverse()
            return {"maze": maze, "path": path}
        # Otherwise the grid is unsolvable: draw a fresh one.

# --- Vocabulary and Dataset Configuration ---
# Declared before generation so the grid size used to build the mazes and the
# size used to encode path tokens come from one place and cannot drift apart.
WALL, PATH, START, END, SEP = 0, 1, 2, 3, 4
PATH_TOKEN_OFFSET = 5  # path-cell tokens start right after the grid/SEP tokens 0-4
MAZE_SIZE = 5
MAZE_TOKENS = MAZE_SIZE * MAZE_SIZE
NUM_TRAINING_MAZES = 100

# --- Generate Datasets ---
print(f"Generating dataset of {NUM_TRAINING_MAZES} mazes... this may take a moment.")
MAZE_DATASET = []
hashes = set()  # grids already accepted, stored as tuples of row tuples

# Keep drawing until NUM_TRAINING_MAZES distinct grids have been collected.
while len(MAZE_DATASET) < NUM_TRAINING_MAZES:
    data = generate_random_maze(size=MAZE_SIZE)
    # Lists are unhashable; a tuple of row tuples can be stored in a set.
    maze_tuple = tuple(tuple(row) for row in data["maze"])
    if maze_tuple not in hashes:
        hashes.add(maze_tuple)
        MAZE_DATASET.append(data)

print(f"Generated {len(MAZE_DATASET)} training mazes.")

# Generate one unseen maze for testing whose grid is not in the training set.
while True:
    UNSEEN_MAZE = generate_random_maze(size=MAZE_SIZE)
    maze_tuple = tuple(tuple(row) for row in UNSEEN_MAZE["maze"])
    if maze_tuple not in hashes:
        break

print("Generated unseen test maze.")

# --- Preprocessing ---

def preprocess_maze_data(dataset):
    """Convert maze records into 1-D token tensors.

    Each record's grid is flattened row by row, followed by the SEP token and
    then one token per path cell, computed as
    ``PATH_TOKEN_OFFSET + row * MAZE_SIZE + col``.

    Args:
        dataset: Iterable of dicts with "maze" (rows of cell tokens) and
            "path" (sequence of (row, col) pairs).

    Returns:
        A list holding one torch tensor per record, in input order.
    """

    def encode(record):
        # Grid cells first, in row-major order.
        grid_tokens = []
        for grid_row in record["maze"]:
            grid_tokens.extend(grid_row)
        # Then the route, one token per visited cell.
        step_tokens = [
            PATH_TOKEN_OFFSET + row * MAZE_SIZE + col
            for row, col in record["path"]
        ]
        return torch.tensor([*grid_tokens, SEP, *step_tokens])

    return [encode(record) for record in dataset]

# Encode every training maze as one token tensor: flattened grid, SEP, path tokens.
training_sequences = preprocess_maze_data(MAZE_DATASET)

Generating dataset of 100 mazes... this may take a moment.
Generated 100 training mazes.
Generated unseen test maze.


In [None]:
# Display the held-out maze grid, one row per output line.
print("Unseen Maze:")
for row in UNSEEN_MAZE["maze"]:
    print(row)

# Header and path go out on consecutive lines, after a blank separator line.
print("\nUnseen Maze Path:", UNSEEN_MAZE["path"], sep="\n")

Unseen Maze:
[2, 1, 1, 1, 0]
[0, 1, 1, 1, 0]
[0, 1, 1, 1, 0]
[1, 0, 0, 1, 1]
[0, 1, 0, 1, 3]

Unseen Maze Path:
[(0, 0), (0, 1), (0, 2), (0, 3), (1, 3), (2, 3), (3, 3), (3, 4), (4, 4)]


In [None]:
def verify_maze_path(maze_data):
    """Check that a maze record's path is a legal start-to-end walk.

    The path must begin at (0, 0) on a START cell, finish at the bottom-right
    corner on an END cell, move one cell horizontally or vertically per step,
    stay inside the grid, and never enter a WALL cell. The first violation
    found is printed and ends the check. Relies on the module-level START,
    END and WALL constants.

    Args:
        maze_data: Dict with "maze" (square grid as a list of rows) and
            "path" (sequence of (row, col) tuples).

    Returns:
        True if the path is valid, False otherwise (including when the maze
        or the path is empty).
    """
    maze = maze_data["maze"]
    path = maze_data["path"]

    # Nothing to walk: report it instead of failing on path[0] / maze[0][0].
    if not maze or not path:
        print("  - Empty maze or path")
        return False

    size = len(maze)

    # Check start and end points
    if path[0] != (0, 0) or maze[0][0] != START:
        print(f"  - Invalid start point: {path[0]} (Expected (0,0) with value {START})")
        return False
    if path[-1] != (size - 1, size - 1) or maze[size-1][size-1] != END:
        print(f"  - Invalid end point: {path[-1]} (Expected ({size-1},{size-1}) with value {END})")
        return False

    # Check each consecutive pair of cells. Only the destination needs a wall
    # test: the first cell was verified as START above, and every later cell
    # was already vetted as the destination of the previous step.
    for (r, c), (nr, nc) in zip(path, path[1:]):
        # Check bounds before indexing into the grid
        if not (0 <= nr < size and 0 <= nc < size):
            print(f"  - Path step out of bounds: ({nr},{nc})")
            return False

        # Check adjacency (only horizontal or vertical moves allowed)
        if not ((abs(nr - r) == 1 and nc == c) or (abs(nc - c) == 1 and nr == r)):
            print(f"  - Invalid move from ({r},{c}) to ({nr},{nc}) (not adjacent or diagonal)")
            return False

        # The end cell was verified to hold END above, so it needs no exemption.
        if maze[nr][nc] == WALL:
            print(f"  - Path attempts to enter a wall at ({nr},{nc})")
            return False

    return True

# Validate every training path, stopping at the first bad one for quick feedback.
print("Verifying paths for all mazes in MAZE_DATASET...")
all_paths_valid = True
for idx, maze_data in enumerate(MAZE_DATASET):
    if verify_maze_path(maze_data):
        continue
    print(f"Path verification FAILED for MAZE_DATASET[{idx}]:")
    # To debug, dump maze_data['maze'] row by row and maze_data['path'] here.
    all_paths_valid = False
    break

print(
    "All training maze paths are valid!"
    if all_paths_valid
    else "Some training maze paths are invalid. Please check the output above."
)

# Run the same check on the held-out maze.
print("\nVerifying path for UNSEEN_MAZE...")
print(
    "UNSEEN_MAZE path is valid!"
    if verify_maze_path(UNSEEN_MAZE)
    else "UNSEEN_MAZE path is invalid. Please check the output above."
)

Verifying paths for all mazes in MAZE_DATASET...
All training maze paths are valid!

Verifying path for UNSEEN_MAZE...
UNSEEN_MAZE path is valid!
