<a href="https://colab.research.google.com/github/weagan/Tiny-Recursive-Models/blob/main/trm_colab_py.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Less is More: Recursive Reasoning with Tiny Networks (TRM)

This notebook sets up, trains, and evaluates the Tiny Recursive Model (TRM) from the paper:
"Less is More: Recursive Reasoning with Tiny Networks"

**Key Features:**
- 45% test accuracy on ARC-AGI-1 with only 7M parameters
- 8% test accuracy on ARC-AGI-2
- Recursive reasoning with a single tiny two-layer network instead of a massive model
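The recursion can be illustrated with a toy example. The sketch below follows the structure of the pseudocode in the paper: the latent state `z` is refined `n` times from the input `x`, the current answer `y`, and `z`; then `y` is updated from `y` and `z`; and this whole round is repeated `T` times. The scalar `net` is a hypothetical stand-in for the real two-layer network, and reading `T` and `n` as the `arch.H_cycles` and `arch.L_cycles` overrides used in the training cells is an assumption.

```python
import math


def make_toy_net():
    """Hypothetical stand-in for TRM's single tiny network: a fixed scalar
    function plus a counter of how many times it has been evaluated."""
    calls = {"n": 0}

    def net(*inputs):
        calls["n"] += 1
        # Any deterministic mixing of the inputs works for this illustration
        return math.tanh(sum(inputs))

    return net, calls


def latent_recursion(net, x, y, z, n):
    # Refine the latent reasoning state z n times, conditioned on x, y and z
    for _ in range(n):
        z = net(x, y, z)
    # Then refine the answer y from the current answer and latent (no x)
    y = net(y, z)
    return y, z


def deep_recursion(net, x, y, z, n, T):
    # In the paper, the first T-1 rounds run without gradients and only the
    # last round is backpropagated; plain Python has no autograd, so every
    # round here is just a forward pass.
    for _ in range(T - 1):
        y, z = latent_recursion(net, x, y, z, n)
    y, z = latent_recursion(net, x, y, z, n)
    return y, z


net, calls = make_toy_net()
y, z = deep_recursion(net, x=0.5, y=0.0, z=0.0, n=4, T=3)
print(calls["n"])  # 15 = T * (n + 1)
```

With `T=3` and `n=4` (the values the ARC cells pass as `H_cycles` and `L_cycles`), one pass evaluates the same small network 3 × (4 + 1) = 15 times, which is how recursion substitutes for depth.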

**Paper:** https://arxiv.org/abs/2510.04871

**Original Code:** Builds on the Hierarchical Reasoning Model (HRM) codebase

**Runtime Requirements:**
- ARC-AGI training: ~3 days on 4x H100 GPUs
- Sudoku-Extreme: <36 hours on 1x L40S GPU
- Maze-Hard: <24 hours on 4x L40S GPUs
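Multiplying the wall-clock times above by the GPU counts gives a rough compute budget per experiment. This is only arithmetic on the listed figures: "~3 days" is approximate, and the Sudoku and Maze times are upper bounds.

```python
# (wall-clock hours, number of GPUs) taken from the runtime list above
runtime_budget = {
    "ARC-AGI (H100)": (3 * 24, 4),      # ~3 days on 4 GPUs
    "Sudoku-Extreme (L40S)": (36, 1),   # <36 hours on 1 GPU
    "Maze-Hard (L40S)": (24, 4),        # <24 hours on 4 GPUs
}

gpu_hours = {name: hours * gpus for name, (hours, gpus) in runtime_budget.items()}
for name, total in gpu_hours.items():
    print(f"{name}: ~{total} GPU-hours")
```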

## Setup and Installation

In [None]:
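# Alternative source (GitHub). The "Clone Repository" cell below clones a repository
# with the same name from Hugging Face; run only one of the two, since both create Samsung-TRM/
!git clone https://github.com/weagan/Samsung-TRM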

In [1]:
# Check GPU and System Info
import torch

print("PyTorch version:", torch.__version__)
print("CUDA available:", torch.cuda.is_available())
if torch.cuda.is_available():
    print("CUDA version:", torch.version.cuda)
    print("Number of GPUs:", torch.cuda.device_count())
    for i in range(torch.cuda.device_count()):
        print(f"GPU {i}: {torch.cuda.get_device_name(i)}")
        print(f"  Memory: {torch.cuda.get_device_properties(i).total_memory / 1e9:.2f} GB")

PyTorch version: 2.8.0+cu126
CUDA available: False


In [None]:
# Clone Repository
!git clone https://huggingface.co/wtfmahe/Samsung-TRM
%cd Samsung-TRM

In [None]:
# Install Dependencies
# Note: Adjust PyTorch installation based on your CUDA version

# Upgrade pip and core tools
!pip install --upgrade pip wheel setuptools -q

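# Install PyTorch nightly wheels (adjust for your CUDA version):
# the cu121 suffix in the index URL selects CUDA 12.1 builds. Compare it with the
# suffix of the preinstalled PyTorch version printed above (e.g. +cu126) and match it.
!pip install --pre --upgrade torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu121 -q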

# Install other requirements
!pip install -r requirements.txt -q

# Install adam-atan2 optimizer
!pip install --no-cache-dir --no-build-isolation adam-atan2 -q

print("✓ All dependencies installed!")

In [None]:
# Login to Weights & Biases (Optional)
# Skip this cell if you don't want to use W&B

import wandb

# Option 1: Login interactively
wandb.login()

# Option 2: Login with API key (uncomment and add your key)
# wandb.login(key="YOUR_WANDB_API_KEY")
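If you skip the login, you can instead tell the W&B client not to sync. The sketch below sets the `WANDB_MODE` environment variable, which the `wandb` library reads: `"offline"` keeps runs locally and `"disabled"` turns logging calls into no-ops. It assumes `pretrain.py` logs through `wandb`; because `!` shell commands inherit the notebook's environment, the setting carries over to the training cells.

```python
import os

# "offline" stores runs locally (they can be uploaded later with `wandb sync`);
# "disabled" makes wandb calls no-ops
os.environ["WANDB_MODE"] = "offline"
print("WANDB_MODE =", os.environ["WANDB_MODE"])
```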

## Dataset Preparation

In [None]:
# Download ARC-AGI Dataset
# Requires Kaggle API credentials (~/.kaggle/kaggle.json) and that you have
# accepted the arc-prize-2024 competition rules on Kaggle

# Create the target directory
!mkdir -p kaggle/combined

# Download the competition archive and extract it
# (-o overwrites existing files instead of prompting, which would hang the cell)
!kaggle competitions download -c arc-prize-2024
!unzip -o -q arc-prize-2024.zip -d kaggle/combined/

print("✓ ARC-AGI dataset downloaded!")
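Before building the datasets, it can help to confirm what the archive actually contained, since the build cells locate their inputs through the `kaggle/combined/arc-agi` prefix. This sketch only lists JSON files matching that prefix; which subsets (for example `concept`, `training2`, `evaluation2`) must be present is decided by `dataset.build_arc_dataset`, not by this check.

```python
import glob
import os


def list_arc_files(prefix="kaggle/combined/arc-agi"):
    """Return the sorted JSON files whose path starts with `prefix`."""
    return sorted(p for p in glob.glob(prefix + "*") if p.endswith(".json"))


files = list_arc_files()
if files:
    for path in files:
        print(f"{os.path.basename(path)}: {os.path.getsize(path) / 1e6:.1f} MB")
else:
    print("No files match kaggle/combined/arc-agi*; check the download and unzip steps")
```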

In [None]:
# Build ARC-AGI-1 Dataset
# This creates augmented versions of the data

!python -m dataset.build_arc_dataset \
  --input-file-prefix kaggle/combined/arc-agi \
  --output-dir data/arc1concept-aug-1000 \
  --subsets training evaluation concept \
  --test-set-name evaluation

print("✓ ARC-AGI-1 dataset prepared!")
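The `aug-1000` suffix suggests about 1000 augmented variants per puzzle. One common ARC augmentation, shown here as an illustrative assumption rather than a description of `build_arc_dataset`, permutes the nine non-background colors while keeping color 0 fixed, and applies the same mapping to a task's input and output so that the transformation rule stays valid.

```python
import random


def permute_colors(input_grid, output_grid, rng):
    """Relabel ARC colors 1-9 with one random permutation, keeping 0 fixed.

    The same mapping is applied to both grids so the input -> output rule
    is preserved under the augmentation.
    """
    shuffled = list(range(1, 10))
    rng.shuffle(shuffled)
    mapping = {0: 0, **{c: shuffled[c - 1] for c in range(1, 10)}}

    def apply(grid):
        return [[mapping[v] for v in row] for row in grid]

    return apply(input_grid), apply(output_grid)


rng = random.Random(0)
inp = [[0, 1], [2, 0]]
out = [[1, 1], [2, 2]]
aug_in, aug_out = permute_colors(inp, out, rng)
print(aug_in, aug_out)
```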

In [None]:
# Build ARC-AGI-2 Dataset
# Note: Do not train on ARC-AGI-1 and ARC-AGI-2 together and evaluate on both,
# because the ARC-AGI-2 training split contains some ARC-AGI-1 evaluation puzzles

!python -m dataset.build_arc_dataset \
  --input-file-prefix kaggle/combined/arc-agi \
  --output-dir data/arc2concept-aug-1000 \
  --subsets training2 evaluation2 concept \
  --test-set-name evaluation2

print("✓ ARC-AGI-2 dataset prepared!")

In [None]:
# Build Sudoku-Extreme Dataset
# Subsample 1000 training puzzles and generate 1000 augmentations of each

!python dataset/build_sudoku_dataset.py \
  --output-dir data/sudoku-extreme-1k-aug-1000 \
  --subsample-size 1000 \
  --num-aug 1000

print("✓ Sudoku-Extreme dataset prepared!")
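Sudoku admits many transformations that map a valid grid to another valid grid, which is what makes 1000 augmentations per puzzle possible. The sketch below shows typical ones: relabeling digits, shuffling the row bands and the rows inside each band, doing the same for columns, and optionally transposing. It is an illustrative assumption, not necessarily what `build_sudoku_dataset.py` does, and encoding blanks as 0 is also an assumption.

```python
import random

DIGITS = set(range(1, 10))


def is_valid_solution(grid):
    """True if every row, column and 3x3 box of a 9x9 grid contains 1-9 once."""
    rows_ok = all(set(row) == DIGITS for row in grid)
    cols_ok = all(set(col) == DIGITS for col in zip(*grid))
    boxes_ok = all(
        {grid[r][c] for r in range(br, br + 3) for c in range(bc, bc + 3)} == DIGITS
        for br in range(0, 9, 3)
        for bc in range(0, 9, 3)
    )
    return rows_ok and cols_ok and boxes_ok


def augment_sudoku(grid, rng):
    """Apply one random validity-preserving transformation (0 = blank)."""
    # Relabel digits 1-9 with a random permutation; blanks stay 0
    perm = rng.sample(range(1, 10), 9)
    mapping = {0: 0, **{d: perm[d - 1] for d in range(1, 10)}}
    # Shuffle the three row bands, then the rows inside each band
    row_order = [3 * b + r for b in rng.sample(range(3), 3) for r in rng.sample(range(3), 3)]
    # Same for the column stacks and the columns inside each stack
    col_order = [3 * s + c for s in rng.sample(range(3), 3) for c in rng.sample(range(3), 3)]
    out = [[mapping[grid[r][c]] for c in col_order] for r in row_order]
    # Optionally transpose
    if rng.random() < 0.5:
        out = [list(row) for row in zip(*out)]
    return out


# A valid solved grid built from a standard shifting pattern
solution = [[(3 * (r % 3) + r // 3 + c) % 9 + 1 for c in range(9)] for r in range(9)]
rng = random.Random(42)
augmented = [augment_sudoku(solution, rng) for _ in range(1000)]
print(all(is_valid_solution(g) for g in augmented))  # True
```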

In [None]:
# Build Maze-Hard Dataset
# Generate 1000 examples with 8 augmentations

!python dataset/build_maze_dataset.py

print("✓ Maze-Hard dataset prepared!")
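The eight augmentations are presumably the eight symmetries of a square grid (four rotations, each optionally mirrored), which keep a maze solvable and carry its shortest path onto the transformed maze. That is an assumption about `build_maze_dataset.py`; the sketch only shows how the eight variants of a grid can be generated.

```python
def rotate90(grid):
    """Rotate a grid (list of equal-length strings or lists) 90 degrees clockwise."""
    return [list(row) for row in zip(*grid[::-1])]


def dihedral_variants(grid):
    """Return the 8 rotations/reflections of a grid: 4 rotations x optional mirror."""
    variants = []
    current = [list(row) for row in grid]
    for _ in range(4):
        variants.append(current)
        variants.append([row[::-1] for row in current])  # horizontal mirror
        current = rotate90(current)
    return variants


# Toy maze: '#' wall, ' ' open, 'S' start, 'G' goal
maze = ["S #", "  #", "# G"]
variants = dihedral_variants(maze)
print(len(variants), len({tuple(map(tuple, v)) for v in variants}))  # 8 8
```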

## Training Experiments

### ARC-AGI Tasks

In [None]:
# Train on ARC-AGI-1 (Multi-GPU)
# Requires 4 H100 GPUs, runs for ~3 days

run_name = "pretrain_att_arc1concept_4"

!torchrun --nproc-per-node 4 \
  --rdzv_backend=c10d \
  --rdzv_endpoint=localhost:0 \
  --nnodes=1 \
  pretrain.py \
  arch=trm \
  data_paths="[data/arc1concept-aug-1000]" \
  arch.L_layers=2 \
  arch.H_cycles=3 \
  arch.L_cycles=4 \
  +run_name={run_name} \
  ema=True

In [None]:
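# Train on ARC-AGI-1 (Single GPU - for Colab)
# Same hyperparameters as the 4-GPU run, launched with `python` instead of torchrun.
# Expect it to take many times longer than the ~3-day multi-GPU run (far beyond a
# Colab session limit), and you may need a smaller batch size to fit in GPU memory.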

run_name = "pretrain_att_arc1concept_1gpu"

!python pretrain.py \
  arch=trm \
  data_paths="[data/arc1concept-aug-1000]" \
  arch.L_layers=2 \
  arch.H_cycles=3 \
  arch.L_cycles=4 \
  +run_name={run_name} \
  ema=True
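The ARC-AGI cells above differ only in the launcher (`torchrun` versus `python`), the dataset path, and the run name. A hypothetical helper such as `build_train_command` below assembles the same arguments from Python so the variants cannot drift apart; it reuses only flags and overrides that appear in those cells. In Hydra's override syntax, `key=value` changes an existing config entry, while the `+` in `+run_name=...` adds a key that is not in the base config.

```python
import shlex


def build_train_command(data_path, run_name, num_gpus=1, overrides=None):
    """Return the argv list for one TRM training run (hypothetical helper)."""
    if num_gpus > 1:
        launcher = [
            "torchrun", "--nproc-per-node", str(num_gpus),
            "--rdzv_backend=c10d", "--rdzv_endpoint=localhost:0", "--nnodes=1",
        ]
    else:
        launcher = ["python"]
    args = [
        "pretrain.py",
        "arch=trm",
        # No shell is involved, so the brackets need no extra quoting
        f"data_paths=[{data_path}]",
        *(overrides or []),
        f"+run_name={run_name}",
        "ema=True",
    ]
    return launcher + args


arc1_overrides = ["arch.L_layers=2", "arch.H_cycles=3", "arch.L_cycles=4"]
cmd = build_train_command("data/arc1concept-aug-1000", "pretrain_att_arc1concept_1gpu",
                          num_gpus=1, overrides=arc1_overrides)
print(shlex.join(cmd))
# To launch: import subprocess; subprocess.run(cmd, check=True)
```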

In [None]:
# Train on ARC-AGI-2 (Multi-GPU)
# Requires 4 H100 GPUs, runs for ~3 days

run_name = "pretrain_att_arc2concept_4"

!torchrun --nproc-per-node 4 \
  --rdzv_backend=c10d \
  --rdzv_endpoint=localhost:0 \
  --nnodes=1 \
  pretrain.py \
  arch=trm \
  data_paths="[data/arc2concept-aug-1000]" \
  arch.L_layers=2 \
  arch.H_cycles=3 \
  arch.L_cycles=4 \
  +run_name={run_name} \
  ema=True

### Sudoku-Extreme Tasks

In [None]:
# Train on Sudoku-Extreme (MLP version)
# Runtime: <36 hours on 1 L40S GPU

run_name = "pretrain_mlp_t_sudoku"

!python pretrain.py \
  arch=trm \
  data_paths="[data/sudoku-extreme-1k-aug-1000]" \
  evaluators="[]" \
  epochs=50000 \
  eval_interval=5000 \
  lr=1e-4 \
  puzzle_emb_lr=1e-4 \
  weight_decay=1.0 \
  puzzle_emb_weight_decay=1.0 \
  arch.mlp_t=True \
  arch.pos_encodings=none \
  arch.L_layers=2 \
  arch.H_cycles=3 \
  arch.L_cycles=6 \
  +run_name={run_name} \
  ema=True
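According to the paper, the MLP variant replaces self-attention with an MLP applied across the sequence dimension (in the spirit of MLP-Mixer), which is practical when the context is small and fixed, as with Sudoku's 81 cells. The pure-Python sketch below illustrates such token mixing; it is not the TRM layer itself, and that `arch.mlp_t=True` selects this variant is inferred from the cell's title. Because the weights are tied to positions, position is encoded implicitly, which plausibly explains why this run sets `arch.pos_encodings=none`.

```python
def token_mixing_mlp(x, w1, w2):
    """Two-layer MLP applied along the sequence axis (MLP-Mixer style).

    x:  L x D  (sequence length x channels)
    w1: H x L  (hidden units x sequence length)
    w2: L x H
    Each channel d is mixed across positions independently; the weights are
    tied to positions, so L must be fixed in advance.
    """
    L, D, H = len(x), len(x[0]), len(w1)
    # hidden[h][d] = relu(sum_j w1[h][j] * x[j][d])
    hidden = [
        [max(0.0, sum(w1[h][j] * x[j][d] for j in range(L))) for d in range(D)]
        for h in range(H)
    ]
    # out[i][d] = sum_h w2[i][h] * hidden[h][d]
    return [
        [sum(w2[i][h] * hidden[h][d] for h in range(H)) for d in range(D)]
        for i in range(L)
    ]


x = [[1.0, 2.0], [0.5, -1.0], [-0.5, 0.0]]  # 3 tokens, 2 channels
w1 = [[0.2, -0.1, 0.3], [0.5, 0.5, 0.5], [-0.3, 0.1, 0.2], [0.0, 1.0, -1.0]]
w2 = [[0.1, 0.2, 0.3, 0.4], [-0.2, 0.0, 0.1, 0.5], [0.3, -0.4, 0.2, 0.0]]
y = token_mixing_mlp(x, w1, w2)
print(len(y), len(y[0]))  # 3 2
```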

In [None]:
# Train on Sudoku-Extreme (Attention version)
# Runtime: <36 hours on 1 L40S GPU

run_name = "pretrain_att_sudoku"

!python pretrain.py \
  arch=trm \
  data_paths="[data/sudoku-extreme-1k-aug-1000]" \
  evaluators="[]" \
  epochs=50000 \
  eval_interval=5000 \
  lr=1e-4 \
  puzzle_emb_lr=1e-4 \
  weight_decay=1.0 \
  puzzle_emb_weight_decay=1.0 \
  arch.L_layers=2 \
  arch.H_cycles=3 \
  arch.L_cycles=6 \
  +run_name={run_name} \
  ema=True

### Maze-Hard Task

In [None]:
# Train on Maze-Hard
# Runtime: <24 hours on 4 L40S GPUs

run_name = "pretrain_att_maze30x30"

!torchrun --nproc-per-node 4 \
  --rdzv_backend=c10d \
  --rdzv_endpoint=localhost:0 \
  --nnodes=1 \
  pretrain.py \
  arch=trm \
  data_paths="[data/maze-30x30-hard-1k]" \
  evaluators="[]" \
  epochs=50000 \
  eval_interval=5000 \
  lr=1e-4 \
  puzzle_emb_lr=1e-4 \
  weight_decay=1.0 \
  puzzle_emb_weight_decay=1.0 \
  arch.L_layers=2 \
  arch.H_cycles=3 \
  arch.L_cycles=4 \
  +run_name={run_name} \
  ema=True

## Evaluation and Analysis

In [None]:
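# Monitor Training Progress
import os
import glob

# Parent directory of the run folders; adjust it if pretrain.py writes its
# checkpoints somewhere else
runs_root = "outputs"

# List all run directories
runs = glob.glob(os.path.join(runs_root, "*/"))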
print("Available training runs:")
for run in sorted(runs):
    print(f"  {run}")

# Check latest checkpoint
latest_run = max(runs, key=os.path.getmtime) if runs else None
if latest_run:
    print(f"\nLatest run: {latest_run}")
    checkpoints = glob.glob(os.path.join(latest_run, "*.pt"))
    print(f"Checkpoints: {len(checkpoints)}")

In [None]:
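# Load a Checkpoint
import torch
from pathlib import Path

# Specify checkpoint path (replace YOUR_RUN_NAME with your run directory)
checkpoint_path = "outputs/YOUR_RUN_NAME/checkpoint_best.pt"

if Path(checkpoint_path).exists():
    # map_location="cpu" lets GPU-trained checkpoints load on a CPU-only runtime.
    # PyTorch >= 2.6 defaults to weights_only=True; pass weights_only=False only
    # for checkpoints you trust if loading fails on non-tensor objects.
    checkpoint = torch.load(checkpoint_path, map_location="cpu")
    print(f"Loaded checkpoint from epoch {checkpoint.get('epoch', 'unknown')}")
    print(f"Best validation accuracy: {checkpoint.get('best_val_acc', 'unknown')}")
else:
    print(f"Checkpoint not found: {checkpoint_path}")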
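To compare a checkpoint with the paper's figure of about 7M parameters, you can sum the element counts of its tensors. The helper below is a sketch that assumes the loaded `checkpoint` is a flat mapping from names to tensors (a `state_dict`); if the weights are nested under another key, pass that sub-dictionary instead. The total may not match 7M exactly, since a checkpoint can also hold buffers or the puzzle embeddings trained via `puzzle_emb_lr`.

```python
import math


def count_elements(state_dict):
    """Sum the number of elements over all array-like values with a `.shape`.

    Works for torch tensors and NumPy arrays; non-array entries are skipped.
    """
    total = 0
    for value in state_dict.values():
        shape = getattr(value, "shape", None)
        if shape is not None:
            total += math.prod(shape)
    return total


# Usage with the checkpoint loaded above (assumes a flat state_dict):
# print(f"{count_elements(checkpoint) / 1e6:.2f}M elements")
```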

In [None]:
# Visualize Results
import matplotlib.pyplot as plt
import numpy as np

def visualize_arc_prediction(input_grid, true_output, predicted_output):
    """Plot an ARC-AGI input, its true output, and the predicted output side by side"""
    fig, axes = plt.subplots(1, 3, figsize=(15, 5))
    panels = [(input_grid, 'Input'), (true_output, 'True Output'), (predicted_output, 'Predicted Output')]

    for ax, (grid, title) in zip(axes, panels):
        # Fix the color range to the 10 ARC colors (0-9) so the same value gets
        # the same color in every panel, regardless of which values appear
        ax.imshow(grid, cmap='tab10', vmin=0, vmax=9, interpolation='nearest')
        ax.set_title(title)
        ax.axis('off')

    plt.tight_layout()
    plt.show()

# Example usage (with dummy data)
# input_grid = np.random.randint(0, 10, (10, 10))
# true_output = np.random.randint(0, 10, (10, 10))
# predicted_output = np.random.randint(0, 10, (10, 10))
# visualize_arc_prediction(input_grid, true_output, predicted_output)

print("Visualization functions ready!")
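Beyond visual inspection, predictions can be scored numerically. The sketch below computes two simple metrics for a pair of grids: the fraction of matching cells, and whether the whole grid (including its shape) matches exactly, the stricter criterion ARC-style evaluation relies on. The function name is illustrative, and the official TRM evaluators may aggregate results differently (for example over several attempts per task).

```python
def grid_metrics(true_grid, predicted_grid):
    """Return (cell_accuracy, exact_match) for two grids given as lists of rows.

    A shape mismatch counts as 0 cell accuracy and no exact match.
    """
    same_shape = len(true_grid) == len(predicted_grid) and all(
        len(t) == len(p) for t, p in zip(true_grid, predicted_grid)
    )
    if not same_shape:
        return 0.0, False
    cells = [(t, p) for t_row, p_row in zip(true_grid, predicted_grid) for t, p in zip(t_row, p_row)]
    if not cells:  # two empty grids trivially match
        return 1.0, True
    correct = sum(t == p for t, p in cells)
    return correct / len(cells), correct == len(cells)


true = [[1, 2], [3, 4]]
pred = [[1, 2], [3, 0]]
print(grid_metrics(true, pred))  # (0.75, False)
```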

## Citation

If you use this code, please cite:

```bibtex
@misc{jolicoeurmartineau2025morerecursivereasoningtiny,
      title={Less is More: Recursive Reasoning with Tiny Networks},
      author={Alexia Jolicoeur-Martineau},
      year={2025},
      eprint={2510.04871},
      archivePrefix={arXiv},
      primaryClass={cs.LG},
      url={https://arxiv.org/abs/2510.04871},
}
```

And the Hierarchical Reasoning Model (HRM):

```bibtex
@misc{wang2025hierarchicalreasoningmodel,
      title={Hierarchical Reasoning Model},
      author={Guan Wang and Jin Li and Yuhao Sun and Xing Chen and Changling Liu and Yue Wu and Meng Lu and Sen Song and Yasin Abbasi Yadkori},
      year={2025},
      eprint={2506.21734},
      archivePrefix={arXiv},
      primaryClass={cs.AI},
      url={https://arxiv.org/abs/2506.21734},
}
```