In [1]:
import os
import sys
import torch
import yaml
from typing import Dict, Any

sys.path.append(os.getcwd())

from utils.functions import load_model_class

# ============================================================================
# Configuration
# ============================================================================

CHECKPOINT_PATH = "/home/zakarianarjis/workspace/TinyRecursiveModels/checkpoints/Sudoku-extreme-1k-aug-1000-ACT-torch/pretrain_mlp_t_sudoku/step_65100"
DATA_PATH = "data/sudoku-extreme-1k-aug-1000"  # Need this to get dataset metadata

# ============================================================================
# Helper Functions
# ============================================================================

def strip_compiled_prefix(state_dict):
    """Remove _orig_mod. prefix from compiled model state dict."""
    new_state_dict = {}
    for key, value in state_dict.items():
        if key.startswith('_orig_mod.'):
            new_key = key.replace('_orig_mod.', '')
            new_state_dict[new_key] = value
        else:
            new_state_dict[key] = value
    return new_state_dict

def get_dataset_metadata(data_path, split="test"):
    """Load dataset metadata without loading the full dataset."""
    import json
    from dataset.common import PuzzleDatasetMetadata
    
    metadata_file = os.path.join(data_path, split, "dataset.json")
    with open(metadata_file, "r") as f:
        metadata = PuzzleDatasetMetadata(**json.load(f))
    
    return metadata

def load_model(checkpoint_path, data_path=None, device="cuda"):
    """Load model from checkpoint."""
    
    # Load config
    config_file = os.path.join(os.path.dirname(checkpoint_path), "all_config.yaml")
    with open(config_file, "r") as f:
        full_config = yaml.safe_load(f)
    
    # Extract model config
    arch_config = full_config["arch"]
    model_name = arch_config["name"]
    loss_config = arch_config["loss"]
    loss_name = loss_config["name"]
    
    # Build model config
    model_cfg = {k: v for k, v in arch_config.items() if k not in ["name", "loss"]}
    model_cfg["causal"] = False
    
    # Get dataset metadata if data_path provided, otherwise use defaults from config
    if data_path is not None:
        metadata = get_dataset_metadata(data_path)
        model_cfg["vocab_size"] = metadata.vocab_size
        model_cfg["seq_len"] = metadata.seq_len
        model_cfg["num_puzzle_identifiers"] = metadata.num_puzzle_identifiers
        model_cfg["batch_size"] = full_config.get("global_batch_size", 64)
    else:
        # Try to infer from training config
        train_data_path = full_config["data_paths"][0]
        if os.path.exists(train_data_path):
            metadata = get_dataset_metadata(train_data_path, split="train")
            model_cfg["vocab_size"] = metadata.vocab_size
            model_cfg["seq_len"] = metadata.seq_len
            model_cfg["num_puzzle_identifiers"] = metadata.num_puzzle_identifiers
            model_cfg["batch_size"] = full_config.get("global_batch_size", 64)
        else:
            raise ValueError(f"Cannot find dataset at {train_data_path}. Please provide data_path argument.")
    
    # Load classes
    model_cls = load_model_class(model_name)
    loss_head_cls = load_model_class(loss_name)
    loss_kwargs = {k: v for k, v in loss_config.items() if k != "name"}
    
    # Create model
    with torch.device(device):
        model = model_cls(model_cfg)
        model = loss_head_cls(model, **loss_kwargs)
        
        # Load weights
        state_dict = torch.load(checkpoint_path, map_location=device)
        state_dict = strip_compiled_prefix(state_dict)
        model.load_state_dict(state_dict, strict=True)
        model.eval()
    
    return model, full_config, model_cfg

def inspect_model_structure(model):
    """Print detailed model structure."""
    print("\n" + "=" * 80)
    print("DETAILED MODEL STRUCTURE")
    print("=" * 80)
    
    print("\nLayer breakdown:")
    for name, module in model.named_modules():
        if len(list(module.children())) == 0:  # Only leaf modules
            num_params = sum(p.numel() for p in module.parameters())
            if num_params > 0:
                print(f"  {name}: {module.__class__.__name__} ({num_params:,} params)")

def inspect_state_dict(model):
    """Print state dict keys and shapes."""
    print("\n" + "=" * 80)
    print("MODEL STATE DICT")
    print("=" * 80)
    
    for name, param in model.state_dict().items():
        print(f"  {name}: {tuple(param.shape)}")

# ============================================================================
# Load Model
# ============================================================================

model, full_config, model_cfg = load_model(CHECKPOINT_PATH, DATA_PATH)

print("=" * 80)
print("MODEL LOADED SUCCESSFULLY")
print("=" * 80)

# Print config
print("\nModel Configuration:")
for key, value in model_cfg.items():
    print(f"  {key}: {value}")

