<a href="https://colab.research.google.com/github/weagan/Tiny-Recursive-Models/blob/main/Samsung_TRM_py.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Less is More: Recursive Reasoning with Tiny Networks (TRM)

This notebook sets up, trains, and evaluates the Tiny Recursion Model (TRM) from the paper:
"Less is More: Recursive Reasoning with Tiny Networks"

**Key Features:**
- 45% accuracy on ARC-AGI-1 with only 7M parameters
- 8% accuracy on ARC-AGI-2
- Recursive reasoning without massive models
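
The recursion behind the last point can be sketched in a few lines of plain Python. The loop structure follows the pseudocode in the paper as read here; `toy_net` is a hypothetical stand-in for the real network, and mapping the outer and inner loops to the `arch.H_cycles` and `arch.L_cycles` settings used in the training cells is an assumption, not something confirmed against the repository code.

```python
import math

def toy_net(*inputs):
    """Hypothetical stand-in for the tiny network: squash the sum of the inputs."""
    return math.tanh(sum(inputs))

def latent_recursion(net, x, y, z, l_cycles):
    """Refine the latent z several times, then update the answer y once."""
    for _ in range(l_cycles):
        z = net(x, y, z)  # latent update sees the question, answer, and latent
    y = net(y, z)         # answer update sees only the answer and latent
    return y, z

def deep_recursion(net, x, y, z, h_cycles, l_cycles):
    """Repeat the latent recursion h_cycles times, carrying (y, z) forward."""
    for _ in range(h_cycles):
        y, z = latent_recursion(net, x, y, z, l_cycles)
    return y, z

# Count network evaluations to show how depth grows without adding parameters.
counter = {"calls": 0}

def counting_net(*inputs):
    counter["calls"] += 1
    return toy_net(*inputs)

y, z = deep_recursion(counting_net, x=0.5, y=0.0, z=0.0, h_cycles=3, l_cycles=4)
print(counter["calls"])  # 3 * (4 + 1) = 15 evaluations of the same small network
```

The same small function is applied 15 times here, which is the sense in which depth comes from recursion rather than from parameter count.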

**Paper:** https://arxiv.org/abs/2510.04871

**Original Code:** Based on the Hierarchical Reasoning Model (HRM) codebase

**Runtime Requirements:**
- ARC-AGI training: ~3 days on 4x H100 GPUs
- Sudoku-Extreme: <36 hours on 1x L40S GPU
- Maze-Hard: <24 hours on 4x L40S GPUs

## Setup and Installation

In [1]:
# Check GPU and System Info
import torch
import subprocess

print("PyTorch version:", torch.__version__)
print("CUDA available:", torch.cuda.is_available())
if torch.cuda.is_available():
    print("CUDA version:", torch.version.cuda)
    print("Number of GPUs:", torch.cuda.device_count())
    for i in range(torch.cuda.device_count()):
        print(f"GPU {i}: {torch.cuda.get_device_name(i)}")
        print(f"  Memory: {torch.cuda.get_device_properties(i).total_memory / 1e9:.2f} GB")

PyTorch version: 2.8.0+cu126
CUDA available: True
CUDA version: 12.6
Number of GPUs: 1
GPU 0: Tesla T4
  Memory: 15.83 GB


In [2]:
# Clone Repository
!git clone https://huggingface.co/wtfmahe/Samsung-TRM
%cd Samsung-TRM

Cloning into 'Samsung-TRM'...
remote: Enumerating objects: 58, done.
remote: Counting objects: 100% (54/54), done.
remote: Compressing objects: 100% (51/51), done.
remote: Total 58 (delta 16), reused 0 (delta 0), pack-reused 4 (from 1)
Unpacking objects: 100% (58/58), 722.67 KiB | 2.22 MiB/s, done.
/content/Samsung-TRM


In [3]:
# Install Dependencies
# Note: Adjust PyTorch installation based on your CUDA version

# Upgrade pip and core tools
!pip install --upgrade pip wheel setuptools -q

# Install PyTorch (adjust for your CUDA version)
# For Colab with CUDA 12.x:
!pip install --pre --upgrade torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu121 -q

# Install other requirements
!pip install -r requirements.txt -q

# Install adam-atan2 optimizer
!pip install --no-cache-dir --no-build-isolation adam-atan2 -q

print("✓ All dependencies installed!")

ERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
ipython 7.34.0 requires jedi>=0.16, which is not installed.
  Installing build dependencies ... done
  Getting requirements to build wheel ... done
  Installing backend dependencie

In [4]:
import wandb
from google.colab import userdata

# Option 1: Login interactively
#wandb.login()

# Option 2: Login with API key from Colab secrets
wandb.login(key=userdata.get('W&B_Key'))

wandb: No netrc file found, creating one.
wandb: Appending key for api.wandb.ai to your netrc file: /root/.netrc
wandb: Currently logged in as: weagan (weagan-abc) to https://api.wandb.ai. Use `wandb login --relogin` to force relogin


True

## Dataset Preparation

In [5]:
import os
from google.colab import userdata

kaggle_json_content = userdata.get('KAGGLE_KEY')

# Create .kaggle directory if it doesn't exist
!mkdir -p ~/.kaggle/

# Write the Kaggle API key to kaggle.json
with open(os.path.expanduser('~/.kaggle/kaggle.json'), 'w') as f:
    f.write(kaggle_json_content)

# Set permissions for kaggle.json
!chmod 600 ~/.kaggle/kaggle.json

print("Kaggle API configured successfully!")

Kaggle API configured successfully!
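
The cell above writes whatever the `KAGGLE_KEY` secret contains straight to `kaggle.json`, so a malformed secret only surfaces later as a download failure. A minimal sketch of an up-front check follows, assuming the secret holds the standard `kaggle.json` object with `username` and `key` fields; `parse_kaggle_credentials` is a hypothetical helper, not part of the Kaggle API.

```python
import json

def parse_kaggle_credentials(text):
    """Return the parsed credentials, or raise ValueError with a readable reason."""
    if not text:
        raise ValueError("secret is empty or missing")
    try:
        creds = json.loads(text)
    except json.JSONDecodeError as exc:
        raise ValueError(f"secret is not valid JSON: {exc}") from exc
    if not isinstance(creds, dict):
        raise ValueError("secret must be a JSON object")
    # Both fields are needed for the Kaggle CLI to authenticate.
    missing = [field for field in ("username", "key") if not creds.get(field)]
    if missing:
        raise ValueError(f"missing fields: {', '.join(missing)}")
    return creds

# Example with placeholder values (not real credentials).
example = '{"username": "your-kaggle-username", "key": "0123456789abcdef"}'
print(parse_kaggle_credentials(example)["username"])
```

In the notebook, calling `parse_kaggle_credentials(kaggle_json_content)` before writing the file would stop early with a clear message.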


In [6]:
# Download ARC-AGI Dataset
# You'll need Kaggle API credentials configured

# Create kaggle directory
!mkdir -p kaggle/combined

# Download ARC-AGI dataset
# Note: You need to accept competition rules and have kaggle.json configured
# !kaggle competitions download -c arc-prize-2024
# !unzip -q arc-prize-2024.zip -d kaggle/combined/

# print("✓ ARC-AGI dataset downloaded!")

In [7]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [11]:
#!git clone https://github.com/fchollet/ARC


Cloning into 'ARC'...
remote: Enumerating objects: 1277, done.

In [8]:
!kaggle datasets download -d namank24/arc-prize-2024-dataset
!unzip -q arc-prize-2024-dataset.zip -d kaggle/combined/
#!unzip arc-prize-2024-dataset.zip


Dataset URL: https://www.kaggle.com/datasets/namank24/arc-prize-2024-dataset
License(s): apache-2.0
Downloading arc-prize-2024-dataset.zip to /content/Samsung-TRM
  0% 0.00/150k [00:00<?, ?B/s]
100% 150k/150k [00:00<00:00, 508MB/s]
replace kaggle/combined/arc-agi_evaluation_challenges.json? [y]es, [n]o, [A]ll, [N]one, [r]ename: A
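
Before building the datasets, it may help to confirm which files actually landed under the `kaggle/combined/arc-agi` prefix that the build cells pass as `--input-file-prefix`. The sketch below only lists matching JSON files; that the build script discovers its inputs with this exact pattern is an assumption, and `list_prefixed_files` and `summarize` are hypothetical helpers.

```python
import glob
import os

def list_prefixed_files(prefix):
    """Return sorted JSON files whose path starts with the given prefix."""
    return sorted(glob.glob(prefix + "*.json"))

def summarize(prefix):
    """Print each matching file with its size and return the list of paths."""
    files = list_prefixed_files(prefix)
    if not files:
        print(f"No files found for prefix {prefix!r}; the dataset build would have nothing to read.")
    for path in files:
        size_kb = os.path.getsize(path) / 1024
        print(f"{os.path.basename(path)}: {size_kb:.1f} KiB")
    return files

files = summarize("kaggle/combined/arc-agi")
```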


In [14]:
# Build ARC-AGI-1 Dataset
# This creates augmented versions of the data

!python -m dataset.build_arc_dataset \
  --input-file-prefix kaggle/combined/arc-agi \
  --output-dir data/arc1concept-aug-1000 \
  --subsets training evaluation concept \
  --test-set-name evaluation

print("✓ ARC-AGI-1 dataset prepared!")

[Puzzle 8be77c9e] augmentation not full, only 72
[Puzzle 4258a5f9] augmentation not full, only 576
[Puzzle 3618c87e] augmentation not full, only 575
[Puzzle 2281f1f4] augmentation not full, only 576
[Puzzle 3906de3d] augmentation not full, only 576
[Puzzle aedd82e4] augmentation not full, only 576
[Puzzle 4612dd53] augmentation not full, only 575
[Puzzle 28e73c20] augmentation not full, only 72
[Puzzle f15e1fac] augmentation not full, only 576
[Puzzle 44d8ac46] augmentation not full, only 576
[Puzzle dc433765] augmentation not full, only 575
[Puzzle a5313dff] augmentation not full, only 576
[Puzzle 3f7978a0] augmentation not full, only 576
[Puzzle d4f3cd78] augmentation not full, only 576
[Puzzle 760b3cac] augmentation not full, only 576
[Puzzle cce03e0d] augmentation not full, only 576
[Puzzle ef135b50] augmentation not full, only 576
[Puzzle a3df8b1e] augmentation not full, only 72
[Puzzle 6855a6e4] augmentation not full, only 576
[Puzzle 05f2a901] augmentation not full, only 575
[Pu
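
The build prints one line per puzzle that produced fewer augmentations than requested, which gets long quickly. A small sketch for condensing that log into a histogram follows, assuming the output was captured as text; the sample lines are copied from the output above, and `summarize_augmentation_log` is a hypothetical helper.

```python
import re
from collections import Counter

# Matches lines such as "[Puzzle 8be77c9e] augmentation not full, only 72".
LOG_LINE = re.compile(r"\[Puzzle (\w+)\] augmentation not full, only (\d+)")

def summarize_augmentation_log(text):
    """Map each reported augmentation count to the number of puzzles that hit it."""
    return Counter(int(count) for _, count in LOG_LINE.findall(text))

sample = """\
[Puzzle 8be77c9e] augmentation not full, only 72
[Puzzle 4258a5f9] augmentation not full, only 576
[Puzzle 3618c87e] augmentation not full, only 575
[Puzzle 2281f1f4] augmentation not full, only 576
"""
for count, puzzles in sorted(summarize_augmentation_log(sample).items()):
    print(f"{puzzles} puzzle(s) produced {count} augmentations")
```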

In [15]:
# Build ARC-AGI-2 Dataset
# Note: Cannot train on both ARC-AGI-1 and ARC-AGI-2 together

!python -m dataset.build_arc_dataset \
  --input-file-prefix kaggle/combined/arc-agi \
  --output-dir data/arc2concept-aug-1000 \
  --subsets training2 evaluation2 concept \
  --test-set-name evaluation2

print("✓ ARC-AGI-2 dataset prepared!")

[Puzzle bd14c3bf] augmentation not full, only 574
[Puzzle 8dab14c2] augmentation not full, only 575
[Puzzle 3aa6fb7a] augmentation not full, only 576
[Puzzle e6de6e8f] augmentation not full, only 576
[Puzzle 5c0a986e] augmentation not full, only 576
[Puzzle 2697da3f] augmentation not full, only 72
[Puzzle 3618c87e] augmentation not full, only 576
[Puzzle 5168d44c] augmentation not full, only 576
[Puzzle 9bebae7a] augmentation not full, only 575
[Puzzle 90f3ed37] augmentation not full, only 576
[Puzzle 1990f7a8] augmentation not full, only 72
[Puzzle 20981f0e] augmentation not full, only 576
[Puzzle a934301b] augmentation not full, only 576
[Puzzle 18419cfa] augmentation not full, only 576
[Puzzle cce03e0d] augmentation not full, only 576
[Puzzle 4612dd53] augmentation not full, only 576
[Puzzle d37a1ef5] augmentation not full, only 576
[Puzzle f9a67cb5] augmentation not full, only 576
[Puzzle 4258a5f9] augmentation not full, only 576
[Puzzle 0b17323b] augmentation not full, only 288
[P

In [None]:
# Build Sudoku-Extreme Dataset
# Generate with 1000 examples and 1000 augmentations

!python dataset/build_sudoku_dataset.py \
  --output-dir data/sudoku-extreme-1k-aug-1000 \
  --subsample-size 1000 \
  --num-aug 1000

print("✓ Sudoku-Extreme dataset prepared!")

In [9]:
# Build Maze-Hard Dataset
# Generate 1000 examples with 8 augmentations

!python dataset/build_maze_dataset.py

print("✓ Maze-Hard dataset prepared!")

train.csv: 1.81MB [00:00, 82.2MB/s]
100% 1000/1000 [00:00<00:00, 1404186.14it/s]
test.csv: 1.81MB [00:00, 172MB/s]
100% 1000/1000 [00:00<00:00, 1505493.18it/s]
✓ Maze-Hard dataset prepared!


## Training Experiments

### ARC-AGI Tasks

In [None]:
# Train on ARC-AGI-1 (Multi-GPU)
# Requires 4 H100 GPUs, runs for ~3 days

run_name = "pretrain_att_arc1concept_4"

!torchrun --nproc-per-node 4 \
  --rdzv_backend=c10d \
  --rdzv_endpoint=localhost:0 \
  --nnodes=1 \
  pretrain.py \
  arch=trm \
  data_paths="[data/arc1concept-aug-1000]" \
  arch.L_layers=2 \
  arch.H_cycles=3 \
  arch.L_cycles=4 \
  +run_name={run_name} \
  ema=True

In [None]:
# Train on ARC-AGI-1 (Single GPU - for Colab)
# Modified for Colab constraints

run_name = "pretrain_att_arc1concept_1gpu"

!python pretrain.py \
  arch=trm \
  data_paths="[data/arc1concept-aug-1000]" \
  arch.L_layers=2 \
  arch.H_cycles=3 \
  arch.L_cycles=4 \
  +run_name={run_name} \
  ema=True
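
The single-GPU and multi-GPU cells differ only in the launcher, so the command can be assembled from one set of overrides. This is a sketch under that assumption: `build_launch_command` is a hypothetical helper, and the launcher flags are copied from the cells in this notebook rather than taken from the repository's documentation.

```python
import shlex

def build_launch_command(num_gpus, overrides, script="pretrain.py"):
    """Return the argument list for a single- or multi-GPU training launch."""
    if num_gpus > 1:
        launcher = [
            "torchrun", "--nproc-per-node", str(num_gpus),
            "--rdzv_backend=c10d", "--rdzv_endpoint=localhost:0", "--nnodes=1",
        ]
    else:
        launcher = ["python"]
    # Each override becomes a key=value argument, in insertion order.
    args = [f"{key}={value}" for key, value in overrides.items()]
    return launcher + [script] + args

overrides = {
    "arch": "trm",
    "data_paths": "[data/arc1concept-aug-1000]",
    "arch.L_layers": 2,
    "arch.H_cycles": 3,
    "arch.L_cycles": 4,
    "+run_name": "pretrain_att_arc1concept_1gpu",
    "ema": True,
}
print(shlex.join(build_launch_command(1, overrides)))
```

Passing `torch.cuda.device_count()` as `num_gpus` would pick the launcher that matches the runtime, which avoids starting four processes on a one-GPU Colab instance.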

In [None]:
# Train on ARC-AGI-2 (Multi-GPU)
# Requires 4 H100 GPUs, runs for ~3 days

run_name = "pretrain_att_arc2concept_4"

!torchrun --nproc-per-node 4 \
  --rdzv_backend=c10d \
  --rdzv_endpoint=localhost:0 \
  --nnodes=1 \
  pretrain.py \
  arch=trm \
  data_paths="[data/arc2concept-aug-1000]" \
  arch.L_layers=2 \
  arch.H_cycles=3 \
  arch.L_cycles=4 \
  +run_name={run_name} \
  ema=True

### Sudoku-Extreme Tasks

In [None]:
# Train on Sudoku-Extreme (MLP version)
# Runtime: <36 hours on 1 L40S GPU

run_name = "pretrain_mlp_t_sudoku"

!python pretrain.py \
  arch=trm \
  data_paths="[data/sudoku-extreme-1k-aug-1000]" \
  evaluators="[]" \
  epochs=50000 \
  eval_interval=5000 \
  lr=1e-4 \
  puzzle_emb_lr=1e-4 \
  weight_decay=1.0 \
  puzzle_emb_weight_decay=1.0 \
  arch.mlp_t=True \
  arch.pos_encodings=none \
  arch.L_layers=2 \
  arch.H_cycles=3 \
  arch.L_cycles=6 \
  +run_name={run_name} \
  ema=True

In [None]:
# Train on Sudoku-Extreme (Attention version)
# Runtime: <36 hours on 1 L40S GPU

run_name = "pretrain_att_sudoku"

!python pretrain.py \
  arch=trm \
  data_paths="[data/sudoku-extreme-1k-aug-1000]" \
  evaluators="[]" \
  epochs=50000 \
  eval_interval=5000 \
  lr=1e-4 \
  puzzle_emb_lr=1e-4 \
  weight_decay=1.0 \
  puzzle_emb_weight_decay=1.0 \
  arch.L_layers=2 \
  arch.H_cycles=3 \
  arch.L_cycles=6 \
  +run_name={run_name} \
  ema=True

### Maze-Hard Task

In [16]:
# Train on Maze-Hard
# Runtime: <24 hours on 4 L40S GPUs

run_name = "pretrain_att_maze30x30"

!torchrun --nproc-per-node 4 \
  --rdzv_backend=c10d \
  --rdzv_endpoint=localhost:0 \
  --nnodes=1 \
  pretrain.py \
  arch=trm \
  data_paths="[data/maze-30x30-hard-1k]" \
  evaluators="[]" \
  epochs=50000 \
  eval_interval=5000 \
  lr=1e-4 \
  puzzle_emb_lr=1e-4 \
  weight_decay=1.0 \
  puzzle_emb_weight_decay=1.0 \
  arch.L_layers=2 \
  arch.H_cycles=3 \
  arch.L_cycles=4 \
  +run_name={run_name} \
  ema=True

W1120 05:22:02.760000 3473 torch/distributed/run.py:774] 
W1120 05:22:02.760000 3473 torch/distributed/run.py:774] *****************************************
W1120 05:22:02.760000 3473 torch/distributed/run.py:774] Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed. 
W1120 05:22:02.760000 3473 torch/distributed/run.py:774] *****************************************
Traceback (most recent call last):
  File "/content/Samsung-TRM/pretrain.py", line 20, in <module>
    from adam_atan2 import AdamATan2
  File "/usr/local/lib/python3.12/dist-packages/adam_atan2/__init__.py", line 1, in <module>
    from .adam_atan2 import AdamATan2
  File "/usr/local/lib/python3.12/dist-packages/adam_atan2/adam_atan2.py", line 4, in <module>
    import adam_atan2_backend
ModuleNotFoundError: No module named 'adam_atan2_backend'
Traceback (most recent call

In [15]:
# Install Dependencies
# Note: Adjust PyTorch installation based on your CUDA version

# Upgrade pip and core tools
!pip install --upgrade pip wheel setuptools -q

# Install PyTorch (adjust for your CUDA version)
# For Colab with CUDA 12.x:
!pip install --pre --upgrade torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu121 -q

# Install other requirements
!pip install -r requirements.txt -q

# Install adam-atan2 optimizer
# Ask pip to build from source so that adam_atan2_backend gets compiled;
# check the install log to confirm a source build actually ran.
!pip install adam-atan2 --force-reinstall --no-binary :all: --verbose

print("✓ All dependencies installed!")

Using pip 25.3 from /usr/local/lib/python3.12/dist-packages/pip (python 3.12)
Collecting adam-atan2
  Using cached adam_atan2-0.0.3-py3-none-any.whl
Installing collected packages: adam-atan2
  Attempting uninstall: adam-atan2
    Found existing installation: adam_atan2 0.0.3
    Uninstalling adam_atan2-0.0.3:
      Removing file or directory /usr/local/lib/python3.12/dist-packages/adam_atan2-0.0.3.dist-info/
      Removing file or directory /usr/local/lib/python3.12/dist-packages/adam_atan2/
      Successfully uninstalled adam_atan2-0.0.3
Successfully installed adam-atan2-0.0.3
✓ All dependencies installed!
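
The training traceback failed on `import adam_atan2_backend`, and a successful `pip` exit status does not by itself show that this module now exists. A minimal sketch of a check to run before relaunching training follows; `module_available` is a hypothetical helper built on the standard library's `importlib.util.find_spec`, which locates a module without loading it, so a module reported as found could still fail at import time.

```python
import importlib.util

def module_available(name):
    """Return True if the named top-level module can be located for import."""
    return importlib.util.find_spec(name) is not None

# The Python package and its compiled backend are separate top-level modules.
for name in ("adam_atan2", "adam_atan2_backend"):
    status = "found" if module_available(name) else "MISSING"
    print(f"{name}: {status}")
```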


## Evaluation and Analysis

In [None]:
# Monitor Training Progress
import os
import glob

# List all run directories
runs = glob.glob("outputs/*/")
print("Available training runs:")
for run in sorted(runs):
    print(f"  {run}")

# Check latest checkpoint
latest_run = max(runs, key=os.path.getmtime) if runs else None
if latest_run:
    print(f"\nLatest run: {latest_run}")
    checkpoints = glob.glob(os.path.join(latest_run, "*.pt"))
    print(f"Checkpoints: {len(checkpoints)}")

In [None]:
# Load and Evaluate Model
import torch
from pathlib import Path

# Specify checkpoint path
checkpoint_path = "outputs/YOUR_RUN_NAME/checkpoint_best.pt"

if Path(checkpoint_path).exists():
    # map_location="cpu" lets the checkpoint load even when no GPU is attached
    checkpoint = torch.load(checkpoint_path, map_location="cpu")
    print(f"Loaded checkpoint from epoch {checkpoint.get('epoch', 'unknown')}")
    print(f"Best validation accuracy: {checkpoint.get('best_val_acc', 'unknown')}")
else:
    print(f"Checkpoint not found: {checkpoint_path}")

In [None]:
# Visualize Results
import matplotlib.pyplot as plt
import numpy as np

def visualize_arc_prediction(input_grid, true_output, predicted_output):
    """Visualize ARC-AGI predictions"""
    fig, axes = plt.subplots(1, 3, figsize=(15, 5))

    # Fix the color range to ARC's ten cell values (0-9) so that a given value
    # gets the same color in every panel, whatever values each grid contains.
    axes[0].imshow(input_grid, cmap='tab10', vmin=0, vmax=9, interpolation='nearest')
    axes[0].set_title('Input')
    axes[0].axis('off')

    axes[1].imshow(true_output, cmap='tab10', vmin=0, vmax=9, interpolation='nearest')
    axes[1].set_title('True Output')
    axes[1].axis('off')

    axes[2].imshow(predicted_output, cmap='tab10', vmin=0, vmax=9, interpolation='nearest')
    axes[2].set_title('Predicted Output')
    axes[2].axis('off')

    plt.tight_layout()
    plt.show()

# Example usage (with dummy data)
# input_grid = np.random.randint(0, 10, (10, 10))
# true_output = np.random.randint(0, 10, (10, 10))
# predicted_output = np.random.randint(0, 10, (10, 10))
# visualize_arc_prediction(input_grid, true_output, predicted_output)

print("Visualization functions ready!")
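
A side-by-side plot shows where a prediction goes wrong, but a number is easier to track across puzzles. ARC-AGI counts a prediction as correct only when the whole output grid matches, so the sketch below reports that exact-match flag alongside a per-cell accuracy. `grid_metrics` is a hypothetical helper that expects grids as lists of equal-length rows (call `.tolist()` on NumPy arrays first); scoring a wrong-sized prediction as 0.0 cell accuracy is a choice made here, not part of the benchmark.

```python
def grid_metrics(true_output, predicted_output):
    """Return (exact_match, cell_accuracy) for two grids given as lists of rows."""
    true_shape = (len(true_output), len(true_output[0]) if true_output else 0)
    pred_shape = (len(predicted_output), len(predicted_output[0]) if predicted_output else 0)
    if true_shape != pred_shape:
        return False, 0.0  # a wrong-sized grid is scored as matching no cells
    total = true_shape[0] * true_shape[1]
    correct = sum(
        t == p
        for true_row, pred_row in zip(true_output, predicted_output)
        for t, p in zip(true_row, pred_row)
    )
    accuracy = correct / total if total else 1.0
    return correct == total, accuracy

true_grid = [[0, 1], [2, 3]]
pred_grid = [[0, 1], [2, 4]]
print(grid_metrics(true_grid, pred_grid))  # (False, 0.75)
```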

## Citation

If you use this code, please cite:

```bibtex
@misc{jolicoeurmartineau2025morerecursivereasoningtiny,
      title={Less is More: Recursive Reasoning with Tiny Networks},
      author={Alexia Jolicoeur-Martineau},
      year={2025},
      eprint={2510.04871},
      archivePrefix={arXiv},
      primaryClass={cs.LG},
      url={https://arxiv.org/abs/2510.04871},
}
```

And the Hierarchical Reasoning Model (HRM):

```bibtex
@misc{wang2025hierarchicalreasoningmodel,
      title={Hierarchical Reasoning Model},
      author={Guan Wang and Jin Li and Yuhao Sun and Xing Chen and Changling Liu and Yue Wu and Meng Lu and Sen Song and Yasin Abbasi Yadkori},
      year={2025},
      eprint={2506.21734},
      archivePrefix={arXiv},
      primaryClass={cs.AI},
      url={https://arxiv.org/abs/2506.21734},
}
```