# Memory and Attention Visualization for Neural State Machines

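This notebook focuses specifically on visualizing attention maps and external memory contents in Neural State Machine (NSM) models. All inputs are synthetic: the attention maps are softmax-normalized random tensors and the memory read/write parameters are random draws, so the plots demonstrate the visualization tooling rather than the behavior of a trained model.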

In [None]:
import torch
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import sys
from pathlib import Path

# Make the `nsm` package importable by putting the directory two levels above
# the current working directory at the front of sys.path (highest priority).
# NOTE(review): the original note called this directory "src"; that only holds
# if the notebook is launched from two levels below it — TODO confirm layout.
sys.path.insert(0, str(Path().cwd().parent.parent))

# Import NSM components and visualization tools
from nsm.utils.advanced_visualizer import AdvancedNSMVisualizer
from nsm.modules.ntm_memory import NTMMemory

## 1. Attention Map Visualization

In [None]:
# One shared plotting helper for every cell below; the figure size is handed
# straight to AdvancedNSMVisualizer.
VISUALIZER_FIGSIZE = (12, 10)
visualizer = AdvancedNSMVisualizer(figsize=VISUALIZER_FIGSIZE)

In [None]:
# Seed both RNGs so every synthetic pattern below is reproducible. The draws
# happen in a fixed order (self, token->state, state->state, multi-head), so
# keep that order if you edit this cell.
torch.manual_seed(42)
np.random.seed(42)

self_attention = torch.randn(8, 8).softmax(dim=-1)            # 8 positions x 8 positions
token_state_attention = torch.randn(12, 6).softmax(dim=-1)    # 12 tokens routed over 6 states
state_state_attention = torch.randn(6, 6).softmax(dim=-1)     # 6 states x 6 states
multi_head_attention = torch.randn(4, 8, 8).softmax(dim=-1)   # 4 heads, each 8x8

print("Attention patterns generated successfully!")

In [None]:
# Derive label counts from the matrix itself (8 x 8 here). Rows are queries;
# columns are attended positions (assumes the visualizer puts columns on x).
num_queries, num_positions = self_attention.shape
fig1 = visualizer.plot_attention_map(
    self_attention,
    title="Self-Attention Pattern",
    x_labels=["Pos%d" % k for k in range(num_positions)],
    y_labels=["Query%d" % k for k in range(num_queries)],
)

print("Self-attention visualization complete.")

In [None]:
# Routing matrix is (tokens, states) = (12, 6); label each axis from its size.
num_tokens, num_states = token_state_attention.shape
token_names = [f"T{t}" for t in range(num_tokens)]
state_names = [f"S{s}" for s in range(num_states)]

fig2 = visualizer.plot_token_to_state_routing(
    token_state_attention,
    token_labels=token_names,
    state_labels=state_names,
    title="Token-to-State Attention Routing",
)

print("Token-to-state attention visualization complete.")

In [None]:
# Square (6 x 6) state matrix: a single label list, one name per state.
communication_labels = [f"State{s}" for s in range(state_state_attention.shape[0])]
fig3 = visualizer.plot_state_communication(
    state_state_attention,
    state_labels=communication_labels,
    title="State-to-State Communication Pattern",
)

print("State-to-state attention visualization complete.")

In [None]:
# One panel per head; head names are 1-based for display.
num_heads = multi_head_attention.shape[0]
head_names = [f"Head {h}" for h in range(1, num_heads + 1)]
fig4 = visualizer.plot_attention_patterns(
    multi_head_attention,
    pattern_names=head_names,
    title="Multi-Head Attention Patterns",
)

print("Multi-head attention visualization complete.")

## 2. External Memory Content Visualization

In [None]:
# Create NTM memory for demonstration (32 slots x 16 values per slot,
# 2 read heads, 1 write head — as named by the constructor arguments).
ntm_memory = NTMMemory(mem_size=32, mem_dim=16, num_read_heads=2, num_write_heads=1)

# Take a snapshot of the memory exactly as constructed. `detach().clone()`
# makes it an independent copy, so the read/write operations run on
# `ntm_memory` later in the notebook cannot change what this snapshot (and the
# heatmap / report / statistics built from it) shows, even if
# `get_memory_state()` returns the module's live buffer.
memory_content = ntm_memory.get_memory_state().detach().clone()

print(f"NTM Memory shape: {memory_content.shape}")
print(f"Memory content range: [{memory_content.min().item():.3f}, {memory_content.max().item():.3f}]")

In [None]:
# Heatmap of the memory snapshot read right after the module was created.
content_title = "NTM External Memory Content"
fig5 = visualizer.plot_memory_content(memory_content, title=content_title)

print("Memory content visualization complete.")

In [None]:
# Simulate memory evolution
def simulate_memory_evolution(ntm, steps=5, batch_size=1):
    """Run ``steps`` random read/write operations on ``ntm`` and record its memory.

    Each step draws fresh random keys, key strengths, erase vectors and add
    vectors and passes them to ``ntm(...)``. The result is synthetic: it
    exercises the memory module and the plots, not a trained controller.

    Args:
        ntm: memory module (e.g. ``NTMMemory``) exposing ``num_read_heads``,
            ``num_write_heads``, ``mem_dim`` and ``get_memory_state()``, and
            callable as ``ntm(read_keys, write_keys, read_strengths,
            write_strengths, erase_vectors, add_vectors)`` returning
            ``(read_vectors, updated_memory)``.
        steps: number of read/write operations to perform.
        batch_size: leading batch dimension of the generated parameters.

    Returns:
        List of ``steps + 1`` memory states: the initial state followed by the
        state after each step. Tensor entries are detached, independent copies.
    """
    def _snapshot(memory):
        # Copy so in-place updates inside ``ntm`` cannot rewrite states that
        # were already recorded; non-tensor values are passed through as-is.
        if isinstance(memory, torch.Tensor):
            return memory.detach().clone()
        return memory

    memory_states = [_snapshot(ntm.get_memory_state())]

    for _ in range(steps):
        read_keys = torch.randn(batch_size, ntm.num_read_heads, ntm.mem_dim)
        write_keys = torch.randn(batch_size, ntm.num_write_heads, ntm.mem_dim)
        # Key strengths are non-negative in the NTM formulation. softplus keeps
        # the random draws >= 0; raw randn values can be negative, which would
        # favour the *least* similar slots unless NTMMemory applies its own
        # positivity transform (not visible here).
        read_strengths = torch.nn.functional.softplus(torch.randn(batch_size, ntm.num_read_heads))
        write_strengths = torch.nn.functional.softplus(torch.randn(batch_size, ntm.num_write_heads))
        erase_vectors = torch.sigmoid(torch.randn(batch_size, ntm.num_write_heads, ntm.mem_dim))
        add_vectors = torch.tanh(torch.randn(batch_size, ntm.num_write_heads, ntm.mem_dim))

        # Visualization only: no autograd graph is needed.
        with torch.no_grad():
            _, updated_memory = ntm(
                read_keys, write_keys, read_strengths, write_strengths,
                erase_vectors, add_vectors
            )

        memory_states.append(_snapshot(updated_memory))

    return memory_states

# Four synthetic read/write steps; the returned list also holds the start state.
NUM_EVOLUTION_STEPS = 4
memory_evolution = simulate_memory_evolution(ntm_memory, steps=NUM_EVOLUTION_STEPS)
print("Memory evolution steps: {}".format(len(memory_evolution)))

In [None]:
# One plot for the whole recorded history of memory states.
evolution_title = "Memory Content Evolution Over Time"
fig6 = visualizer.plot_state_evolution(memory_evolution, title=evolution_title)

print("Memory evolution visualization complete.")

In [None]:
# Per-slot importance scores as reported by the memory module itself.
importance_scores = ntm_memory.get_importance_scores()
fig7 = visualizer.plot_memory_importance(importance_scores, title="Memory Slot Importance Scores")
print("Memory importance visualization complete.")

average_importance = importance_scores.mean().item()
print("Average importance score: {:.3f}".format(average_importance))

## 3. Memory Read/Write Operations Visualization

In [None]:
# Simulate read/write operations
def simulate_read_write_operations(ntm, batch_size=1):
    """Run one random read/write step on ``ntm`` and capture head weights around it.

    Args:
        ntm: memory module (e.g. ``NTMMemory``) exposing ``num_read_heads``,
            ``num_write_heads``, ``mem_dim``, ``get_read_weights()`` and
            ``get_write_weights()``, and callable as ``ntm(read_keys,
            write_keys, read_strengths, write_strengths, erase_vectors,
            add_vectors)`` returning ``(read_vectors, updated_memory)``.
        batch_size: leading batch dimension of the generated parameters.

    Returns:
        dict with ``'initial_read'`` / ``'initial_write'`` (head weights before
        the step), ``'final_read'`` / ``'final_write'`` (after the step), and
        ``'read_vectors'`` / ``'memory'`` (the two values returned by
        ``ntm(...)``). Tensor entries are detached, independent copies.
    """
    def _snapshot(value):
        # Copy so the "before" weights cannot be overwritten by in-place
        # updates during the forward call; non-tensors (e.g. None before the
        # first step) are passed through unchanged.
        if isinstance(value, torch.Tensor):
            return value.detach().clone()
        return value

    # Generate parameters for memory operations
    read_keys = torch.randn(batch_size, ntm.num_read_heads, ntm.mem_dim)
    write_keys = torch.randn(batch_size, ntm.num_write_heads, ntm.mem_dim)
    # Key strengths are non-negative in the NTM formulation; softplus keeps the
    # random draws >= 0 (negative values would favour the least similar slots
    # unless NTMMemory applies its own positivity transform — not visible here).
    read_strengths = torch.nn.functional.softplus(torch.randn(batch_size, ntm.num_read_heads))
    write_strengths = torch.nn.functional.softplus(torch.randn(batch_size, ntm.num_write_heads))
    erase_vectors = torch.sigmoid(torch.randn(batch_size, ntm.num_write_heads, ntm.mem_dim))
    add_vectors = torch.tanh(torch.randn(batch_size, ntm.num_write_heads, ntm.mem_dim))

    initial_read_weights = _snapshot(ntm.get_read_weights())
    initial_write_weights = _snapshot(ntm.get_write_weights())

    # Visualization only: no autograd graph is needed.
    with torch.no_grad():
        read_vectors, updated_memory = ntm(
            read_keys, write_keys, read_strengths, write_strengths,
            erase_vectors, add_vectors
        )

    return {
        'initial_read': initial_read_weights,
        'initial_write': initial_write_weights,
        'final_read': _snapshot(ntm.get_read_weights()),
        'final_write': _snapshot(ntm.get_write_weights()),
        'read_vectors': _snapshot(read_vectors),
        'memory': _snapshot(updated_memory),
    }

# One random read/write step; keeps before/after head weights for the plots.
operation_results = simulate_read_write_operations(ntm=ntm_memory)
print("Memory operations simulated successfully!")

In [None]:
# Head weights captured *before* the simulated step, one label per memory slot.
slot_names = [f"Slot{slot}" for slot in range(ntm_memory.mem_size)]
fig8 = visualizer.plot_memory_read_write_operations(
    operation_results['initial_read'],
    operation_results['initial_write'],
    memory_slots=slot_names,
    title="Initial Memory Read/Write Attention Weights",
)

print("Initial read/write operations visualization complete.")

In [None]:
# Same view *after* the simulated step, for comparison with the initial weights.
read_after = operation_results['final_read']
write_after = operation_results['final_write']
fig9 = visualizer.plot_memory_read_write_operations(
    read_after,
    write_after,
    memory_slots=["Slot%d" % s for s in range(ntm_memory.mem_size)],
    title="Final Memory Read/Write Attention Weights",
)

print("Final read/write operations visualization complete.")

## 4. Comparative Attention Analysis

In [None]:
# Generate attention patterns with different characteristics
def generate_attention_patterns(size=10, window=2, noise_scale=0.1):
    """Generate four synthetic ``size x size`` attention matrices for comparison.

    Every row of every matrix sums to 1.

    Args:
        size: number of queries (rows) and keys (columns).
        window: half-width of the band used by the ``'local'`` pattern; query
            ``i`` is biased towards keys ``j`` with ``|i - j| <= window``.
        noise_scale: standard deviation of the Gaussian noise added to the
            ``'diagonal'`` and ``'local'`` logits before the softmax.

    Returns:
        dict (in this order) with keys ``'diagonal'`` (self-focused),
        ``'local'`` (banded neighbours), ``'global'`` (exactly uniform) and
        ``'random'`` (softmax of unit Gaussian logits).
    """
    # NOTE: the three torch.randn draws happen in a fixed order (diagonal,
    # local, random) so results stay reproducible under a fixed seed.

    # 1. Diagonal/Identity pattern (self-focused)
    diagonal = torch.softmax(torch.eye(size) + torch.randn(size, size) * noise_scale, dim=-1)

    # 2. Local pattern: logit 1 inside the band |i - j| <= window, 0 outside.
    positions = torch.arange(size)
    band = ((positions[:, None] - positions[None, :]).abs() <= window).float()
    local_pattern = torch.softmax(band + torch.randn(size, size) * noise_scale, dim=-1)

    # 3. Global pattern (uniform)
    global_pattern = torch.ones(size, size) / size

    # 4. Random pattern
    random_pattern = torch.softmax(torch.randn(size, size), dim=-1)

    return {
        'diagonal': diagonal,
        'local': local_pattern,
        'global': global_pattern,
        'random': random_pattern,
    }

# name -> 10x10 row-stochastic matrix (diagonal, local, global, random)
attention_patterns = generate_attention_patterns()
print("Attention patterns generated for comparison.")

In [None]:
# Stack the patterns into one (num_patterns, 10, 10) tensor; names follow the
# dict's insertion order, matching the stacked values.
pattern_stack = torch.stack(list(attention_patterns.values()))
fig10 = visualizer.plot_attention_patterns(
    pattern_stack,
    pattern_names=list(attention_patterns),
    title="Comparative Attention Patterns",
)

print("Comparative attention patterns visualization complete.")

In [None]:
# Analyze pattern characteristics
def analyze_attention_pattern(pattern, name):
    """Summarize one attention tensor with three scalar statistics.

    Args:
        pattern: torch tensor of attention weights.
        name: label copied into the result under ``'name'``.

    Returns:
        dict with ``'name'``, ``'entropy'`` (``-sum(p * log(p + 1e-8))`` over
        *all* entries, so it grows with matrix size), ``'max_attention'``
        (largest weight) and ``'sparsity'`` (fraction of weights below 0.01).
    """
    weights = pattern.detach().cpu().numpy()
    # The 1e-8 offset keeps the log finite where a weight is exactly zero.
    log_weights = np.log(weights + 1e-8)
    return {
        'name': name,
        'entropy': -(weights * log_weights).sum(),
        'max_attention': weights.max(),
        'sparsity': (weights < 0.01).mean(),
    }

# Score every synthetic pattern, then tabulate the results (one row each).
pattern_analysis = [
    analyze_attention_pattern(pattern, name)
    for name, pattern in attention_patterns.items()
]

# Display analysis
import pandas as pd
analysis_df = pd.DataFrame(pattern_analysis)
analysis_df

## 5. Comprehensive Visualization Report

In [None]:
# Bundle everything computed above under the keys passed to
# create_comprehensive_report (the accepted key set is defined by
# AdvancedNSMVisualizer).
comprehensive_data = dict(
    attention_weights=self_attention,
    token_state_attention=token_state_attention,
    state_attention=state_state_attention,
    memory_content=memory_content,
    importance_scores=importance_scores,
    initial_read_weights=operation_results['initial_read'],
    initial_write_weights=operation_results['initial_write'],
    multi_head_attention=multi_head_attention,
)

REPORT_DIR_NAME = "memory_attention_reports"
report_dir = visualizer.create_comprehensive_report(comprehensive_data, save_dir=REPORT_DIR_NAME)

print(f"Comprehensive visualization report generated in: {report_dir}")

## 6. Interactive Memory Analysis

In [None]:
# Create interactive analysis of memory content
def analyze_memory_content(memory_tensor):
    """Summarize a memory tensor with basic statistics.

    Args:
        memory_tensor: torch tensor of memory values. The last axis is treated
            as the per-slot feature axis (e.g. ``(mem_size, mem_dim)`` or
            ``(batch, mem_size, mem_dim)`` — assumed, confirm against NTMMemory).

    Returns:
        dict with ``'shape'`` (tuple), ``'total_elements'`` (int) and the
        Python floats ``'mean'``, ``'std'``, ``'min'``, ``'max'``,
        ``'sparsity'`` (fraction of entries with ``|x| < 0.01``) and
        ``'active_slots'`` (fraction of slots whose values vary with std > 0.1).
    """
    memory_np = memory_tensor.detach().cpu().numpy()

    # Spread of each slot's values along the last (feature) axis. Using axis=-1
    # stays correct when a leading batch axis is present; for 2-D input it is
    # the same as axis=1.
    slot_std = np.std(memory_np, axis=-1)

    # Cast numpy scalars (float32 for float32 memory) to plain floats so
    # callers can rely on isinstance(value, float) for formatting.
    analysis = {
        'shape': memory_np.shape,
        'total_elements': memory_np.size,
        'mean': float(np.mean(memory_np)),
        'std': float(np.std(memory_np)),
        'min': float(np.min(memory_np)),
        'max': float(np.max(memory_np)),
        'sparsity': float(np.mean(np.abs(memory_np) < 0.01)),  # Fraction near zero
        'active_slots': float(np.mean(slot_std > 0.1)),  # Slots with variation
    }

    return analysis

# Analyze memory
memory_analysis = analyze_memory_content(memory_content)
print("Memory Content Analysis:")
for key, value in memory_analysis.items():
    # np.float32 is not a subclass of float, so check np.floating as well;
    # otherwise float32 statistics would print unformatted.
    if isinstance(value, (float, np.floating)):
        print(f"  {key}: {value:.4f}")
    else:
        print(f"  {key}: {value}")

In [None]:
# Interactive attention analysis
def analyze_attention_content(attention_tensor, name="Attention"):
    """Summarize an attention tensor with basic and (for 2-D input) entropy statistics.

    Args:
        attention_tensor: torch tensor of attention weights. For 2-D input,
            rows are queries and are assumed to already sum to 1 (e.g. a
            softmax over ``dim=-1``).
        name: label copied into the result under ``'name'``.

    Returns:
        dict with ``'name'``, ``'shape'``, ``'mean'``, ``'std'``, ``'max'`` and
        ``'min'``. For 2-D input it also has:

        - ``'avg_row_entropy'``: mean entropy of the query rows.
        - ``'avg_col_entropy'``: mean entropy of the key columns, each
          normalized to a distribution first.
        - ``'focus_score'``: ``1 - avg_row_entropy / log(num_keys)``, from
          0 (uniform) to 1 (one-hot). It is 1.0 when there is only one key.
    """
    eps = 1e-8  # keeps log finite where a weight is exactly zero
    attention_np = attention_tensor.detach().cpu().numpy()

    analysis = {
        'name': name,
        'shape': attention_np.shape,
        'mean': float(np.mean(attention_np)),
        'std': float(np.std(attention_np)),
        'max': float(np.max(attention_np)),
        'min': float(np.min(attention_np)),
    }
    if attention_np.ndim != 2:
        return analysis

    num_keys = attention_np.shape[1]

    # Row-wise entropy: each query row is already a probability distribution.
    row_entropy = -np.sum(attention_np * np.log(attention_np + eps), axis=1)

    # Columns are not normalized by a row softmax, so normalize each column
    # before taking its entropy. Without this, the column sum of -p*log(p)
    # totals the same as the row sum, and for square matrices the average
    # would always equal the row average. All-zero columns stay all zero.
    col_sums = attention_np.sum(axis=0, keepdims=True)
    col_dist = np.divide(
        attention_np, col_sums,
        out=np.zeros(attention_np.shape, dtype=np.float64),
        where=col_sums > 0,
    )
    col_entropy = -np.sum(col_dist * np.log(col_dist + eps), axis=0)

    avg_row_entropy = float(np.mean(row_entropy))
    if num_keys > 1:
        # Normalize by the maximum possible row entropy, log(num_keys).
        focus_score = 1.0 - avg_row_entropy / float(np.log(num_keys))
    else:
        # A single key receives all the mass by construction (and log(1) == 0
        # would make the ratio undefined), so treat it as fully focused.
        focus_score = 1.0

    analysis.update({
        'avg_row_entropy': avg_row_entropy,
        'avg_col_entropy': float(np.mean(col_entropy)),
        'focus_score': focus_score,
    })
    return analysis

# One summary row per 2-D attention matrix generated at the top of the notebook.
attention_analysis = [
    analyze_attention_content(tensor, label)
    for tensor, label in (
        (self_attention, "Self-Attention"),
        (token_state_attention, "Token-State Attention"),
        (state_state_attention, "State-State Attention"),
    )
]

# Display analysis
attention_df = pd.DataFrame(attention_analysis)
attention_df

## Summary

This notebook demonstrates comprehensive visualization tools for memory and attention in Neural State Machine models:

### Attention Visualization:
1. **Self-Attention Patterns** - Traditional attention between sequence elements
2. **Token-to-State Routing** - How input tokens are distributed to state nodes
3. **State-to-State Communication** - Interaction patterns between states
4. **Multi-Head Attention** - Comparison of different attention heads
5. **Comparative Patterns** - Analysis of different attention characteristics

### Memory Visualization:
1. **Memory Content Heatmaps** - Visualization of external memory values
2. **Memory Evolution** - How memory changes over time
3. **Memory Importance** - Slot-wise importance scoring
4. **Read/Write Operations** - Visualization of memory access patterns
5. **Interactive Analysis** - Detailed statistical analysis of memory and attention

### Key Features:
- **Multiple Visualization Types**: Heatmaps, line plots, bar charts
- **Statistical Analysis**: Entropy, sparsity, focus metrics
- **Comparative Analysis**: Side-by-side pattern comparison
- **Automated Reporting**: Comprehensive report generation
- **Interactive Exploration**: Detailed data analysis tools

These visualization tools help interpret the internal mechanisms of NSM models and provide insights into their memory management and attention processes.