print(f"\nTotal parameters: {sum(p.numel() for p in model.parameters()):,}")
print(f"Trainable parameters: {sum(p.numel() for p in model.parameters() if p.requires_grad):,}")

# Print model structure
print("\n" + "=" * 80)
print("MODEL STRUCTURE (HIGH LEVEL)")
print("=" * 80)
print(model)

# ============================================================================
# Input/Output Shapes
# ============================================================================

print("\n" + "=" * 80)
print("INPUT/OUTPUT TENSOR SHAPES")
print("=" * 80)

batch_size = 1
seq_len = model_cfg["seq_len"]
vocab_size = model_cfg["vocab_size"]
num_puzzle_identifiers = model_cfg["num_puzzle_identifiers"]

print(f"\nBatch structure:")
print(f"  inputs: torch.Tensor of shape ({batch_size}, {seq_len})")
print(f"    - dtype: torch.int32 or torch.long")
print(f"    - values: token IDs in range [0, {vocab_size - 1}]")
print(f"\n  labels: torch.Tensor of shape ({batch_size}, {seq_len})")
print(f"    - dtype: torch.int32 or torch.long")
print(f"    - values: token IDs in range [0, {vocab_size - 1}] or -100 (ignore)")
print(f"\n  puzzle_identifiers: torch.Tensor of shape ({batch_size},)")
print(f"    - dtype: torch.int32 or torch.long")
print(f"    - values: puzzle IDs in range [0, {num_puzzle_identifiers - 1}]")

print(f"\nOutput structure:")
print(f"  logits: torch.Tensor of shape ({batch_size}, {seq_len}, {vocab_size})")
print(f"    - unnormalized logits for each token position")

# ============================================================================
# Create a dummy batch for testing
# ============================================================================

print("\n" + "=" * 80)
print("CREATING DUMMY BATCH FOR TESTING")
print("=" * 80)

dummy_batch = {
    "inputs": torch.randint(0, vocab_size, (batch_size, seq_len), device="cuda"),
    "labels": torch.randint(0, vocab_size, (batch_size, seq_len), device="cuda"),
    "puzzle_identifiers": torch.randint(0, num_puzzle_identifiers, (batch_size,), device="cuda"),
}

print(f"\nDummy batch created:")
for key, value in dummy_batch.items():
    print(f"  {key}: {value.shape}, dtype={value.dtype}, device={value.device}")

# Initialize carry
print("\nInitializing carry...")
with torch.device("cuda"):
    with torch.no_grad():
        carry = model.initial_carry(dummy_batch)
        print(f"Carry initialized: {type(carry)}")
        print(f"  inner_carry.z_H: {carry.inner_carry.z_H.shape}")
        print(f"  inner_carry.z_L: {carry.inner_carry.z_L.shape}")
        print(f"  steps: {carry.steps.shape}")
        print(f"  halted: {carry.halted.shape}")

# Run one forward pass
print("\nRunning forward pass...")
with torch.no_grad():
    carry, loss, metrics, preds, all_finish = model(
        carry=carry,
        batch=dummy_batch,
        return_keys=["logits"]
    )
    
print(f"\nForward pass complete!")
print(f"  loss: {loss.item():.4f}")
print(f"  all_finish: {all_finish}")
print(f"  logits shape: {preds['logits'].shape}")
print(f"  metrics: {list(metrics.keys())}")

# Get predictions
pred_tokens = torch.argmax(preds['logits'], dim=-1)
print(f"  predicted tokens shape: {pred_tokens.shape}")

print("\n" + "=" * 80)
print("READY FOR INSPECTION!")
print("=" * 80)
print("\nAvailable objects:")
print("  - model: the full model (ACTLossHead)")
print("  - model.model: the ACT wrapper")
print("  - model.model.inner: the inner transformer model")
print("  - full_config: the full training config")
print("  - model_cfg: the model config")
print("  - dummy_batch: a sample batch")
print("  - carry: the model's recurrent state")
print("  - preds: predictions from forward pass")
print("\nUseful inspection functions:")
print("  - inspect_model_structure(model): detailed layer breakdown")
print("  - inspect_state_dict(model): all parameter shapes")

MODEL LOADED SUCCESSFULLY

Model Configuration:
  H_cycles: 3
  H_layers: 0
  L_cycles: 6
  L_layers: 2
  expansion: 4
  forward_dtype: bfloat16
  halt_exploration_prob: 0.1
  halt_max_steps: 16
  hidden_size: 512
  mlp_t: True
  no_ACT_continue: True
  num_heads: 8
  pos_encodings: none
  puzzle_emb_len: 16
  puzzle_emb_ndim: 512
  causal: False
  vocab_size: 11
  seq_len: 81
  num_puzzle_identifiers: 1
  batch_size: 768

