In [None]:
import os
import math
import time
from typing import Tuple, List, Dict, Optional, Any, Sequence, Union
from dataclasses import dataclass

import numpy as np
import torch
import torch.nn.functional as F
from torch import nn
from torch.utils.data import Dataset, DataLoader
from torch.utils.tensorboard import SummaryWriter
from torch.optim.optimizer import Optimizer, ParamsT
from torch.optim import AdamW # Using standard AdamW
from tqdm import tqdm
import coolname
from pydantic import BaseModel


# ==============================================================================
# 2. Model Components
# ==============================================================================

def trunc_normal_init_(tensor: torch.Tensor, std: float = 1.0, lower: float = -2.0, upper: float = 2.0):
    """Fill ``tensor`` in place with a truncated normal whose final std is ``std``.

    Samples of a standard normal truncated to ``[lower, upper]`` are drawn by
    inverse-CDF sampling (``uniform_`` on the erf range, then ``erfinv_``) and
    scaled by ``comp_std`` so that the standard deviation of the *truncated*
    distribution equals ``std``. The result is finally clipped to
    ``[lower * comp_std, upper * comp_std]`` to absorb rounding at the ends.

    Args:
        tensor: Tensor to overwrite in place.
        std: Target standard deviation; ``0`` zero-fills the tensor.
        lower: Lower truncation bound, in units of the unscaled standard normal.
        upper: Upper truncation bound, in units of the unscaled standard normal.

    Returns:
        The same ``tensor`` object, for chaining.
    """
    with torch.no_grad():
        if std == 0:
            tensor.zero_()
        else:
            sqrt2 = math.sqrt(2)
            a = math.erf(lower / sqrt2)
            b = math.erf(upper / sqrt2)
            # Probability mass of N(0, 1) inside [lower, upper].
            z = (b - a) / 2
            c = (2 * math.pi) ** -0.5
            pdf_l = c * math.exp(-0.5 * lower ** 2)  # N(0, 1) density at `lower`
            pdf_u = c * math.exp(-0.5 * upper ** 2)  # N(0, 1) density at `upper`
            # Variance of N(0, 1) truncated to [lower, upper]:
            #   1 - (upper*pdf(upper) - lower*pdf(lower)) / z - ((pdf(upper) - pdf(lower)) / z)^2
            # Each density must be paired with its own bound; for symmetric bounds
            # both densities are equal, so the default case is unaffected.
            comp_std = std / math.sqrt(1 - (upper * pdf_u - lower * pdf_l) / z - ((pdf_u - pdf_l) / z) ** 2)
            tensor.uniform_(a, b)
            tensor.erfinv_()
            tensor.mul_(sqrt2 * comp_std)
            tensor.clip_(lower * comp_std, upper * comp_std)
    return tensor

# Pair of rotary tables ``(cos, sin)``, as returned by ``RotaryEmbedding.forward``.
CosSin = Tuple[torch.Tensor, torch.Tensor]

