In [1]:
%%capture
import os, re
if "COLAB_" not in "".join(os.environ.keys()):
    !pip install unsloth
else:
    # Do this only in Colab notebooks! Otherwise use pip install unsloth
    import torch; v = re.match(r"[0-9\.]{3,}", str(torch.__version__)).group(0)
    xformers = "xformers==" + ("0.0.32.post2" if v == "2.8.0" else "0.0.29.post3")
    !pip install --no-deps bitsandbytes accelerate {xformers} peft trl triton cut_cross_entropy unsloth_zoo
    !pip install sentencepiece protobuf "datasets>=3.4.1,<4.0.0" "huggingface_hub>=0.34.0" hf_transfer
    !pip install --no-deps unsloth
!pip install transformers==4.55.4
!pip install --no-deps trl==0.22.2

!pip install wandb

In [10]:
import torch
import torch.nn as nn
from torch.optim import AdamW
from datasets import load_dataset
from torch.utils.data import DataLoader
import numpy as np
from tqdm import tqdm

# --- MEMORY OPTIMIZATIONS ---
D_MODEL = 512
N_LAYERS = 2
N_SUP = 16
N_RECURSION = 6
T_RECURSION = 3
BATCH_SIZE = 16  # HALVED from 32
LEARNING_RATE = 2e-5
DIM_FEEDFORWARD = 2048 # REDUCED from 2048 default
NUM_EPOCHS = 20

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")
print(f"Batch Size set to: {BATCH_SIZE}")
print(f"Transformer Feed-Forward Dim set to: {DIM_FEEDFORWARD}")

Using device: cuda
Batch Size set to: 16
Transformer Feed-Forward Dim set to: 2048


Batch 2: Model Definition (Tiny Recursive Model)
Here, we define the architecture for our TinyRecursiveModel (TRM). According to the paper, this is a simple 2-layer Transformer-style network. It processes a concatenation of the input x, the current prediction y, and the latent reasoning state z. We'll create a standard Transformer encoder block for this.

In [11]:
class TinyRecursiveModel(nn.Module):
    """
    Implements the Tiny Recursive Model (TRM) from the paper.
    This is a small Transformer-based network with 2 layers.
    """
    def __init__(self, d_model=D_MODEL, n_layers=N_LAYERS, nhead=8, dim_feedforward=2048, dropout=0.1):
        super().__init__()
        self.d_model = d_model

        # The paper's model takes 3 inputs (x, y, z). We'll concatenate them.
        # An input projection layer to map the concatenated input to the model dimension.
        self.input_proj = nn.Linear(d_model * 3, d_model)

        encoder_layer = nn.TransformerEncoderLayer(
            d_model=d_model,
            nhead=nhead,
            dim_feedforward=dim_feedforward,
            dropout=dropout,
            batch_first=True
        )
        self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=n_layers)

        # Output heads
        # This head refines the latent state 'z'
        self.latent_head = nn.Linear(d_model, d_model)
        # This head refines the answer 'y'
        self.answer_head = nn.Linear(d_model, d_model)

    def forward(self, x, y, z):
        """
        x: The embedded input question/puzzle. Shape: (batch, seq_len, d_model)
        y: The current embedded prediction. Shape: (batch, seq_len, d_model)
        z: The current latent reasoning state. Shape: (batch, seq_len, d_model)
        """
        # Concatenate inputs along the feature dimension
        combined_input = torch.cat((x, y, z), dim=-1)
        
        # Project the combined input to the model's dimension
        projected_input = self.input_proj(combined_input)

        # Pass through the Transformer encoder
        transformer_output = self.transformer_encoder(projected_input)

        # As per the paper's logic, the model can update both y and z.
        # We'll create two separate outputs from the transformer result.
        
        # The logic is: given x, y, z -> produce a new z'
        # And given y, z -> produce a new y'
        
        # For simplicity and following Figure 1, we will have two heads.
        # The main output of the transformer will be used to update the latent 'z'
        new_z = self.latent_head(transformer_output)

        # The paper suggests the answer update step uses the (new) latent z and old y.
        # We will model this by feeding the transformer output and old 'y' to the answer head.
        new_y = self.answer_head(transformer_output + y) # Using a residual connection for stability

        return new_y, new_z

Batch 3: Defining the Recursive Logic and Training Functions
This is the core logic from the paper's pseudocode. We'll implement latent_recursion for the inner loop and deep_recursion for the outer loop, which cleverly uses torch.no_grad() to manage memory and mimic the paper's training strategy.

In [12]:
def latent_recursion(model, x, y, z, n=N_RECURSION):
    """
    Performs the inner 'latent reasoning' loop.
    'n' recursive steps to refine the latent state 'z'.
    """
    for _ in range(n):
        # In this simplified model, both y and z are returned.
        # The paper's core idea is that the latent state 'z' is what's primarily refined here.
        _, z = model(x, y, z)
    
    # After refining z, one final step to refine y
    y, z = model(x, y, z)
    
    return y, z

def deep_recursion(model, x, y, z, n=N_RECURSION, T=T_RECURSION):
    """
    Performs one step of the outer 'deep supervision' loop.
    This involves T-1 steps with no gradients and one final step with gradients.
    """
    # Recurse for T-1 steps without tracking gradients to save memory
    with torch.no_grad():
        for _ in range(T - 1):
            y, z = latent_recursion(model, x, y, z, n)
            
    # Perform the final recursion step with gradients enabled
    y, z = latent_recursion(model, x, y, z, n)
    
    return y, z

# We also need an output head to convert our embedding back to vocabulary tokens
class OutputHead(nn.Module):
    def __init__(self, d_model, vocab_size):
        super().__init__()
        self.linear = nn.Linear(d_model, vocab_size)

    def forward(self, y_embedding):
        return self.linear(y_embedding)

# And a Q-head for the Adaptive Computation Time (ACT) halting mechanism
class QHead(nn.Module):
    def __init__(self, d_model):
        super().__init__()
        self.linear = nn.Linear(d_model, 1)

    def forward(self, y_embedding):
        # We just need a single scalar value for the halt probability
        # We take the mean over the sequence length dimension
        return self.linear(y_embedding.mean(dim=1))

Batch 4: Data Preparation
In this batch, we will:
Load a Sudoku dataset from the Hugging Face Hub.
Create a simple tokenizer to convert the puzzle strings into numerical tokens.
Set up a custom PyTorch Dataset to handle the tokenization.
Create DataLoaders to feed the data to our model in batches.

In [13]:
# --- DATA LOADING ---
dataset_name = "sapientinc/sudoku-extreme"
dataset = load_dataset(dataset_name)

# Define Vocabulary
VOCAB = ['.'] + [str(i) for i in range(1, 10)] # The dataset uses '.' for empty cells
VOCAB_SIZE = len(VOCAB)
token_to_id = {token: i for i, token in enumerate(VOCAB)}
id_to_token = {i: token for i, token in enumerate(VOCAB)}
SEQ_LEN = 81

# Let's inspect the actual structure of the loaded dataset to be sure.
print("Dataset Features:", dataset['train'].features)
print("\nFirst Training Example:", dataset['train'][0])

Dataset Features: {'source': Value(dtype='string', id=None), 'question': Value(dtype='string', id=None), 'answer': Value(dtype='string', id=None), 'rating': Value(dtype='int64', id=None)}

First Training Example: {'source': 'puzzles4_forum_hardest_1905', 'question': '5...27..9..41......1..5.3...92.6.8...5......66..7..29.8...7...2.......8...9..36..', 'answer': '583427169974136528216859374792364851351298746648715293865971432137642985429583617', 'rating': 18}


In [14]:
from torch.utils.data import Dataset

class SudokuDataset(Dataset):
    def __init__(self, data, token_map):
        self.data = data
        self.token_map = token_map

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        item = self.data[idx]
        
        # --- FINAL CORRECTION ---
        # Using the exact column names you found: 'question' and 'answer'
        # Also, the dataset uses '.' for empty cells, not '0'.
        quiz_str = item['question']
        solution_str = item['answer']
        
        # Convert strings to lists of integer token IDs
        quiz_tokens = torch.tensor([self.token_map[char] for char in quiz_str], dtype=torch.long)
        solution_tokens = torch.tensor([self.token_map[char] for char in solution_str], dtype=torch.long)
        
        return quiz_tokens, solution_tokens

# Create the datasets.
# The paper trains on only 1K samples, so let's use .select() for efficiency.
train_data = SudokuDataset(dataset['train'].select(range(1000)), token_to_id)
test_data = SudokuDataset(dataset['test'].select(range(1000)), token_to_id) # Using a 1k slice of test for quick eval

# Create the DataLoaders
train_dataloader = DataLoader(train_data, batch_size=BATCH_SIZE, shuffle=True)
test_dataloader = DataLoader(test_data, batch_size=BATCH_SIZE, shuffle=False)

# Let's inspect a single batch to confirm the shape
try:
    sample_quiz, sample_solution = next(iter(train_dataloader))
    print(f"\nQuiz batch shape: {sample_quiz.shape}")
    print(f"Solution batch shape: {sample_solution.shape}")
except Exception as e:
    print(f"\nFailed to create a batch. Error: {e}")


Quiz batch shape: torch.Size([16, 81])
Solution batch shape: torch.Size([16, 81])


Batch 5: Initializing Models and Optimizer
Now we have our model architecture and our data loaders. The final step before training is to instantiate all the necessary components:
Embedding Layer: A layer to convert our numerical tokens into dense vectors (the d_model dimension).
TRM Model: Our main recursive network.
Output Head: To convert the model's output embeddings back into token probabilities.
Q-Head: For the halting mechanism.
Optimizer: To update the model's weights during training.
Loss Function: To measure the difference between the model's predictions and the true solutions.

In [15]:
# ==============================================================================
# : Login to Hugging Face and Weights & Biases
# ==============================================================================
import wandb
from huggingface_hub import login
from kaggle_secrets import UserSecretsClient

# --- Hugging Face Login ---
print("--- Attempting Hugging Face Login ---")
try:
    user_secrets = UserSecretsClient()
    hf_token = user_secrets.get_secret("HUGGINGFACE_API_KEY")
    login(token=hf_token)
    print("✅ Successfully logged into Hugging Face.")
except Exception as e:
    print("Could not log into Hugging Face. Please ensure the 'HUGGINGFACE_API_KEY' secret is set.")
    print(f"Error: {e}")

# --- Weights & Biases Login ---
print("\n--- Attempting Weights & Biases Login ---")
try:
    user_secrets = UserSecretsClient()
    wandb_api_key = user_secrets.get_secret("wandb_api_key")
    wandb.login(key=wandb_api_key)
    print("✅ Successfully logged into Weights & Biases.")
    
    # --- Initialize W&B Run ---
    # This should happen right after a successful login
    run = wandb.init(
        project="tiny-recursive-model-sudoku-v1",
        config={
            "learning_rate": LEARNING_RATE,
            "epochs": NUM_EPOCHS,
            "batch_size": BATCH_SIZE,
            "d_model": D_MODEL,
            "n_layers": N_LAYERS,
            "dim_feedforward": DIM_FEEDFORWARD,
            "n_recursion": N_RECURSION,
            "t_recursion": T_RECURSION,
        },
    )
    print("✅ W&B run initialized successfully.")

except Exception as e:
    print("Could not log into W&B or initialize run. Please ensure the 'wandb_api_key' secret is set.")
    print(f"Error: {e}")

--- Attempting Hugging Face Login ---
✅ Successfully logged into Hugging Face.

--- Attempting Weights & Biases Login ---