Total parameters: 5,028,866
Trainable parameters: 5,028,866

MODEL STRUCTURE (HIGH LEVEL)
ACTLossHead(
  (model): TinyRecursiveReasoningModel_ACTV1(
    (inner): TinyRecursiveReasoningModel_ACTV1_Inner(
      (embed_tokens): CastedEmbedding()
      (lm_head): CastedLinear()
      (q_head): CastedLinear()
      (puzzle_emb): CastedSparseEmbedding()
      (L_level): TinyRecursiveReasoningModel_ACTV1ReasoningModule(
        (layers): ModuleList(
          (0-1): 2 x TinyRecursiveReasoningModel_ACTV1Block(
            (mlp_t): SwiGLU(
              (gate_up_proj): Casted

In [2]:
model

ACTLossHead(
  (model): TinyRecursiveReasoningModel_ACTV1(
    (inner): TinyRecursiveReasoningModel_ACTV1_Inner(
      (embed_tokens): CastedEmbedding()
      (lm_head): CastedLinear()
      (q_head): CastedLinear()
      (puzzle_emb): CastedSparseEmbedding()
      (L_level): TinyRecursiveReasoningModel_ACTV1ReasoningModule(
        (layers): ModuleList(
          (0-1): 2 x TinyRecursiveReasoningModel_ACTV1Block(
            (mlp_t): SwiGLU(
              (gate_up_proj): CastedLinear()
              (down_proj): CastedLinear()
            )
            (mlp): SwiGLU(
              (gate_up_proj): CastedLinear()
              (down_proj): CastedLinear()
            )
          )
        )
      )
    )
  )
)

In [3]:
# Test L_level with dummy input
import torch

L_level = model.model.inner.L_level
hidden_size = model_cfg["hidden_size"]
seq_len_with_emb = model_cfg["seq_len"] + model.model.inner.puzzle_emb_len
batch_size_test = 1

hidden_states = 10*torch.randn(batch_size_test, seq_len_with_emb, hidden_size, device="cuda", dtype=torch.bfloat16)
input_injection = torch.zeros(batch_size_test, seq_len_with_emb, hidden_size, device="cuda", dtype=torch.bfloat16)
cos_sin = model.model.inner.rotary_emb() if hasattr(model.model.inner, 'rotary_emb') else None

with torch.no_grad():
    output = L_level(hidden_states=hidden_states, input_injection=input_injection, cos_sin=cos_sin)

In [None]:
import matplotlib.pyplot as plt
import numpy as np
import torch

# We will use the 'model', 'dummy_batch', 'metrics', and 'preds'
# variables already in your notebook's memory from the previous cell.

# ============================================================================
# Analysis 1: Visualize ACT Ponder Time
# ============================================================================
print("\n" + "=" * 80)
print("ANALYSIS 1: ACT PONDER TIME")
print("=" * 80)

# The 'metrics' dictionary from the forward pass should contain the 
# ponder time or step count for each token. We need to find the key.
print(f"Available metrics keys: {metrics.keys()}")

# !!! IMPORTANT !!!
# Replace 'act_steps' with the actual key from the list above.
# It might be 'ponder_time', 'steps', 'act_steps', or similar.
# For this demo, we'll assume the key is 'act_steps'.
PONDER_METRIC_KEY = 'act_steps' # <-- !!! CHECK AND CHANGE THIS KEY !!!

if PONDER_METRIC_KEY in metrics:
    # Fetch the ponder times [batch_size, seq_len]
    ponder_times = metrics[PONDER_METRIC_KEY].cpu().detach().numpy()
    
    # Select the first puzzle in the batch to visualize
    ponder_times_sample = ponder_times[0:1, :] # Keep 2D shape [1, seq_len]

    print(f"\nVisualizing ponder time for sample 0 (Shape: {ponder_times_sample.shape})")

    # Plot
    plt.figure(figsize=(20, 2))
    plt.imshow(ponder_times_sample, aspect='auto', cmap='viridis', interpolation='none')
    plt.colorbar(label='Ponder Steps')
    plt.xlabel('Token Position (Sequence Length)')
    plt.yticks([])
    plt.title('ACT Ponder Time (Computational Effort) per Token')
    plt.show()

    print(f"\nInterpretation:")
    print("> High values (brighter colors) show where the model 'thought' longer.")
    print("> This often corresponds to more difficult parts of the puzzle.")
    print("> NOTE: Since this is a 'dummy_batch' of random data, the pattern")
    print("  is meaningless. Run this with a *real* puzzle batch to see real insights.")

else:
    print(f"\nCould not find key '{PONDER_METRIC_KEY}' in metrics. Skipping ACT visualization.")


# ============================================================================
# Analysis 2: Input Gradient Saliency (Final Correction)
# ============================================================================
print("\n" + "=" * 80)
print("ANALYSIS 2: INPUT GRADIENT SALIENCY")
print("=" * 80)

# --- Saliency Helper Function (Corrected) ---
def get_input_saliency(model, batch, position_to_explain, sample_index=0):
    """
    Computes the gradient of a specific output logit w.r.t. input embeddings.
    """
    print(f"\nCalculating saliency for prediction at position {position_to_explain}...")
    
    # We will store the original 'requires_grad' state for all parameters
    original_grad_states = {}
    
    try:
        # 0. Set model to eval mode BUT enable grads on all parameters
        model.eval() 
        for name, p in model.named_parameters():
            original_grad_states[name] = p.requires_grad
            p.requires_grad = True # <<< THE KEY FIX
            
        model.zero_grad()

        # 1. We need to "hook" the input embeddings
        embeddings_tensor = None
        
        def hook_fn(module, input, output):
            nonlocal embeddings_tensor
            embeddings_tensor = output
            embeddings_tensor.retain_grad() # Crucial to save grad
        
        hook = model.model.inner.embed_tokens.register_forward_hook(hook_fn)

        # 2. Re-initialize carry and run forward pass *with grads enabled*
        try:
            with torch.device("cuda"), torch.enable_grad():
                carry = model.initial_carry(batch)
                
                # Forward pass *with* grads
                carry, loss, metrics, preds, all_finish = model(
                    carry=carry,
                    batch=batch,
                    return_keys=["logits"]
                )
        finally:
            hook.remove() # Always remove the hook!

        if embeddings_tensor is None:
            print("Error: Hook did not capture embeddings.")
            return

        # 3. Define the "decision" to explain
        logits_at_pos = preds['logits'][sample_index, position_to_explain] # [Vocab]
        predicted_token_idx = torch.argmax(logits_at_pos).item()
        score_to_explain = logits_at_pos[predicted_token_idx]

        print(f"  > Explaining prediction for token {predicted_token_idx} at pos {position_to_explain}.")

        # 4. Backward pass: Calculate gradients
        # This should FINALLY work!
        score_to_explain.backward()

        # 5. Get the saliency map from the gradient
        if embeddings_tensor.grad is None:
            print("Error: Gradients were not populated on the embeddings tensor.")
            return
            
        saliency = embeddings_tensor.grad[sample_index].norm(dim=1) # [SeqLen]
        
        return saliency.cpu().detach().numpy()

    finally:
        # 6. Restore original parameter grad states
        print("Restoring original model parameter grad states...")
        for name, p in model.named_parameters():
            if name in original_grad_states:
                p.requires_grad = original_grad_states[name]
        print("Done.")

# --- Run Saliency ---

# Explain the prediction at token position 40 (e.g., a specific Sudoku cell)
POS_TO_EXPLAIN = 40 
SAMPLE_TO_EXPLAIN = 0 # First item in batch

# NOTE: This uses 'dummy_batch'. For real insights, replace 'dummy_batch'
# with a real data batch from your dataloader!
saliency_map = get_input_saliency(model, dummy_batch, POS_TO_EXPLAIN, SAMPLE_TO_EXPLAIN)

# 7. Plot the saliency map
if saliency_map is not None:
    plt.figure(figsize=(20, 2))
    plt.imshow(saliency_map[None, :], aspect='auto', cmap='hot', interpolation='none')
    plt.colorbar(label='Saliency (Gradient Norm)')
    plt.xlabel('Input Token Position')
    plt.yticks([])
    plt.title(f'Input Saliency for Prediction at Position {POS_TO_EXPLAIN}')
    plt.show()

    print(f"\nInterpretation:")
    print("> High values (brighter colors) show which *input* tokens were most")
    print(f"  influential for the model's prediction at position {POS_TO_EXPLAIN}.")
    print("> This is the 'Grad-CAM' for your transformer.")
    print("> NOTE: Again, this result is random because the input is 'dummy_batch'.")
    print("  On a real Sudoku puzzle, you would see it focus on the relevant")
    print("  row, column, and 3x3 box for the cell it's trying to solve.")
else:
    print("Saliency map calculation failed.")


print("\n" + "=" * 80)
print("ANALYSIS CELL COMPLETE")
print("=" * 80)


ANALYSIS 1: ACT PONDER TIME
Available metrics keys: dict_keys(['count', 'accuracy', 'exact_accuracy', 'q_halt_accuracy', 'steps', 'lm_loss', 'q_halt_loss'])

Could not find key 'act_steps' in metrics. Skipping ACT visualization.

ANALYSIS 2: INPUT GRADIENT SALIENCY

Calculating saliency for prediction at position 40...
  > Explaining prediction for token 4 at pos 40.


RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn