# Token-to-State Routing Visualization

This notebook visualizes how input tokens are routed to different state nodes in a Neural State Machine. We'll create heatmaps showing which tokens attend to which states and optionally overlay importance scores of state nodes.

In [None]:
import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import sys

# Add src to path
sys.path.insert(0, '../src')

from nsm.components import TokenToStateRouter, StateManager
from nsm.models.simple_nsm import SimpleNSM

# Fix the PyTorch and NumPy global seeds so the synthetic data below is repeatable
np.random.seed(42)
torch.manual_seed(42)

# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(f"Using device: {device}")

## Create a Sample Model

We'll create a sample NSM model to demonstrate routing visualization.

In [None]:
# Dimensions shared by the three components built in this cell. They are named
# once here so the model, router and state manager cannot drift apart.
TOKEN_DIM = 128          # Input / token dimension
STATE_DIM = 64           # State dimension
NUM_STATES = 16          # Number of states
NUM_ROUTING_HEADS = 4    # Routing heads used by the router

# Create a sample NSM model
# NOTE(review): `model` is not referenced by the visualization cells visible in
# this notebook; it is kept because other code may rely on it.
model = SimpleNSM(
    input_dim=TOKEN_DIM,      # Input dimension
    state_dim=STATE_DIM,      # State dimension
    num_states=NUM_STATES,    # Number of states
    output_dim=10,            # Classification output
    gate_type='gru'           # Gating mechanism
).to(device)

# Create a TokenToStateRouter for visualization
router = TokenToStateRouter(
    token_dim=TOKEN_DIM,
    state_dim=STATE_DIM,
    num_states=NUM_STATES,
    num_heads=NUM_ROUTING_HEADS
).to(device)

# Create a StateManager for importance scores
# (left on its default device, exactly as before; consumers call .cpu() on its output)
state_manager = StateManager(
    state_dim=STATE_DIM,
    max_states=NUM_STATES,
    initial_states=NUM_STATES,
    prune_threshold=0.3
)

print("Model components created successfully!")

## Generate Sample Data

We'll create synthetic token sequences to visualize the routing patterns.

In [None]:
# Sizes of the synthetic inputs
batch_size, seq_len = 4, 20
token_dim, state_dim = 128, 64
num_states = 16

# Draw random tokens first, then random states, and move each to the target device.
# The draw happens before the transfer so the values come from the default (CPU) generator.
sample_tokens = torch.randn(batch_size, seq_len, token_dim)
sample_tokens = sample_tokens.to(device)
sample_states = torch.randn(batch_size, num_states, state_dim)
sample_states = sample_states.to(device)

print(f"Sample tokens shape: {sample_tokens.shape}")
print(f"Sample states shape: {sample_states.shape}")

## Visualize Routing Patterns

Let's create heatmaps showing how tokens are routed to states.

In [None]:
def visualize_routing_heatmap(tokens, states, router, example_idx=0, title="Token-to-State Routing Heatmap"):
    """
    Draw an annotated heatmap of the routing weights for one example of the batch.

    Rows are labelled "Token i" and columns "State i"; the figure is shown
    immediately with ``plt.show()``.

    Args:
        tokens (torch.Tensor): Input tokens [batch_size, seq_len, token_dim]
        states (torch.Tensor): State vectors [batch_size, num_states, state_dim]
        router (TokenToStateRouter): Router module; its second return value is
            used as the routing weights
        example_idx (int): Index of example in batch to visualize
        title (str): Title for the plot

    Returns:
        np.ndarray: The routing weights of the selected example, as plotted.
    """
    # Forward pass and tensor -> NumPy conversion happen with autograd disabled.
    with torch.no_grad():
        _, routing_weights = router(tokens, states)
        example_weights = routing_weights[example_idx].cpu().numpy()

    # Axis 0 of the selected slice indexes tokens, axis 1 indexes states.
    n_rows, n_cols = example_weights.shape[0], example_weights.shape[1]
    row_labels = [f'Token {i}' for i in range(n_rows)]
    col_labels = [f'State {i}' for i in range(n_cols)]

    plt.figure(figsize=(10, 8))
    heatmap_options = dict(
        annot=True,
        fmt='.2f',
        cmap='viridis',
        cbar=True,
        xticklabels=col_labels,
        yticklabels=row_labels,
    )
    sns.heatmap(example_weights, **heatmap_options)
    plt.title(title)
    plt.xlabel('State Nodes')
    plt.ylabel('Tokens')
    plt.tight_layout()
    plt.show()

    return example_weights

# Heatmap for the first example in the batch
routing_weights = visualize_routing_heatmap(
    tokens=sample_tokens,
    states=sample_states,
    router=router,
    example_idx=0,
    title="Token-to-State Routing Heatmap (Example 0)",
)

## Visualize Routing with State Importance Overlay

Let's enhance the visualization by showing the state importance scores in a bar chart next to the routing heatmap.

In [None]:
def visualize_routing_with_importance(tokens, states, router, state_manager, example_idx=0):
    """
    Show one example's routing heatmap next to a bar chart of state importance scores.

    The left panel is the annotated token-by-state heatmap for ``example_idx``;
    the right panel has one labelled bar per entry returned by
    ``state_manager.get_importance_scores()``.

    Args:
        tokens (torch.Tensor): Input tokens [batch_size, seq_len, token_dim]
        states (torch.Tensor): State vectors [batch_size, num_states, state_dim]
        router (TokenToStateRouter): Router module; its second return value is
            used as the routing weights
        state_manager (StateManager): State manager with importance scores
        example_idx (int): Index of example in batch to visualize

    Returns:
        tuple: ``(example_weights, importance_scores)`` as NumPy arrays, exactly
        the data that was plotted.
    """
    with torch.no_grad():
        # Get routing weights and select the requested example
        _, routing_weights = router(tokens, states)
        example_weights = routing_weights[example_idx].cpu().numpy()

        # detach() is required because no_grad() does not clear requires_grad on
        # a tensor that already exists: if the manager hands back a learnable
        # parameter directly, .numpy() would raise without it. It is a no-op for
        # tensors that do not require grad.
        # NOTE(review): assumes the scores form a 1-D tensor with one entry per state — confirm.
        importance_scores = state_manager.get_importance_scores().detach().cpu().numpy()

    # Create figure with subplots
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 6))

    # Plot routing heatmap
    sns.heatmap(
        example_weights,
        annot=True,
        fmt='.2f',
        cmap='viridis',
        cbar=True,
        xticklabels=[f'State {i}' for i in range(example_weights.shape[1])],
        yticklabels=[f'Token {i}' for i in range(example_weights.shape[0])],
        ax=ax1
    )
    ax1.set_title(f'Token-to-State Routing (Example {example_idx})')
    ax1.set_xlabel('State Nodes')
    ax1.set_ylabel('Tokens')

    # Plot state importance
    score_positions = range(len(importance_scores))
    bars = ax2.bar(score_positions, importance_scores, color='skyblue')
    ax2.set_xlabel('State Nodes')
    ax2.set_ylabel('Importance Score')
    ax2.set_title('State Importance Scores')
    ax2.set_xticks(score_positions)
    ax2.set_xticklabels([f'State {i}' for i in score_positions], rotation=45)

    # Add value labels on bars
    for bar in bars:
        height = bar.get_height()
        ax2.text(bar.get_x() + bar.get_width()/2., height,
                f'{height:.2f}',
                ha='center', va='bottom')

    plt.tight_layout()
    plt.show()

    return example_weights, importance_scores

# Routing heatmap plus importance bar chart for the first example
routing_weights, importance_scores = visualize_routing_with_importance(
    tokens=sample_tokens,
    states=sample_states,
    router=router,
    state_manager=state_manager,
    example_idx=0,
)

## Batch-Level Routing Analysis

Let's analyze routing patterns across the entire batch to identify trends.

In [None]:
def analyze_batch_routing(tokens, states, router):
    """
    Summarize routing weights over the whole batch in a three-panel figure.

    Panel 1 is a heatmap of the weights averaged over the batch axis. Panel 2
    is that average reduced over tokens (one bar per state). Panel 3 is the
    same average reduced over states (one bar per token).

    Args:
        tokens (torch.Tensor): Input tokens [batch_size, seq_len, token_dim]
        states (torch.Tensor): State vectors [batch_size, num_states, state_dim]
        router (TokenToStateRouter): Router module; its second return value is
            used as the routing weights

    Returns:
        tuple: ``(avg_routing_weights, token_preferences, state_preferences)``
        where the first is the batch mean, the second has one value per state
        and the third has one value per token.
    """
    def _labelled_bars(ax, values, color, xlabel, title, tick_labels):
        # Bar chart with rotated tick labels and each bar's value printed above it.
        positions = range(len(values))
        bars = ax.bar(positions, values, color=color)
        ax.set_xlabel(xlabel)
        ax.set_ylabel('Average Attention')
        ax.set_title(title)
        ax.set_xticks(positions)
        ax.set_xticklabels(tick_labels, rotation=45)
        for bar in bars:
            height = bar.get_height()
            ax.text(bar.get_x() + bar.get_width()/2., height,
                    f'{height:.2f}',
                    ha='center', va='bottom')

    # Forward pass with autograd disabled, then continue in NumPy
    with torch.no_grad():
        _, routing_weights = router(tokens, states)
        routing_weights_np = routing_weights.cpu().numpy()

    # Collapse the batch axis, then reduce the remaining two axes one at a time
    avg_routing_weights = routing_weights_np.mean(axis=0)
    token_preferences = avg_routing_weights.mean(axis=0)  # reduced over tokens -> per state
    # NOTE(review): if the router normalizes each token's weights over the states
    # (e.g. a softmax), every value below equals 1/num_states — confirm this panel is informative.
    state_preferences = avg_routing_weights.mean(axis=1)  # reduced over states -> per token

    fig, (heat_ax, state_ax, token_ax) = plt.subplots(1, 3, figsize=(18, 5))

    # Panel 1: batch-averaged heatmap
    sns.heatmap(
        avg_routing_weights,
        annot=True,
        fmt='.2f',
        cmap='viridis',
        cbar=True,
        xticklabels=[f'State {i}' for i in range(avg_routing_weights.shape[1])],
        yticklabels=[f'Token {i}' for i in range(avg_routing_weights.shape[0])],
        ax=heat_ax
    )
    heat_ax.set_title('Average Routing Weights (Across Batch)')
    heat_ax.set_xlabel('State Nodes')
    heat_ax.set_ylabel('Tokens')

    # Panel 2: which states are preferred overall
    _labelled_bars(
        state_ax,
        token_preferences,
        color='lightcoral',
        xlabel='State Nodes',
        title='State Preference Across All Tokens',
        tick_labels=[f'State {i}' for i in range(len(token_preferences))],
    )

    # Panel 3: average weight per token
    _labelled_bars(
        token_ax,
        state_preferences,
        color='lightgreen',
        xlabel='Tokens',
        title='Token Attention Distribution',
        tick_labels=[f'Token {i}' for i in range(len(state_preferences))],
    )

    plt.tight_layout()
    plt.show()

    return avg_routing_weights, token_preferences, state_preferences

# Batch-level summary of the routing weights
avg_weights, token_prefs, state_prefs = analyze_batch_routing(
    tokens=sample_tokens,
    states=sample_states,
    router=router,
)

## Visualize Routing Entropy

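Let's analyze the entropy of routing patterns to understand how focused or diffuse the routing is. The entropy is computed per token over its state weights using the natural logarithm: a value near 0 means the token routes almost entirely to a single state, while the maximum, $\ln(\text{num\_states})$, means its weight is spread uniformly over all states.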

In [None]:
def compute_entropy(probs):
    """
    Compute the Shannon entropy (natural log, i.e. in nats) along the last axis.

    Zero-probability entries contribute exactly 0, following the convention
    0 * log(0) = 0, so a one-hot distribution has entropy exactly 0.0 rather
    than a small positive value caused by epsilon clipping.

    Args:
        probs (np.ndarray): Probability distributions; the last axis is the
            one that is summed over. Values are clipped to [0, 1].

    Returns:
        np.ndarray: Entropy values, with the last axis of ``probs`` removed.
    """
    probs = np.clip(np.asarray(probs), 0.0, 1.0)
    # Replace zeros by 1 inside the log only: log(1) = 0, so those terms vanish
    # and log(0) is never evaluated (no -inf, no warnings).
    log_safe = np.where(probs > 0.0, probs, 1.0)
    weighted_log = probs * np.log(log_safe)
    # "0.0 - x" instead of "-x" so an all-zero sum yields +0.0 rather than -0.0,
    # which would otherwise print as "-0.000".
    return 0.0 - np.sum(weighted_log, axis=-1)

def visualize_routing_entropy(tokens, states, router, example_idx=0):
    """
    Plot and summarize the per-token routing entropy of one batch example.

    The left panel is a line plot of entropy against token index, the right
    panel a 10-bin histogram of the same values. The mean, minimum and maximum
    entropy are printed after the figure is shown.

    Args:
        tokens (torch.Tensor): Input tokens [batch_size, seq_len, token_dim]
        states (torch.Tensor): State vectors [batch_size, num_states, state_dim]
        router (TokenToStateRouter): Router module; its second return value is
            used as the routing weights
        example_idx (int): Index of example in batch to visualize

    Returns:
        np.ndarray: One entropy value per token of the selected example.
    """
    # Forward pass and conversion to NumPy with autograd disabled
    with torch.no_grad():
        _, routing_weights = router(tokens, states)
        example_weights = routing_weights[example_idx].cpu().numpy()

    # One entropy per row (token), taken over that token's state weights
    entropies = compute_entropy(example_weights)

    fig, (line_ax, hist_ax) = plt.subplots(1, 2, figsize=(15, 5))

    # Left: entropy as a function of token position
    token_positions = range(len(entropies))
    line_ax.plot(token_positions, entropies, marker='o', linewidth=2, markersize=6)
    line_ax.set_xlabel('Token Index')
    line_ax.set_ylabel('Routing Entropy')
    line_ax.set_title(f'Routing Entropy per Token (Example {example_idx})')
    line_ax.grid(True, alpha=0.3)

    # Right: how the entropy values are distributed
    hist_ax.hist(entropies, bins=10, color='skyblue', edgecolor='black', alpha=0.7)
    hist_ax.set_xlabel('Routing Entropy')
    hist_ax.set_ylabel('Frequency')
    hist_ax.set_title('Distribution of Routing Entropies')
    hist_ax.grid(True, alpha=0.3)

    plt.tight_layout()
    plt.show()

    print(f"Average routing entropy: {np.mean(entropies):.3f}")
    print(f"Min routing entropy: {np.min(entropies):.3f} (most focused routing)")
    print(f"Max routing entropy: {np.max(entropies):.3f} (most diffuse routing)")

    return entropies

# Entropy profile of the first example
entropies = visualize_routing_entropy(
    tokens=sample_tokens,
    states=sample_states,
    router=router,
    example_idx=0,
)

## Compare Different Routing Heads

Let's visualize how different routing heads attend to different state nodes.

In [None]:
def visualize_routing_heads(tokens, states, router, example_idx=0):
    """
    Visualize routing patterns for individual heads.

    For the selected example, the raw logits from ``router.router`` are turned
    into per-head probabilities with a softmax over the state axis. One bar
    chart per head then shows that head's probabilities averaged over tokens.

    Args:
        tokens (torch.Tensor): Input tokens [batch_size, seq_len, token_dim]
        states (torch.Tensor): State vectors [batch_size, num_states, state_dim];
            only ``states.shape[1]`` (the number of states) is used here
        router (TokenToStateRouter): Router module; must expose ``router``
            (callable on ``tokens``) and ``num_heads``
        example_idx (int): Index of example in batch to visualize

    Returns:
        np.ndarray: Per-head routing probabilities for the selected example,
        shape [seq_len, num_heads, num_states].
    """
    with torch.no_grad():
        batch_size, seq_len, _ = tokens.shape
        num_states = states.shape[1]
        num_heads = router.num_heads

        # Raw routing logits: [batch_size, seq_len, num_states * num_heads]
        # NOTE(review): the view below assumes the last axis is laid out
        # head-major (all states of head 0, then head 1, ...) — confirm against
        # TokenToStateRouter.
        routing_logits = router.router(tokens)
        routing_logits = routing_logits.view(batch_size, seq_len, num_heads, num_states)

        # torch.softmax is numerically stable, unlike a bare exp(x) / sum(exp(x)),
        # which overflows to inf / inf = NaN for large logits. It also
        # normalizes every (token, head) pair in one vectorized call.
        example_weights = torch.softmax(routing_logits[example_idx], dim=-1).cpu().numpy()

    # Average across tokens for each head: [num_heads, num_states]
    avg_weights_per_head = example_weights.mean(axis=0)

    # squeeze=False always returns a 2-D array of axes, so a single head needs
    # no special case.
    fig, axes_grid = plt.subplots(1, num_heads, figsize=(5*num_heads, 6), squeeze=False)
    axes = axes_grid[0]

    # Plot each head
    for head in range(num_heads):
        avg_head_weights = avg_weights_per_head[head]
        state_positions = range(len(avg_head_weights))

        bars = axes[head].bar(state_positions, avg_head_weights, color=f'C{head}')
        axes[head].set_xlabel('State Nodes')
        axes[head].set_ylabel('Average Attention')
        axes[head].set_title(f'Routing Head {head+1}')
        axes[head].set_xticks(state_positions)
        axes[head].set_xticklabels([f'State {i}' for i in state_positions], rotation=45)

        # Add value labels on bars
        for bar in bars:
            height = bar.get_height()
            axes[head].text(bar.get_x() + bar.get_width()/2., height,
                           f'{height:.2f}',
                           ha='center', va='bottom')

    plt.tight_layout()
    plt.show()

    return example_weights

# Per-head routing profile of the first example
head_weights = visualize_routing_heads(
    tokens=sample_tokens,
    states=sample_states,
    router=router,
    example_idx=0,
)

## Summary

This notebook demonstrated several ways to visualize token-to-state routing patterns in Neural State Machines:

1. **Routing Heatmaps**: Show which tokens attend to which states
2. **Importance Overlay**: Combine routing patterns with state importance scores
3. **Batch Analysis**: Analyze routing patterns across the entire batch
4. **Routing Entropy**: Measure how focused or diffuse the routing is
5. **Routing Heads**: Compare attention patterns across different routing heads

These visualizations help understand how information flows through the NSM architecture and can be useful for debugging and model analysis.