HRM Training Process

load sudoku data

In [1]:
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader

import os
import pydantic 

from typing import Optional

from dataset.sudoku import SudokuDataset

from typing import Tuple, List, Dict, Optional

import dataset.sudoku as sudoku

def load_sudoku_data(data_path: str, max_samples: int = 1000):
    # Load the dictionary that was saved
    data = np.load(data_path, allow_pickle=True).item()
 
    puzzles = []
    solutions = []

    # Extract puzzles and solutions from the dictionary
    for i in range(min(max_samples, len(data))):
        if i in data:
            puzzles.append(data[i]["puzzle"])
            solutions.append(data[i]["solution"])

    return np.array(puzzles), np.array(solutions)

# presenting the data
puzzles, solutions = load_sudoku_data("./data/sudoku_train.npy")
test_puzzles, test_solutions = load_sudoku_data("./data/sudoku_test.npy")

print(test_puzzles.shape)
print(test_solutions.shape)

sudoku.display_puzzle_pair(puzzles[0].reshape(9, 9), solutions[0].reshape(9, 9))

(1000, 81)
(1000, 81)

INPUT (_ = blank)        SOLUTION
  0 1 2 3 4 5 6 7 8      0 1 2 3 4 5 6 7 8
  -----------------      -----------------
0| _ _ _ _ 4 _ 9 _ 5    0| 7 3 2 8 4 6 9 1 5
1| _ 8 _ _ _ 1 2 _ _    1| 4 8 9 5 3 1 2 7 6
2| 5 _ _ 2 _ _ 3 4 _    2| 5 1 6 2 7 9 3 4 8
3| _ 7 8 4 _ _ _ 2 _    3| 1 7 8 4 6 3 5 2 9
4| _ _ _ _ _ _ _ _ _    4| 9 5 4 1 8 2 6 3 7
5| 6 _ _ _ _ _ 4 8 1    5| 6 2 3 9 5 7 4 8 1
6| _ _ _ 7 _ _ 8 _ _    6| 3 9 5 7 1 4 8 6 2
7| _ 6 _ _ _ _ 7 _ _    7| 8 6 1 3 2 5 7 9 4
8| _ _ _ 6 9 8 _ _ _    8| 2 4 7 6 9 8 1 5 3

Statistics: 25 filled, 56 blank cells


In [2]:
from hrm import HRMConfig, HierarchicalReasoningModel, ModelConfig