[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc


✅ Successfully logged into Weights & Biases.


✅ W&B run initialized successfully.


In [16]:
# 1. Token Embedding Layer
embedding_layer = nn.Embedding(VOCAB_SIZE, D_MODEL).to(device)

# 2. The main Tiny Recursive Model
# --- MEMORY OPTIMIZATION ---
# Pass our smaller dim_feedforward to the model constructor
trm_model = TinyRecursiveModel(
    d_model=D_MODEL,
    n_layers=N_LAYERS,
    dim_feedforward=DIM_FEEDFORWARD
).to(device)

# 3. The Output Head
output_head = OutputHead(D_MODEL, VOCAB_SIZE).to(device)

# 4. The Q-Head
q_head = QHead(D_MODEL).to(device)

# Combine all model parameters for the optimizer
all_params = (
    list(embedding_layer.parameters()) +
    list(trm_model.parameters()) +
    list(output_head.parameters()) +
    list(q_head.parameters())
)

# 5. Optimizer
optimizer = AdamW(all_params, lr=LEARNING_RATE)

# 6. Loss Functions
prediction_loss_fn = nn.CrossEntropyLoss()
act_loss_fn = nn.BCEWithLogitsLoss()

Batch 6: The Training Loop
This is where we put everything together. The loop will iterate through our training data for a few epochs. Inside the loop, we implement the full "Deep Supervision" logic as described in the paper's pseudocode.
Initialize States: For each puzzle, we start with initial y and z embeddings. We'll use simple zero tensors for this.
Deep Supervision Loop: We loop for N_SUP steps.
Forward Pass: In each step, we call our deep_recursion function.
Calculate Losses:
Calculate the prediction loss between the model's output and the true solution.
Calculate the ACT (halting) loss. The paper's pseudocode suggests the target is 1 if the prediction is correct and 0 otherwise.
Backpropagation: We compute the gradients and update the model weights.
Detach: Crucially, we detach the y and z states from the computation graph before the next supervision step. This is the key to managing memory.
ACT Halting: We check the halt condition from the Q-head to potentially break the loop early, saving computation.

In [18]:
# --- DEEPER TRAINING SETUP with W&B ---
# NUM_EPOCHS = 2

# We need to tell W&B to watch our model for gradient tracking (optional but good practice)
wandb.watch(trm_model, log="all", log_freq=100)

for epoch in range(NUM_EPOCHS):
    print(f"--- Epoch {epoch+1}/{NUM_EPOCHS} ---")
    
    # --- TRAINING PHASE ---
    trm_model.train() # Set model to training mode
    epoch_loss = 0.0
    progress_bar = tqdm(train_dataloader, desc=f"Training Epoch {epoch+1}")
    
    for i, (quiz_tokens, solution_tokens) in enumerate(progress_bar):
        quiz_tokens = quiz_tokens.to(device)
        solution_tokens = solution_tokens.to(device)
        
        optimizer.zero_grad()
        
        x_embedded = embedding_layer(quiz_tokens)
        y_embedded = torch.zeros_like(x_embedded, device=device)
        z_embedded = torch.zeros_like(x_embedded, device=device)

        total_loss = 0
        actual_steps = 0

        for step in range(N_SUP):
            actual_steps += 1
            y_refined, z_refined = deep_recursion(trm_model, x_embedded, y_embedded, z_embedded)
            
            y_logits = output_head(y_refined)
            q_value = q_head(y_refined)

            prediction_loss = prediction_loss_fn(y_logits.view(-1, VOCAB_SIZE), solution_tokens.view(-1))

            with torch.no_grad():
                y_pred_tokens = torch.argmax(y_logits, dim=-1)
                is_correct = torch.all(y_pred_tokens == solution_tokens, dim=1).float()
            
            halt_loss = act_loss_fn(q_value.squeeze(), is_correct)
            step_loss = prediction_loss + halt_loss
            total_loss += step_loss

            y_embedded = y_refined.detach()
            z_embedded = z_refined.detach()

            if torch.sigmoid(q_value).mean() > 0.9 and step > 0:
                break
        
        avg_loss = total_loss / actual_steps
        avg_loss.backward()
        optimizer.step()
        
        epoch_loss += avg_loss.item()
        
        # Log batch loss to W&B
        wandb.log({"batch_loss": avg_loss.item()})
        progress_bar.set_postfix({"avg_loss": epoch_loss / (i + 1)})

    avg_epoch_loss = epoch_loss / len(train_dataloader)
    print(f"Epoch {epoch+1} finished. Average Training Loss: {avg_epoch_loss}")
    
    # --- EVALUATION PHASE ---
def evaluate(embedding_layer, model, output_head, test_loader, device):
    embedding_layer.eval()
    model.eval()
    output_head.eval()
    q_head.eval() # Also set q_head to eval mode

    total_puzzles = 0
    correct_puzzles = 0
    total_digits = 0
    correct_digits = 0

    with torch.no_grad():
        for quiz_tokens, solution_tokens in tqdm(test_loader, desc="Evaluating"):
            quiz_tokens = quiz_tokens.to(device)
            solution_tokens = solution_tokens.to(device)

            # --- Evaluation follows the same deep recursion logic ---
            x_embedded = embedding_layer(quiz_tokens)
            y_embedded = torch.zeros_like(x_embedded, device=device)
            z_embedded = torch.zeros_like(x_embedded, device=device)
            
            final_y_pred_tokens = None

            # The paper uses the full N_sup steps at test time
            for step in range(N_SUP):
                y_refined, z_refined = deep_recursion(model, x_embedded, y_embedded, z_embedded)
                y_embedded, z_embedded = y_refined.detach(), z_refined.detach()
                
                # We can check the final prediction after all steps
                if step == N_SUP - 1:
                    y_logits = output_head(y_refined)
                    final_y_pred_tokens = torch.argmax(y_logits, dim=-1)

            # Compare the final prediction with the solution
            correct_puzzles_batch = torch.all(final_y_pred_tokens == solution_tokens, dim=1).sum().item()
            correct_digits_batch = (final_y_pred_tokens == solution_tokens).sum().item()
            
            correct_puzzles += correct_puzzles_batch
            total_puzzles += quiz_tokens.size(0)
            correct_digits += correct_digits_batch
            total_digits += quiz_tokens.size(0) * SEQ_LEN

    puzzle_accuracy = (correct_puzzles / total_puzzles) * 100
    digit_accuracy = (correct_digits / total_digits) * 100
    
    return puzzle_accuracy, digit_accuracy
    
    print("Running evaluation...")
    puzzle_acc, digit_acc = evaluate(embedding_layer, trm_model, output_head, test_dataloader, device)
    
    print(f"Evaluation Results - Puzzle Acc: {puzzle_acc:.2f}%, Digit Acc: {digit_acc:.2f}%")

    # Log epoch-level metrics to W&B
    wandb.log({
        "epoch": epoch + 1,
        "avg_train_loss": avg_epoch_loss,
        "puzzle_accuracy": puzzle_acc,
        "digit_accuracy": digit_acc,
    })

print("Training finished!")

# --- FINISH W&B RUN ---
wandb.finish()

--- Epoch 1/2 ---


Training Epoch 1: 100%|██████████| 63/63 [05:11<00:00,  4.94s/it, avg_loss=1.53]


Epoch 1 finished. Average Training Loss: 1.5275759034686618
--- Epoch 2/2 ---


Training Epoch 2: 100%|██████████| 63/63 [05:12<00:00,  4.95s/it, avg_loss=1.51]

Epoch 2 finished. Average Training Loss: 1.5147268147695632
Training finished!





0,1
batch_loss,█▇▆▅▄▃▃▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁

0,1
batch_loss,1.4997


NUM_EPOCHS = 3 # Let's start with 3 epochs for a quick test run.

# Set models to training mode
embedding_layer.train()
trm_model.train()
output_head.train()
q_head.train()

for epoch in range(NUM_EPOCHS):
    print(f"--- Epoch {epoch+1}/{NUM_EPOCHS} ---")
    
    # Use tqdm for a progress bar
    for quiz_tokens, solution_tokens in tqdm(train_dataloader, desc="Training"):
        quiz_tokens = quiz_tokens.to(device)
        solution_tokens = solution_tokens.to(device)
        
        # --- Start of Deep Supervision ---
        
        # 1. Initialize states for the deep supervision loop
        # Embed the input puzzle once per supervision cycle
        x_embedded = embedding_layer(quiz_tokens)

        # Initialize y (prediction) and z (latent) embeddings as zero tensors
        # The paper calls this y_init and z_init
        y_embedded = torch.zeros_like(x_embedded, device=device)
        z_embedded = torch.zeros_like(x_embedded, device=device)

        # Reset gradients for the new supervision cycle
        optimizer.zero_grad()
        
        total_loss = 0

        for step in range(N_SUP):
            # 2. Perform one full deep recursion step (T-1 no-grad, 1 with-grad)
            y_refined, z_refined = deep_recursion(trm_model, x_embedded, y_embedded, z_embedded)
            
            # Get the model's prediction logits and halting value
            y_logits = output_head(y_refined)
            q_value = q_head(y_refined) # Halting probability logit

            # 3. Calculate Losses
            # Reshape for CrossEntropyLoss: (Batch * SeqLen, VocabSize)
            prediction_loss = prediction_loss_fn(
                y_logits.view(-1, VOCAB_SIZE),
                solution_tokens.view(-1)
            )

            # ACT halting loss
            with torch.no_grad():
                # Get the actual predictions by finding the max logit
                y_pred_tokens = torch.argmax(y_logits, dim=-1)
                # The halt target is 1 if the entire puzzle is correct, 0 otherwise.
                is_correct = torch.all(y_pred_tokens == solution_tokens, dim=1).float()
            
            # The paper's pseudocode uses a simple BCE loss for halting
            halt_loss = act_loss_fn(q_value.squeeze(), is_correct)

            # Combine losses
            step_loss = prediction_loss + halt_loss
            total_loss += step_loss

            # 4. Detach states for the next iteration (THIS IS THE KEY MEMORY SAVING STEP)
            y_embedded = y_refined.detach()
            z_embedded = z_refined.detach()

            # 5. ACT Early Stopping
            # If the model is confident enough to halt, we stop the supervision loop for this batch.
            # We use sigmoid to convert the logit to a probability
            if torch.sigmoid(q_value).mean() > 0.9 and step > 0: # Check if average halt prob is high
                break
        
        # 6. Backpropagation
        # The accumulated loss from all steps is backpropagated at once.
        total_loss.backward()
        optimizer.step()

    print(f"Epoch {epoch+1} finished. Last batch total loss: {total_loss.item()}")

print("Training finished!")

def evaluate(embedding_layer, model, output_head, test_loader, device):
    embedding_layer.eval()
    model.eval()
    output_head.eval()
    q_head.eval() # Also set q_head to eval mode

    total_puzzles = 0
    correct_puzzles = 0
    total_digits = 0
    correct_digits = 0

    with torch.no_grad():
        for quiz_tokens, solution_tokens in tqdm(test_loader, desc="Evaluating"):
            quiz_tokens = quiz_tokens.to(device)
            solution_tokens = solution_tokens.to(device)

            # --- Evaluation follows the same deep recursion logic ---
            x_embedded = embedding_layer(quiz_tokens)
            y_embedded = torch.zeros_like(x_embedded, device=device)
            z_embedded = torch.zeros_like(x_embedded, device=device)
            
            final_y_pred_tokens = None

            # The paper uses the full N_sup steps at test time
            for step in range(N_SUP):
                y_refined, z_refined = deep_recursion(model, x_embedded, y_embedded, z_embedded)
                y_embedded, z_embedded = y_refined.detach(), z_refined.detach()
                
                # We can check the final prediction after all steps
                if step == N_SUP - 1:
                    y_logits = output_head(y_refined)
                    final_y_pred_tokens = torch.argmax(y_logits, dim=-1)

            # Compare the final prediction with the solution
            correct_puzzles_batch = torch.all(final_y_pred_tokens == solution_tokens, dim=1).sum().item()
            correct_digits_batch = (final_y_pred_tokens == solution_tokens).sum().item()
            
            correct_puzzles += correct_puzzles_batch
            total_puzzles += quiz_tokens.size(0)
            correct_digits += correct_digits_batch
            total_digits += quiz_tokens.size(0) * SEQ_LEN

    puzzle_accuracy = (correct_puzzles / total_puzzles) * 100
    digit_accuracy = (correct_digits / total_digits) * 100
    
    return puzzle_accuracy, digit_accuracy

# --- Run Evaluation ---
print("\nStarting evaluation on the test set...")
puzzle_acc, digit_acc = evaluate(embedding_layer, trm_model, output_head, test_dataloader, device)

print(f"\n--- Evaluation Results ---")
print(f"Puzzle Accuracy: {puzzle_acc:.2f}%")
print(f"Digit Accuracy: {digit_acc:.2f}%")