HRM Training Process

Load Sudoku data

In [41]:
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader

import os
import pydantic 

from typing import Optional

from dataset.sudoku import SudokuDataset

from typing import Tuple, List, Dict, Optional

import dataset.sudoku as sudoku

def load_sudoku_data(data_path: str, max_samples: int = 1000):
    """Read puzzle/solution pairs out of a pickled dictionary stored in a ``.npy`` file.

    The file must contain a single dictionary. Entries are looked up by the
    integer keys ``0 .. min(max_samples, len(dict)) - 1``; indices that are not
    present as keys are skipped. Each entry is read through its ``"puzzle"``
    and ``"solution"`` items.

    Args:
        data_path: Path of the ``.npy`` file to read.
        max_samples: Upper bound on the number of indices that are probed.

    Returns:
        A ``(puzzles, solutions)`` tuple of NumPy arrays built from the
        collected entries, in index order.
    """
    # NOTE: allow_pickle=True unpickles the file content; only load trusted files.
    records = np.load(data_path, allow_pickle=True).item()

    limit = min(max_samples, len(records))
    selected = [records[idx] for idx in range(limit) if idx in records]

    puzzle_rows = [entry["puzzle"] for entry in selected]
    solution_rows = [entry["solution"] for entry in selected]
    return np.array(puzzle_rows), np.array(solution_rows)

# Load the training and test splits and preview the first training puzzle.
# The four arrays bound here are module-level names that later cells reuse
# when building the datasets.
puzzles, solutions = load_sudoku_data("./data/sudoku_train.npy")
test_puzzles, test_solutions = load_sudoku_data("./data/sudoku_test.npy")

# Quick sanity check of the loaded test arrays.
print(test_puzzles.shape)
print(test_solutions.shape)

# Each sample is reshaped to a 9x9 grid before being displayed next to its solution.
sudoku.display_puzzle_pair(puzzles[0].reshape(9, 9), solutions[0].reshape(9, 9))

(1000, 81)
(1000, 81)

INPUT (_ = blank)        SOLUTION
  0 1 2 3 4 5 6 7 8      0 1 2 3 4 5 6 7 8
  -----------------      -----------------
0| _ _ _ _ 4 _ 9 _ 5    0| 7 3 2 8 4 6 9 1 5
1| _ 8 _ _ _ 1 2 _ _    1| 4 8 9 5 3 1 2 7 6
2| 5 _ _ 2 _ _ 3 4 _    2| 5 1 6 2 7 9 3 4 8
3| _ 7 8 4 _ _ _ 2 _    3| 1 7 8 4 6 3 5 2 9
4| _ _ _ _ _ _ _ _ _    4| 9 5 4 1 8 2 6 3 7
5| 6 _ _ _ _ _ 4 8 1    5| 6 2 3 9 5 7 4 8 1
6| _ _ _ 7 _ _ 8 _ _    6| 3 9 5 7 1 4 8 6 2
7| _ 6 _ _ _ _ 7 _ _    7| 8 6 1 3 2 5 7 9 4
8| _ _ _ 6 9 8 _ _ _    8| 2 4 7 6 9 8 1 5 3

Statistics: 25 filled, 56 blank cells


In [42]:


class RecurrentModule(nn.Module):
    """Chain of LSTM blocks followed by layer normalisation and a linear projection.

    Every LSTM block maps ``input_dim`` features to ``input_dim`` features so
    that the blocks can be chained and the output keeps the input's feature
    size. ``hidden_dim`` is stored on the instance but is not used to size any
    layer.

    Args:
        input_dim: Feature size of the input, of every LSTM state and of the output.
        hidden_dim: Stored as ``self.hidden_dim``; kept for interface compatibility.
        num_layers: Number of LSTM blocks in the chain, and also the number of
            stacked layers inside each block.
        dropout: Dropout probability passed to each ``nn.LSTM``.
    """

    def __init__(
        self,
        input_dim: int,
        hidden_dim: int = 128,
        num_layers: int = 2,
        dropout: float = 0.1
    ):
        super().__init__()
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.num_layers = num_layers

        # nn.LSTM applies dropout only between stacked layers and warns when a
        # non-zero value is combined with a single layer, so drop it in that case.
        lstm_dropout = dropout if num_layers > 1 else 0.0

        # going with lstm to simplify understanding
        self.layers = nn.ModuleList(
            nn.LSTM(input_size=input_dim, hidden_size=input_dim,
                    num_layers=num_layers, batch_first=True, dropout=lstm_dropout)
            for _ in range(num_layers)
        )

        self.projection = nn.Linear(input_dim, input_dim, bias=False)

        self.layer_norm = nn.LayerNorm(input_dim)

    def forward(self, x, hidden=None) -> torch.Tensor:
        """Run ``x`` through every LSTM block, then normalise and project it.

        Args:
            x: Batch-first 3-D tensor ``(batch, seq_len, input_dim)``.
            hidden: Optional ``(h0, c0)`` pair for the first LSTM block, each of
                shape ``(num_layers, batch, input_dim)``. Zeros are used when omitted.

        Returns:
            Tensor with the same shape as ``x``.

        Raises:
            ValueError: If ``x`` is not 3-dimensional.
        """
        if x.dim() != 3:
            raise ValueError(
                f"expected a 3-D (batch, seq_len, features) tensor, got shape {tuple(x.shape)}"
            )
        batch_size = x.shape[0]

        # Initialize hidden state if not provided. The state must match the
        # LSTMs' hidden size (input_dim, not hidden_dim) and x's dtype/device.
        if hidden is None:
            state_shape = (self.num_layers, batch_size, self.input_dim)
            hidden = (x.new_zeros(state_shape), x.new_zeros(state_shape))

        # The final state of each block seeds the next block.
        for layer in self.layers:
            x, hidden = layer(x, hidden)

        # should add attention?

        # Project the *normalised* activations (the normalised tensor used to
        # be computed and then thrown away).
        output = self.layer_norm(x)
        output = self.projection(output)

        return output

In [None]:
# HierarchicalReasoningModel
 # added to ensure HRMConfig is defined successfully
class HRMConfig(pydantic.BaseModel):
    """Architecture hyper-parameters read by HierarchicalReasoningModel."""

    # Feature size of the tensors fed to the model (in-features of its input projection).
    input_dim: int = 64
    # Width of the internal states and of the first output-head layer.
    hidden_dim: int = 128
    # Forwarded as num_layers to both recurrent sub-modules.
    num_layers: int = 4
    # Dropout probability used by the recurrent sub-modules and the output head.
    dropout: float = 0.1
    # Feature size of the tensor returned by the model.
    output_dim: int = 10

    # The model runs N * T low-level updates in total.
    N: int = 2  # number of high-level module cycles
    T: int = 4  # number of low-level module cycles
    # NOTE(review): not read by any code visible in this file.
    max_seq_len: int = 256

class ModelConfig(pydantic.BaseModel):
    """Optimisation hyper-parameters read by HRMTrainer.

    An identical ModelConfig class is declared again in a later cell of this file.
    """

    # Learning rate handed to the Adam optimiser.
    learning_rate: float = 0.001
    # Batch size of the training and validation DataLoaders.
    batch_size: int = 32
    # Upper bound on the number of training epochs.
    max_epochs: int = 200
    # NOTE(review): not read by any code visible in this file.
    embeddings_lr: float = 0.001
    # Passed as weight_decay to torch.optim.Adam.
    # NOTE(review): 1.0 is a very strong penalty for Adam's coupled L2 decay
    # and may be why the recorded run stays near chance accuracy — confirm
    # whether AdamW or a smaller value was intended.
    weight_decay: float = 1.0