class HRMTrainer:
    """
    Trainer class for the Hierarchical Reasoning Model.
    
    """
    
    def __init__(self, 
                 model: HierarchicalReasoningModel, 
                 config=None, device=None):
        
        self.model = model
        self.config = config or ModelConfig()
        self.device = device or torch.device('mps' if torch.backends.mps.is_available() else 'cpu')

        self.model.to(self.device)
        
        # Training components
        self.criterion = nn.CrossEntropyLoss()
        self.optimizer = optim.Adam(
            list(self.model.parameters()),
            lr=self.config.learning_rate,
            weight_decay=self.config.weight_decay
        )
        self.scheduler = optim.lr_scheduler.ReduceLROnPlateau(
            self.optimizer, mode='min', factor=0.5, patience=5
        )
        
        # Training state
        self.train_losses = []
        self.val_losses = []
        self.train_accuracies = []
        self.val_accuracies = []
        self.best_val_loss = float('inf')
        self.best_model_state = None
        self.epochs_without_improvement = 0
        
        # Early stopping params
        self.patience = 15
        
        # Create results directory
        os.makedirs('results', exist_ok=True)

    def _accuracy(self, logits: torch.Tensor, targets: torch.Tensor) -> float:
        # logits: (B, 81, C) targets: (B,81)
        print(logits.shape, targets.shape)
        preds = logits.argmax(dim=-1)
        correct = (preds == targets).float().sum().item()
        total = targets.numel()
        return correct / total

    # Replace your _run_epoch method with this corrected version
    def _run_epoch(self, loader: DataLoader, train: bool = True):
        
        if train:
            self.model.train()
        else:
            self.model.eval()
        
        epoch_loss = 0.0
        epoch_acc = 0.0
        total_batches = 0
        
        for batch in loader:
            puzzles, solutions = batch["puzzle"], batch["solution"]
            puzzles = puzzles.to(self.device).float()
            solutions = solutions.to(self.device).long()
            
            # Debug shapes at each step
            #print(f"Batch {total_batches}: Input {puzzles.shape}, Target {solutions.shape}")
            
            if train:
                self.optimizer.zero_grad()

            # Forward pass - add unsqueeze to make input 3D
            if puzzles.dim() == 2:
                puzzles = puzzles.unsqueeze(-1)  # (B, 81) -> (B, 81, 1)
            
            model_output = self.model(puzzles)
            #print(f"Model output shape: {model_output.shape}")

            
            # Calculate loss based on output dimensions
            batch_size, seq_len, num_classes = model_output.shape

            output_flat = model_output.view(-1, num_classes)  # (B*81, C)
            solutions_flat = solutions.view(-1)               # (B*81,)
            #print(f"  Flattened: output {output_flat.shape}, target {solutions_flat.shape}")
            
            loss = self.criterion(output_flat, solutions_flat)
            
            # FIX: Compare flattened predictions with flattened solutions
            preds = output_flat.argmax(dim=-1)  # (B*81,)
            acc = (preds == solutions_flat).float().mean().item()  # Both same shape now

            if train:
                loss.backward()
                torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1.0)
                self.optimizer.step()

            epoch_loss += loss.item()
            epoch_acc += acc
            total_batches += 1
                
        return epoch_loss / max(1, total_batches), epoch_acc / max(1, total_batches)

    def train(self, train_dataset, val_dataset=None):
        epochs = self.config.max_epochs
        train_loader = DataLoader(train_dataset, batch_size=self.config.batch_size, shuffle=True)
        val_loader = DataLoader(val_dataset, batch_size=self.config.batch_size) if val_dataset is not None else None

        for epoch in range(1, epochs + 1):
            train_loss, train_acc = self._run_epoch(train_loader, train=True)
            self.train_losses.append(train_loss)
            self.train_accuracies.append(train_acc)

            if val_loader is not None:
                with torch.no_grad():
                    val_loss, val_acc = self._run_epoch(val_loader, train=False)
                self.val_losses.append(val_loss)
                self.val_accuracies.append(val_acc)
                self.scheduler.step(val_loss)
            else:
                val_loss, val_acc = train_loss, train_acc  # fallback

            improved = val_loss < self.best_val_loss - 1e-5
            if improved:
                self.best_val_loss = val_loss
                self.best_model_state = {
                    'model': self.model.state_dict(),
                    'epoch': epoch,
                    'val_loss': val_loss
                }
                self.epochs_without_improvement = 0
            else:
                self.epochs_without_improvement += 1

            print(f"Epoch {epoch:03d} | Train Loss {train_loss:.4f} Acc {train_acc:.4f} | Val Loss {val_loss:.4f} Acc {val_acc:.4f} | LR {self.optimizer.param_groups[0]['lr']:.2e}")

            if self.epochs_without_improvement >= self.patience:
                print("Early stopping triggered.")
                break

        # Save best checkpoint
        if self.best_model_state is not None:
            torch.save(self.best_model_state, 'results/best_model.pt')
            print(f"Best model (val_loss={self.best_model_state['val_loss']:.4f}) saved to results/best_model.pt")

    def evaluate(self, dataset):
        loader = DataLoader(dataset, batch_size=self.config.batch_size)
        with torch.no_grad():
            loss, acc = self._run_epoch(loader, train=False)
        print(f"Eval Loss {loss:.4f} Acc {acc:.4f}")
        return loss, acc

In [None]:
device = torch.device('mps' if torch.backends.mps.is_available() else 'cpu')

hrm_config = HRMConfig(
    input_dim=1,   # Each position has 1 feature (digit value) - input is (B, 81, 1)
    output_dim=10, # 10 possible values (0-9) for each position
    hidden_dim=512, # Reasonable hidden size
    num_layers=4,
    dropout=0.1,
    N=2,  # Number of high-level cycles
    T=4   # Number of low-level cycles per high-level cycle
)

model_config = ModelConfig(
    learning_rate=0.001,
    batch_size=32,
    max_epochs=20,
    embeddings_lr=0.001,
    weight_decay=1e-4  # Much smaller weight decay
)

model = HierarchicalReasoningModel(config=hrm_config, device=device)

trainer = HRMTrainer(model, config=model_config, device=device)

train_dataset = sudoku.SudokuDataset(puzzles, solutions)
val_dataset = sudoku.SudokuDataset(test_puzzles, test_solutions)

trainer.train(train_dataset, val_dataset=val_dataset)

Epoch 001 | Train Loss 1.6376 Acc 0.3618 | Val Loss 1.5231 Acc 0.3971 | LR 1.00e-03
Epoch 002 | Train Loss 1.5139 Acc 0.3965 | Val Loss 1.5023 Acc 0.4002 | LR 1.00e-03
Epoch 002 | Train Loss 1.5139 Acc 0.3965 | Val Loss 1.5023 Acc 0.4002 | LR 1.00e-03
Epoch 003 | Train Loss 1.5032 Acc 0.3996 | Val Loss 1.5004 Acc 0.3985 | LR 1.00e-03
Epoch 003 | Train Loss 1.5032 Acc 0.3996 | Val Loss 1.5004 Acc 0.3985 | LR 1.00e-03
Epoch 004 | Train Loss 1.5015 Acc 0.4001 | Val Loss 1.4957 Acc 0.3995 | LR 1.00e-03
Epoch 004 | Train Loss 1.5015 Acc 0.4001 | Val Loss 1.4957 Acc 0.3995 | LR 1.00e-03
Epoch 005 | Train Loss 1.5049 Acc 0.3985 | Val Loss 1.4980 Acc 0.4002 | LR 1.00e-03
Epoch 005 | Train Loss 1.5049 Acc 0.3985 | Val Loss 1.4980 Acc 0.4002 | LR 1.00e-03
Epoch 006 | Train Loss 1.4968 Acc 0.3998 | Val Loss 1.4895 Acc 0.4022 | LR 1.00e-03
Epoch 006 | Train Loss 1.4968 Acc 0.3998 | Val Loss 1.4895 Acc 0.4022 | LR 1.00e-03
Epoch 007 | Train Loss 1.4935 Acc 0.4005 | Val Loss 1.4884 Acc 0.4011 | LR 1

In [4]:
# Minimal test
import torch
import torch.nn as nn

# Test basic tensor operation
test_tensor = torch.zeros(1, 81)
print(f"Test tensor shape: {test_tensor.shape}")
print(f"Test tensor unsqueezed: {test_tensor.unsqueeze(-1).shape}")

# Import the HRM module and check for issues
try:
    from hrm import HRMConfig, HierarchicalReasoningModel
    print("HRM import successful")
    
    # Create config - match the main training config
    config = HRMConfig(
        input_dim=1,      # Feature dimension per cell
        output_dim=10,    # Number of classes (0-9)
        hidden_dim=128,   # Hidden dimension
        N=2, T=4          # Hierarchical parameters
    )
    print(f"Config created: input_dim={config.input_dim}, output_dim={config.output_dim}")
    
    # Create model
    device = torch.device('cpu')  # Use CPU for simplicity
    model = HierarchicalReasoningModel(config=config, device=device)
    print("Model created successfully")
    
    # Test model forward with 2D input (will be converted to 3D internally)
    test_input = torch.zeros(1, 81, dtype=torch.float32)
    print(f"Test input shape: {test_input.shape}")
    
    with torch.no_grad():
        output = model(test_input)
        print(f"Model output shape: {output.shape}")
        print(f"Expected shape: (1, 81, 10)")
    
except Exception as e:
    print(f"Error: {e}")
    import traceback
    traceback.print_exc()

Test tensor shape: torch.Size([1, 81])
Test tensor unsqueezed: torch.Size([1, 81, 1])
HRM import successful
Config created: input_dim=1, output_dim=10
Model created successfully
Test input shape: torch.Size([1, 81])
Model output shape: torch.Size([1, 81, 10])
Expected shape: (1, 81, 10)


In [5]:
def validate_sudoku_predictions(model, dataset, num_samples=10):
    """
    Validate that the model's predictions form valid Sudoku solutions.
    
    Args:
        model: Trained HRM model
        dataset: SudokuDataset to validate on
        num_samples: Number of samples to validate
    
    Returns:
        dict: Validation results including accuracy metrics
    """
    model.eval()
    
    def is_valid_sudoku(grid):
        """Check if a 9x9 sudoku grid is valid"""
        # Check rows
        for row in grid:
            if len(set(row)) != 9 or not all(1 <= x <= 9 for x in row):
                return False
        
        # Check columns
        for col in range(9):
            column = [grid[row][col] for row in range(9)]
            if len(set(column)) != 9:
                return False
        
        # Check 3x3 boxes
        for box_row in range(3):
            for box_col in range(3):
                box = []
                for r in range(3):
                    for c in range(3):
                        box.append(grid[box_row*3 + r][box_col*3 + c])
                if len(set(box)) != 9:
                    return False
        
        return True
    
    def calculate_cell_accuracy(predicted, target):
        """Calculate per-cell accuracy"""
        return (predicted == target).float().mean().item()
    
    # Sample validation data
    loader = DataLoader(dataset, batch_size=1, shuffle=True)
    
    results = {
        'valid_sudokus': 0,
        'total_samples': 0,
        'cell_accuracies': [],
        'sudoku_validity_rate': 0.0,
        'average_cell_accuracy': 0.0
    }
    
    with torch.no_grad():
        for i, batch in enumerate(loader):
            if i >= num_samples:
                break
                
            puzzle, solution = batch["puzzle"], batch["solution"]
            puzzle = puzzle.to(device).float()
            solution = solution.to(device).long()
            
            # Get model prediction
            if puzzle.dim() == 2:
                puzzle = puzzle.unsqueeze(-1)  # Add feature dimension
            
            output = model(puzzle)  # Shape: (1, 81, 10)
            predicted = output.argmax(dim=-1)  # Shape: (1, 81)
            
            # Convert to numpy and reshape to 9x9
            pred_grid = predicted.cpu().numpy()[0].reshape(9, 9)
            true_grid = solution.cpu().numpy()[0].reshape(9, 9)
            
            # Calculate cell accuracy
            cell_acc = calculate_cell_accuracy(predicted[0], solution[0])
            results['cell_accuracies'].append(cell_acc)
            
            # Check if prediction is valid sudoku
            is_valid = is_valid_sudoku(pred_grid)
            if is_valid:
                results['valid_sudokus'] += 1
            
            results['total_samples'] += 1
            
            # Print first few examples
            if i < 3:
                print(f"\nSample {i+1}:")
                print("Puzzle:")
                puzzle_display = puzzle.cpu().numpy()[0].reshape(9, 9)
                for row in puzzle_display:
                    print([int(x) for x in row])
                
                print("Predicted Solution:")
                for row in pred_grid:
                    print(list(row))
                
                print("True Solution:")
                for row in true_grid:
                    print(list(row))
                
                print(f"Cell Accuracy: {cell_acc:.4f}")
                print(f"Valid Sudoku: {is_valid}")
    
    # Calculate final metrics
    results['sudoku_validity_rate'] = results['valid_sudokus'] / results['total_samples']
    results['average_cell_accuracy'] = sum(results['cell_accuracies']) / len(results['cell_accuracies'])
    
    print(f"\n=== Sudoku Validation Results ===")
    print(f"Total samples validated: {results['total_samples']}")
    print(f"Valid Sudoku solutions: {results['valid_sudokus']}")
    print(f"Sudoku validity rate: {results['sudoku_validity_rate']:.4f} ({results['sudoku_validity_rate']*100:.2f}%)")
    print(f"Average cell accuracy: {results['average_cell_accuracy']:.4f} ({results['average_cell_accuracy']*100:.2f}%)")
    
    return results

# Run validation
validation_results = validate_sudoku_predictions(model, val_dataset, num_samples=10)


Sample 1:
Puzzle:
[0, 0, 0, 0, 0, 0, 9, 0, 0]
[0, 0, 1, 0, 9, 0, 0, 0, 2]
[0, 0, 6, 8, 5, 0, 0, 0, 7]
[2, 8, 9, 3, 0, 6, 0, 0, 0]
[0, 1, 3, 0, 0, 0, 0, 0, 8]
[0, 0, 0, 1, 0, 9, 0, 0, 0]
[0, 0, 4, 7, 0, 8, 0, 0, 9]
[0, 0, 0, 5, 0, 0, 0, 0, 0]
[0, 0, 0, 0, 1, 4, 0, 0, 0]
Predicted Solution:
[np.int64(0), np.int64(0), np.int64(0), np.int64(0), np.int64(0), np.int64(0), np.int64(3), np.int64(0), np.int64(0)]
[np.int64(0), np.int64(0), np.int64(6), np.int64(0), np.int64(3), np.int64(0), np.int64(0), np.int64(0), np.int64(2)]
[np.int64(0), np.int64(0), np.int64(3), np.int64(3), np.int64(3), np.int64(0), np.int64(0), np.int64(0), np.int64(3)]
[np.int64(2), np.int64(3), np.int64(3), np.int64(2), np.int64(0), np.int64(3), np.int64(0), np.int64(0), np.int64(0)]
[np.int64(0), np.int64(6), np.int64(2), np.int64(0), np.int64(0), np.int64(0), np.int64(0), np.int64(0), np.int64(3)]
[np.int64(0), np.int64(0), np.int64(0), np.int64(6), np.int64(0), np.int64(3), np.int64(0), np.int64(0), np.int64(0)]
[