# NSM Debug Mode Demonstration

This notebook demonstrates the debug mode functionality for Neural State Machine models, including step-by-step state tracking and memory operation monitoring.

In [None]:
import torch
import torch.nn as nn
import numpy as np
import sys
from pathlib import Path

# Add src to path
sys.path.insert(0, str(Path().cwd().parent.parent))

# Import debug tools
from nsm.utils.debugger import NSMDebugger
from nsm.modules.debuggable_components import (
    DebuggableTokenToStateRouter,
    DebuggableStateManager,
    DebuggableStatePropagator
)

## 1. Setting up Debug Mode

In [None]:
# First positional argument handed to NSMDebugger
# (presumably the directory logs are written to -- confirm against NSMDebugger)
DEBUG_LOG_TARGET = "nsm_debug_logs"

# Build the single debugger shared by every cell below, then switch it on
debugger = NSMDebugger(DEBUG_LOG_TARGET, verbose=True)
debugger.enable_debug()

print("✅ NSM Debugger initialized and enabled")

## 2. Debuggable Component Demonstration

In [None]:
# Fix the torch and NumPy seeds before any component is constructed
torch.manual_seed(42)
np.random.seed(42)

# --- Token-to-state router -------------------------------------------------
router_kwargs = dict(
    token_dim=64,
    state_dim=128,
    num_states=8,
    num_heads=4,
    debug_mode=True,
)
router = DebuggableTokenToStateRouter(**router_kwargs)
router.set_debugger(debugger)

# --- State manager ---------------------------------------------------------
state_manager_kwargs = dict(
    state_dim=128,
    max_states=16,
    initial_states=8,
    debug_mode=True,
)
state_manager = DebuggableStateManager(**state_manager_kwargs)
state_manager.set_debugger(debugger)

# --- State propagator ------------------------------------------------------
propagator_kwargs = dict(
    state_dim=128,
    gate_type='gru',
    enable_communication=True,
    debug_mode=True,
)
propagator = DebuggableStatePropagator(**propagator_kwargs)
propagator.set_debugger(debugger)

# All three components now report to the same debugger instance
print("✅ Debuggable components created and connected to debugger")

## 3. Step-by-Step Processing with Debugging

In [None]:
# Dimensions of the synthetic batch used by the following steps
batch_size, seq_len = 2, 10
token_dim, state_dim = 64, 128
num_states = 8

# Random stand-ins: tokens are (batch, seq, token_dim), states are (batch, num_states, state_dim)
tokens = torch.randn(batch_size, seq_len, token_dim)
states = torch.randn(batch_size, num_states, state_dim)

# Echo the shapes and sizes so later output can be checked against them
print(f"Input tokens shape: {tokens.shape}")
print(f"Input states shape: {states.shape}")
print(f"Batch size: {batch_size}")
print(f"Sequence length: {seq_len}")
print(f"Number of states: {num_states}")

In [None]:
# Step 1: run the router on the sample tokens and states
print("\n=== Step 1: Token-to-State Routing ===")

# The router call returns a pair: routed tokens and the routing weights
routed_tokens, routing_weights = router(tokens, states)

print(f"Routed tokens shape: {routed_tokens.shape}")
print(f"Routing weights shape: {routing_weights.shape}")

# Scalar summaries of the routing weights
weight_mean = routing_weights.mean().item()
print(f"Average routing weight: {weight_mean:.4f}")
weight_max = routing_weights.max().item()
print(f"Max routing weight: {weight_max:.4f}")
weight_min = routing_weights.min().item()
print(f"Min routing weight: {weight_min:.4f}")

In [None]:
# Step 2: inspect the state manager, then grow and prune its state pool
print("\n=== Step 2: State Management ===")

# Calling the manager with no arguments gives the currently active states
active_states = state_manager()
print(f"Active states shape: {active_states.shape}")
active_before = state_manager.get_active_count()
print(f"Active states count: {active_before}")

# Importance scores: report their range and mean
importance_scores = state_manager.get_importance_scores()
score_low = importance_scores.min().item()
score_high = importance_scores.max().item()
print(f"Importance scores range: [{score_low:.4f}, {score_high:.4f}]")
score_mean = importance_scores.mean().item()
print(f"Average importance score: {score_mean:.4f}")

# Request two additional states and report what the manager returned
allocated = state_manager.allocate_states(2)
print(f"Allocated states: {allocated}")
active_after_alloc = state_manager.get_active_count()
print(f"New active count: {active_after_alloc}")

# Ask the manager to prune low-importance states and report the outcome
pruned = state_manager.prune_low_importance_states()
print(f"Pruned states: {pruned}")
active_after_prune = state_manager.get_active_count()
print(f"Final active count: {active_after_prune}")

In [None]:
# Step 3: push the manager's states through the propagator
print("\n=== Step 3: State Propagation ===")

# Re-read the manager's states and add a leading batch dimension of size batch_size
current_states = state_manager()
current_states_expanded = current_states.unsqueeze(0).expand(batch_size, -1, -1)

