# Neural State Machine Visualization Tools

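This notebook demonstrates the visualization tools for interpreting Neural State Machine (NSM) models. All inputs below are randomly generated sample tensors, so the plots illustrate what each tool shows rather than the behavior of a trained model.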

In [None]:
import torch
import numpy as np
import sys
from pathlib import Path

# Add src to path
sys.path.insert(0, str(Path().cwd().parent))

# Import visualization tools
from nsm.utils.visualizer import NSMVisualizer
from nsm.utils.advanced_visualizer import AdvancedNSMVisualizer

## 1. Basic Visualization Tools

In [None]:
# Instantiate the basic visualizer used by the plotting cells in this section.
visualizer = NSMVisualizer(
    figsize=(10, 8),  # presumably (width, height) in inches, matplotlib-style
)

In [None]:
# Generate sample data
torch.manual_seed(42)
np.random.seed(42)

# Sample attention weights
attention_weights = torch.softmax(torch.randn(8, 8), dim=-1)

# Sample memory content
memory_content = torch.randn(16, 20)

# Sample state evolution
states = [torch.randn(8, 16) for _ in range(5)]

# Sample importance scores
importance_scores = torch.sigmoid(torch.randn(16))

print("Sample data generated successfully!")

In [None]:
# Plot attention map. Tick labels are derived from the weight matrix shape
# (rows = queries, columns = positions), so they stay in sync if the sample
# size in the data cell changes.
fig1 = visualizer.plot_attention_map(
    attention_weights,
    title="Sample Attention Map",
    x_labels=[f"Pos{i}" for i in range(attention_weights.shape[1])],
    y_labels=[f"Query{i}" for i in range(attention_weights.shape[0])]
)

In [None]:
# Visualize the raw sample memory tensor (slots x dims).
fig2 = visualizer.plot_memory_content(memory_content, title="Sample Memory Content")

In [None]:
# Show how the sampled state snapshots change from one step to the next.
fig3 = visualizer.plot_state_evolution(states, title="Sample State Evolution")

In [None]:
# Display the per-slot importance scores.
fig4 = visualizer.plot_memory_importance(importance_scores, title="Sample Memory Importance")

## 2. Advanced NSM-Specific Visualizations

In [None]:
# Instantiate the NSM-specific visualizer; it uses a larger figsize (12x8) than
# the basic one (10x8).
advanced_visualizer = AdvancedNSMVisualizer(
    figsize=(12, 8),
)

In [None]:
# Generate advanced sample data.
# NOTE: this continues drawing from the global RNG seeded in the first data
# cell, so its values depend on running the notebook top to bottom.

# Token-to-state routing weights: 12 tokens x 8 states, each row sums to 1.
routing_weights = torch.softmax(torch.randn(12, 8), dim=-1)

# State-to-state communication among 8 states.
state_attention = torch.softmax(torch.randn(8, 8), dim=-1)

# Memory read/write operations over 10 memory slots (each vector sums to 1).
read_weights = torch.softmax(torch.randn(10), dim=0)
write_weights = torch.softmax(torch.randn(10), dim=0)

# State dynamics: 10 time steps, each a 6-state x 12-dim snapshot.
state_trajectories = [torch.randn(6, 12) for _ in range(10)]

print("Advanced sample data generated successfully!")

In [None]:
# Plot token-to-state routing. Labels come from the weight shape
# (rows = tokens, columns = states) so they always match the data.
fig5 = advanced_visualizer.plot_token_to_state_routing(
    routing_weights,
    token_labels=[f"T{i}" for i in range(routing_weights.shape[0])],
    state_labels=[f"S{i}" for i in range(routing_weights.shape[1])],
    title="Token-to-State Routing"
)

In [None]:
# Plot state-to-state communication (square matrix: one label per state).
fig6 = advanced_visualizer.plot_state_communication(
    state_attention,
    state_labels=[f"State{i}" for i in range(state_attention.shape[0])],
    title="State-to-State Communication"
)

In [None]:
# Plot memory read/write operations. Read and write weights share one set of
# slot labels, so they must cover the same number of slots.
assert read_weights.shape == write_weights.shape, "read/write weights must cover the same slots"
fig7 = advanced_visualizer.plot_memory_read_write_operations(
    read_weights, write_weights,
    memory_slots=[f"Slot{i}" for i in range(read_weights.shape[0])],
    title="Memory Read/Write Operations"
)

In [None]:
# Plot state dynamics. Each trajectory step is (states x dims), so the number
# of state labels is taken from the first step's row count.
fig8 = advanced_visualizer.plot_state_dynamics(
    state_trajectories,
    state_labels=[f"S{i}" for i in range(state_trajectories[0].shape[0])],
    metrics=['norm', 'mean', 'std'],
    title="State Dynamics Over Time"
)

## 3. Comprehensive Visualization Report

In [None]:
# Bundle the sample tensors for the report generator.
# NOTE(review): keys presumably must match what create_comprehensive_report
# expects — confirm against AdvancedNSMVisualizer.
visualization_data = {
    'attention_weights': attention_weights,
    'routing_weights': routing_weights,
    'memory_content': memory_content,
    'state_attention': state_attention,
    'importance_scores': importance_scores
}

# Generate comprehensive report into a directory relative to the working dir.
report_dir = advanced_visualizer.create_comprehensive_report(
    visualization_data,
    save_dir="visualization_reports"
)

print(f"Comprehensive report generated in: {report_dir}")

## 4. Interactive Analysis

In [None]:
# Create interactive summary (presumably a pandas DataFrame). Leaving it as the
# bare last expression — with no trailing `;` — lets Jupyter render it.
summary_df = visualizer.create_interactive_summary(visualization_data)
summary_df

In [None]:
# Analyze attention patterns: global statistics of the weight matrix.
attention_np = attention_weights.detach().cpu().numpy()

print("Attention Pattern Analysis:")
for stat_name, stat_fn in (("Max", np.max), ("Min", np.min), ("Mean", np.mean), ("Std", np.std)):
    print(f"  - {stat_name} attention weight: {stat_fn(attention_np):.4f}")

# Find the strongest connection: argmax over the flattened matrix, mapped back
# to (row, column) = (query, key).
max_indices = np.unravel_index(np.argmax(attention_np), attention_np.shape)
print(f"  - Strongest connection: Query {max_indices[0]} → Key {max_indices[1]} = {attention_np[max_indices]:.4f}")

In [None]:
# Analyze memory content: shape, global statistics, and the most active slot.
memory_np = memory_content.detach().cpu().numpy()

print("\nMemory Content Analysis:")
print(f"  - Memory slots: {memory_np.shape[0]}")
print(f"  - Memory dimensions: {memory_np.shape[1]}")
for stat_name, stat_fn in (("Max", np.max), ("Min", np.min), ("Mean", np.mean), ("Std", np.std)):
    print(f"  - {stat_name} value: {stat_fn(memory_np):.4f}")

# A slot's "activity" is its mean absolute value across dimensions (axis 1).
slot_means = np.mean(np.abs(memory_np), axis=1)
most_active_slot = np.argmax(slot_means)
print(f"  - Most active slot: {most_active_slot} (mean abs: {slot_means[most_active_slot]:.4f})")

## 5. Custom Visualization Examples

In [None]:
# Create custom attention pattern visualization
def plot_custom_attention_pattern(pattern_type="diagonal", size=10, window=2):
    """Build a synthetic attention pattern and plot it with ``visualizer``.

    Args:
        pattern_type: ``"diagonal"`` (each query attends only to its own
            position), ``"local"`` (each query attends uniformly to positions
            at most ``window`` steps away, clipped at the edges) or
            ``"random"`` (row-wise softmax of Gaussian noise). Any other value
            falls back to a uniform pattern.
        size: Number of query/key positions; the pattern is ``size x size``.
        window: Half-width of the ``"local"`` band (ignored otherwise).

    Returns:
        Whatever ``visualizer.plot_attention_map`` returns (presumably a
        matplotlib figure).
    """
    if pattern_type == "diagonal":
        pattern = torch.eye(size)
    elif pattern_type == "local":
        # |i - j| <= window, vectorized instead of a nested Python loop.
        positions = torch.arange(size)
        pattern = ((positions[:, None] - positions[None, :]).abs() <= window).float()
    elif pattern_type == "random":
        pattern = torch.softmax(torch.randn(size, size), dim=-1)
    else:
        pattern = torch.ones(size, size)

    # Row-normalize so each query's weights sum to 1. A softmax here would blur
    # 0/1 masks (softmax(eye) leaves only e/(e+size-1) on the diagonal) and
    # would double-softmax the random pattern.
    pattern = pattern / pattern.sum(dim=-1, keepdim=True)

    fig = visualizer.plot_attention_map(
        pattern,
        title=f"Custom Attention Pattern: {pattern_type.capitalize()}",
        x_labels=[f"P{i}" for i in range(size)],
        y_labels=[f"Q{i}" for i in range(size)]
    )
    return fig

# Render one figure per built-in pattern type
fig9 = plot_custom_attention_pattern("diagonal")
fig10 = plot_custom_attention_pattern("local")
fig11 = plot_custom_attention_pattern("random")

In [None]:
# Create memory evolution visualization
def plot_memory_evolution(num_steps=5, num_slots=8, slot_dim=12, step_scale=0.1, seed=123):
    """Simulate a random-walk evolution of memory content and plot it.

    Starts from Gaussian memory of shape ``(num_slots, slot_dim)`` and, at each
    subsequent step, adds Gaussian noise scaled by ``step_scale``. The initial
    memory counts as the first of ``num_steps`` snapshots; at least one
    snapshot is always produced.

    A private ``torch.Generator`` seeded with ``seed`` drives the simulation,
    so results are reproducible without resetting the global RNG that other
    cells draw from.

    Returns:
        Whatever ``visualizer.plot_state_evolution`` returns (presumably a
        matplotlib figure).
    """
    generator = torch.Generator().manual_seed(seed)

    memory_states = [torch.randn(num_slots, slot_dim, generator=generator)]
    for _ in range(num_steps - 1):
        change = torch.randn(num_slots, slot_dim, generator=generator) * step_scale
        memory_states.append(memory_states[-1] + change)

    fig = visualizer.plot_state_evolution(
        memory_states,
        title="Memory Content Evolution Over Time"
    )
    return fig

# Simulate and plot six memory snapshots
fig12 = plot_memory_evolution(6)

## Summary

This notebook demonstrates:

1. **Basic Visualization Tools**:
   - Attention maps
   - Memory content visualization
   - State evolution tracking
   - Memory importance scoring

2. **Advanced NSM-Specific Visualizations**:
   - Token-to-state routing
   - State-to-state communication
   - Memory read/write operations
   - State dynamics over time

3. **Comprehensive Reporting and Interactive Analysis**:
   - Automated report generation
   - Interactive data analysis
   - Statistical summaries

4. **Custom Visualization Examples**:
   - Custom attention patterns
   - Memory evolution tracking

These tools help interpret the internal workings of Neural State Machine models and provide insights into their decision-making processes.