In [6]:
import os
import sys
import torch
import yaml
from typing import Dict, Any

sys.path.append(os.getcwd())

from utils.functions import load_model_class

# ============================================================================
# Configuration
# ============================================================================

CHECKPOINT_PATH = "/home/zakarianarjis/workspace/TinyRecursiveModels/checkpoints/Sudoku-extreme-1k-aug-1000-ACT-torch/pretrain_mlp_t_sudoku/step_65100"
CHECKPOINT_PATH = "/home/zakarianarjis/workspace/TinyRecursiveModels/checkpoints/single_z/step_65100"
DATA_PATH = "data/sudoku-extreme-1k-aug-1000"  # Need this to get dataset metadata

# ============================================================================
# Helper Functions
# ============================================================================

def strip_compiled_prefix(state_dict):
    """Remove _orig_mod. prefix from compiled model state dict."""
    new_state_dict = {}
    for key, value in state_dict.items():
        if key.startswith('_orig_mod.'):
            new_key = key.replace('_orig_mod.', '')
            new_state_dict[new_key] = value
        else:
            new_state_dict[key] = value
    return new_state_dict

def get_dataset_metadata(data_path, split="test"):
    """Load dataset metadata without loading the full dataset."""
    import json
    from dataset.common import PuzzleDatasetMetadata
    
    metadata_file = os.path.join(data_path, split, "dataset.json")
    with open(metadata_file, "r") as f:
        metadata = PuzzleDatasetMetadata(**json.load(f))
    
    return metadata

def load_model(checkpoint_path, data_path=None, device="cuda"):
    """Load model from checkpoint."""
    
    # Load config
    config_file = os.path.join(os.path.dirname(checkpoint_path), "all_config.yaml")
    with open(config_file, "r") as f:
        full_config = yaml.safe_load(f)
    
    # Extract model config
    arch_config = full_config["arch"]
    model_name = arch_config["name"]
    loss_config = arch_config["loss"]
    loss_name = loss_config["name"]
    
    # Build model config
    model_cfg = {k: v for k, v in arch_config.items() if k not in ["name", "loss"]}
    model_cfg["causal"] = False
    
    # Get dataset metadata if data_path provided, otherwise use defaults from config
    if data_path is not None:
        metadata = get_dataset_metadata(data_path)
        model_cfg["vocab_size"] = metadata.vocab_size
        model_cfg["seq_len"] = metadata.seq_len
        model_cfg["num_puzzle_identifiers"] = metadata.num_puzzle_identifiers
        model_cfg["batch_size"] = full_config.get("global_batch_size", 64)
    else:
        # Try to infer from training config
        train_data_path = full_config["data_paths"][0]
        if os.path.exists(train_data_path):
            metadata = get_dataset_metadata(train_data_path, split="train")
            model_cfg["vocab_size"] = metadata.vocab_size
            model_cfg["seq_len"] = metadata.seq_len
            model_cfg["num_puzzle_identifiers"] = metadata.num_puzzle_identifiers
            model_cfg["batch_size"] = full_config.get("global_batch_size", 64)
        else:
            raise ValueError(f"Cannot find dataset at {train_data_path}. Please provide data_path argument.")
    
    # Load classes
    model_cls = load_model_class(model_name)
    loss_head_cls = load_model_class(loss_name)
    loss_kwargs = {k: v for k, v in loss_config.items() if k != "name"}
    
    # Create model
    with torch.device(device):
        model = model_cls(model_cfg)
        model = loss_head_cls(model, **loss_kwargs)
        
        # Load weights
        state_dict = torch.load(checkpoint_path, map_location=device)
        state_dict = strip_compiled_prefix(state_dict)
        model.load_state_dict(state_dict, strict=True)
        model.eval()
    
    return model, full_config, model_cfg

def inspect_model_structure(model):
    """Print detailed model structure."""
    print("\n" + "=" * 80)
    print("DETAILED MODEL STRUCTURE")
    print("=" * 80)
    
    print("\nLayer breakdown:")
    for name, module in model.named_modules():
        if len(list(module.children())) == 0:  # Only leaf modules
            num_params = sum(p.numel() for p in module.parameters())
            if num_params > 0:
                print(f"  {name}: {module.__class__.__name__} ({num_params:,} params)")

def inspect_state_dict(model):
    """Print state dict keys and shapes."""
    print("\n" + "=" * 80)
    print("MODEL STATE DICT")
    print("=" * 80)
    
    for name, param in model.state_dict().items():
        print(f"  {name}: {tuple(param.shape)}")

# ============================================================================
# Load Model
# ============================================================================

model, full_config, model_cfg = load_model(CHECKPOINT_PATH, DATA_PATH)

print("=" * 80)
print("MODEL LOADED SUCCESSFULLY")
print("=" * 80)

# Print config
print("\nModel Configuration:")
for key, value in model_cfg.items():
    print(f"  {key}: {value}")

print(f"\nTotal parameters: {sum(p.numel() for p in model.parameters()):,}")
print(f"Trainable parameters: {sum(p.numel() for p in model.parameters() if p.requires_grad):,}")

# Print model structure
print("\n" + "=" * 80)
print("MODEL STRUCTURE (HIGH LEVEL)")
print("=" * 80)
print(model)

# ============================================================================
# Input/Output Shapes
# ============================================================================

print("\n" + "=" * 80)
print("INPUT/OUTPUT TENSOR SHAPES")
print("=" * 80)

batch_size = 1
seq_len = model_cfg["seq_len"]
vocab_size = model_cfg["vocab_size"]
num_puzzle_identifiers = model_cfg["num_puzzle_identifiers"]

print(f"\nBatch structure:")
print(f"  inputs: torch.Tensor of shape ({batch_size}, {seq_len})")
print(f"    - dtype: torch.int32 or torch.long")
print(f"    - values: token IDs in range [0, {vocab_size - 1}]")
print(f"\n  labels: torch.Tensor of shape ({batch_size}, {seq_len})")
print(f"    - dtype: torch.int32 or torch.long")
print(f"    - values: token IDs in range [0, {vocab_size - 1}] or -100 (ignore)")
print(f"\n  puzzle_identifiers: torch.Tensor of shape ({batch_size},)")
print(f"    - dtype: torch.int32 or torch.long")
print(f"    - values: puzzle IDs in range [0, {num_puzzle_identifiers - 1}]")

print(f"\nOutput structure:")
print(f"  logits: torch.Tensor of shape ({batch_size}, {seq_len}, {vocab_size})")
print(f"    - unnormalized logits for each token position")

# ============================================================================
# Create a dummy batch for testing
# ============================================================================

print("\n" + "=" * 80)
print("CREATING DUMMY BATCH FOR TESTING")
print("=" * 80)

dummy_batch = {
    "inputs": torch.randint(0, vocab_size, (batch_size, seq_len), device="cuda"),
    "labels": torch.randint(0, vocab_size, (batch_size, seq_len), device="cuda"),
    "puzzle_identifiers": torch.randint(0, num_puzzle_identifiers, (batch_size,), device="cuda"),
}

# print(f"\nDummy batch created:")
# for key, value in dummy_batch.items():
#     print(f"  {key}: {value.shape}, dtype={value.dtype}, device={value.device}")

# # Initialize carry
# print("\nInitializing carry...")
# with torch.device("cuda"):
#     with torch.no_grad():
#         carry = model.initial_carry(dummy_batch)
#         print(f"Carry initialized: {type(carry)}")
#         print(f"  inner_carry.z_H: {carry.inner_carry.z_H.shape}")
#         print(f"  inner_carry.z_L: {carry.inner_carry.z_L.shape}")
#         print(f"  steps: {carry.steps.shape}")
#         print(f"  halted: {carry.halted.shape}")

# # Run one forward pass
# print("\nRunning forward pass...")
# with torch.no_grad():
#     carry, loss, metrics, preds, all_finish = model(
#         carry=carry,
#         batch=dummy_batch,
#         return_keys=["logits"]
#     )
    
# print(f"\nForward pass complete!")
# print(f"  loss: {loss.item():.4f}")
# print(f"  all_finish: {all_finish}")
# print(f"  logits shape: {preds['logits'].shape}")
# print(f"  metrics: {list(metrics.keys())}")

# # Get predictions
# pred_tokens = torch.argmax(preds['logits'], dim=-1)
# print(f"  predicted tokens shape: {pred_tokens.shape}")

print("\n" + "=" * 80)
print("READY FOR INSPECTION!")
print("=" * 80)
print("\nAvailable objects:")
print("  - model: the full model (ACTLossHead)")
print("  - model.model: the ACT wrapper")
print("  - model.model.inner: the inner transformer model")
print("  - full_config: the full training config")
print("  - model_cfg: the model config")
print("  - dummy_batch: a sample batch")
print("  - carry: the model's recurrent state")
print("  - preds: predictions from forward pass")
print("\nUseful inspection functions:")
print("  - inspect_model_structure(model): detailed layer breakdown")
print("  - inspect_state_dict(model): all parameter shapes")

MODEL LOADED SUCCESSFULLY

Model Configuration:
  H_cycles: 3
  H_layers: 0
  L_cycles: 6
  L_layers: 2
  expansion: 4
  forward_dtype: bfloat16
  halt_exploration_prob: 0.1
  halt_max_steps: 16
  hidden_size: 512
  mlp_t: True
  no_ACT_continue: True
  num_heads: 8
  pos_encodings: none
  puzzle_emb_len: 16
  puzzle_emb_ndim: 512
  causal: False
  vocab_size: 11
  seq_len: 81
  num_puzzle_identifiers: 1
  batch_size: 768

Total parameters: 5,028,866
Trainable parameters: 5,028,866