# NOTE(review): routed_tokens was produced against the 8 random `states`, while the
# manager's active count may have changed after allocate/prune in Step 2 -- confirm
# the propagator accepts these two inputs together.
updated_states = propagator(current_states_expanded, routed_tokens)

print(f"Previous states shape: {current_states_expanded.shape}")
print(f"Updated states shape: {updated_states.shape}")

# Element-wise change introduced by the propagator, summarised as mean and std
state_delta = updated_states - current_states_expanded
print(f"State difference mean: {torch.mean(state_delta).item():.6f}")
print(f"State difference std: {torch.std(state_delta).item():.6f}")

## 4. Memory Operation Monitoring

In [None]:
# Feed the debugger three synthetic memory operations: read, write, erase
print("\n=== Memory Operation Monitoring ===")

# --- Read: random (8, 128) payload plus a softmax-normalised weight vector ---
read_data = torch.randn(8, 128)
attention_weights = torch.softmax(torch.randn(8), dim=0)
read_info = {'purpose': 'retrieve_context', 'priority': 'high'}

debugger.log_memory_operation(
    'read',
    'memory_slot_0',
    read_data=read_data,
    attention_weights=attention_weights,
    operation_info=read_info,
)

print("✅ Memory read operation logged")

# --- Write: random payload; the sigmoid-squashed erase vector is passed as the weights ---
write_data = torch.randn(8, 128)
erase_vector = torch.sigmoid(torch.randn(8, 128))
write_info = {'purpose': 'store_result', 'priority': 'medium'}

debugger.log_memory_operation(
    'write',
    'memory_slot_1',
    write_data=write_data,
    attention_weights=erase_vector,
    operation_info=write_info,
)

print("✅ Memory write operation logged")

# --- Erase: the same erase vector is reused, this time as the write payload ---
erase_info = {'purpose': 'clear_old_data', 'priority': 'low'}

debugger.log_memory_operation(
    'erase',
    'memory_slot_2',
    write_data=erase_vector,
    operation_info=erase_info,
)

print("✅ Memory erase operation logged")

## 5. Attention Operation Monitoring

In [None]:
# Feed the debugger two synthetic attention operations.
# (The original cell used `//` as a comment marker, which is a SyntaxError in Python.)
print("\n=== Attention Operation Monitoring ===")

# Token-to-state attention: each row is softmax-normalised over the last axis
attention_weights = torch.softmax(torch.randn(10, 8), dim=-1)  # 10 tokens, 8 states

debugger.log_attention_operation(
    'token_to_state', 'input_sequence', 'state_nodes',
    attention_weights=attention_weights,
    attended_values=torch.randn(10, 8),
    operation_info={'layer': 'encoding', 'head': 'all'}
)

print("✅ Token-to-state attention operation logged")

# State-to-state attention: square weight matrix, source and target are both 'state_nodes'
state_attention = torch.softmax(torch.randn(8, 8), dim=-1)  # 8 states

debugger.log_attention_operation(
    'state_to_state', 'state_nodes', 'state_nodes',
    attention_weights=state_attention,
    attended_values=torch.randn(8, 8),
    operation_info={'layer': 'communication', 'head': 'all'}
)

print("✅ State-to-state attention operation logged")

## 6. State Update Monitoring

In [None]:
# Log one synthetic state update and summarise the size of the change
print("\n=== State Update Monitoring ===")

# A random (6, 128) state and a copy perturbed by noise scaled by 0.1
old_state = torch.randn(6, 128)
perturbation = torch.randn(6, 128) * 0.1
new_state = old_state + perturbation

update_info = {'update_type': 'propagation', 'learning_rate': 0.001}
debugger.log_state_update('StateManager', old_state, new_state, update_info=update_info)

print("✅ State update operation logged")

# Summary statistics of the element-wise difference
state_diff = new_state - old_state
abs_diff = state_diff.abs()
print("State update statistics:")
print(f"  - Mean change: {state_diff.mean().item():.6f}")
print(f"  - Std change: {state_diff.std().item():.6f}")
print(f"  - Max change: {abs_diff.max().item():.6f}")
print(f"  - Min change: {abs_diff.min().item():.6f}")

## 7. Debug Summary and Log Analysis

In [None]:
# Print the debugger's summary of everything logged so far.
# (Comment marker fixed: `//` is not valid Python comment syntax.)
print("\n=== Debug Summary ===")
debugger.print_summary()

In [None]:
# Analyze logged data.
# (Comment markers fixed throughout this cell: `//` is a SyntaxError in Python.)
print("\n=== Log Analysis ===")
log_data = debugger.log_data

print(f"Total log entries: {len(log_data)}")

# Tally how many entries were recorded under each step name
step_types = {}
for entry in log_data:
    step_name = entry['step_name']
    step_types[step_name] = step_types.get(step_name, 0) + 1

print("\nStep type breakdown:")
for step_name, count in sorted(step_types.items()):
    print(f"  {step_name}: {count}")

# Entries whose 'data' payload contains the 'memory_operation' key
memory_ops = [entry for entry in log_data if 'memory_operation' in entry['data']]
print(f"\nMemory operations: {len(memory_ops)}")

# Entries whose 'data' payload contains the 'attention_operation' key
attention_ops = [entry for entry in log_data if 'attention_operation' in entry['data']]
print(f"Attention operations: {len(attention_ops)}")

# Entries whose 'data' payload contains the 'state_update' key
state_updates = [entry for entry in log_data if 'state_update' in entry['data']]
print(f"State updates: {len(state_updates)}")

In [None]:
# Save the debug log and report the value save_debug_log() returns.
# (Comment marker fixed: `//` is not valid Python comment syntax.)
print("\n=== Log Saving ===")
log_file = debugger.save_debug_log()
print(f"Debug log saved to: {log_file}")

## 8. Advanced Debug Features

In [None]:
# Custom step logging.
# (Comment markers fixed: `//` is not valid Python comment syntax.)
print("\n=== Custom Step Logging ===")

# Arbitrary nested payload to record as a custom step
custom_data = {
    'processing_stage': 'custom_analysis',
    'parameters': {
        'temperature': 1.0,
        'top_k': 50,
        'top_p': 0.9
    },
    'metrics': {
        'perplexity': 15.67,
        'accuracy': 0.892
    }
}

# step_info carries extra metadata alongside the payload
debugger.log_step(
    'custom_analysis',
    custom_data,
    step_info={'model_version': 'v1.2', 'experiment_id': 'exp_001'}
)

print("✅ Custom step logged")

In [None]:
# Conditional logging based on thresholds.
# (Comment markers fixed: `//` is not valid Python comment syntax.)
print("\n=== Conditional Logging ===")

def conditional_log_check(state_diff_norm, threshold=0.5):
    """Log a step only when the state change exceeds a threshold.

    Args:
        state_diff_norm: Magnitude of the state change to test.
        threshold: Value that `state_diff_norm` must strictly exceed for a
            log entry to be written. Defaults to 0.5.

    Returns:
        True if a 'large_state_change_detected' step was logged through the
        notebook-level `debugger`, False otherwise.
    """
    if state_diff_norm > threshold:
        debugger.log_step(
            'large_state_change_detected',
            {'state_diff_norm': state_diff_norm},
            step_info={'warning': 'Significant state change detected', 'threshold': threshold}
        )
        return True
    return False

# Test conditional logging with one value above and one below the threshold
large_change = 0.8
small_change = 0.2

logged_large = conditional_log_check(large_change, 0.5)
logged_small = conditional_log_check(small_change, 0.5)

print(f"Large change ({large_change}) logged: {logged_large}")
print(f"Small change ({small_change}) logged: {logged_small}")

## 9. Final Summary and Log Export

In [None]:
# Final debug summary, now including the custom and conditional steps logged above.
# (Comment markers fixed: `//` is not valid Python comment syntax.)
print("\n=== Final Debug Summary ===")
debugger.print_summary()

# Save the final log and report the value save_debug_log() returns
final_log_file = debugger.save_debug_log()
print(f"\nFinal debug log saved to: {final_log_file}")

print("\n🎉 NSM Debug Mode Demonstration Completed!")

## Summary

This notebook demonstrated the NSM debug mode functionality:

### Key Features Implemented:

1. **Step-by-Step State Tracking**:
   - Each processing step is logged with detailed information
   - Components can log their internal operations
   - Timestamped and indexed for chronological analysis

2. **Memory Operation Monitoring**:
   - Read, write, and erase operations are tracked
   - Attention weights used for memory access are recorded
   - Memory slot identification and operation purposes

3. **Attention Operation Monitoring**:
   - Token-to-state and state-to-state attention tracking
   - Attention weight patterns and attended values
   - Layer and head information for multi-head attention

4. **State Update Monitoring**:
   - Before and after state comparisons
   - Statistical analysis of state changes
   - Update type and parameter information

5. **Comprehensive Logging System**:
   - JSON-based structured logging
   - Detailed tensor statistics (mean, std, min, max)
   - Metadata and timestamping
   - Automatic log file generation

### Debug Modes Available:

- **Verbose Mode**: Real-time printing of debug information
- **Silent Mode**: Logging without console output
- **Conditional Logging**: Threshold-based selective logging
- **Custom Logging**: User-defined step and data logging

### Benefits for Development:

✅ **Transparent Model Behavior**: Clear insight into internal operations

✅ **Debugging Assistance**: Easy identification of issues in processing

✅ **Performance Analysis**: Detailed timing and resource usage tracking

✅ **Research Support**: Quantitative analysis of model behavior

✅ **Educational Tool**: Understanding of complex NSM mechanics

✅ **Production Monitoring**: Real-time monitoring in deployed systems

The debug mode provides comprehensive visibility into Neural State Machine operations, making it easier to understand, debug, and optimize these complex models.