def _find_multiple(a, b):
    return (-(a // -b)) * b

def rotate_half(x: torch.Tensor):
    """Split the last dimension into halves ``[x1, x2]`` and return ``[-x2, x1]``."""
    half = x.shape[-1] // 2
    front = x[..., :half]
    back = x[..., half:]
    return torch.cat([back.neg(), front], dim=-1)

def apply_rotary_pos_emb(q: torch.Tensor, k: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor):
    """Apply rotary position embeddings to queries and keys.

    ``q``/``k`` are computed in ``cos.dtype`` and cast back to ``q.dtype``.
    ``cos``/``sin`` get a singleton axis inserted before their last dim so they
    broadcast over the heads axis of ``q``/``k``.

    Returns:
        ``(q_rotated, k_rotated)``, both in ``q``'s original dtype.
    """
    out_dtype = q.dtype
    cos_b = cos.unsqueeze(-2)
    sin_b = sin.unsqueeze(-2)

    def _rotate(t):
        # [t1, t2] -> [-t2, t1] over the last dimension
        mid = t.shape[-1] // 2
        return torch.cat((-t[..., mid:], t[..., :mid]), dim=-1)

    def _embed(t):
        t = t.to(cos.dtype)
        return (t * cos_b) + (_rotate(t) * sin_b)

    return _embed(q).to(out_dtype), _embed(k).to(out_dtype)

class CastedLinear(nn.Module):
    """Bias-optional linear layer whose parameters are cast to the input dtype per call.

    The weight is initialised with ``trunc_normal_init_`` using std
    ``1 / sqrt(in_features)``; the optional bias starts at zero.
    """

    def __init__(self, in_features: int, out_features: int, bias: bool):
        super().__init__()
        init_std = 1.0 / (in_features ** 0.5)
        weight = torch.empty((out_features, in_features))
        self.weight = nn.Parameter(trunc_normal_init_(weight, std=init_std))
        if bias:
            self.bias = nn.Parameter(torch.zeros((out_features, )))
        else:
            self.bias = None

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        target_dtype = input.dtype
        casted_bias = None if self.bias is None else self.bias.to(target_dtype)
        return F.linear(input, self.weight.to(target_dtype), bias=casted_bias)

class CastedEmbedding(nn.Module):
    """Embedding table stored as a parameter and cast to ``cast_to`` on every lookup."""

    def __init__(self, num_embeddings: int, embedding_dim: int, init_std: float, cast_to: torch.dtype):
        super().__init__()
        self.cast_to = cast_to
        table = torch.empty((num_embeddings, embedding_dim))
        self.embedding_weight = nn.Parameter(trunc_normal_init_(table, std=init_std))

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        casted_table = self.embedding_weight.to(self.cast_to)
        return F.embedding(input, casted_table)

class RotaryEmbedding(nn.Module):
    """Precomputed, non-persistent rotary cos/sin tables.

    Both tables have shape ``(max_position_embeddings, dim)``; the frequency
    vector is duplicated so it matches ``rotate_half``'s two-halves layout.
    """

    def __init__(self, dim, max_position_embeddings, base, device=None):
        super().__init__()
        exponents = torch.arange(0, dim, 2, dtype=torch.float32, device=device) / dim
        inv_freq = 1.0 / (base ** exponents)
        positions = torch.arange(max_position_embeddings, dtype=torch.float32, device=device)
        angles = positions[:, None] * inv_freq[None, :]
        full_angles = angles.repeat(1, 2)
        self.register_buffer("cos_cached", full_angles.cos(), persistent=False)
        self.register_buffer("sin_cached", full_angles.sin(), persistent=False)

    def forward(self):
        cos_sin = (self.cos_cached, self.sin_cached)
        return cos_sin

class Attention(nn.Module):
    """Multi-head self-attention with a fused QKV projection and optional RoPE.

    Queries, keys and values come from a single ``CastedLinear``. Rotary
    embeddings are applied when ``cos_sin`` is given, and attention is computed
    with ``F.scaled_dot_product_attention`` (causal only if ``causal=True``).

    Grouped-query attention is supported. When ``num_key_value_heads`` is smaller
    than ``num_heads``, each key/value head is shared by
    ``num_heads // num_key_value_heads`` consecutive query heads.

    Raises:
        ValueError: if ``num_key_value_heads`` is not a positive divisor of
            ``num_heads``.
    """

    def __init__(self, hidden_size, head_dim, num_heads, num_key_value_heads, causal=False):
        super().__init__()
        if num_key_value_heads <= 0 or num_heads % num_key_value_heads != 0:
            raise ValueError(
                f"num_heads ({num_heads}) must be a positive multiple of "
                f"num_key_value_heads ({num_key_value_heads})"
            )
        self.hidden_size = hidden_size
        self.head_dim = head_dim
        self.output_size = head_dim * num_heads
        self.num_heads = num_heads
        self.num_key_value_heads = num_key_value_heads
        self.causal = causal
        self.qkv_proj = CastedLinear(self.hidden_size, (self.num_heads + 2 * self.num_key_value_heads) * self.head_dim, bias=False)
        self.o_proj = CastedLinear(self.output_size, self.hidden_size, bias=False)

    def forward(self, cos_sin: Optional[CosSin], hidden_states: torch.Tensor) -> torch.Tensor:
        batch_size, seq_len, _ = hidden_states.shape
        qkv = self.qkv_proj(hidden_states)
        # (batch, seq, q_heads + 2 * kv_heads, head_dim), sliced into Q, K, V along the head axis.
        qkv = qkv.view(batch_size, seq_len, self.num_heads + 2 * self.num_key_value_heads, self.head_dim)
        query = qkv[:, :, :self.num_heads]
        key = qkv[:, :, self.num_heads: self.num_heads + self.num_key_value_heads]
        value = qkv[:, :, self.num_heads + self.num_key_value_heads:]

        if cos_sin is not None:
            cos, sin = cos_sin
            query, key = apply_rotary_pos_emb(query, key, cos, sin)

        # PyTorch's native scaled_dot_product_attention (instead of flash-attn)
        # expects (batch, heads, seq_len, head_dim).
        query = query.transpose(1, 2)
        key = key.transpose(1, 2)
        value = value.transpose(1, 2)

        # SDPA needs matching head counts. Expand each K/V head to the query
        # heads that share it: query head h uses K/V head h // group_size.
        if self.num_key_value_heads != self.num_heads:
            group_size = self.num_heads // self.num_key_value_heads
            key = key.repeat_interleave(group_size, dim=1)
            value = value.repeat_interleave(group_size, dim=1)

        attn_output = F.scaled_dot_product_attention(query, key, value, is_causal=self.causal)

        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.view(batch_size, seq_len, self.output_size)
        return self.o_proj(attn_output)

class SwiGLU(nn.Module):
    """Gated feed-forward block: ``down(silu(gate) * up)`` with a fused gate/up projection.

    The intermediate width is ``expansion * hidden_size * 2 / 3``, rounded and
    then rounded up to a multiple of 256.
    """

    def __init__(self, hidden_size: int, expansion: float):
        super().__init__()
        intermediate = _find_multiple(round(expansion * hidden_size * 2 / 3), 256)
        self.gate_up_proj = CastedLinear(hidden_size, intermediate * 2, bias=False)
        self.down_proj = CastedLinear(intermediate, hidden_size, bias=False)

    def forward(self, x):
        projected = self.gate_up_proj(x)
        gate, up = projected.chunk(2, dim=-1)
        activated = F.silu(gate) * up
        return self.down_proj(activated)

def rms_norm(hidden_states: torch.Tensor, variance_epsilon: float) -> torch.Tensor:
    """Parameter-free RMS normalisation over the last dimension.

    The computation runs in float32 and the result is cast back to the input dtype.
    """
    out_dtype = hidden_states.dtype
    x32 = hidden_states.float()
    inv_rms = (x32.pow(2).mean(dim=-1, keepdim=True) + variance_epsilon).rsqrt()
    return (x32 * inv_rms).to(out_dtype)


# ==============================================================================
# 3. Sparse Embedding
# ==============================================================================

class CastedSparseEmbedding(nn.Module):
    """Embedding table held in buffers (not parameters) for sparse optimisation.

    In eval mode rows are gathered straight from ``weights``. In training mode
    the looked-up rows and their ids are copied into the fixed-size
    ``local_weights`` / ``local_ids`` buffers (``batch_size`` rows).
    ``local_weights`` requires grad, so backward populates its ``.grad`` for an
    optimizer such as ``CastedSparseEmbeddingSignSGD``.
    """

    def __init__(self, num_embeddings: int, embedding_dim: int, batch_size: int, init_std: float, cast_to: torch.dtype):
        super().__init__()
        self.cast_to = cast_to
        table = trunc_normal_init_(torch.empty((num_embeddings, embedding_dim)), std=init_std)
        self.register_buffer('weights', table)
        # Per-batch scratch buffers; not saved in the state_dict.
        self.register_buffer('local_weights', torch.zeros(batch_size, embedding_dim, requires_grad=True), persistent=False)
        self.register_buffer('local_ids', torch.zeros(batch_size, dtype=torch.int32), persistent=False)

    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
        if self.training:
            with torch.no_grad():
                self.local_weights.copy_(self.weights[inputs])
                self.local_ids.copy_(inputs)
            return self.local_weights.to(self.cast_to)
        return self.weights[inputs].to(self.cast_to)

class CastedSparseEmbeddingSignSGD(Optimizer):
    def __init__(self, params: ParamsT, lr: float = 1e-3, weight_decay: float = 1e-2):
        defaults = dict(lr=lr, weight_decay=weight_decay)
        super().__init__(params, defaults)

    @torch.no_grad
    def step(self, closure=None):
        for group in self.param_groups:
            local_weights_grad, local_ids, weights = None, None, None
            for p in group["params"]:
                if p.requires_grad: local_weights_grad = p.grad
                elif p.ndim == 1: local_ids = p
                elif p.ndim == 2: weights = p
            
            if local_weights_grad is None or local_ids is None or weights is None: continue

            grad_ids, inv = local_ids.unique(return_inverse=True)
            grad = torch.zeros((grad_ids.shape[0], local_weights_grad.shape[1]), dtype=local_weights_grad.dtype, device=local_weights_grad.device)
            grad.scatter_add_(0, inv.unsqueeze(-1).expand(-1, local_weights_grad.shape[1]), local_weights_grad)
            
            p = weights[grad_ids]
            p.mul_(1.0 - group['lr'] * group['weight_decay']).add_(torch.sign(grad), alpha=-group['lr'])
            weights[grad_ids] = p

# ==============================================================================
# 4. Main Model
# ==============================================================================

@dataclass
class HierarchicalReasoningModel_ACTV1InnerCarry:
    """Recurrent hidden state carried between inner-model segments.

    ``empty_carry`` allocates both tensors with shape
    ``(batch, seq_len + puzzle_emb_len, hidden_size)`` in the forward dtype.
    """
    # High-level (H) module state.
    z_H: torch.Tensor
    # Low-level (L) module state.
    z_L: torch.Tensor

@dataclass
class HierarchicalReasoningModel_ACTV1Carry:
    """Full ACT state for a batch: inner state plus per-sequence halting bookkeeping.

    Field order matters: ``HierarchicalReasoningModel_ACTV1.forward`` builds
    this positionally.
    """
    # Recurrent (z_H, z_L) state.
    inner_carry: HierarchicalReasoningModel_ACTV1InnerCarry
    # Per-sequence count of segments run since the last reset (int32, shape (batch,)).
    steps: torch.Tensor
    # Per-sequence flag (bool, shape (batch,)); True means the next forward reloads that row.
    halted: torch.Tensor
    # Batch tensors currently being worked on, keyed like the input batch.
    current_data: Dict[str, torch.Tensor]

class HierarchicalReasoningModel_ACTV1Config(BaseModel):
    """Validated hyperparameters for the hierarchical ACT model (pydantic model)."""
    # Batch size for the sparse puzzle-embedding scratch buffers.
    batch_size: int
    # Number of token positions per example (excluding puzzle-embedding positions).
    seq_len: int
    # Width of the per-puzzle embedding; 0 disables it.
    puzzle_emb_ndim: int
    # Rows in the puzzle-embedding table.
    num_puzzle_identifiers: int
    vocab_size: int
    # Outer (H) and inner (L) recurrence counts per segment.
    H_cycles: int
    L_cycles: int
    # Transformer blocks in the H and L stacks.
    H_layers: int
    L_layers: int
    hidden_size: int
    # SwiGLU width multiplier.
    expansion: float
    num_heads: int
    # "rope" or "learned"; any other value disables positional encoding.
    pos_encodings: str
    rms_norm_eps: float = 1e-5
    rope_theta: float = 10000.0
    # Upper bound on ACT segments per sequence.
    halt_max_steps: int
    # Probability of forcing a random minimum number of steps during training.
    halt_exploration_prob: float
    # Name of a torch dtype; the default is evaluated once, at import time.
    forward_dtype: str = "bfloat16" if torch.cuda.is_available() and torch.cuda.is_bf16_supported() else "float32"

class HierarchicalReasoningModel_ACTV1Block(nn.Module):
    """Post-norm transformer block.

    Runs non-causal attention, then a residual add and RMSNorm, then SwiGLU,
    then another residual add and RMSNorm.
    """

    def __init__(self, config: HierarchicalReasoningModel_ACTV1Config) -> None:
        super().__init__()
        n_heads = config.num_heads
        self.self_attn = Attention(
            hidden_size=config.hidden_size,
            head_dim=config.hidden_size // n_heads,
            num_heads=n_heads,
            num_key_value_heads=n_heads,
            causal=False,
        )
        self.mlp = SwiGLU(hidden_size=config.hidden_size, expansion=config.expansion)
        self.norm_eps = config.rms_norm_eps

    def forward(self, cos_sin: Optional[CosSin], hidden_states: torch.Tensor) -> torch.Tensor:
        eps = self.norm_eps
        attn_out = self.self_attn(cos_sin=cos_sin, hidden_states=hidden_states)
        hidden_states = rms_norm(hidden_states + attn_out, variance_epsilon=eps)
        mlp_out = self.mlp(hidden_states)
        return rms_norm(hidden_states + mlp_out, variance_epsilon=eps)

class HierarchicalReasoningModel_ACTV1ReasoningModule(nn.Module):
    """A stack of blocks that first adds an input injection to the state.

    Extra keyword arguments (e.g. ``cos_sin``) are forwarded to every block.
    """

    def __init__(self, layers: List[HierarchicalReasoningModel_ACTV1Block]):
        super().__init__()
        self.layers = nn.ModuleList(layers)

    def forward(self, hidden_states: torch.Tensor, input_injection: torch.Tensor, **kwargs) -> torch.Tensor:
        state = hidden_states + input_injection
        for block in self.layers:
            state = block(hidden_states=state, **kwargs)
        return state

class HierarchicalReasoningModel_ACTV1_Inner(nn.Module):
    """One reasoning segment of the hierarchical model, without halting logic.

    Holds the token / puzzle / positional embeddings, the high-level (H) and
    low-level (L) block stacks, the LM head and a 2-logit Q head. ``forward``
    advances a carried ``(z_H, z_L)`` pair. It returns the detached new carry,
    token logits and ``(q_halt_logits, q_continue_logits)``.
    """
    def __init__(self, config: HierarchicalReasoningModel_ACTV1Config) -> None:
        super().__init__()
        # NOTE: submodules are built in a fixed order and their initialisers draw
        # from the global RNG (trunc_normal_init_ -> uniform_). Reordering them
        # changes the initialisation under a fixed seed.
        self.config = config
        # e.g. "bfloat16" -> torch.bfloat16
        self.forward_dtype = getattr(torch, self.config.forward_dtype)
        self.embed_scale = math.sqrt(self.config.hidden_size)
        embed_init_std = 1.0 / self.embed_scale
        self.embed_tokens = CastedEmbedding(self.config.vocab_size, self.config.hidden_size, init_std=embed_init_std, cast_to=self.forward_dtype)
        self.lm_head = CastedLinear(self.config.hidden_size, self.config.vocab_size, bias=False)
        self.q_head = CastedLinear(self.config.hidden_size, 2, bias=True)
        # Number of hidden_size-wide positions needed to hold the puzzle embedding (ceil division).
        self.puzzle_emb_len = -(self.config.puzzle_emb_ndim // -self.config.hidden_size)
        
        if self.config.puzzle_emb_ndim > 0:
            # init_std=0 -> the puzzle-embedding table starts at zero.
            self.puzzle_emb = CastedSparseEmbedding(self.config.num_puzzle_identifiers, self.config.puzzle_emb_ndim, batch_size=self.config.batch_size, init_std=0, cast_to=self.forward_dtype)

        # Any other pos_encodings value creates neither module, and forward() then
        # runs without positional information.
        if self.config.pos_encodings == "rope":
            self.rotary_emb = RotaryEmbedding(dim=self.config.hidden_size // self.config.num_heads, max_position_embeddings=self.config.seq_len + self.puzzle_emb_len, base=self.config.rope_theta)
        elif self.config.pos_encodings == "learned":
            self.embed_pos = CastedEmbedding(self.config.seq_len + self.puzzle_emb_len, self.config.hidden_size, init_std=embed_init_std, cast_to=self.forward_dtype)
        
        self.H_level = HierarchicalReasoningModel_ACTV1ReasoningModule(layers=[HierarchicalReasoningModel_ACTV1Block(self.config) for _ in range(self.config.H_layers)])
        self.L_level = HierarchicalReasoningModel_ACTV1ReasoningModule(layers=[HierarchicalReasoningModel_ACTV1Block(self.config) for _ in range(self.config.L_layers)])
        
        # Learned-free initial states (persistent buffers), broadcast over batch and positions by reset_carry.
        self.register_buffer('H_init', trunc_normal_init_(torch.empty(self.config.hidden_size, dtype=self.forward_dtype), std=1))
        self.register_buffer('L_init', trunc_normal_init_(torch.empty(self.config.hidden_size, dtype=self.forward_dtype), std=1))
        
        # Q head starts with zero weights and bias -5, so both Q logits begin at -5.
        with torch.no_grad():
            self.q_head.weight.zero_()
            if self.q_head.bias is not None:
                self.q_head.bias.fill_(-5)

    def _input_embeddings(self, input_tensor: torch.Tensor, puzzle_identifiers: torch.Tensor):
        """Embed tokens, prepend puzzle-embedding positions, add positions and scale.

        Returns a tensor of shape ``(batch, puzzle_emb_len + seq_len, hidden_size)``,
        assuming ``input_tensor`` is ``(batch, seq_len)`` token ids.
        """
        embedding = self.embed_tokens(input_tensor.to(torch.int32))
        if self.config.puzzle_emb_ndim > 0 and hasattr(self, 'puzzle_emb'):
            puzzle_embedding = self.puzzle_emb(puzzle_identifiers)
            # Zero-pad the puzzle embedding to a whole number of hidden_size positions.
            pad_count = self.puzzle_emb_len * self.config.hidden_size - puzzle_embedding.shape[-1]
            if pad_count > 0:
                puzzle_embedding = F.pad(puzzle_embedding, (0, pad_count))
            embedding = torch.cat((puzzle_embedding.view(-1, self.puzzle_emb_len, self.config.hidden_size), embedding), dim=-2)
        if self.config.pos_encodings == "learned":
            # 0.707106781 ~= 1/sqrt(2): rescales the sum of token and position embeddings.
            embedding = 0.707106781 * (embedding + self.embed_pos.embedding_weight.to(self.forward_dtype))
        return self.embed_scale * embedding

    def empty_carry(self, batch_size: int):
        # Uninitialised memory (torch.empty): callers must reset_carry() before use.
        device = self.H_init.device
        return HierarchicalReasoningModel_ACTV1InnerCarry(
            z_H=torch.empty(batch_size, self.config.seq_len + self.puzzle_emb_len, self.config.hidden_size, dtype=self.forward_dtype, device=device),
            z_L=torch.empty(batch_size, self.config.seq_len + self.puzzle_emb_len, self.config.hidden_size, dtype=self.forward_dtype, device=device),
        )
        
    def reset_carry(self, reset_flag: torch.Tensor, carry: HierarchicalReasoningModel_ACTV1InnerCarry):
        # Rows where reset_flag is True are replaced by H_init / L_init (broadcast to every position).
        return HierarchicalReasoningModel_ACTV1InnerCarry(
            z_H=torch.where(reset_flag.view(-1, 1, 1), self.H_init, carry.z_H),
            z_L=torch.where(reset_flag.view(-1, 1, 1), self.L_init, carry.z_L),
        )

    def forward(self, carry: HierarchicalReasoningModel_ACTV1InnerCarry, batch: Dict[str, torch.Tensor]) -> Tuple[HierarchicalReasoningModel_ACTV1InnerCarry, torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        """Run one segment.

        Reads ``batch["inputs"]`` and ``batch["puzzle_identifiers"]``.

        Returns:
            ``(new_carry, logits, (q_halt_logits, q_continue_logits))``.
            ``logits`` excludes the puzzle-embedding positions. The Q logits
            (float32) are read from position 0 of ``z_H``.
        """
        seq_info = dict(cos_sin=self.rotary_emb() if hasattr(self, "rotary_emb") else None)
        input_embeddings = self._input_embeddings(batch["inputs"], batch["puzzle_identifiers"])
        z_H, z_L = carry.z_H, carry.z_L
        
        # NOTE(review): the no-grad loop runs H_cycles * L_cycles L-updates and
        # H_cycles H-updates. One more L- and H-update with gradients follows
        # (H_cycles*L_cycles + 1 and H_cycles + 1 in total). Confirm the extra
        # step is intended; checkpoints trained with this code depend on it.
        with torch.no_grad():
            for _ in range(self.config.H_cycles):
                for _ in range(self.config.L_cycles -1):
                    z_L = self.L_level(z_L, z_H + input_embeddings, **seq_info)
                z_L = self.L_level(z_L, z_H + input_embeddings, **seq_info) # last L_cycle for H
                z_H = self.H_level(z_H, z_L, **seq_info)

        # Single gradient-carrying step; gradients do not flow into earlier segments.
        z_L = self.L_level(z_L.detach(), z_H.detach() + input_embeddings, **seq_info)
        z_H = self.H_level(z_H.detach(), z_L, **seq_info)

        new_carry = HierarchicalReasoningModel_ACTV1InnerCarry(z_H=z_H.detach(), z_L=z_L.detach())
        output = self.lm_head(z_H)[:, self.puzzle_emb_len:]
        q_logits = self.q_head(z_H[:, 0]).to(torch.float32)
        return new_carry, output, (q_logits[..., 0], q_logits[..., 1])

class HierarchicalReasoningModel_ACTV1(nn.Module):
    """Adaptive-computation-time (ACT) wrapper around the inner reasoning model.

    Each ``forward`` runs one inner segment for every sequence. Sequences that
    halted on the previous call are reset and reloaded from ``batch``; the
    others keep their carried state and data. During training, halting uses the
    Q-head logits plus random exploration. Otherwise a sequence halts only after
    ``halt_max_steps`` segments.
    """
    def __init__(self, config_dict: dict):
        super().__init__()
        # Validated by pydantic; raises if fields are missing or invalid.
        self.config = HierarchicalReasoningModel_ACTV1Config(**config_dict)
        self.inner = HierarchicalReasoningModel_ACTV1_Inner(self.config)

    @property
    def puzzle_emb(self):
        # Only exists when puzzle_emb_ndim > 0; otherwise accessing it raises AttributeError.
        return self.inner.puzzle_emb

    def initial_carry(self, batch: Dict[str, torch.Tensor]):
        # Placeholder carry: every sequence is marked halted, so the first
        # forward() resets the inner state and copies in the real batch data.
        batch_size = batch["inputs"].shape[0]
        device = batch["inputs"].device
        return HierarchicalReasoningModel_ACTV1Carry(
            inner_carry=self.inner.empty_carry(batch_size),
            steps=torch.zeros((batch_size,), dtype=torch.int32, device=device),
            halted=torch.ones((batch_size,), dtype=torch.bool, device=device),
            current_data={k: torch.empty_like(v) for k, v in batch.items()}
        )
        
    def forward(self, carry: HierarchicalReasoningModel_ACTV1Carry, batch: Dict[str, torch.Tensor]) -> Tuple[HierarchicalReasoningModel_ACTV1Carry, Dict[str, torch.Tensor]]:
        """Advance every sequence by one segment.

        Returns:
            ``(new_carry, outputs)``. ``outputs`` holds ``logits``,
            ``q_halt_logits`` and ``q_continue_logits``, plus
            ``target_q_continue`` in training mode when ``halt_max_steps > 1``.
        """
        # Reload halted rows: reset their state, step counter and data from `batch`.
        new_inner_carry = self.inner.reset_carry(carry.halted, carry.inner_carry)
        new_steps = torch.where(carry.halted, 0, carry.steps)
        new_current_data = {k: torch.where(carry.halted.view((-1,) + (1,) * (batch[k].ndim - 1)), batch[k], v) for k, v in carry.current_data.items()}
        
        new_inner_carry, logits, (q_halt_logits, q_continue_logits) = self.inner(new_inner_carry, new_current_data)
        outputs = {"logits": logits, "q_halt_logits": q_halt_logits, "q_continue_logits": q_continue_logits}
        
        with torch.no_grad():
            new_steps += 1
            is_last_step = new_steps >= self.config.halt_max_steps
            halted = is_last_step
            if self.training and (self.config.halt_max_steps > 1):
                # Halt when the model prefers halting over continuing...
                halted = halted | (q_halt_logits > q_continue_logits)
                # ...but with probability halt_exploration_prob, force a random minimum
                # step count in [2, halt_max_steps]. Elsewhere the minimum is 0.
                # (The order of these RNG calls affects reproducibility.)
                min_halt_steps = (torch.rand_like(q_halt_logits) < self.config.halt_exploration_prob) * torch.randint_like(new_steps, low=2, high=self.config.halt_max_steps + 1)
                halted = halted & (new_steps >= min_halt_steps)
                # Bootstrapped target for q_continue from an extra inner step on the new
                # carry: next q_halt on the last step, else max(next q_halt, next q_continue).
                next_q_halt_logits, next_q_continue_logits = self.inner(new_inner_carry, new_current_data)[-1]
                outputs["target_q_continue"] = torch.sigmoid(torch.where(is_last_step, next_q_halt_logits, torch.maximum(next_q_halt_logits, next_q_continue_logits)))
        
        return HierarchicalReasoningModel_ACTV1Carry(new_inner_carry, new_steps, halted, new_current_data), outputs


# ==============================================================================
# 5. Loss Head
# ==============================================================================
# Label value excluded from the LM loss and accuracy metrics; matches the
# default ignore_index of the loss functions below.
IGNORE_LABEL_ID = -100

def s(x, epsilon=1e-30):
    """Stablemax elementwise transform.

    Returns ``x + 1`` where ``x >= 0`` and ``1 / (1 - x + epsilon)`` where ``x < 0``.
    """
    below_zero = x < 0
    negative_branch = 1 / (1 - x + epsilon)
    positive_branch = x + 1
    return torch.where(below_zero, negative_branch, positive_branch)

def log_stablemax(x, dim=-1):
    """Log-probabilities of the stablemax distribution along ``dim``.

    Stablemax replaces ``exp`` in softmax with the transform ``s`` (``x + 1``
    for ``x >= 0``, ``1 / (1 - x + 1e-30)`` for ``x < 0``), inlined here.
    """
    transformed = torch.where(x < 0, 1 / (1 - x + 1e-30), x + 1)
    normaliser = transformed.sum(dim=dim, keepdim=True)
    return (transformed / normaliser).log()

def stablemax_cross_entropy(logits, labels, ignore_index: int = -100):
    """Per-element stablemax cross-entropy, computed in float64.

    Positions whose label equals ``ignore_index`` get a loss of 0. The result
    has the shape of ``labels``.
    """
    keep = labels != ignore_index
    # Point ignored positions at class 0 so gather stays in range; they are masked out below.
    safe_labels = torch.where(keep, labels, 0).to(torch.long)
    logprobs = log_stablemax(logits.to(torch.float64), dim=-1)
    picked = logprobs.gather(-1, safe_labels.unsqueeze(-1)).squeeze(-1)
    return -torch.where(keep, picked, 0)

def softmax_cross_entropy(logits, labels, ignore_index: int = -100):
    """Per-element softmax cross-entropy, computed in float32.

    Args:
        logits: ``(*leading, vocab)`` scores. Any number of leading dims is
            accepted, e.g. ``(batch, seq, vocab)`` or ``(N, vocab)``.
        labels: ``(*leading)`` integer class ids.
        ignore_index: Label value that contributes a loss of 0.

    Returns:
        Unreduced losses with the shape of ``labels``.
    """
    # Collapse all leading dims so cross_entropy sees (N, vocab) / (N,),
    # instead of assuming exactly two leading dims.
    flat_logits = logits.to(torch.float32).reshape(-1, logits.shape[-1])
    flat_labels = labels.to(torch.long).reshape(-1)
    losses = F.cross_entropy(flat_logits, flat_labels, ignore_index=ignore_index, reduction="none")
    return losses.view(labels.shape)

class ACTLossHead(nn.Module):
    """Wraps an ACT model; computes the step loss and summed metrics.

    ``loss_type`` is the name of a loss function defined at module level in this
    file, e.g. ``"softmax_cross_entropy"`` or ``"stablemax_cross_entropy"``.
    """
    def __init__(self, model: nn.Module, loss_type: str):
        super().__init__()
        self.model = model
        # NOTE(review): looked up via globals(), so any module-level name is accepted;
        # an unknown name raises KeyError.
        self.loss_fn = globals()[loss_type]
        
    def initial_carry(self, *args, **kwargs):
        return self.model.initial_carry(*args, **kwargs)

    def forward(self, carry: Any, batch: Dict[str, torch.Tensor], return_keys: Sequence[str] = ()):
        """Run one ACT step and compute ``lm_loss + 0.5 * (q_halt_loss + q_continue_loss)``.

        Metrics are *sums* over sequences that just halted and have at least one
        label; divide by ``metrics["count"]`` to get means. The losses are sums
        over the whole batch.

        Returns:
            ``(new_carry, total_loss, metrics, detached_outputs, all_halted)``.
            ``detached_outputs`` holds the requested ``return_keys`` present in
            the model outputs.
        """
        new_carry, outputs = self.model(carry, batch)
        labels = new_carry.current_data["labels"]

        with torch.no_grad():
            mask = labels != IGNORE_LABEL_ID
            loss_counts = mask.sum(-1)
            loss_divisor = loss_counts.clamp_min(1)  # avoid dividing by zero for fully-ignored rows
            is_correct = mask & (torch.argmax(outputs["logits"], dim=-1) == labels)
            seq_is_correct = is_correct.sum(-1) == loss_counts
            # Only halted sequences with at least one label contribute to metrics.
            valid_metrics = new_carry.halted & (loss_counts > 0)
            metrics = {
                "count": valid_metrics.sum(),
                "accuracy": torch.where(valid_metrics, is_correct.float().sum(-1) / loss_divisor, 0).sum(),
                "exact_accuracy": (valid_metrics & seq_is_correct).sum(),
                "q_halt_accuracy": (valid_metrics & ((outputs["q_halt_logits"] >= 0) == seq_is_correct)).sum(),
                "steps": torch.where(valid_metrics, new_carry.steps, 0).sum(),
            }

        # Per-sequence mean token loss, summed over the batch.
        lm_loss = (self.loss_fn(outputs["logits"], labels, ignore_index=IGNORE_LABEL_ID).sum(-1) / loss_divisor).sum()
        # q_halt is trained to predict whether the current answer is fully correct.
        q_halt_loss = F.binary_cross_entropy_with_logits(outputs["q_halt_logits"], seq_is_correct.to(outputs["q_halt_logits"].dtype), reduction="sum")
        metrics.update({"lm_loss": lm_loss.detach(), "q_halt_loss": q_halt_loss.detach()})

        q_continue_loss = torch.tensor(0.0, device=lm_loss.device)
        # target_q_continue is only produced by the model in training mode.
        if "target_q_continue" in outputs:
            q_continue_loss = F.binary_cross_entropy_with_logits(outputs["q_continue_logits"], outputs["target_q_continue"], reduction="sum")
            metrics["q_continue_loss"] = q_continue_loss.detach()

        total_loss = lm_loss + 0.5 * (q_halt_loss + q_continue_loss)
        detached_outputs = {k: outputs[k].detach() for k in return_keys if k in outputs}
        return new_carry, total_loss, metrics, detached_outputs, new_carry.halted.all()


# ==============================================================================
# 6. Sudoku Dataset
# ==============================================================================
class SudokuDataset(Dataset):
    """Map-style dataset of (puzzle, solution) grids encoded for the model vocabulary.

    Raw grids use 0 for a blank cell and 1-9 for digits. Every value is shifted
    by +1, so 0 stays free for PAD, 1 marks a blank and 2-10 encode digits 1-9.
    """

    def __init__(self, puzzles, solutions):
        self.puzzles = puzzles
        self.solutions = solutions

    def __len__(self):
        return len(self.puzzles)

    @staticmethod
    def _encode(grid):
        # Flatten to a 1-D int64 tensor and shift into vocab space.
        return torch.from_numpy(grid.flatten().astype(np.int64)) + 1

    def __getitem__(self, idx):
        return {"inputs": self._encode(self.puzzles[idx]), "labels": self._encode(self.solutions[idx])}

# ==============================================================================
# 7. Main Training & Evaluation Loop
# ==============================================================================
def main():
    """Train the hierarchical reasoning model on Sudoku, evaluating periodically.

    Hyperparameters come from ``Config``. The ``.npz`` file at
    ``cfg.data_path`` must provide ``puzzles`` and ``solutions`` arrays. After a
    seeded shuffle, the first ``cfg.num_train_samples`` pairs are used for
    training and the rest for validation.

    Metrics are logged to TensorBoard under ``runs/<run_name>``. A final
    checkpoint ``epoch_<n>.pt`` is written to ``cfg.checkpoint_path``. If
    validation exact accuracy reaches 100%, ``best_model.pt`` is saved as well
    and training stops early.
    """
    cfg = Config()
    torch.manual_seed(cfg.seed)
    np.random.seed(cfg.seed)
    
    # --- 1. Setup ---
    writer = SummaryWriter(f'runs/{cfg.run_name}')
    try:
        os.makedirs(cfg.checkpoint_path, exist_ok=True)

        print("===================================================")
        print(f"Starting run: {cfg.run_name}")
        print(f"Device: {cfg.device}")
        print(f"Configuration: {cfg}")
        print("===================================================")

        # --- 2. Data Loading ---
        with np.load(cfg.data_path) as data:
            puzzles = data['puzzles']
            solutions = data['solutions']

        # Shuffle data before splitting (seeded above, so the split is reproducible)
        indices = np.arange(len(puzzles))
        np.random.shuffle(indices)
        puzzles = puzzles[indices]
        solutions = solutions[indices]

        train_puzzles, val_puzzles = puzzles[:cfg.num_train_samples], puzzles[cfg.num_train_samples:]
        train_solutions, val_solutions = solutions[:cfg.num_train_samples], solutions[cfg.num_train_samples:]

        train_dataset = SudokuDataset(train_puzzles, train_solutions)
        val_dataset = SudokuDataset(val_puzzles, val_solutions)

        # drop_last keeps every batch the same size: each carry is built from the
        # first batch it sees and then reused for later batches.
        train_loader = DataLoader(train_dataset, batch_size=cfg.global_batch_size, shuffle=True, num_workers=0, drop_last=True)
        val_loader = DataLoader(val_dataset, batch_size=100, shuffle=False, num_workers=0, drop_last=True)

        # --- 3. Model & Optimizer ---
        model_cfg_dict = {
            "batch_size": cfg.global_batch_size,
            "seq_len": 81,
            "puzzle_emb_ndim": cfg.puzzle_emb_ndim,
            "num_puzzle_identifiers": 1, # Only one task: Sudoku
            "vocab_size": 11, # 0=PAD, 1=blank, 2-10 for digits 1-9
            "H_cycles": cfg.H_cycles,
            "L_cycles": cfg.L_cycles,
            "H_layers": cfg.H_layers,
            "L_layers": cfg.L_layers,
            "hidden_size": cfg.hidden_size,
            "expansion": cfg.expansion,
            "num_heads": cfg.num_heads,
            "pos_encodings": cfg.pos_encodings,
            "halt_max_steps": cfg.halt_max_steps,
            "halt_exploration_prob": cfg.halt_exploration_prob,
        }

        base_model = HierarchicalReasoningModel_ACTV1(model_cfg_dict).to(cfg.device)
        model = ACTLossHead(base_model, loss_type=cfg.loss_type).to(cfg.device)

        # For Sudoku, the sparse embedding optimizer is not really needed since there is only one puzzle type.
        # We use a standard optimizer for all parameters. CastedSparseEmbedding stores its
        # table as buffers, so model.parameters() does not include it and the
        # (zero-initialised) puzzle embedding is never updated.
        optimizer = AdamW(
            model.parameters(),
            lr=cfg.lr,
            weight_decay=cfg.weight_decay,
            betas=(cfg.beta1, cfg.beta2)
        )

        # --- 4. Training Loop ---
        global_step = 0
        train_carry = None
        last_epoch = 0  # 1-based index of the last epoch started; names the final checkpoint

        for epoch in range(cfg.epochs):
            last_epoch = epoch + 1
            model.train()
            pbar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{cfg.epochs}")
            for batch in pbar:
                batch = {k: v.to(cfg.device) for k, v in batch.items()}
                # Sudoku is a single task, so puzzle_identifiers are all zeros
                batch["puzzle_identifiers"] = torch.zeros(batch["inputs"].shape[0], dtype=torch.long, device=cfg.device)

                if train_carry is None:
                    train_carry = model.initial_carry(batch)

                # Linear warmup from 0 to cfg.lr, then constant.
                lr = cfg.lr
                if global_step < cfg.lr_warmup_steps:
                    lr = cfg.lr * (global_step / cfg.lr_warmup_steps)
                for param_group in optimizer.param_groups:
                    param_group['lr'] = lr

                optimizer.zero_grad()

                train_carry, loss, metrics, _, _ = model(carry=train_carry, batch=batch)

                loss.backward()
                optimizer.step()

                # Logging: metrics are sums over halted sequences; normalise by their count.
                count = metrics['count'].item()
                if count > 0:
                    for k, v in metrics.items():
                        if k != 'count':
                            writer.add_scalar(f"train/{k}", v.item() / count, global_step)
                    writer.add_scalar("train/loss", loss.item() / count, global_step)
                    writer.add_scalar("train/learning_rate", lr, global_step)

                pbar.set_postfix({"loss": f"{loss.item():.4f}"})
                global_step += 1

            # --- 5. Evaluation ---
            if (epoch + 1) % cfg.eval_interval == 0:
                model.eval()
                val_metrics = {}
                with torch.no_grad():
                    val_carry = None
                    for batch in tqdm(val_loader, desc=f"Validation Epoch {epoch+1}"):
                        batch = {k: v.to(cfg.device) for k, v in batch.items()}
                        batch["puzzle_identifiers"] = torch.zeros(batch["inputs"].shape[0], dtype=torch.long, device=cfg.device)

                        if val_carry is None:
                            val_carry = model.initial_carry(batch)

                        # In eval mode sequences halt only at halt_max_steps; step until all have halted.
                        while True:
                            val_carry, _, metrics, _, all_finished = model(carry=val_carry, batch=batch)
                            for k, v in metrics.items():
                                val_metrics[k] = val_metrics.get(k, 0) + v.item()
                            if all_finished:
                                break

                # Log validation metrics. An empty validation loader (e.g. fewer samples
                # than the validation batch size with drop_last) produces no 'count' key.
                count = val_metrics.pop('count', 0)
                if count > 0:
                    print(f"\n--- Validation Results (Epoch {epoch+1}) ---")
                    for k, v in val_metrics.items():
                        avg_v = v / count
                        writer.add_scalar(f"val/{k}", avg_v, global_step)
                        print(f"  {k}: {avg_v:.4f}")
                    print("----------------------------------------")
                    if val_metrics.get('exact_accuracy', 0) / count >= 1.0:
                        print("Early stopping as exact accuracy reached 100%")
                        # Save the best model
                        best_model_file = os.path.join(cfg.checkpoint_path, "best_model.pt")
                        torch.save(model.state_dict(), best_model_file)
                        print(f"Best model saved to {best_model_file}")
                        break
                else:
                    print(f"\nNo validation metrics collected at epoch {epoch+1} (empty validation set?)")

        # Save checkpoint
        checkpoint_file = os.path.join(cfg.checkpoint_path, f"epoch_{last_epoch}.pt")
        torch.save(model.state_dict(), checkpoint_file)
        print(f"Checkpoint saved to {checkpoint_file}")
    finally:
        writer.close()
    print("Training finished.")

In [None]:
# ==============================================================================
# 1. Configuration
# ==============================================================================
class Config:
    """Training, model and logging settings for the Sudoku run, read by main().

    All values are class attributes shared by every instance. In particular,
    ``run_name`` (and ``checkpoint_path``) are generated once, when the class
    is defined, not per ``Config()`` call.
    """
    # Data
    data_path: str = "sudoku_dataset_2100.npz"
    num_train_samples: int = 2000
    # Not read by main(); the validation split is everything after num_train_samples.
    num_val_samples: int = 100
    
    # Training Hyperparameters
    global_batch_size: int = 1000
    epochs: int = 2000 # Reduced for faster demonstration
    eval_interval: int = 5 # Evaluate every 5 epochs
    
    lr: float = 7e-5
    # Not read by main() (no decay schedule is applied after warmup).
    lr_min_ratio: float = 1.0
    lr_warmup_steps: int = 100
    
    weight_decay: float = 1.0
    beta1: float = 0.9
    beta2: float = 0.95
    
    # Puzzle embedding (not used for Sudoku but kept for compatibility)
    puzzle_emb_lr: float = 7e-5
    puzzle_emb_weight_decay: float = 1.0

    # Model Architecture (from hrm_v1.yaml)
    halt_exploration_prob: float = 0.1
    halt_max_steps: int = 8
    H_cycles: int = 2
    L_cycles: int = 2
    H_layers: int = 4
    L_layers: int = 4
    hidden_size: int = 384
    num_heads: int = 8
    expansion: int = 4
    puzzle_emb_ndim: int = 384
    pos_encodings: str = "rope" # 'rope' or 'learned'
    
    # Loss
    loss_type: str = "softmax_cross_entropy" # 'softmax_cross_entropy' or 'stablemax_cross_entropy'

    # System
    seed: int = 42
    device: str = "cuda" if torch.cuda.is_available() else "cpu"
    
    # Logging
    run_name: str = f"sudoku_{coolname.generate_slug(2)}"
    checkpoint_path: str = f"checkpoints/{run_name}"

In [None]:
# Train
# Start training when executed as a script.
# NOTE(review): in an interactive notebook kernel __name__ is typically also
# "__main__", so running this cell starts a full training run — confirm intended.
if __name__ == "__main__":
    main()

In [4]:
# Rebuild the same model configuration that main() uses, so a trained
# checkpoint can be loaded interactively.
cfg = Config()
model_cfg_dict = {
        "batch_size": cfg.global_batch_size,
        "seq_len": 81,
        "puzzle_emb_ndim": cfg.puzzle_emb_ndim,
        "num_puzzle_identifiers": 1, # Only one task: Sudoku
        "vocab_size": 11, # 0=PAD, 1=blank, 2-10 for digits 1-9
        "H_cycles": cfg.H_cycles,
        "L_cycles": cfg.L_cycles,
        "H_layers": cfg.H_layers,
        "L_layers": cfg.L_layers,
        "hidden_size": cfg.hidden_size,
        "expansion": cfg.expansion,
        "num_heads": cfg.num_heads,
        "pos_encodings": cfg.pos_encodings,
        "halt_max_steps": cfg.halt_max_steps,
        "halt_exploration_prob": cfg.halt_exploration_prob,
    }

In [5]:
# Rebuild the training-time model (ACT wrapper + loss head) and restore trained weights.
base_model = HierarchicalReasoningModel_ACTV1(model_cfg_dict).to(cfg.device)
model = ACTLossHead(base_model, loss_type=cfg.loss_type).to(cfg.device)
# map_location lets a checkpoint saved on a GPU be restored on a CPU-only host
# (and vice versa), instead of torch.load recreating the original device.
model.load_state_dict(torch.load("checkpoints/sudoku_adaptable-leech/epoch_500.pt", map_location=cfg.device))

<All keys matched successfully>

In [None]:
import torch
import torch.onnx
from torch import nn
from typing import Tuple

class InferenceModel(nn.Module):
    """Tensor-only wrapper around one inner-model step, suitable for ONNX tracing.

    Takes the raw inputs and the previous ``(z_H, z_L)`` state as separate
    tensors. Returns ``(logits, z_H_out, z_L_out)`` and drops the Q-head logits.
    """

    def __init__(self, inner_model: HierarchicalReasoningModel_ACTV1_Inner):
        super().__init__()
        self.inner = inner_model

    def forward(self, 
                inputs: torch.Tensor, 
                puzzle_identifiers: torch.Tensor,
                z_H_in: torch.Tensor, 
                z_L_in: torch.Tensor
               ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        step_batch = {"inputs": inputs, "puzzle_identifiers": puzzle_identifiers}
        carry_in = HierarchicalReasoningModel_ACTV1InnerCarry(z_H=z_H_in, z_L=z_L_in)
        carry_out, logits, _q_logits = self.inner(carry_in, step_batch)
        return logits, carry_out.z_H, carry_out.z_L

def export_to_onnx(
    checkpoint_path: str = "./checkpoints/sudoku_able-junglefowl/epoch_500.pt",
    onnx_model_path: str = "sudoku_hrm_inner_fp16.onnx",
    batch_size: int = 1,
):
    """Export a single float16 inner-model step of the Sudoku HRM to ONNX.

    Steps:

    1. Rebuild the model from ``Config`` with ``forward_dtype="float16"``.
    2. Load the ``model.inner.*`` entries of a training checkpoint (a
       ``state_dict`` of an ``ACTLossHead``-wrapped model).
    3. Cast the model to float16 and trace ``InferenceModel`` on CPU.

    The graph maps ``(inputs, puzzle_identifiers, z_H_in, z_L_in)`` to
    ``(logits, z_H_out, z_L_out)``, with a dynamic batch axis.

    Args:
        checkpoint_path: Training checkpoint to load.
        onnx_model_path: Destination path of the exported ONNX file.
        batch_size: Batch size of the dummy tracing inputs. The batch axis is
            exported as dynamic regardless.

    Raises:
        TypeError: if the initial states are not float16 after casting.
    """
    cfg = Config()
    cfg.device = 'cpu'
    seq_len = 81  # flattened 9x9 grid

    print(f"Loading checkpoint from: {checkpoint_path}")

    model_cfg_dict = {
        "batch_size": batch_size,
        "seq_len": seq_len,
        "puzzle_emb_ndim": cfg.puzzle_emb_ndim,
        "num_puzzle_identifiers": 1,
        "vocab_size": 11,
        "H_cycles": cfg.H_cycles,
        "L_cycles": cfg.L_cycles,
        "H_layers": cfg.H_layers,
        "L_layers": cfg.L_layers,
        "hidden_size": cfg.hidden_size,
        "expansion": cfg.expansion,
        "num_heads": cfg.num_heads,
        "pos_encodings": cfg.pos_encodings,
        "halt_max_steps": cfg.halt_max_steps,
        "halt_exploration_prob": cfg.halt_exploration_prob,
        "forward_dtype": "float16",
    }

    inner_model = HierarchicalReasoningModel_ACTV1(model_cfg_dict).inner.to(cfg.device)

    # The checkpoint was saved from ACTLossHead(model=ACTV1(inner=...)), so the
    # inner weights live under the "model.inner." prefix.
    state_dict = torch.load(checkpoint_path, map_location=cfg.device)
    prefix = 'model.inner.'
    inner_state_dict = {
        k[len(prefix):]: v
        for k, v in state_dict.items()
        if k.startswith(prefix)
    }
    inner_model.load_state_dict(inner_state_dict)

    inner_model = inner_model.to(torch.float16)
    inner_model.eval()

    inference_model = InferenceModel(inner_model)
    inference_model.eval()

    puzzle_emb_len = inner_model.puzzle_emb_len
    seq_len_with_emb = seq_len + puzzle_emb_len

    dummy_inputs = torch.randint(1, 11, (batch_size, seq_len), dtype=torch.long, device=cfg.device)
    dummy_puzzle_ids = torch.zeros(batch_size, dtype=torch.long, device=cfg.device)

    # Initial states as produced by reset_carry: H_init / L_init broadcast over batch and positions.
    dummy_z_H_in = inner_model.H_init.expand(batch_size, seq_len_with_emb, -1)
    dummy_z_L_in = inner_model.L_init.expand(batch_size, seq_len_with_emb, -1)

    # Explicit checks rather than `assert`, so they still run under `python -O`.
    if dummy_z_H_in.dtype != torch.float16:
        raise TypeError(f"Expected float16 but got {dummy_z_H_in.dtype}")
    if dummy_z_L_in.dtype != torch.float16:
        raise TypeError(f"Expected float16 but got {dummy_z_L_in.dtype}")

    print(f"Exporting model to {onnx_model_path}...")
    torch.onnx.export(
        inference_model,
        (dummy_inputs, dummy_puzzle_ids, dummy_z_H_in, dummy_z_L_in),
        onnx_model_path,
        input_names=['inputs', 'puzzle_identifiers', 'z_H_in', 'z_L_in'],
        output_names=['logits', 'z_H_out', 'z_L_out'],
        opset_version=14,
        do_constant_folding=True,
        dynamic_axes={
            'inputs': {0: 'batch_size'},
            'puzzle_identifiers': {0: 'batch_size'},
            'z_H_in': {0: 'batch_size'},
            'z_L_in': {0: 'batch_size'},
            'logits': {0: 'batch_size'},
            'z_H_out': {0: 'batch_size'},
            'z_L_out': {0: 'batch_size'},
        }
    )
    print("Export complete.")

# Export to ONNX when executed as a script.
# NOTE(review): the earlier `if __name__ == "__main__": main()` guard in this file
# runs training first in the same process — confirm both are meant to run.
if __name__ == '__main__':
    export_to_onnx()