class HierarchicalReasoningModel(nn.Module):
    """Two-level recurrent model trained with a one-step gradient approximation.

    The input is projected to ``hidden_dim``. A low-level state is then updated
    ``N * T`` times and a high-level state ``N`` times. All updates except the
    last low/high pair run under ``torch.no_grad()``, so gradients only flow
    through the final pair. The two final states are concatenated, normalised
    and mapped to ``output_dim`` by an MLP head.

    Args:
        config: Architecture hyper-parameters.
        device: Default device for the initial states when none is given.
    """

    def __init__(self, config: HRMConfig, device: torch.device):
        super().__init__()
        self.config = config
        self.total_steps = config.N * config.T  # total steps in the HRM
        self.device = device
        self.N = config.N
        self.T = config.T

        # Input projection (project puzzle embedding dim -> hidden)
        self.input_proj = nn.Linear(config.input_dim, config.hidden_dim)

        # Both recurrent stacks only ever see tensors that already live in
        # hidden_dim space (the output of input_proj), so they are sized with
        # hidden_dim. Sizing them with input_dim only worked when both were equal.
        self.High_net = RecurrentModule(
            input_dim=self.config.hidden_dim,
            num_layers=self.config.num_layers,
            hidden_dim=self.config.hidden_dim,
            dropout=self.config.dropout
        )

        self.Low_net = RecurrentModule(
            input_dim=self.config.hidden_dim,
            num_layers=self.config.num_layers,
            hidden_dim=self.config.hidden_dim,
            dropout=self.config.dropout
        )

        # Combine both states and project to the latent (hrm latent == output_dim)
        self.layer_norm = nn.LayerNorm(self.config.hidden_dim * 2)
        self.output_proj = nn.Sequential(
            nn.Linear(self.config.hidden_dim * 2, self.config.hidden_dim),
            nn.ReLU(),
            nn.Dropout(self.config.dropout),
            nn.Linear(self.config.hidden_dim, self.config.hidden_dim // 2),
            nn.ReLU(),
            nn.Dropout(self.config.dropout),
            nn.Linear(self.config.hidden_dim // 2, self.config.output_dim)
        )

        # Bias-free projections applied to the *other* level's state before it
        # is mixed into the level being updated.
        self.low_level_proj = nn.Linear(self.config.hidden_dim, self.config.hidden_dim, bias=False)
        self.high_level_proj = nn.Linear(self.config.hidden_dim, self.config.hidden_dim, bias=False)

    def initialize_hidden_states(self, batch_size: int, device=None, dtype=None) -> Tuple[torch.Tensor, torch.Tensor]:
        """Create zero initial states for both levels.

        Args:
            batch_size: Leading dimension of the states.
            device: Device for the states; defaults to ``self.device``.
            dtype: Dtype for the states; defaults to torch's default dtype.

        Returns:
            ``(z0_H, z0_L)``, each of shape ``(batch_size, 1, hidden_dim)``. The
            singleton middle dimension broadcasts against the sequence axis.
        """
        target_device = self.device if device is None else device
        state_shape = (batch_size, 1, self.config.hidden_dim)
        z0_L = torch.zeros(state_shape, device=target_device, dtype=dtype)
        z0_H = torch.zeros(state_shape, device=target_device, dtype=dtype)
        return z0_H, z0_L

    def level_step(self,
                   first_level: torch.Tensor,
                   second_level: torch.Tensor,
                   input_embedding: torch.Tensor,
                   network: nn.Module,
                   projection: nn.Module
                   ) -> torch.Tensor:
        """Compute the next state of one level.

        Args:
            first_level: Current state of the level being updated.
            second_level: Current state of the other level.
            input_embedding: Projected model input.
            network: Module whose ``layers`` are applied in order; each layer
                must return an ``(output, state)`` pair.
            projection: Module applied to ``second_level`` before mixing.

        Returns:
            The updated state tensor.
        """
        level_influence = projection(second_level)
        combined = first_level + level_influence + input_embedding
        # Only the recurrent layers of `network` are used; each one starts
        # from its default (zero) state and its final state is discarded.
        for layer in network.layers:
            combined, _ = layer(combined)
        return combined

    def forward(self, x, hidden_states=None):
        """Run the full reasoning schedule and return the latent output.

        Args:
            x: Input tensor whose last dimension is ``config.input_dim``
                (the original author notes ``(B, 81, input_dim)``).
            hidden_states: Optional ``(high_level_state, low_level_state)``
                pair; zero states are created when omitted.

        Returns:
            Tensor whose last dimension is ``config.output_dim``.
        """
        x = self.input_proj(x)
        # Initialize hidden states if not provided, on the same device and
        # with the same dtype as the data actually being processed.
        if hidden_states is None:
            high_level_state, low_level_state = self.initialize_hidden_states(
                x.shape[0], device=x.device, dtype=x.dtype
            )
        else:
            high_level_state, low_level_state = hidden_states

        # All but the last update run without building an autograd graph.
        with torch.no_grad():
            for step in range(self.total_steps - 1):
                low_level_state = self.level_step(
                    low_level_state, high_level_state, x, self.Low_net, self.low_level_proj
                )

                # The high level is refreshed once every T low-level updates.
                if (step + 1) % self.T == 0:
                    high_level_state = self.level_step(
                        high_level_state, low_level_state, x, self.High_net, self.high_level_proj
                    )

        # 1 step with gradient
        low_level_state = self.level_step(
            low_level_state, high_level_state, x, self.Low_net, self.low_level_proj
        )
        high_level_state = self.level_step(
            high_level_state, low_level_state, x, self.High_net, self.high_level_proj
        )

        combined = torch.cat([low_level_state, high_level_state], dim=-1)
        combined = self.layer_norm(combined)

        latent = self.output_proj(combined)
        return latent

In [44]:

class ModelConfig(pydantic.BaseModel):
    """Optimisation hyper-parameters read by HRMTrainer.

    This repeats, field for field, the ModelConfig declared in an earlier cell
    of this file; running this cell rebinds the name.
    """

    # Learning rate handed to the Adam optimiser.
    learning_rate: float = 0.001
    # Batch size of the training and validation DataLoaders.
    batch_size: int = 32
    # Upper bound on the number of training epochs.
    max_epochs: int = 200
    # NOTE(review): not read by any code visible in this file.
    embeddings_lr: float = 0.001
    # Passed as weight_decay to torch.optim.Adam.
    # NOTE(review): 1.0 is a very strong penalty for Adam's coupled L2 decay —
    # confirm whether AdamW or a smaller value was intended.
    weight_decay: float = 1.0



In [45]:



class HRMTrainer:
    """
    Trainer class for the Hierarchical Reasoning Model.

    Owns the optimiser, the LR scheduler, the loss, the per-epoch metric
    history and the early-stopping bookkeeping. The model and the adapter are
    optimised jointly; the best checkpoint (lowest validation loss) is written
    to ``results/best_model.pt`` when training ends.
    """

    def __init__(self,
                 model: HierarchicalReasoningModel,
                 adapter: sudoku.SudokuAdapter,
                 config=None, device=None):
        """Set up the optimiser, scheduler and training state.

        Args:
            model: Network mapping puzzle embeddings to latents.
            adapter: Provides ``encode_puzzle`` (tokens -> embeddings) and
                ``decoder`` (latents -> logits); trained together with ``model``.
            config: Optimisation settings; a default ``ModelConfig`` when omitted.
            device: Target device; ``mps`` when available, otherwise ``cpu``.
        """
        self.model = model
        self.adapter = adapter
        self.config = config or ModelConfig()
        self.device = device or torch.device('mps' if torch.backends.mps.is_available() else 'cpu')

        self.model.to(self.device)
        self.adapter.to(self.device)

        # Training components
        self.criterion = nn.CrossEntropyLoss()
        self.optimizer = optim.Adam(
            list(self.model.parameters()) + list(self.adapter.parameters()),
            lr=self.config.learning_rate,
            weight_decay=self.config.weight_decay
        )
        self.scheduler = optim.lr_scheduler.ReduceLROnPlateau(
            self.optimizer, mode='min', factor=0.5, patience=5
        )

        # Training state
        self.train_losses = []
        self.val_losses = []
        self.train_accuracies = []
        self.val_accuracies = []
        self.best_val_loss = float('inf')
        self.best_model_state = None
        self.epochs_without_improvement = 0

        # Early stopping params
        self.patience = 15

        # Create results directory
        os.makedirs('results', exist_ok=True)

    @staticmethod
    def _snapshot(module) -> dict:
        """Return a copy of ``module``'s state dict that later updates cannot alter.

        ``state_dict()`` hands back tensors that share storage with the live
        parameters, so they must be cloned to freeze the weights of a given epoch.
        """
        return {name: tensor.detach().clone() for name, tensor in module.state_dict().items()}

    def _accuracy(self, logits: torch.Tensor, targets: torch.Tensor) -> float:
        """Return the fraction of positions where the arg-max class equals the target."""
        # logits: (B, 81, C) targets: (B,81)
        preds = logits.argmax(dim=-1)
        correct = (preds == targets).float().sum().item()
        total = targets.numel()
        return correct / total

    def _run_epoch(self, loader: DataLoader, train: bool = True):
        """Run one pass over ``loader``.

        Args:
            loader: Yields dict batches with ``"puzzle"`` and ``"solution"`` tensors.
            train: When true, parameters are updated; otherwise the modules are
                put in eval mode and no gradients are tracked.

        Returns:
            ``(mean_loss, mean_accuracy)`` averaged over batches; ``(0.0, 0.0)``
            for an empty loader.
        """
        if train:
            self.model.train()
            self.adapter.train()
        else:
            self.model.eval()
            self.adapter.eval()
        epoch_loss = 0.0
        epoch_acc = 0.0
        total_batches = 0
        # Gradient tracking follows the mode, so evaluation never builds a graph
        # regardless of how the caller invokes this method.
        with torch.set_grad_enabled(train):
            for batch in loader:
                puzzles, solutions = batch["puzzle"], batch["solution"]
                puzzles = puzzles.to(self.device)  # (B,81)
                solutions = solutions.to(self.device)  # (B,81)

                if train:
                    self.optimizer.zero_grad()

                puzzle_embeds = self.adapter.encode_puzzle(puzzles)
                # Model forward -> latent (B,81,output_dim)
                latent = self.model(puzzle_embeds)
                # Decode to logits (B,81,10)
                logits = self.adapter.decoder(latent)

                # reshape (rather than view) also accepts non-contiguous tensors.
                loss = self.criterion(logits.reshape(-1, logits.size(-1)), solutions.reshape(-1))
                acc = self._accuracy(logits, solutions)

                if train:
                    loss.backward()
                    torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1.0)
                    torch.nn.utils.clip_grad_norm_(self.adapter.parameters(), 1.0)
                    self.optimizer.step()

                epoch_loss += loss.item()
                epoch_acc += acc
                total_batches += 1
        return epoch_loss / max(1, total_batches), epoch_acc / max(1, total_batches)

    def train(self, train_dataset, val_dataset=None):
        """Train for up to ``config.max_epochs`` epochs with early stopping.

        Args:
            train_dataset: Dataset used for the optimisation passes.
            val_dataset: Optional dataset used for model selection and LR
                scheduling. Without it the training metrics are used for
                model selection and the scheduler is never stepped.
        """
        epochs = self.config.max_epochs
        train_loader = DataLoader(train_dataset, batch_size=self.config.batch_size, shuffle=True)
        val_loader = DataLoader(val_dataset, batch_size=self.config.batch_size) if val_dataset is not None else None

        for epoch in range(1, epochs + 1):
            train_loss, train_acc = self._run_epoch(train_loader, train=True)
            self.train_losses.append(train_loss)
            self.train_accuracies.append(train_acc)

            if val_loader is not None:
                val_loss, val_acc = self._run_epoch(val_loader, train=False)
                self.val_losses.append(val_loss)
                self.val_accuracies.append(val_acc)
                self.scheduler.step(val_loss)
            else:
                val_loss, val_acc = train_loss, train_acc  # fallback

            # Require a minimum improvement so noise does not reset the patience counter.
            improved = val_loss < self.best_val_loss - 1e-5
            if improved:
                self.best_val_loss = val_loss
                # Snapshot copies: storing state_dict() directly would keep
                # references that subsequent epochs overwrite, so the file
                # saved below would contain the final, not the best, weights.
                self.best_model_state = {
                    'model': self._snapshot(self.model),
                    'adapter': self._snapshot(self.adapter),
                    'epoch': epoch,
                    'val_loss': val_loss
                }
                self.epochs_without_improvement = 0
            else:
                self.epochs_without_improvement += 1

            print(f"Epoch {epoch:03d} | Train Loss {train_loss:.4f} Acc {train_acc:.4f} | Val Loss {val_loss:.4f} Acc {val_acc:.4f} | LR {self.optimizer.param_groups[0]['lr']:.2e}")

            if self.epochs_without_improvement >= self.patience:
                print("Early stopping triggered.")
                break

        # Save best checkpoint
        if self.best_model_state is not None:
            torch.save(self.best_model_state, 'results/best_model.pt')
            print(f"Best model (val_loss={self.best_model_state['val_loss']:.4f}) saved to results/best_model.pt")

    def evaluate(self, dataset):
        """Compute and print the mean loss and accuracy on ``dataset``.

        Returns:
            ``(loss, accuracy)`` as returned by ``_run_epoch`` in eval mode.
        """
        loader = DataLoader(dataset, batch_size=self.config.batch_size)
        loss, acc = self._run_epoch(loader, train=False)
        print(f"Eval Loss {loss:.4f} Acc {acc:.4f}")
        return loss, acc




In [51]:

# Use Apple's Metal backend when it is available, otherwise fall back to CPU.
device = torch.device('mps' if torch.backends.mps.is_available() else 'cpu')

# input_dim and hidden_dim are set to the same value here.
config = HRMConfig(
    input_dim=512,
    output_dim=512,
    hidden_dim=512,
    num_layers=4,
    dropout=0.1,
    
)

model = HierarchicalReasoningModel(config=config, device=device)

# hrm_input_dim / hrm_output_dim mirror the model's input_dim / output_dim above.
adapter = sudoku.SudokuAdapter(hidden_dim=256, hrm_input_dim=512, hrm_output_dim=512)

# Short run: only max_epochs is overridden, the other ModelConfig defaults apply.
trainer = HRMTrainer(model, adapter, config=ModelConfig(max_epochs=10 ),  device=device)

# The test split doubles as the validation set for this run.
train_dataset = sudoku.SudokuDataset(puzzles, solutions)
val_dataset = sudoku.SudokuDataset(test_puzzles, test_solutions)


trainer.train(train_dataset, val_dataset=val_dataset)

  'puzzle': torch.tensor(self.puzzles[idx], dtype=torch.long),
  'solution': torch.tensor(self.solutions[idx], dtype=torch.long)


Epoch 001 | Train Loss 2.7467 Acc 0.0963 | Val Loss 2.4840 Acc 0.1126 | LR 1.00e-03
Epoch 002 | Train Loss 2.5664 Acc 0.1144 | Val Loss 2.3825 Acc 0.1115 | LR 1.00e-03
Epoch 003 | Train Loss 2.4543 Acc 0.1090 | Val Loss 2.2974 Acc 0.1119 | LR 1.00e-03
Epoch 004 | Train Loss 2.3390 Acc 0.1223 | Val Loss 2.2459 Acc 0.1107 | LR 1.00e-03
Epoch 005 | Train Loss 2.3130 Acc 0.1119 | Val Loss 2.2224 Acc 0.1114 | LR 1.00e-03
Epoch 006 | Train Loss 2.2689 Acc 0.1142 | Val Loss 2.2117 Acc 0.1115 | LR 1.00e-03
Epoch 007 | Train Loss 2.2524 Acc 0.1169 | Val Loss 2.2121 Acc 0.1118 | LR 1.00e-03
Epoch 008 | Train Loss 2.2368 Acc 0.1026 | Val Loss 2.2097 Acc 0.1103 | LR 1.00e-03
Epoch 009 | Train Loss 2.2190 Acc 0.1236 | Val Loss 2.2063 Acc 0.1114 | LR 1.00e-03
Epoch 010 | Train Loss 2.2159 Acc 0.1130 | Val Loss 2.2061 Acc 0.1113 | LR 1.00e-03
Best model (val_loss=2.2061) saved to results/best_model.pt


In [50]:
import torch, numpy as np
from torch.utils.data import DataLoader
from dataset.sudoku import SudokuDataset, SudokuAdapter  # adjust if path differs

# Use Apple's Metal backend when it is available, otherwise fall back to CPU.
device = torch.device('mps' if torch.backends.mps.is_available() else 'cpu')

# --- 1. Rebuild config / model / adapter exactly as during training ---
# The constructor arguments must match the training cell, otherwise the
# checkpoint's tensors will not fit the freshly built modules.
hrm_config = HRMConfig(
    input_dim=512,
    output_dim=512,
    hidden_dim=512,
    num_layers=4,
    dropout=0.1,
)
model = HierarchicalReasoningModel(config=hrm_config, device=device)
adapter = SudokuAdapter(hidden_dim=256, hrm_input_dim=512, hrm_output_dim=512)

# The checkpoint is a dict with 'model' and 'adapter' state dicts (written by
# HRMTrainer.train). map_location moves the stored tensors onto `device`.
ckpt = torch.load('results/best_model.pt', map_location=device)
model.load_state_dict(ckpt['model'])
adapter.load_state_dict(ckpt['adapter'])
# eval() switches off dropout for inference.
model.to(device).eval()
adapter.to(device).eval()

# --- 2. Prepare test dataset (reuse already loaded arrays or reload) ---
test_puzzles, test_solutions = load_sudoku_data("./data/sudoku_test.npy")
test_ds = SudokuDataset(test_puzzles, test_solutions)
test_loader = DataLoader(test_ds, batch_size=64)

# --- 3. Metrics helpers ---
def is_valid_sudoku(grid9):
    """Tell whether a completed grid satisfies all Sudoku constraints.

    Args:
        grid9: 2-D NumPy array; its rows, its first nine columns and its nine
            3x3 boxes are each required to contain exactly the digits 1-9.

    Returns:
        True when every row, column and box holds the digits 1-9, else False.
    """
    digits = set(range(1, 10))

    rows_ok = all(set(row) == digits for row in grid9)
    cols_ok = all(set(grid9[:, col]) == digits for col in range(9))
    boxes_ok = all(
        set(grid9[top:top + 3, left:left + 3].reshape(-1)) == digits
        for top in range(0, 9, 3)
        for left in range(0, 9, 3)
    )
    return rows_ok and cols_ok and boxes_ok

# Running totals accumulated over all test batches.
total_cells = 0
correct_cells = 0
correct_given = 0
total_given = 0
solved_puzzles = 0
valid_puzzles = 0

# NOTE: this loop rebinds the module-level names `puzzles` and `solutions`
# (previously the training arrays) to the tensors of the current batch.
with torch.no_grad():
    for batch in test_loader:
        puzzles = batch['puzzle'].to(device)    # (B,81) 0 means empty?
        solutions = batch['solution'].to(device)  # (B,81) digits 1..9
        embeds = adapter.encode_puzzle(puzzles)
        latent = model(embeds)
        logits = adapter.decoder(latent)  # (B,81,10) if classes 0-9
        preds = logits.argmax(-1)  # predicted class index per cell
        # Debug output: dumps the predictions of every batch.
        print(preds)
        # NOTE(review): it is assumed that class index k stands for digit k and
        # that solutions contain only 1..9 — confirm against the adapter/dataset.
        # `mask_blank_target` is computed but never used below.
        mask_blank_target = (solutions == 0)

        # Per-cell correctness over ALL cells (given and blank alike).
        correct = (preds == solutions)
        total_cells += solutions.numel()
        correct_cells += correct.sum().item()

        # Given mask (original puzzle non-zero)
        given_mask = (puzzles > 0)
        correct_given += (correct & given_mask).sum().item()
        total_given += given_mask.sum().item()

        # Whole puzzle solved? (every one of its cells predicted correctly)
        solved_puzzles += (correct.view(puzzles.size(0), -1).all(dim=1).sum().item())

        # Validity check on the raw predicted grid; the given cells of the
        # puzzle are NOT copied in, so the grid consists of predictions only.
        pred_full = preds.view(-1,9,9).cpu().numpy()
        for g in pred_full:
            if (g > 0).all() and is_valid_sudoku(g):
                valid_puzzles += 1

# total_cells is zero (and the first division fails) if the loader was empty.
cell_acc = correct_cells / total_cells
given_acc = correct_given / max(1,total_given)
puzzle_acc = solved_puzzles / len(test_ds)
valid_rate = valid_puzzles / len(test_ds)

print(f"Cell Accuracy: {cell_acc:.4f}")
print(f"Given Cell Accuracy: {given_acc:.4f}")
print(f"Puzzle Solved Rate: {puzzle_acc:.4f}")
print(f"Validity Rate: {valid_rate:.4f}")

tensor([[4, 6, 6,  ..., 6, 6, 6],
        [4, 6, 6,  ..., 6, 6, 6],
        [4, 6, 6,  ..., 6, 6, 6],
        ...,
        [4, 6, 6,  ..., 6, 6, 6],
        [4, 6, 6,  ..., 6, 6, 6],
        [4, 6, 6,  ..., 6, 6, 6]], device='mps:0')
tensor([[4, 6, 6,  ..., 6, 6, 6],
        [4, 6, 6,  ..., 6, 6, 6],
        [4, 6, 6,  ..., 6, 6, 6],
        ...,
        [4, 6, 6,  ..., 6, 6, 6],
        [4, 6, 6,  ..., 6, 6, 6],
        [4, 6, 6,  ..., 6, 6, 6]], device='mps:0')
tensor([[4, 6, 6,  ..., 6, 6, 6],
        [4, 6, 6,  ..., 6, 6, 6],
        [4, 6, 6,  ..., 6, 6, 6],
        ...,
        [4, 6, 6,  ..., 6, 6, 6],
        [4, 6, 6,  ..., 6, 6, 6],
        [4, 6, 6,  ..., 6, 6, 6]], device='mps:0')
tensor([[4, 6, 6,  ..., 6, 6, 6],
        [4, 6, 6,  ..., 6, 6, 6],
        [4, 6, 6,  ..., 6, 6, 6],
        ...,
        [4, 6, 6,  ..., 6, 6, 6],
        [4, 6, 6,  ..., 6, 6, 6],
        [4, 6, 6,  ..., 6, 6, 6]], device='mps:0')
tensor([[4, 6, 6,  ..., 6, 6, 6],
        [4, 6, 6,  ..., 6, 6, 