<a href="https://colab.research.google.com/github/weagan/Tiny-Recursive-Models/blob/main/Extract%2C_adapt%2C_train_TRM_on_new_mazes_.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Task
Extract, clean, and adapt the Tiny Recursive Model (TRM) code from the provided notebook `/content/TRM_Demonstration_30_5x5_mazes3 (1).ipynb`, then train it on 100 newly prepared mazes.

## Read Notebook Content

### Subtask:
Read the content of the provided Jupyter notebook file '/content/TRM_Demonstration_30_5x5_mazes3 (1).ipynb' to understand its structure and existing TRM code.


**Reasoning**:
To read the content of the Jupyter notebook, I will open the file in read mode and print the first 500 characters to understand its structure.



In [None]:
with open('/content/TRM_Demonstration_30_5x5_mazes3 (1).ipynb', 'r') as f:
    notebook_content = f.read()

print(notebook_content[:500])

{
  "nbformat": 4,
  "nbformat_minor": 5,
  "metadata": {
    "kernelspec": {
      "display_name": "Python 3",
      "language": "python",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.10.6"
    },
    "colab": {
      "pro

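Reading the file as text and parsing it afterwards with `json.loads` works. `json.load` can also parse directly from the open file handle in one step. Below is a minimal sketch that uses an in-memory stand-in for the `.ipynb` file so it runs without the real notebook. The stand-in's contents are hypothetical.

```python
import io
import json

# Hypothetical stand-in for the real .ipynb file; a notebook is plain JSON
fake_ipynb = io.StringIO(json.dumps({
    "nbformat": 4,
    "nbformat_minor": 5,
    "metadata": {},
    "cells": [
        {"cell_type": "markdown", "metadata": {}, "source": ["# Title\n"]},
        {"cell_type": "code", "metadata": {}, "execution_count": None,
         "outputs": [], "source": ["import torch\n", "print(torch.__version__)"]},
    ],
}))

# json.load reads and parses a file-like object; open(path) is used the same way
nb = json.load(fake_ipynb)

# nbformat 4 is the layout the cell-walking code in this notebook relies on
assert nb["nbformat"] == 4, f"Unexpected nbformat {nb['nbformat']}"
print(f"nbformat {nb['nbformat']}.{nb['nbformat_minor']}, {len(nb['cells'])} cells")
```

With the real file this becomes `with open(path) as f: notebook_data = json.load(f)`.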

**Reasoning**:
The notebook content has been read as a string. To understand its structure and extract the TRM code, I need to parse the JSON string into a Python dictionary.



In [None]:
import json

notebook_data = json.loads(notebook_content)

print(f"Notebook has {len(notebook_data['cells'])} cells.")
# Display the type of the first few cells to get an idea of the structure
for i, cell in enumerate(notebook_data['cells'][:5]):
    print(f"Cell {i}: Type = {cell['cell_type']}")


Notebook has 30 cells.
Cell 0: Type = markdown
Cell 1: Type = markdown
Cell 2: Type = code
Cell 3: Type = markdown
Cell 4: Type = code
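The loop above prints only the first five cell types. A `collections.Counter` can tally all of them at once. In nbformat 4, a cell's `source` may be either a list of lines or a single string, and `''.join(...)` returns the full text in both cases. The following is a small sketch on a hypothetical four-cell notebook:

```python
from collections import Counter

# Hypothetical parsed notebook; in the notebook above this would be `notebook_data`
notebook_data_example = {
    "cells": [
        {"cell_type": "markdown", "source": ["# Intro\n"]},
        {"cell_type": "code", "source": ["import torch\n"]},
        {"cell_type": "markdown", "source": "Plain-string source is also valid\n"},
        {"cell_type": "code", "source": "class TRM:\n    pass\n"},
    ]
}

# Tally every cell type instead of inspecting only the first few cells
type_counts = Counter(cell["cell_type"] for cell in notebook_data_example["cells"])
print(dict(type_counts))

# Map cell index -> full source text for code cells only;
# "".join works whether `source` is a list of lines or a single string
code_sources = {
    i: "".join(cell["source"])
    for i, cell in enumerate(notebook_data_example["cells"])
    if cell["cell_type"] == "code"
}
print(sorted(code_sources))
```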


**Reasoning**:
Now that the notebook content is parsed into a Python dictionary, I will iterate through the cells to identify and print the content of the first few code cells. This will help in locating the TRM model definition and related code for extraction.



In [None]:
trm_code_cells = []
print("Content of the first 5 code cells:")
for i, cell in enumerate(notebook_data['cells']):
    if cell['cell_type'] == 'code':
        trm_code_cells.append(cell)
        if len(trm_code_cells) <= 5:
            print(f"\n--- Code Cell {i} ---")
            print(''.join(cell['source']))

Content of the first 5 code cells:

--- Code Cell 2 ---
!pip install -q torch
import torch
import torch.nn as nn
import torch.optim as optim
import random

# Set the device to a GPU if available, otherwise use the CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")


--- Code Cell 4 ---
class RecursiveBlock(nn.Module):
    """
    A single recursive block for the TRM.
    """
    def __init__(self, d_model):
        super().__init__()
        self.d_model = d_model
        self.norm_state = nn.LayerNorm(d_model)
        self.norm_input = nn.LayerNorm(d_model)
        self.linear1 = nn.Linear(2 * d_model, 2 * d_model)
        self.activation = nn.GELU()
        self.linear2 = nn.Linear(2 * d_model, 2 * d_model)

    def forward(self, x, state):
        """
        Forward pass for the recursive block.
        Args:
            x (torch.Tensor): Input tensor for the current step. Shape: (batch_size, d_model)
            state (torch.

**Reasoning**:
I have already displayed the first 5 code cells. To comprehensively understand and extract the TRM code, I will now gather the complete source code for the `RecursiveBlock` and `TRM` classes, and any relevant constants or initialization code, from all identified code cells. This will allow for full extraction and later adaptation.



In [None]:
trm_architecture_code = []
for cell in notebook_data['cells']:
    if cell['cell_type'] == 'code':
        source_code = ''.join(cell['source'])
        # Identify and extract core TRM architecture and related setup
        if "class RecursiveBlock" in source_code or \
           "class TRM" in source_code or \
           "import torch" in source_code or \
           "device = torch.device" in source_code:
            trm_architecture_code.append(source_code)

print("--- Extracted TRM Architecture Code ---")
print('\n'.join(trm_architecture_code)[:500])  # Print the first 500 characters to verify extraction


--- Extracted TRM Architecture Code ---
!pip install -q torch
import torch
import torch.nn as nn
import torch.optim as optim
import random

# Set the device to a GPU if available, otherwise use the CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

class RecursiveBlock(nn.Module):
    """
    A single recursive block for the TRM.
    """
    def __init__(self, d_model):
        super().__init__()
        self.d_model = d_model
        self.norm_state = nn.LayerNorm(d_model)
        self.norm_input = nn.LayerNorm(d_model)
        self.linear1 = nn.Linear(2 * d_model, 2 * d_model)
        self.activation = nn.GELU()
        self.linear2 = nn.Linear(2 * d_model, 2 * d_model)

    def forward(self, x, state):
        """
        Forward pass for the recursive block.
        Args:
            x (torch.Tensor): Input tensor for the current step. Shape: (batch_size, d_model)
            state (torch.Tensor): Hidden state from the previo
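The substring test `"class TRM" in source_code` also matches names such as `TRMConfig` and comments that mention the class. A stricter alternative is a regular expression anchored at the start of a line with a word boundary after the name. The sketch below runs on hypothetical cell sources and only detects top-level (unindented) class definitions:

```python
import re

# Hypothetical cell sources; only the first two really define the wanted classes
cells = [
    "import torch\nclass RecursiveBlock:\n    pass\n",
    "class TRM:\n    pass\n",
    "class TRMConfig:\n    pass\n",
    "# TODO: refactor class TRM later\n",
]

# `^` with re.MULTILINE matches at every line start; `\b` rejects longer names like TRMConfig
wanted = re.compile(r"^class\s+(RecursiveBlock|TRM)\b", re.MULTILINE)

matching = [i for i, src in enumerate(cells) if wanted.search(src)]
naive = [i for i, src in enumerate(cells) if "class TRM" in src or "class RecursiveBlock" in src]
print("regex:", matching)
print("substring:", naive)
```

The regex keeps cells 0 and 1. The plain substring test also picks up the `TRMConfig` cell and the comment.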

**Reasoning**:
To have the complete TRM architecture and associated setup code in a single executable block for easy transfer and use, I will join the extracted code snippets from `trm_architecture_code` into one string. This final concatenated string represents the full TRM model definition and initial configuration.



In [None]:
full_trm_code = "\n".join(trm_architecture_code)

print("--- Complete TRM Architecture Code Extracted ---")
print(full_trm_code)


--- Complete TRM Architecture Code Extracted ---
!pip install -q torch
import torch
import torch.nn as nn
import torch.optim as optim
import random

# Set the device to a GPU if available, otherwise use the CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

class RecursiveBlock(nn.Module):
    """
    A single recursive block for the TRM.
    """
    def __init__(self, d_model):
        super().__init__()
        self.d_model = d_model
        self.norm_state = nn.LayerNorm(d_model)
        self.norm_input = nn.LayerNorm(d_model)
        self.linear1 = nn.Linear(2 * d_model, 2 * d_model)
        self.activation = nn.GELU()
        self.linear2 = nn.Linear(2 * d_model, 2 * d_model)

    def forward(self, x, state):
        """
        Forward pass for the recursive block.
        Args:
            x (torch.Tensor): Input tensor for the current step. Shape: (batch_size, d_model)
            state (torch.Tensor): Hidden state from t

## Analyze and Extract TRM Code

### Subtask:
Analyze the extracted TRM code string (`full_trm_code`) for redundancies and prepare it for execution and subsequent adaptation. This includes removing shell commands, duplicate imports, and isolating the core model definition from the specific maze data and training logic.


**Reasoning**:
To analyze and adapt the extracted TRM code, I need to clean it first. This involves removing non-Python shell commands like `!pip install`, eliminating duplicate import statements to avoid redundancy, and separating the core model definitions from specific maze data and example training logic. After cleaning, executing this code will make the `RecursiveBlock` and `TRM` classes and other essential components available for further steps in the environment.



In [None]:
cleaned_trm_code_lines = []
seen_imports = set()
skip_section = False

# Split the code into lines for easier processing
code_lines = full_trm_code.split('\n')

for line in code_lines:
    # 1. Remove the !pip install command
    if line.strip().startswith('!pip install'):
        continue

    # Identify start and end of training code and maze data sections to skip
    if 'TRAINING_VOCAB_SIZE' in line or '# --- Maze Data Setup ---' in line:
        skip_section = True

    # Skip specific examples or data definitions
    if skip_section and (
        'MAZE_DATASET' in line or
        'UNSEEN_MAZE' in line or
        'WALL, PATH, START, END, SEP' in line or
        'PATH_TOKEN_OFFSET' in line or
        'MAZE_SIZE' in line or
        'MAZE_TOKENS' in line or
        'def preprocess_maze_data' in line or
        'training_sequences' in line or
        'print(f"Loaded {len(training_sequences)} training mazes.")' in line or
        '# Instantiate the model and move it to the configured device' in line or # Example model instantiation
        '# Create a dummy input sequence' in line or # Example dummy input
        '# Perform a forward pass' in line or # Example forward pass
        'print("Model Architecture:")' in line or # Example print
        'epochs = 100' in line or # Example training loop
        '# Inference after training' in line or
        '# Set the device to a GPU if available, otherwise use the CPU' in line # This line is good, but the following print is part of training examples, I will re-add device print later
    ):
        # Keep skipping until the end of the data/example training logic
        # The complex logic in MAZE_DATASET means we need to skip line by line until a logical break.
        # This is a bit of a heuristic. Let's make sure we only skip actual data/training lines.
        if 'MAZE_DATASET = [' in line:
             # Find the end of the MAZE_DATASET list
            maze_dataset_start_idx = code_lines.index(line)
            maze_dataset_end_idx = maze_dataset_start_idx
            open_brackets = 0
            for i in range(maze_dataset_start_idx, len(code_lines)):
                open_brackets += code_lines[i].count('[')
                open_brackets -= code_lines[i].count(']')
                if open_brackets == 0 and ']' in code_lines[i]:
                    maze_dataset_end_idx = i
                    break
            # Mark lines in MAZE_DATASET range to be skipped
            for i in range(maze_dataset_start_idx, maze_dataset_end_idx + 1):
                if code_lines[i] not in cleaned_trm_code_lines: # Ensure not to re-mark if already skipped
                    code_lines[i] = '#SKIPPED_DATA_LINE#' # Temporarily mark to skip
            skip_section = False # Reset after processing MAZE_DATASET

        if 'UNSEEN_MAZE = {' in line:
            unseen_maze_start_idx = code_lines.index(line)
            unseen_maze_end_idx = unseen_maze_start_idx
            open_braces = 0
            for i in range(unseen_maze_start_idx, len(code_lines)):
                open_braces += code_lines[i].count('{')
                open_braces -= code_lines[i].count('}')
                if open_braces == 0 and '}' in code_lines[i]:
                    unseen_maze_end_idx = i
                    break
            for i in range(unseen_maze_start_idx, unseen_maze_end_idx + 1):
                if code_lines[i] not in cleaned_trm_code_lines:
                    code_lines[i] = '#SKIPPED_DATA_LINE#'
            skip_section = False

        # More robust skip for the initial training loop
        if 'TRAINING_VOCAB_SIZE' in line:
            # Find the end of this example training loop by looking for the next print statement or end of block
            training_loop_start_idx = code_lines.index(line)
            training_loop_end_idx = training_loop_start_idx
            for i in range(training_loop_start_idx, len(code_lines)):
                if 'print("\n--- Testing after training ---")' in code_lines[i] or 'print("The model has learned to predict the next number!")' in code_lines[i]:
                    training_loop_end_idx = i
                    break
            for i in range(training_loop_start_idx, training_loop_end_idx + 1):
                if code_lines[i] not in cleaned_trm_code_lines:
                    code_lines[i] = '#SKIPPED_DATA_LINE#'
            skip_section = False # Reset after skipping this training block

        continue # Skip current line if part of a section to be skipped

    # If we are not skipping a section, process the line
    if not skip_section:
        # 2. Remove duplicate import statements
        if line.strip().startswith('import') or line.strip().startswith('from'):
            if line.strip() not in seen_imports:
                cleaned_trm_code_lines.append(line)
                seen_imports.add(line.strip())
            continue

        # Heuristic to stop skipping if an empty line or class definition follows
        if not line.strip() and 'class' in code_lines[code_lines.index(line) + 1 if code_lines.index(line) + 1 < len(code_lines) else code_lines.index(line)].strip():
            skip_section = False

        # Only append lines that are not part of data/example training and not marked to be skipped
        if '#SKIPPED_DATA_LINE#' not in line:
            cleaned_trm_code_lines.append(line)

# Join the cleaned lines back into a single string
cleaned_trm_code = "\n".join(cleaned_trm_code_lines)

# Re-insert the device print statement as it's useful for setup
if 'device = torch.device' in cleaned_trm_code and 'print(f"Using device: {device}")' not in cleaned_trm_code:
    # Find the line where device is defined and insert print after it
    device_line_idx = -1
    for i, line in enumerate(cleaned_trm_code_lines):
        if 'device = torch.device' in line:
            device_line_idx = i
            break
    if device_line_idx != -1:
        cleaned_trm_code_lines.insert(device_line_idx + 1, 'print(f"Using device: {device}")')
        cleaned_trm_code = "\n".join(cleaned_trm_code_lines)

print("--- Cleaned TRM Code for Execution ---")
print(cleaned_trm_code)

# 3. Execute the cleaned code
exec(cleaned_trm_code)

print("TRM classes and related setup defined successfully.")

--- Cleaned TRM Code for Execution ---
import torch
import torch.nn as nn
import torch.optim as optim
import random

# Set the device to a GPU if available, otherwise use the CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

class RecursiveBlock(nn.Module):
    """
    A single recursive block for the TRM.
    """
    def __init__(self, d_model):
        super().__init__()
        self.d_model = d_model
        self.norm_state = nn.LayerNorm(d_model)
        self.norm_input = nn.LayerNorm(d_model)
        self.linear1 = nn.Linear(2 * d_model, 2 * d_model)
        self.activation = nn.GELU()
        self.linear2 = nn.Linear(2 * d_model, 2 * d_model)

    def forward(self, x, state):
        """
        Forward pass for the recursive block.
        Args:
            x (torch.Tensor): Input tensor for the current step. Shape: (batch_size, d_model)
            state (torch.Tensor): Hidden state from the previous step. Shape: (batch_

NameError: name 'MAZE_DATASET' is not defined
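The subtask names two mechanical steps that do not depend on the section-skipping heuristics: removing shell commands and removing duplicate imports. The cell above compares stripped lines. As a result, an import inside a function body that repeats a top-level import would be dropped as well. The following sketch only deduplicates unindented imports. The helper `strip_magics_and_dedupe_imports` is hypothetical. It treats any line starting with `!` or `%` as IPython syntax and handles single-line imports only.

```python
def strip_magics_and_dedupe_imports(source: str) -> str:
    """Drop IPython shell/magic lines and repeated top-level imports, keeping order."""
    seen = set()
    kept = []
    for line in source.splitlines():
        stripped = line.strip()
        # `!cmd` and `%magic` are IPython syntax; exec() and ast.parse reject them
        if stripped.startswith(("!", "%")):
            continue
        # Only dedupe unindented imports; an import inside a function body must stay
        is_top_level_import = not line[:1].isspace() and stripped.startswith(("import ", "from "))
        if is_top_level_import:
            if stripped in seen:
                continue
            seen.add(stripped)
        kept.append(line)
    return "\n".join(kept)


example = """!pip install -q torch
import torch
import random
%matplotlib inline
import torch

def f():
    import random
    return random.random()
"""
cleaned = strip_magics_and_dedupe_imports(example)
print(cleaned)
```

The result still has to be pruned of data and training code. This step only makes the text valid Python, so it can be parsed or executed.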

**Reasoning**:
The previous attempt to clean the code resulted in a `NameError` because parts of the data and training setup were inconsistently removed, leaving a reference to `MAZE_DATASET` which was undefined. To fix this, I will refine the code cleaning logic to explicitly identify and include only the necessary imports, device setup, and the `RecursiveBlock` and `TRM` class definitions, completely excluding all data-specific definitions and example training code.



In [None]:
import_device_block_start = 0
import_device_block_end = 0
recursive_block_start = 0
recursive_block_end = 0
trm_class_start = 0
trm_class_end = 0

lines = full_trm_code.split('\n')

# Find relevant sections
for i, line in enumerate(lines):
    if "import torch" in line and import_device_block_start == 0:
        import_device_block_start = i
    if "device = torch.device" in line and import_device_block_end == 0:
        # Record only the first device definition; it closes the initial
        # imports/device setup block. Imports that appear after this line are
        # not copied, and the two classes are located by their own markers below.
        import_device_block_end = i
    if "class RecursiveBlock" in line:
        recursive_block_start = i
    # Heuristic to find the end of RecursiveBlock class
    if recursive_block_start != 0 and recursive_block_end == 0 and "return output, new_state" in line:
        recursive_block_end = i
    if "class TRM" in line:
        trm_class_start = i
    # Heuristic to find the end of TRM class
    if trm_class_start != 0 and trm_class_end == 0 and "return logits" in line:
        trm_class_end = i

cleaned_code_parts = []
seen_imports = set()

# Add imports and device from the initial block
for i in range(import_device_block_start, import_device_block_end + 1):
    line = lines[i]
    stripped_line = line.strip()
    if stripped_line.startswith('!pip install'):
        continue
    if (stripped_line.startswith('import') or stripped_line.startswith('from')):
        if stripped_line not in seen_imports:
            cleaned_code_parts.append(line)
            seen_imports.add(stripped_line)
    elif "device = torch.device" in line:
        cleaned_code_parts.append(line)
        # Ensure the device print statement is added right after device definition
        cleaned_code_parts.append('print(f"Using device: {device}")')

# Add RecursiveBlock class
if recursive_block_start != 0 and recursive_block_end != 0:
    # Add a newline for separation if necessary
    if cleaned_code_parts and cleaned_code_parts[-1].strip() != '':
        cleaned_code_parts.append('')
    for i in range(recursive_block_start, recursive_block_end + 1):
        cleaned_code_parts.append(lines[i])

# Add TRM class
if trm_class_start != 0 and trm_class_end != 0:
    # Add a newline for separation if necessary
    if cleaned_code_parts and cleaned_code_parts[-1].strip() != '':
        cleaned_code_parts.append('')
    for i in range(trm_class_start, trm_class_end + 1):
        cleaned_code_parts.append(lines[i])

final_cleaned_trm_code = "\n".join(cleaned_code_parts)

print("--- Cleaned TRM Code for Execution ---")
print(final_cleaned_trm_code)
exec(final_cleaned_trm_code)
print("TRM classes and related setup defined successfully.")

--- Cleaned TRM Code for Execution ---
import torch
import torch.nn as nn
import torch.optim as optim
import random
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

class RecursiveBlock(nn.Module):
    """
    A single recursive block for the TRM.
    """
    def __init__(self, d_model):
        super().__init__()
        self.d_model = d_model
        self.norm_state = nn.LayerNorm(d_model)
        self.norm_input = nn.LayerNorm(d_model)
        self.linear1 = nn.Linear(2 * d_model, 2 * d_model)
        self.activation = nn.GELU()
        self.linear2 = nn.Linear(2 * d_model, 2 * d_model)

    def forward(self, x, state):
        """
        Forward pass for the recursive block.
        Args:
            x (torch.Tensor): Input tensor for the current step. Shape: (batch_size, d_model)
            state (torch.Tensor): Hidden state from the previous step. Shape: (batch_size, d_model)
        Returns:
            Tuple[torch.Tensor,
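Both cleaning attempts scan the text line by line and rely on marker strings such as `"return logits"`. An alternative is to parse the code with Python's `ast` module and keep only the wanted top-level nodes. This sketch assumes the `!`/`%` lines have already been removed, because `ast.parse` rejects them. `extract_top_level` is a hypothetical helper, and the example source is made up. Passing names such as `{"RecursiveBlock", "TRM"}` would select the two classes. Names like `{"WALL", "MAZE_SIZE", "preprocess_maze_data"}` would select the maze constants and the preprocessing function, including tuple assignments. Decorator lines fall outside a class node's source span, so decorated definitions would need extra handling.

```python
import ast

def extract_top_level(source: str, names: set) -> str:
    """Return top-level imports plus the classes, functions and assignments
    whose names are listed in `names`, in source order, without verbatim repeats."""
    tree = ast.parse(source)  # raises SyntaxError if IPython-only lines remain
    pieces, seen = [], set()
    for node in tree.body:
        if isinstance(node, (ast.Import, ast.ImportFrom)):
            keep = True
        elif isinstance(node, (ast.ClassDef, ast.FunctionDef)):
            keep = node.name in names
        elif isinstance(node, ast.Assign):
            # Every Name inside the targets, so `WALL, PATH = 0, 1` yields {"WALL", "PATH"}
            target_names = {
                sub.id
                for target in node.targets
                for sub in ast.walk(target)
                if isinstance(sub, ast.Name)
            }
            keep = bool(target_names & names)
        else:
            keep = False  # loose statements such as training loops or prints
        if keep:
            segment = ast.get_source_segment(source, node)
            if segment not in seen:  # skip verbatim repeats, e.g. the same import in two cells
                seen.add(segment)
                pieces.append(segment)
    return "\n\n".join(pieces)


example = '''
import torch.nn as nn
WALL, PATH = 0, 1
MAZE_DATASET = [[0, 1], [1, 0]]

class RecursiveBlock(nn.Module):
    """Docstring is kept."""
    def forward(self, x):
        return x

import torch.nn as nn

class TRM(nn.Module):
    pass

model = TRM()
'''
extracted = extract_top_level(example, {"RecursiveBlock", "TRM", "WALL"})
print(extracted)
```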

## Prepare 100 Mazes

### Subtask:
Generate 100 5x5 mazes with a solvable path and preprocess them into the format required by the TRM model.


**Reasoning**:
To prepare for generating new mazes, I need to first extract the maze-related constants (WALL, PATH, START, END, SEP, PATH_TOKEN_OFFSET, MAZE_SIZE, MAZE_TOKENS) and the `preprocess_maze_data` function from the `full_trm_code` string and execute them. This will ensure these essential components are available in the current environment for maze generation and data preprocessing.



In [None]:
import re

maze_constants_and_preprocessing_code = []

# Regex to capture constants, including tuple assignments such as `WALL, PATH, START, END, SEP = ...`
constant_pattern = re.compile(r'^(WALL|PATH|START|END|SEP|PATH_TOKEN_OFFSET|MAZE_SIZE|MAZE_TOKENS)(\s*,\s*\w+)*\s*=.*')
# Regex to capture the preprocess_maze_data function
function_pattern = re.compile(r'^def preprocess_maze_data\(dataset\):.*')

current_function_block = []
in_function = False

for line in full_trm_code.split('\n'):
    if constant_pattern.match(line):
        maze_constants_and_preprocessing_code.append(line)
    elif function_pattern.match(line):
        in_function = True
        current_function_block.append(line)
    elif in_function:
        current_function_block.append(line)
        if not line.strip() and len(current_function_block) > 1: # End of function block
            maze_constants_and_preprocessing_code.extend(current_function_block)
            current_function_block = []
            in_function = False
        elif line.strip() and line[0].isalpha() and not line.strip().startswith('return'): # Heuristic: if a line without indent or a class/def starts, function ended
            maze_constants_and_preprocessing_code.extend(current_function_block[:-1]) # Don't include the breaking line
            current_function_block = []
            in_function = False
            # Re-evaluate the breaking line for constants if it matches
            if constant_pattern.match(line):
                maze_constants_and_preprocessing_code.append(line)

# Add any remaining function lines if the file ends within a function block
if current_function_block:
    maze_constants_and_preprocessing_code.extend(current_function_block)


# No reordering is needed: preprocess_maze_data looks up these globals only when it is called
extracted_code = "\n".join(maze_constants_and_preprocessing_code)

print("--- Extracted Maze Constants and Preprocessing Function ---")
print(extracted_code)

# Execute the extracted code
exec(extracted_code)

print("Maze constants and preprocess_maze_data function defined successfully.")

--- Extracted Maze Constants and Preprocessing Function ---
PATH_TOKEN_OFFSET = 5
MAZE_SIZE = 5
MAZE_TOKENS = MAZE_SIZE * MAZE_SIZE
def preprocess_maze_data(dataset):
    sequences = []
    for item in dataset:
        maze_flat = [token for row in item["maze"] for token in row]
        path_tokens = [PATH_TOKEN_OFFSET + r * MAZE_SIZE + c for r, c in item["path"]]
        full_sequence = maze_flat + [SEP] + path_tokens
        sequences.append(torch.tensor(full_sequence))
    return sequences

Maze constants and preprocess_maze_data function defined successfully.
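`preprocess_maze_data` lays out each example as follows: the `MAZE_TOKENS` flattened grid tokens, one `SEP`, and then one token per path cell, computed as `PATH_TOKEN_OFFSET + r * MAZE_SIZE + c`. The following is a pure-Python worked example of that layout and its inverse, using lists instead of tensors. The values of `WALL` through `SEP` are not shown in the output above, so ids 0 through 4 are an assumption here. The helper names `encode` and `decode_path` are also hypothetical.

```python
# Values from the extracted code above
PATH_TOKEN_OFFSET = 5
MAZE_SIZE = 5
MAZE_TOKENS = MAZE_SIZE * MAZE_SIZE
# Assumed for illustration: the grid and separator tokens occupy ids 0-4
WALL, PATH, START, END, SEP = 0, 1, 2, 3, 4

def encode(maze, path):
    """Same layout as preprocess_maze_data, but returns a plain list."""
    maze_flat = [token for row in maze for token in row]
    path_tokens = [PATH_TOKEN_OFFSET + r * MAZE_SIZE + c for r, c in path]
    return maze_flat + [SEP] + path_tokens

def decode_path(sequence):
    """Recover (row, col) pairs from the tokens after the SEP at index MAZE_TOKENS."""
    return [divmod(t - PATH_TOKEN_OFFSET, MAZE_SIZE) for t in sequence[MAZE_TOKENS + 1:]]

# Toy maze: open along the top row, then down the right-hand column
path = [(0, c) for c in range(MAZE_SIZE)] + [(r, MAZE_SIZE - 1) for r in range(1, MAZE_SIZE)]
maze = [[WALL] * MAZE_SIZE for _ in range(MAZE_SIZE)]
for r, c in path:
    maze[r][c] = PATH
maze[0][0], maze[-1][-1] = START, END

seq = encode(maze, path)
print(len(seq))               # 25 grid tokens + 1 SEP + 9 path tokens = 35
print(seq[MAZE_TOKENS + 1:])  # path tokens
print(decode_path(seq) == path)
```

Under this layout the largest possible token is `PATH_TOKEN_OFFSET + MAZE_TOKENS - 1 = 29`. An embedding table for these sequences would therefore need at least 30 entries, assuming no other special tokens are added.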


**Reasoning**:
To generate 100 solvable mazes, I need to define a Python function that uses a maze generation algorithm (like Randomized DFS) to create a maze grid and then finds a solvable path within it using a search algorithm (like BFS). This function will ensure that the generated mazes always have a start at (0,0) and an end at (MAZE_SIZE-1, MAZE_SIZE-1), and return both the maze grid and a valid path.



In [None]:
import random
from collections import deque

def generate_solvable_maze_and_path(size):
    """Generate a random perfect maze on a size x size grid and its shortest path.

    Randomized DFS treats the squares at even (row, col) coordinates as cells and
    opens the wall square between two cells whenever it moves from one to the
    other. `size` must therefore be odd, so that the end square
    (size - 1, size - 1) is itself a cell.
    """
    if size % 2 == 0:
        raise ValueError("size must be odd for this carving scheme")

    maze = [[WALL for _ in range(size)] for _ in range(size)]
    start_pos = (0, 0)
    end_pos = (size - 1, size - 1)

    # Directions for movement (row_offset, col_offset)
    directions = [(0, 1), (1, 0), (0, -1), (-1, 0)]

    # Randomized DFS over the cell lattice, jumping two squares at a time.
    # Stepping one square at a time would visit and open every square,
    # leaving a grid with no walls at all.
    stack = [start_pos]
    visited = {start_pos}
    maze[start_pos[0]][start_pos[1]] = PATH

    while stack:
        r, c = stack[-1]
        neighbors = []
        for dr, dc in directions:
            nr, nc = r + 2 * dr, c + 2 * dc
            if 0 <= nr < size and 0 <= nc < size and (nr, nc) not in visited:
                neighbors.append((nr, nc, dr, dc))

        if neighbors:
            nr, nc, dr, dc = random.choice(neighbors)
            maze[r + dr][c + dc] = PATH  # open the wall square between the two cells
            maze[nr][nc] = PATH
            visited.add((nr, nc))
            stack.append((nr, nc))
        else:
            stack.pop()  # Backtrack

    # Mark start and end
    maze[start_pos[0]][start_pos[1]] = START
    maze[end_pos[0]][end_pos[1]] = END

    # Find the shortest path using BFS (in a perfect maze it is also the only simple path)
    queue = deque([(start_pos, [start_pos])])
    bfs_visited = {start_pos}

    while queue:
        (r, c), current_path = queue.popleft()

        if (r, c) == end_pos:
            return {"maze": maze, "path": current_path}

        for dr, dc in directions:
            nr, nc = r + dr, c + dc
            if 0 <= nr < size and 0 <= nc < size and (nr, nc) not in bfs_visited and maze[nr][nc] != WALL:
                bfs_visited.add((nr, nc))
                queue.append(((nr, nc), current_path + [(nr, nc)]))

    # Not reachable: the DFS visits every cell of the lattice, including end_pos
    raise RuntimeError("No path from start to end; maze generation failed")

print("Maze generation function 'generate_solvable_maze_and_path' defined.")
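Before generating the 100 mazes and passing them to `preprocess_maze_data`, it may be worth checking each sample. The next sketch defines two hypothetical helpers. `is_valid_path` checks that a path goes from the top-left START to the bottom-right END in unit steps without revisiting a square or entering a WALL. `maze_key` gives a hashable fingerprint for counting duplicates. The token values are assumed, and the helpers are exercised on a hand-built 3x3 maze so the block does not depend on the generator.

```python
# Assumed token values for illustration; in the notebook they come from the extracted constants
WALL, PATH, START, END = 0, 1, 2, 3

def is_valid_path(maze, path):
    """True if `path` runs START -> END in unit steps through non-WALL squares."""
    size = len(maze)
    if not path or path[0] != (0, 0) or path[-1] != (size - 1, size - 1):
        return False
    if maze[0][0] != START or maze[size - 1][size - 1] != END:
        return False
    if len(set(path)) != len(path):
        return False  # revisits a square
    for (r0, c0), (r1, c1) in zip(path, path[1:]):
        if abs(r0 - r1) + abs(c0 - c1) != 1:
            return False  # not a single horizontal or vertical step
    return all(0 <= r < size and 0 <= c < size and maze[r][c] != WALL for r, c in path)

def maze_key(maze):
    """Hashable fingerprint of a maze grid, used to count duplicates."""
    return tuple(tuple(row) for row in maze)

maze = [
    [START, PATH, PATH],
    [WALL,  WALL, PATH],
    [WALL,  WALL, END],
]
good = [(0, 0), (0, 1), (0, 2), (1, 2), (2, 2)]
bad = [(0, 0), (1, 0), (2, 0), (2, 1), (2, 2)]  # walks through wall squares
print(is_valid_path(maze, good), is_valid_path(maze, bad))

dataset = [{"maze": maze, "path": good}, {"maze": [row[:] for row in maze], "path": good}]
print(len({maze_key(item["maze"]) for item in dataset}), "unique of", len(dataset))
```

In the notebook, this could wrap a list such as `[generate_solvable_maze_and_path(MAZE_SIZE) for _ in range(100)]`. A small grid admits only a limited number of distinct mazes, so counting unique keys shows how many of the 100 samples repeat one another.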