MODEL STRUCTURE (HIGH LEVEL)
ACTLossHead(
  (model): TinyRecursiveReasoningModel_ACTV1(
    (inner): TinyRecursiveReasoningModel_ACTV1_Inner(
      (embed_tokens): CastedEmbedding()
      (lm_head): CastedLinear()
      (q_head): CastedLinear()
      (puzzle_emb): CastedSparseEmbedding()
      (L_level): TinyRecursiveReasoningModel_ACTV1ReasoningModule(
        (layers): ModuleList(
          (0-1): 2 x TinyRecursiveReasoningModel_ACTV1Block(
            (mlp_t): SwiGLU(
              (gate_up_proj): Casted

In [8]:
model

ACTLossHead(
  (model): TinyRecursiveReasoningModel_ACTV1(
    (inner): TinyRecursiveReasoningModel_ACTV1_Inner(
      (embed_tokens): CastedEmbedding()
      (lm_head): CastedLinear()
      (q_head): CastedLinear()
      (puzzle_emb): CastedSparseEmbedding()
      (L_level): TinyRecursiveReasoningModel_ACTV1ReasoningModule(
        (layers): ModuleList(
          (0-1): 2 x TinyRecursiveReasoningModel_ACTV1Block(
            (mlp_t): SwiGLU(
              (gate_up_proj): CastedLinear()
              (down_proj): CastedLinear()
            )
            (mlp): SwiGLU(
              (gate_up_proj): CastedLinear()
              (down_proj): CastedLinear()
            )
          )
        )
      )
    )
  )
)

In [213]:
import torch
from puzzle_dataset import PuzzleDataset, PuzzleDatasetConfig
from torch.utils.data import DataLoader
import torch

def get_sudoku_embedding(model: torch.nn.Module,
                         dataset_path: str = "/home/zakarianarjis/workspace/TinyRecursiveModels/data/sudoku-extreme-1k-aug-1000"):
    """
    Loads one sample from the Sudoku dataset and returns the combined input embedding
    (token embeddings + puzzle embeddings) suitable for the ACT model.
    """
    # Minimal config compatible with PuzzleDataset
    dataset_config = PuzzleDatasetConfig(
        seed=42,
        dataset_paths=[dataset_path],
        rank=0,
        num_replicas=1,
        epochs_per_iter=1,
        batch_size=1,
        test_set_mode=False,
        global_batch_size=1
    )

    # Load dataset and dataloader
    dataset = PuzzleDataset(dataset_config, split="train")
    dataloader = DataLoader(dataset, batch_size=None)  # batch_size=None for iterable datasets

    # Take the first sample
    batch = next(iter(dataloader))
    sample_dict = batch[1]  # the dict with 'inputs', 'labels', 'puzzle_identifiers'

    token_ids = sample_dict['inputs'].cuda()                 # [seq_len]
    puzzle_ids = sample_dict['puzzle_identifiers'].cuda()    # [num_puzzle_identifiers]

    # Ensure batch dimension
    token_ids = token_ids  # [1, seq_len]
    puzzle_ids = puzzle_ids.unsqueeze(0) # [1, num_puzzle_identifiers]
    print(token_ids.shape, puzzle_ids.shape)
    # --- Construct combined embedding ---
    # This uses the _input_embeddings method which handles puzzle embeddings internally
    full_input_embed = model.model.inner._input_embeddings(
        token_ids, puzzle_identifiers=puzzle_ids
    )  # [1, seq_len + puzzle_emb_len, hidden_size]

    return full_input_embed.squeeze(0), token_ids.squeeze(0)  # [seq_len + puzzle_emb_len, hidden_size], [seq_len]

model.eval()
embedding, ids = get_sudoku_embedding(model, dataset_path="/home/zakarianarjis/workspace/TinyRecursiveModels/data/sudoku-extreme-1k-aug-1000")

torch.Size([1, 81]) torch.Size([1, 1])


In [218]:
embedding.shape

torch.Size([97, 512])

In [217]:
# Test L_level with dummy input
import torch

L_level = model.model.inner.L_level
hidden_size = model_cfg["hidden_size"]
seq_len_with_emb = model_cfg["seq_len"] + model.model.inner.puzzle_emb_len
batch_size_test = 1

hidden_states = 2 * torch.randn(batch_size_test, seq_len_with_emb, hidden_size, device="cuda", dtype=torch.bfloat16) 
input_injection = embedding.clone()
cos_sin = model.model.inner.rotary_emb() if hasattr(model.model.inner, 'rotary_emb') else None

# with torch.no_grad():
#     output_1 = L_level(hidden_states=hidden_states ,cos_sin=cos_sin)
#     output_2 = L_level(hidden_states=output_1 ,cos_sin=cos_sin)

In [228]:
output = hidden_states.clone() + input_injection
for i in range(10):
    previous_output = output.clone()
    with torch.no_grad():
        output = L_level(hidden_states=output, cos_sin=cos_sin)  
print(output-previous_output)

tensor([[[-0.2207, -0.0723,  0.4297,  ..., -0.3320, -0.1484, -0.1719],
         [-0.0996, -0.0889,  0.3516,  ..., -0.0938, -0.0391, -0.1680],
         [ 0.0195, -0.0391, -0.1924,  ...,  0.0073, -0.0059, -0.0039],
         ...,
         [ 0.1406,  0.1094, -0.1172,  ...,  0.1758,  0.0195,  0.0547],
         [-0.0117,  0.0195, -0.1309,  ...,  0.0859,  0.0742, -0.2031],
         [-0.0449, -0.0938,  0.0469,  ...,  0.1328, -0.1367,  0.1602]]],
       device='cuda:0', dtype=torch.bfloat16)


In [1]:
def _find_multiple(a, b):
    return (-(a // -b)) * b
a= 1365
b = 256
print(_find_multiple(a, b))

1536


In [None]:
# ARC/pass@1: 0.42875
# ARC/pass@2: 0.47625
# ARC/pass@5: 0.50250
# ARC/pass@10: 0.53375
# ARC/pass@100: 0.60625
# ARC/pass@1000: 0.61125