In [None]:
# Import necessary libraries
import sys
sys.path.append('..')

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, random_split
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
import pydantic
import os

# Import the models and dataset
from dataset.sudoku import SudokuDataset, SudokuAdapter

# Configuration classes
class HRMConfig(pydantic.BaseModel):
    """Architecture hyper-parameters for the hierarchical reasoning model.

    Values are range-checked on construction; an invalid value raises
    ``pydantic.ValidationError``. Without these checks some degenerate
    settings would not fail at all but silently change the model:
    - ``N=0`` or ``T=0`` empties the ``N * T - 1`` reasoning loop, leaving
      a single step.
    - ``num_layers=0`` yields an empty LSTM stack.
    - ``hidden_dim < 2`` makes the ``hidden_dim // 2`` output layer zero-width.
    """

    input_dim: int = pydantic.Field(default=512, gt=0)
    # Must be >= 2: the output head narrows to hidden_dim // 2 features.
    hidden_dim: int = pydantic.Field(default=512, ge=2)
    num_layers: int = pydantic.Field(default=4, ge=1)
    # PyTorch dropout probabilities must lie in [0, 1].
    dropout: float = pydantic.Field(default=0.1, ge=0.0, le=1.0)
    output_dim: int = pydantic.Field(default=512, gt=0)
    N: int = pydantic.Field(default=2, ge=1)  # number of high-level module cycles
    T: int = pydantic.Field(default=4, ge=1)  # number of low-level module cycles
    max_seq_len: int = pydantic.Field(default=81, gt=0)

class ModelConfig(pydantic.BaseModel):
    """Optimisation / training-loop hyper-parameters.

    In this notebook the config is handed to ``HRMTrainer``, which is defined
    elsewhere. Values are range-checked on construction, and an invalid value
    raises ``pydantic.ValidationError`` immediately. Without the checks, a zero
    batch size or a negative weight decay would only fail deep inside the
    training loop.
    """

    learning_rate: float = pydantic.Field(default=0.001, gt=0.0)
    batch_size: int = pydantic.Field(default=32, ge=1)
    max_epochs: int = pydantic.Field(default=50, ge=1)
    weight_decay: float = pydantic.Field(default=0.01, ge=0.0)
    # Early-stopping patience in epochs (presumably; interpreted by HRMTrainer).
    patience: int = pydantic.Field(default=10, ge=0)

# Recurrent Module for HRM
class RecurrentModule(nn.Module):
    """Stack of LSTMs followed by LayerNorm and a bias-free linear projection.

    Every LSTM maps ``input_dim`` -> ``input_dim``. ``hidden_dim`` is stored
    on the instance, but no layer's size depends on it.

    Note: the stack holds ``num_layers`` separate ``nn.LSTM`` modules, and each
    of them is itself ``num_layers`` deep. The module therefore contains
    ``num_layers ** 2`` recurrent layers in total.
    """

    def __init__(self, input_dim: int, hidden_dim: int = 128, num_layers: int = 2, dropout: float = 0.1):
        super().__init__()
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.num_layers = num_layers

        lstm_stack = []
        for _ in range(num_layers):
            lstm_stack.append(
                nn.LSTM(
                    input_size=input_dim,
                    hidden_size=input_dim,
                    num_layers=num_layers,
                    batch_first=True,
                    dropout=dropout,
                )
            )
        self.layers = nn.ModuleList(lstm_stack)

        self.projection = nn.Linear(input_dim, input_dim, bias=False)
        self.layer_norm = nn.LayerNorm(input_dim)

    def forward(self, x, hidden=None):
        """Run ``x`` through the LSTM stack, then normalise and project it.

        Each LSTM returns an ``(h, c)`` state, and that state becomes the
        initial state of the next LSTM in the stack. ``hidden`` seeds the
        first LSTM.
        """
        sequence, state = x, hidden
        for lstm in self.layers:
            sequence, state = lstm(sequence, state)
        return self.projection(self.layer_norm(sequence))

# Hierarchical Reasoning Model
class HierarchicalReasoningModel(nn.Module):
    """Two-timescale recurrent reasoner with a fast low-level and a slow high-level module.

    How ``forward`` processes its input:
    1. The input is projected to ``hidden_dim``.
    2. The low-level state is updated ``N * T`` times.
    3. The high-level state is updated once after every ``T`` low-level steps.
    4. Both final states are concatenated, layer-normalised and mapped to
       ``output_dim``.

    Args:
        config: Architecture hyper-parameters (``HRMConfig``).
        device: Fallback device for zero-initialised hidden states when
            ``initialize_hidden_states`` is called without an explicit device.
    """

    def __init__(self, config: HRMConfig, device: torch.device):
        super().__init__()
        self.config = config
        self.total_steps = config.N * config.T
        self.device = device
        self.N = config.N
        self.T = config.T

        # Input projection: input_dim -> hidden_dim. Everything downstream
        # (recurrent modules, state projections, output head) works in
        # hidden_dim space.
        self.input_proj = nn.Linear(config.input_dim, config.hidden_dim)

        # High and low level networks. They consume the *projected* input plus
        # the hidden_dim-sized states, so they must be sized on hidden_dim.
        # Sizing them on input_dim crashed whenever input_dim != hidden_dim.
        # When the two are equal, the parameter shapes are unchanged.
        self.High_net = RecurrentModule(
            input_dim=config.hidden_dim,
            num_layers=config.num_layers,
            hidden_dim=config.hidden_dim,
            dropout=config.dropout
        )

        self.Low_net = RecurrentModule(
            input_dim=config.hidden_dim,
            num_layers=config.num_layers,
            hidden_dim=config.hidden_dim,
            dropout=config.dropout
        )

        # Combine and project to latent
        self.layer_norm = nn.LayerNorm(config.hidden_dim * 2)
        self.output_proj = nn.Sequential(
            nn.Linear(config.hidden_dim * 2, config.hidden_dim),
            nn.ReLU(),
            nn.Dropout(config.dropout),
            nn.Linear(config.hidden_dim, config.hidden_dim // 2),
            nn.ReLU(),
            nn.Dropout(config.dropout),
            nn.Linear(config.hidden_dim // 2, config.output_dim)
        )

        # Projections applied to the *other* level's state before it is mixed
        # into the level being updated.
        self.low_level_proj = nn.Linear(config.hidden_dim, config.hidden_dim, bias=False)
        self.high_level_proj = nn.Linear(config.hidden_dim, config.hidden_dim, bias=False)

    def initialize_hidden_states(self, batch_size: int, device=None, dtype=None):
        """Return zero states ``(z0_H, z0_L)``, each of shape (batch_size, 1, hidden_dim).

        The singleton sequence axis broadcasts against the projected input
        inside ``level_step``.

        Args:
            batch_size: Number of sequences in the batch.
            device: Device for the states; defaults to ``self.device``.
            dtype: Dtype for the states; defaults to torch's default dtype.
        """
        device = self.device if device is None else device
        shape = (batch_size, 1, self.config.hidden_dim)
        z0_L = torch.zeros(shape, device=device, dtype=dtype)
        z0_H = torch.zeros(shape, device=device, dtype=dtype)
        return z0_H, z0_L

    def level_step(self, first_level, second_level, input_embedding, network, projection):
        """Advance one level's state by a single recurrent update.

        Args:
            first_level: State being updated.
            second_level: State of the other level. It conditions the update
                after passing through ``projection``.
            input_embedding: Projected input, added on every step.
            network: ``RecurrentModule`` whose LSTM stack is applied.
            projection: Linear map applied to ``second_level``.

        Returns:
            The new value of ``first_level``.

        Note: only ``network.layers`` is used. The module's own ``layer_norm``
        and ``projection`` (applied in ``RecurrentModule.forward``) are skipped
        here, and every LSTM starts from its default zero state.
        """
        level_influence = projection(second_level)
        combined = first_level + level_influence + input_embedding
        for layer in network.layers:
            combined, _ = layer(combined)
        return combined

    def forward(self, x, hidden_states=None):
        """Run the ``N * T``-step hierarchical reasoning loop on ``x``.

        Args:
            x: Input features, presumably (batch, seq_len, input_dim), with
                seq_len = 81 for Sudoku. TODO confirm against SudokuAdapter.
            hidden_states: Optional ``(high_level_state, low_level_state)``
                pair to resume from. Zero states are used when it is omitted.

        Returns:
            Latent tensor from ``output_proj`` whose last dimension is
            ``config.output_dim``.
        """
        x = self.input_proj(x)

        if hidden_states is None:
            # Allocate the states where the input lives, so the model keeps
            # working after .to() has moved it off the construction-time
            # self.device.
            high_level_state, low_level_state = self.initialize_hidden_states(
                x.shape[0], device=x.device, dtype=x.dtype
            )
        else:
            high_level_state, low_level_state = hidden_states

        # All but the final update run without autograd. This is a one-step
        # gradient approximation, not gradient checkpointing: gradients flow
        # only through the last low- and high-level updates below.
        with torch.no_grad():
            for step in range(self.total_steps - 1):
                low_level_state = self.level_step(
                    low_level_state, high_level_state, x, self.Low_net, self.low_level_proj
                )

                # High level updates once per T low-level steps.
                if (step + 1) % self.T == 0:
                    high_level_state = self.level_step(
                        high_level_state, low_level_state, x, self.High_net, self.high_level_proj
                    )

        # Final step with gradient
        low_level_state = self.level_step(
            low_level_state, high_level_state, x, self.Low_net, self.low_level_proj
        )
        high_level_state = self.level_step(
            high_level_state, low_level_state, x, self.High_net, self.high_level_proj
        )

        # Combine both levels
        combined = torch.cat([low_level_state, high_level_state], dim=-1)
        combined = self.layer_norm(combined)

        latent = self.output_proj(combined)  # (batch, seq_len, output_dim)
        return latent

# Cell-completion marker. It is reached only when every definition above has
# executed without raising.
print("✅ Model architectures loaded successfully!")

In [None]:
# Main training execution
def main_training_pipeline():
    """Complete training pipeline for SudokuAdapter and HRM model.

    Steps:
    1. Pick a device.
    2. Generate and save the Sudoku dataset on the first run.
    3. Split the data 80/20 into train/test.
    4. Build the HRM and adapter, then train them with ``HRMTrainer``.
    5. Plot the training history.
    6. Reload the best checkpoint, if training wrote one.
    7. Evaluate and visualise a few predictions.

    The data/training helpers (``generate_sudoku_field``, ``load_sudoku_data``,
    ``create_train_test_split``, ``HRMTrainer``, ``evaluate_model``,
    ``visualize_predictions``) are defined in other notebook cells.

    Returns:
        ``(model, adapter, trainer, metrics)``, where ``metrics`` is whatever
        ``evaluate_model`` returns.
    """
    data_path = "../data/sudoku_dataset.npy"
    best_model_path = '../results/best_model.pt'

    # Prefer the Apple-silicon GPU backend (MPS) when available.
    device = torch.device('mps' if torch.backends.mps.is_available() else 'cpu')
    print(f"🖥️  Using device: {device}")

    # Check if we have existing data
    if not os.path.exists(data_path):
        print("📊 No existing dataset found. Generating new Sudoku dataset...")
        dataset = generate_sudoku_field(1000, 5)  # 1000 samples, difficulty 5
        # np.save does not create missing parent directories.
        os.makedirs(os.path.dirname(data_path) or ".", exist_ok=True)
        np.save(data_path, dataset)
        print(f"💾 Dataset saved to {data_path}")

    # Load data
    print("📁 Loading Sudoku dataset...")
    puzzles, solutions = load_sudoku_data(data_path, max_samples=1000)
    print(f"📊 Loaded {len(puzzles)} puzzle-solution pairs")

    # Create train/test split (20% held out)
    train_puzzles, train_solutions, test_puzzles, test_solutions = create_train_test_split(
        puzzles, solutions, test_size=0.2
    )

    print(f"🔄 Data split: {len(train_puzzles)} training, {len(test_puzzles)} testing")

    # Create datasets
    train_dataset = SudokuDataset(train_puzzles, train_solutions)
    test_dataset = SudokuDataset(test_puzzles, test_solutions)

    # Initialize configurations
    hrm_config = HRMConfig(
        input_dim=512,
        output_dim=512,
        hidden_dim=512,
        num_layers=4,
        dropout=0.1,
        N=2,
        T=4
    )

    model_config = ModelConfig(
        learning_rate=0.001,
        batch_size=32,
        max_epochs=30,
        weight_decay=0.01,
        patience=10
    )

    print("🏗️  Initializing models...")

    # The adapter's HRM-facing dims are taken from hrm_config so the two
    # cannot drift apart.
    model = HierarchicalReasoningModel(config=hrm_config, device=device)
    adapter = SudokuAdapter(
        hidden_dim=256,
        hrm_input_dim=hrm_config.input_dim,
        hrm_output_dim=hrm_config.output_dim,
    )

    # Initialize trainer
    trainer = HRMTrainer(model, adapter, config=model_config, device=device)

    print(f"🎯 Model parameters: {sum(p.numel() for p in model.parameters()):,}")
    print(f"🎯 Adapter parameters: {sum(p.numel() for p in adapter.parameters()):,}")

    # Start training
    print("\n🚀 Starting training process...")
    trainer.train(train_dataset, val_dataset=test_dataset)

    # Plot training history
    print("\n📈 Plotting training history...")
    trainer.plot_training_history()

    # Load best model for evaluation (the checkpoint is presumably written by
    # HRMTrainer during training).
    if os.path.exists(best_model_path):
        print("🔄 Loading best model for evaluation...")
        checkpoint = torch.load(best_model_path, map_location=device)
        model.load_state_dict(checkpoint['model'])
        adapter.load_state_dict(checkpoint['adapter'])
        print(f"✅ Loaded model from epoch {checkpoint['epoch']} with val_loss {checkpoint['val_loss']:.4f}")

    # Comprehensive evaluation
    print("\n🔍 Running comprehensive evaluation...")
    metrics = evaluate_model(model, adapter, test_dataset, device)

    # Visualize some predictions
    print("\n🎨 Visualizing predictions...")
    visualize_predictions(model, adapter, test_dataset, device, num_examples=3)

    return model, adapter, trainer, metrics

# Alternative: Load pre-trained model if available
def load_pretrained_model(device):
    """Rebuild the HRM + SudokuAdapter pair and restore the best saved checkpoint.

    The architecture built here must mirror the one used in training, or
    ``load_state_dict`` will reject the checkpoint.

    Args:
        device: Device the checkpoint is mapped onto and the models are moved to.

    Returns:
        ``(model, adapter)`` in eval mode, or ``(None, None)`` when
        ``../results/best_model.pt`` does not exist.
    """
    checkpoint_path = '../results/best_model.pt'
    if not os.path.exists(checkpoint_path):
        print("❌ No pre-trained model found. Please train the model first.")
        return None, None

    # Hyper-parameters must match the training configuration.
    hrm_config = HRMConfig(
        N=2,
        T=4,
        input_dim=512,
        hidden_dim=512,
        output_dim=512,
        num_layers=4,
        dropout=0.1,
    )

    model = HierarchicalReasoningModel(config=hrm_config, device=device)
    adapter = SudokuAdapter(
        hidden_dim=256,
        hrm_input_dim=hrm_config.input_dim,
        hrm_output_dim=hrm_config.output_dim,
    )

    checkpoint = torch.load(checkpoint_path, map_location=device)
    restored = {'model': model, 'adapter': adapter}
    for state_key, module in restored.items():
        module.load_state_dict(checkpoint[state_key])
    for module in restored.values():
        module.to(device).eval()

    print(f"✅ Loaded pre-trained model from epoch {checkpoint['epoch']}")
    print(f"📊 Validation loss: {checkpoint['val_loss']:.4f}")

    return model, adapter

# Cell-completion marker. Nothing is trained until main_training_pipeline()
# is called in the next cell.
print("✅ Training pipeline ready! Run main_training_pipeline() to start training.")

In [None]:
# Execute the training pipeline: train a new model, then print its metrics.
print("🎯 Starting SudokuAdapter + HRM Training Pipeline")
print("=" * 60)

# Option 1: Train a new model
model, adapter, trainer, metrics = main_training_pipeline()

print("\n🎉 Training completed!")
print("📊 Final Metrics:")
for key, value in metrics.items():
    # Numeric metrics are shown to 4 decimals. A value that rejects the float
    # format spec (strings, lists, ...) falls back to str(), so the report
    # does not abort halfway through.
    try:
        formatted = f"{value:.4f}"
    except (TypeError, ValueError):
        formatted = str(value)
    print(f"  {key}: {formatted}")