<a href="https://colab.research.google.com/github/weagan/Tiny-Recursive-Models/blob/main/TRM_Demonstration_patched.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Understanding Tiny Recursive Models (TRM)

This Colab notebook provides a hands-on implementation of a Tiny Recursive Model (TRM), based on the concepts outlined in the [learn-tiny-recursive-models GitHub repository](https://github.com/vukrosic/learn-tiny-recursive-models).

We will:
1.  Implement the core `RecursiveBlock`.
2.  Build the full `TRM` model and enable it to run on a GPU.
3.  Run a forward pass to see how it processes a sequence.
4.  Set up a simple training loop to watch the model learn.
5.  Train the model on a more complex task: solving a dataset of **100** 5x5 mazes.
6.  Test the model's generalization on a new, unseen maze and **verify the solution path**.


## Setup

First, let's install PyTorch and set up our device to use a GPU if one is available in the Colab environment. Using a GPU will significantly speed up training.


In [None]:
!pip install -q torch
import torch
import torch.nn as nn
import torch.optim as optim
import random

# Set the device to a GPU if available, otherwise use the CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")


Using device: cuda


## 1. The Core Component: `RecursiveBlock`

The fundamental building block of a TRM is the `RecursiveBlock`. It takes an input tensor and a hidden state from the previous step, and produces an output and an updated hidden state. This recursive nature allows it to process sequences step-by-step.

The block consists of:
-   Two Layer Normalization layers for stability.
-   Two Linear layers to transform the concatenated input and state.
-   A GELU activation function for non-linearity.


In [None]:
class RecursiveBlock(nn.Module):
    """
    A single recursive block for the TRM.
    """
    def __init__(self, d_model):
        super().__init__()
        self.d_model = d_model
        self.norm_state = nn.LayerNorm(d_model)
        self.norm_input = nn.LayerNorm(d_model)
        self.linear1 = nn.Linear(2 * d_model, 2 * d_model)
        self.activation = nn.GELU()
        self.linear2 = nn.Linear(2 * d_model, 2 * d_model)

    def forward(self, x, state):
        """
        Forward pass for the recursive block.
        Args:
            x (torch.Tensor): Input tensor for the current step. Shape: (batch_size, d_model)
            state (torch.Tensor): Hidden state from the previous step. Shape: (batch_size, d_model)
        Returns:
            Tuple[torch.Tensor, torch.Tensor]: The output and the new state. Both have shape (batch_size, d_model)
        """
        normalized_state = self.norm_state(state)
        normalized_input = self.norm_input(x)
        combined_input = torch.cat([normalized_state, normalized_input], dim=1)
        hidden = self.linear1(combined_input)
        hidden = self.activation(hidden)
        processed_output = self.linear2(hidden)
        new_state = state + processed_output[:, :self.d_model]
        output = processed_output[:, self.d_model:]
        return output, new_state


## 2. The TRM Model: Processing Sequences

The full TRM model wraps the `RecursiveBlock`. It initializes the hidden state (usually with zeros) and then iterates through the input sequence, feeding each element into the recursive block one at a time. This step-by-step processing is the core of its operation.

We also add input and output embedding layers to map our vocabulary to the model's dimension and back.


In [None]:
class TRM(nn.Module):
    """
    The Tiny Recursive Model.
    """
    def __init__(self, vocab_size, d_model):
        super().__init__()
        self.d_model = d_model
        self.embedding = nn.Embedding(vocab_size, d_model)
        self.recursive_block = RecursiveBlock(d_model)
        self.norm_out = nn.LayerNorm(d_model)
        self.output_linear = nn.Linear(d_model, vocab_size)

    def forward(self, input_sequence):
        """
        Forward pass for the entire sequence.
        Args:
            input_sequence (torch.Tensor): A sequence of token IDs. Shape: (batch_size, seq_len)
        Returns:
            torch.Tensor: The output logits for each step in the sequence. Shape: (batch_size, seq_len, vocab_size)
        """
        batch_size, seq_len = input_sequence.shape
        embedded_input = self.embedding(input_sequence)

        # Initialize the hidden state on the same device as the input
        state = torch.zeros(batch_size, self.d_model, device=input_sequence.device)

        outputs = []
        for i in range(seq_len):
            step_input = embedded_input[:, i, :]
            output, state = self.recursive_block(step_input, state)
            outputs.append(output)

        outputs_tensor = torch.stack(outputs, dim=1)
        normalized_outputs = self.norm_out(outputs_tensor)
        logits = self.output_linear(normalized_outputs)

        return logits


## 3. Forward Pass Demonstration

Let's create a dummy input sequence and pass it through our model to see the shapes of the tensors at each step. We'll make sure to move both the model and the data to our selected device (GPU or CPU).


In [None]:
VOCAB_SIZE = 20
D_MODEL = 32
SEQ_LEN = 5
BATCH_SIZE = 1

# Instantiate the model and move it to the configured device
model = TRM(vocab_size=VOCAB_SIZE, d_model=D_MODEL).to(device)
print("Model Architecture:")
print(model)

# Create a dummy input sequence and move it to the device
dummy_input = torch.randint(0, VOCAB_SIZE, (BATCH_SIZE, SEQ_LEN)).to(device)
print(f"\nDummy Input Shape: {dummy_input.shape}")
print(f"Dummy Input Tensor:\n{dummy_input}")

# Perform a forward pass
output_logits = model(dummy_input)

print(f"\nOutput Logits Shape: {output_logits.shape}")
print("This shape (batch_size, seq_len, vocab_size) is what we expect!")


Model Architecture:
TRM(
  (embedding): Embedding(20, 32)
  (recursive_block): RecursiveBlock(
    (norm_state): LayerNorm((32,), eps=1e-05, elementwise_affine=True)
    (norm_input): LayerNorm((32,), eps=1e-05, elementwise_affine=True)
    (linear1): Linear(in_features=64, out_features=64, bias=True)
    (activation): GELU(approximate='none')
    (linear2): Linear(in_features=64, out_features=64, bias=True)
  )
  (norm_out): LayerNorm((32,), eps=1e-05, elementwise_affine=True)
  (output_linear): Linear(in_features=32, out_features=20, bias=True)
)

Dummy Input Shape: torch.Size([1, 5])
Dummy Input Tensor:
tensor([[11,  5, 18, 12, 12]], device='cuda:0')

Output Logits Shape: torch.Size([1, 5, 20])
This shape (batch_size, seq_len, vocab_size) is what we expect!


## 4. A Simple Training Example

To show that the model can learn, let's create a simple "next token prediction" task. The goal is to train the model to predict the next number in a sequence.

-   **Input:** `[1, 2, 3, 4]`
-   **Target:** `[2, 3, 4, 5]`


In [None]:
TRAINING_VOCAB_SIZE = 10
TRAINING_D_MODEL = 16

# Create a new model instance for training and move to device
training_model = TRM(vocab_size=TRAINING_VOCAB_SIZE, d_model=TRAINING_D_MODEL).to(device)
optimizer = optim.Adam(training_model.parameters(), lr=0.01)
criterion = nn.CrossEntropyLoss()

# Data - move to device
input_data = torch.tensor([[1, 2, 3, 4]]).to(device)
target_data = torch.tensor([[2, 3, 4, 5]]).to(device)

print(f"Input:  {input_data[0].tolist()}")
print(f"Target: {target_data[0].tolist()}")

# Training Loop
epochs = 100
for epoch in range(epochs):
    optimizer.zero_grad()
    logits = training_model(input_data)
    loss = criterion(logits.view(-1, TRAINING_VOCAB_SIZE), target_data.view(-1))
    loss.backward()
    optimizer.step()

    if (epoch + 1) % 10 == 0:
        print(f"Epoch [{epoch+1}/{epochs}], Loss: {loss.item():.4f}")

# Inference after training
print("\n--- Testing after training ---")
with torch.no_grad():
    predictions = training_model(input_data)
    predicted_ids = torch.argmax(predictions, dim=2)

    print(f"Input:              {input_data[0].tolist()}")
    print(f"Predicted sequence: {predicted_ids[0].tolist()}")
    print(f"Target sequence:    {target_data[0].tolist()}")
    print("\nThe model has learned to predict the next number!")


Input:  [1, 2, 3, 4]
Target: [2, 3, 4, 5]
Epoch [10/100], Loss: 0.3124
Epoch [20/100], Loss: 0.0543
Epoch [30/100], Loss: 0.0162
Epoch [40/100], Loss: 0.0081
Epoch [50/100], Loss: 0.0054
Epoch [60/100], Loss: 0.0043
Epoch [70/100], Loss: 0.0036
Epoch [80/100], Loss: 0.0031
Epoch [90/100], Loss: 0.0028
Epoch [100/100], Loss: 0.0025

--- Testing after training ---
Input:              [1, 2, 3, 4]
Predicted sequence: [2, 3, 4, 5]
Target sequence:    [2, 3, 4, 5]

The model has learned to predict the next number!


## 5. Advanced Example: Solving 5x5 Mazes

Now for a more challenging task. Let's train the TRM to solve 5x5 mazes.

**The Task:** We provide the model with a sequence representing the maze structure and the starting position. Its goal is to predict the sequence of coordinates that solves the maze.

**Data:** We will programmatically generate a dataset of **100 random mazes** for training and test its generalization on one **unseen maze**.

**Data Representation:**
- **Vocabulary:**
    - `0`: Wall (`#`), `1`: Path (`.`), `2`: Start (`S`), `3`: End (`E`)
    - `4`: Separator (`|`)
    - `5-29`: Path Coordinates, mapping each cell `(row, col)` to a unique token `5 + row * 5 + col`.
- **Loss Mask:** We only care about predicting the path. We will use a **loss mask** to ignore the model's predictions for the maze layout part of the sequence during training. This focuses the model on the solving task.


In [None]:
from collections import deque

# --- Maze Generation Helper ---
def generate_random_maze(size=5):
    # 0: Wall, 1: Path, 2: Start, 3: End
    while True:
        # Create a grid of all paths (1)
        maze = [[1 for _ in range(size)] for _ in range(size)]

        # Add random walls (0) with ~30% probability
        # Keep (0,0) and (size-1, size-1) clear for Start/End
        for r in range(size):
            for c in range(size):
                if (r == 0 and c == 0) or (r == size-1 and c == size-1):
                    continue
                if random.random() < 0.3:
                    maze[r][c] = 0

        maze[0][0] = 2
        maze[size-1][size-1] = 3

        # BFS to find shortest path
        queue = deque([(0, 0, [])])
        visited = set([(0,0)])
        solution = None

        while queue:
            r, c, path = queue.popleft()
            current_path = path + [(r, c)]

            if r == size-1 and c == size-1:
                solution = current_path
                break

            for dr, dc in [(0,1), (0,-1), (1,0), (-1,0)]:
                nr, nc = r + dr, c + dc
                if 0 <= nr < size and 0 <= nc < size and maze[nr][nc] != 0 and (nr, nc) not in visited:
                    visited.add((nr, nc))
                    queue.append((nr, nc, current_path))

        if solution:
            return {"maze": maze, "path": solution}

# --- Generate Datasets ---
print("Generating dataset of 100 mazes... this may take a moment.")
MAZE_DATASET = []
hashes = set()

# Generate 100 unique training mazes
while len(MAZE_DATASET) < 100:
    data = generate_random_maze()
    # Create a hashable representation (tuple of tuples) to ensure uniqueness
    maze_tuple = tuple(tuple(row) for row in data["maze"])
    if maze_tuple not in hashes:
        hashes.add(maze_tuple)
        MAZE_DATASET.append(data)

print(f"Generated {len(MAZE_DATASET)} training mazes.")

# Generate one unseen maze for testing that isn't in the training set
while True:
    UNSEEN_MAZE = generate_random_maze()
    maze_tuple = tuple(tuple(row) for row in UNSEEN_MAZE["maze"])
    if maze_tuple not in hashes:
        break

print("Generated unseen test maze.")

# --- Vocabulary and Preprocessing ---
WALL, PATH, START, END, SEP = 0, 1, 2, 3, 4
PATH_TOKEN_OFFSET = 5
MAZE_SIZE = 5
MAZE_TOKENS = MAZE_SIZE * MAZE_SIZE

def preprocess_maze_data(dataset):
    sequences = []
    for item in dataset:
        maze_flat = [token for row in item["maze"] for token in row]
        path_tokens = [PATH_TOKEN_OFFSET + r * MAZE_SIZE + c for r, c in item["path"]]
        full_sequence = maze_flat + [SEP] + path_tokens
        sequences.append(torch.tensor(full_sequence))
    return sequences

training_sequences = preprocess_maze_data(MAZE_DATASET)


Generating dataset of 100 mazes... this may take a moment.
Generated 100 training mazes.
Generated unseen test maze.


In [None]:
# --- Maze Model Training (TRM Deep Supervision) ---
print("Starting maze training with TRM-style deep supervision...")

# Wrap training sequences into a simple DataLoader
from torch.utils.data import DataLoader

class MazeDataset(torch.utils.data.Dataset):
    def __init__(self, seqs):
        self.seqs = seqs
    def __len__(self):
        return len(self.seqs)
    def __getitem__(self, idx):
        seq = self.seqs[idx]
        x = seq[:-1]
        y = seq[1:]
        return x, y

maze_dataset = MazeDataset(training_sequences)
maze_loader = DataLoader(maze_dataset, batch_size=1, shuffle=True)

# Use train_trm_deepsup from appended utilities
train_trm_deepsup(
    maze_model,
    maze_optimizer,
    nn.CrossEntropyLoss(),
    maze_loader,
    device,
    epochs=150,
    Nsup=4,
    n_inner=4
)

print("Maze training finished (TRM deep supervision).")


Starting maze training with TRM-style deep supervision...


NameError: name 'train_trm_deepsup' is not defined

In [None]:
# --- Inference on an Unseen Maze ---
def solve_maze(model, maze_grid, device, max_len=30):
    model.eval() # Set model to evaluation mode

    maze_flat = [token for row in maze_grid for token in row]

    # Find start position to begin generation
    start_pos_flat = maze_flat.index(START)
    start_r, start_c = divmod(start_pos_flat, MAZE_SIZE)
    start_path_token = PATH_TOKEN_OFFSET + start_r * MAZE_SIZE + start_c

    # Initial input: maze layout + separator + start token, moved to device
    input_seq = torch.tensor(maze_flat + [SEP] + [start_path_token]).unsqueeze(0).to(device)

    generated_path_tokens = [start_path_token]

    with torch.no_grad():
        for _ in range(max_len - 1):
            logits = model(input_seq)

            # Get the prediction for the very last token in the sequence
            next_token_logits = logits[:, -1, :]
            predicted_token = torch.argmax(next_token_logits, dim=1).item()

            generated_path_tokens.append(predicted_token)
            # Append the prediction and update the input for the next step
            input_seq = torch.cat([input_seq, torch.tensor([[predicted_token]], device=device)], dim=1)

            # Stop if we predict the end token
            end_pos_flat = maze_flat.index(END)
            end_r, end_c = divmod(end_pos_flat, MAZE_SIZE)
            end_path_token = PATH_TOKEN_OFFSET + end_r * MAZE_SIZE + end_c
            if predicted_token == end_path_token:
                break

    # Convert token IDs back to (row, col) coordinates
    path_coords = []
    for token in generated_path_tokens:
        flat_pos = token - PATH_TOKEN_OFFSET
        r, c = divmod(flat_pos, MAZE_SIZE)
        path_coords.append((r, c))

    return path_coords

# --- Verification Logic ---
def verify_path(maze_grid, path):
    print("\nVerifying solution...")

    if not path:
        return False, "Path is empty."

    # Check Start
    start_r, start_c = path[0]
    if maze_grid[start_r][start_c] != START:
        return False, f"Path starts at ({start_r}, {start_c}) but maze start is not there."

    # Check End
    end_r, end_c = path[-1]
    if maze_grid[end_r][end_c] != END:
        return False, f"Path ends at ({end_r}, {end_c}) but maze end is not there."

    # Check Validity of steps
    for i in range(len(path) - 1):
        r1, c1 = path[i]
        r2, c2 = path[i+1]

        # Adjacency
        if abs(r1 - r2) + abs(c1 - c2) != 1:
            return False, f"Invalid jump from ({r1},{c1}) to ({r2},{c2})."

        # Wall collision
        if maze_grid[r2][c2] == WALL:
            return False, f"Step ({r2},{c2}) hits a wall."

    return True, "Path is valid!"

# Test with the unseen maze
unseen_maze_grid = UNSEEN_MAZE["maze"]
predicted_path = solve_maze(maze_model, unseen_maze_grid, device=device)

print("Unseen Maze to solve:")
for row in unseen_maze_grid:
     print("".join([{0:'#', 1:'.', 2:'S', 3:'E'}[c] for c in row]))

print(f"\nPredicted Path (coordinates):\n{predicted_path}")
print(f"\nCorrect Path (coordinates):\n{UNSEEN_MAZE['path']}")

# Run verification
valid, msg = verify_path(unseen_maze_grid, predicted_path)
print(f"Verification Result: {msg}")
if valid:
  print("SUCCESS: The model solved the unseen maze!")
else:
  print("FAILURE: The model failed to solve the maze.")

# Visualize the path
print("\nVisualized solution (* = predicted path):")
for r in range(MAZE_SIZE):
    line = ""
    for c in range(MAZE_SIZE):
        if unseen_maze_grid[r][c] == WALL:
            line += '#'
        elif (r,c) in predicted_path:
            line += '*'
        else:
            line += '.'
    print(line)


NameError: name 'maze_model' is not defined

The model learns a general strategy for solving mazes, which it can then apply to a new maze it has never encountered during training. This demonstrates the TRM's ability to handle structured sequence-to-sequence tasks that require understanding a context (the maze layout) to generate a relevant output (the path).


## 6. Conclusion

This notebook provided a brief, practical introduction to Tiny Recursive Models. We implemented the core components in PyTorch, demonstrated its learning capability on a simple task, and then successfully trained it to solve a complex maze-solving problem, verifying its generalization on unseen data.

TRMs offer an interesting alternative to Transformers, particularly for applications where computational efficiency and explicit state management are important.



## Added: Deep-supervision training utilities (TRM-style)
The following cells implement a training wrapper that performs `Nsup` supervised improvement steps, and an `inner_recursive_updates`
function that runs several inner recursion steps with `torch.no_grad()` for all but the final step (the 1-step gradient approximation).
Use these utilities instead of the notebook's original training loop to get behaviour closer to the TRM paper.


In [None]:
import torch
import torch.nn as nn
from typing import Optional

def inner_recursive_updates(model: nn.Module, z: torch.Tensor, x: Optional[torch.Tensor]=None,
                            n_inner: int = 4, device: Optional[torch.device]=None, **forward_kwargs):
    """Run n_inner recursive updates on latent z using the model's forward-step function.
    Behavior expected from `model`:
      - model should implement a method `step(z, x, **kwargs)` or callable `model(z, x, **kwargs)` that returns new z.
      - If the notebook's model exposes a different API, pass a wrapper into forward_kwargs like step_fn=...
    Strategy:
      - Run (n_inner - 1) steps in torch.no_grad() to avoid computing gradients for intermediate steps.
      - Run the final step normally to allow gradients to flow into that last step (1-step gradient approx).
    Returns final z (with gradients) and the intermediate ungrad z from previous steps (detached).
    """
    step_fn = forward_kwargs.pop('step_fn', None)
    if step_fn is None:
        # try common names
        if hasattr(model, 'step'):
            step_fn = model.step
        else:
            # fallback to calling model as a function: model(z, x, **kwargs)
            step_fn = lambda z_, x_, **kw: model(z_, x_, **kw)
    # run intermediate steps without grad
    z_cur = z
    for _ in range(max(0, n_inner - 1)):
        with torch.no_grad():
            z_cur = step_fn(z_cur, x, **forward_kwargs)
    # final step with grad
    z_final = step_fn(z_cur, x, **forward_kwargs)
    return z_final

def default_answer_head(model: nn.Module, z: torch.Tensor, **kwargs):
    """Try to produce an answer/prediction y from z.
    The notebook likely has a head named `head` or `output_layer`. We'll try to use them,
    otherwise we fall back to a linear layer temporarily created (not trained).
    """
    # Prefer user-provided heads if present
    if hasattr(model, 'head'):
        return model.head(z)
    if hasattr(model, 'output_layer'):
        return model.output_layer(z)
    # fallback: if model has attribute `vocab_size` or `num_classes`, create a linear layer (on-the-fly)
    if hasattr(model, 'vocab_size'):
        out_dim = int(model.vocab_size)
    elif hasattr(model, 'num_classes'):
        out_dim = int(model.num_classes)
    else:
        # can't guess â€” return z directly
        return z
    temp_head = nn.Linear(z.size(-1), out_dim).to(z.device)
    return temp_head(z)

SyntaxError: unexpected character after line continuation character (ipython-input-891387040.py, line 9)

In [None]:
# Deep-supervision training step.
# This function accumulates the supervised loss across Nsup steps (as in the TRM paper).
def train_step_deepsup(model: nn.Module, optimizer: torch.optim.Optimizer, criterion, batch, device: torch.device,
                       Nsup: int = 4, n_inner: int = 4, detach_between_sup: bool = True, step_fn=None, answer_fn=None):
    """Run a single training step with deep supervision.
    - model: your TRM model (must accept z updates via step_fn or model.step)
    - optimizer, criterion: usual objects
    - batch: training batch (caller must know structure: typically inputs, targets)
    - device: torch device
    - Nsup: number of supervised improvement steps (outer loop)
    - n_inner: number of inner recursive updates per supervision step
    - detach_between_sup: whether to .detach() the latent z between supervision steps (as paper suggests)
    - step_fn: optional function(z, x, **kwargs)->z to perform one recursive update. If omitted, inferred.
    - answer_fn: optional function(model, z)->y to produce predictions from z. If omitted, attempts common names.
    Returns loss (scalar tensor).
    """
    model.train()
    optimizer.zero_grad()
    # User must provide how to extract x and target from batch.
    # We'll support common patterns: (x, target) or dict with keys 'x' and 'y' or 'input'/'target'.
    if isinstance(batch, (list, tuple)) and len(batch) >= 2:
        x, target = batch[0], batch[1]
    elif isinstance(batch, dict):
        x = batch.get('x', batch.get('input', None))
        target = batch.get('y', batch.get('target', None))
    else:
        raise ValueError('Unrecognized batch format for train_step_deepsup. Provide (x, target) or dict.')
    x = x.to(device) if isinstance(x, torch.Tensor) else x
    target = target.to(device) if isinstance(target, torch.Tensor) else target

    # initialize latent z; prefer a model-provided initializer
    if hasattr(model, 'init_z'):
        z = model.init_z(x).to(device)
    elif hasattr(model, 'init_state'):
        z = model.init_state(x).to(device)
    else:
        # fallback: use zeros with hidden size guess
        hidden_size = getattr(model, 'hidden_size', getattr(model, 'd_model', None))
        if hidden_size is None:
            # try run a single forward to get a z (not ideal but robust)
            with torch.no_grad():
                try:
                    trial = model(x)
                    # attempt to treat trial as a dict or tuple
                    if isinstance(trial, dict) and 'z' in trial:
                        z = trial['z'].detach().clone().to(device)
                    elif isinstance(trial, (list, tuple)) and len(trial) >= 1:
                        z = trial[0].detach().clone().to(device)
                    else:
                        # as last resort, make zeros with batch dimension
                        bsz = x.size(0) if hasattr(x, 'size') else 1
                        hidden_size = trial.size(-1) if hasattr(trial, 'size') else 128
                        z = torch.zeros(bsz, hidden_size, device=device)
                except Exception as e:
                    # cannot infer; create small zeros
                    z = torch.zeros( (x.size(0) if hasattr(x, 'size') else 1, 128), device=device)

    loss_total = 0.0
    for s in range(Nsup):
        # run inner recursion: many steps but only last contributes to gradient
        z = inner_recursive_updates(model, z, x=x, n_inner=n_inner, device=device, step_fn=step_fn)
        # produce answer y
        if answer_fn is None:
            y_pred = default_answer_head(model, z)
        else:
            y_pred = answer_fn(model, z)
        # compute supervised loss (supporting classification or regression)
        if isinstance(criterion, nn.CrossEntropyLoss):
            # expect target to be LongTensor of class indices
            loss = criterion(y_pred, target.long())
        else:
            # generic loss: assume shapes align
            loss = criterion(y_pred, target)
        loss_total = loss_total + loss
        # detach z between supervision steps to match paper's detach strategy (if requested)
        if detach_between_sup and (s < Nsup - 1):
            z = z.detach()
    # average loss across supervision steps
    loss_avg = loss_total / float(Nsup)
    loss_avg.backward()
    optimizer.step()
    return loss_avg.item()

SyntaxError: unexpected character after line continuation character (ipython-input-2000704924.py, line 5)

In [None]:

# Example training loop wrapper that uses train_step_deepsup.
def train_trm_deepsup(model, optimizer, criterion, dataloader, device, epochs=3,
                      Nsup=4, n_inner=4, detach_between_sup=True, step_fn=None, answer_fn=None, log_every=10):
    for epoch in range(epochs):
        running = 0.0
        for i, batch in enumerate(dataloader):
            loss = train_step_deepsup(model, optimizer, criterion, batch, device,
                                     Nsup=Nsup, n_inner=n_inner, detach_between_sup=detach_between_sup,
                                     step_fn=step_fn, answer_fn=answer_fn)
            running += loss
            if (i + 1) % log_every == 0:
                print(f\"Epoch {epoch+1} | batch {i+1} | avg loss {running/log_every:.4f}\")
                running = 0.0
    print('Finished training (deep-supervision wrapper).')


# Usage notes (show to the user):
print(\"\\n--- TRM deep-supervision utilities loaded. To use them:\\n\"\
      \"1) Call train_trm_deepsup(model, optimizer, criterion, train_loader, device, Nsup=..., n_inner=... )\\n\"\
      \"2) If your model uses a custom step API, pass step_fn=model.step or a wrapper step_fn.\\n\"\
      \"3) If your model provides an answer head, default_answer_head will use model.head or model.output_layer.\\n\" )
