# Tutorial: Hybrid Attention in NSM

This tutorial explains the **Hybrid Attention** mechanism in Neural State Machines (NSM). It combines local (token-token) attention with global (token-state) attention, so each token's output draws on the other tokens in the sequence and on a shared set of state nodes.

## 🎯 Learning Objectives

By the end of this tutorial, you will understand:

1. The concept of hybrid attention in NSM
2. How local and global attention complement each other
3. The benefits of combining both attention types
4. A simple implementation of hybrid attention

## 🧠 What is Hybrid Attention?

Traditional Transformers rely solely on token-token attention, which has no explicit, compact representation of global context. NSM introduces **hybrid attention**, which combines:

- **Local Attention**: Captures immediate, token-level context
- **Global Attention**: Captures long-term, state-level context

This dual approach enables the model to:
- Understand fine-grained details (local)
- Maintain big-picture understanding (global)
- Balance efficiency with expressivity (the token-state step adds only one score per token-state pair)

## 🔍 How Hybrid Attention Works

The hybrid attention process involves:

1. **Local Attention Computation**: Standard self-attention between tokens
2. **Global Attention Computation**: Attention from tokens to state nodes
3. **Context Combination**: Merging local and global contexts
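The three steps can be written compactly. The notation here is introduced for convenience and is not from NSM itself: $X \in \mathbb{R}^{N \times d}$ holds the $N$ token embeddings, $Z \in \mathbb{R}^{S \times d}$ holds the $S$ state embeddings, and $\alpha$ is the balance parameter used in the combination cell.

$$
\begin{aligned}
L &= \mathrm{MHA}(X, X, X), \qquad \text{per head: } \operatorname{softmax}\!\left(\frac{Q K^\top}{\sqrt{d_h}}\right) V,\quad Q = X W_Q,\ K = X W_K,\ V = X W_V \\
R &= \operatorname{softmax}\!\left(X Z^\top\right) \in \mathbb{R}^{N \times S} \qquad \text{(softmax over the } S \text{ states, so each row sums to 1)} \\
G &= R\, Z \in \mathbb{R}^{N \times d} \\
C &= \alpha L + (1 - \alpha)\, G, \qquad \alpha \in [0, 1]
\end{aligned}
$$

Two details follow directly from the code below. First, the routing scores $X Z^\top$ are raw dot products with no $1/\sqrt{d}$ scaling. Second, $G$ lives in the state space, so the sum $C$ is only well defined when the state width equals the token width (both 16 here).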

Let's implement a simple example to see this in action.

In [None]:
# Import necessary libraries
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt

# For better visualization
import seaborn as sns
sns.set(style="whitegrid")

print("Libraries imported successfully!")

In [None]:
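# Define parameters
torch.manual_seed(0)  # make the random tensors (and the printed numbers) reproducible
batch_size = 1
seq_length = 6  # Number of tokens
token_dim = 16
num_states = 4
state_dim = 16  # Must equal token_dim: routing uses token-state dot products, and the contexts are summed later

# Create random token embeddings
tokens = torch.randn(batch_size, seq_length, token_dim)

# Create random state embeddings
states = torch.randn(batch_size, num_states, state_dim)

# Routing weights (from the previous tutorial): softmax over states of token-state dot products,
# so each token's weights over the states sum to 1. Shape: [B, N, S]
routing_weights = F.softmax(torch.bmm(tokens, states.transpose(1, 2)), dim=-1)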

print(f"Token embeddings shape: {tokens.shape}")
print(f"State embeddings shape: {states.shape}")
print(f"Routing weights shape: {routing_weights.shape}")

In [None]:
# Local Attention (Token-Token)

# Initialize a simple multi-head attention module (randomly initialized, untrained)
local_attn = nn.MultiheadAttention(token_dim, num_heads=2, batch_first=True)

# Compute local attention
local_context, local_weights = local_attn(tokens, tokens, tokens)

print(f"Local context shape: {local_context.shape}")
print(f"Local attention weights shape: {local_weights.shape}")

# Visualize the local attention matrix (averaged over heads) for the first batch element
plt.figure(figsize=(8, 4))
sns.heatmap(local_weights[0].detach().numpy(), annot=True, cmap="YlGnBu", fmt=".2f",
            xticklabels=[f'Token {i}' for i in range(seq_length)],
            yticklabels=[f'Token {i}' for i in range(seq_length)])
plt.title('Local Attention (Token-Token)')
plt.xlabel('Key Tokens')
plt.ylabel('Query Tokens')
plt.show()
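`nn.MultiheadAttention` hides the actual computation. As a minimal sketch, the code below recomputes it by hand from the module's own parameters (`in_proj_weight`, `in_proj_bias`, `out_proj`) so you can see the per-head scaled dot-product softmax(QKᵀ/√d_h)V and the head-averaged weights that the heatmap shows. `manual_self_attention` is a helper written for this tutorial, and the comparison assumes the settings used above: no masks, no `bias_k`/`bias_v`, and the default dropout of 0.0.

```python
import math
import torch
import torch.nn as nn
import torch.nn.functional as F

def manual_self_attention(x: torch.Tensor, mha: nn.MultiheadAttention):
    """Recompute mha(x, x, x) step by step (assumes no masks and dropout 0)."""
    B, N, E = x.shape
    H = mha.num_heads
    d_h = E // H
    # Packed input projection: output columns [0:E) -> Q, [E:2E) -> K, [2E:3E) -> V
    q, k, v = F.linear(x, mha.in_proj_weight, mha.in_proj_bias).chunk(3, dim=-1)
    # Split heads: [B, N, E] -> [B, H, N, d_h]
    q, k, v = [t.reshape(B, N, H, d_h).transpose(1, 2) for t in (q, k, v)]
    # Scaled dot-product attention per head: [B, H, N, N]
    weights = torch.softmax(q @ k.transpose(-2, -1) / math.sqrt(d_h), dim=-1)
    # Merge heads back: [B, H, N, d_h] -> [B, N, E], then apply the output projection
    ctx = (weights @ v).transpose(1, 2).reshape(B, N, E)
    # Average weights over heads, matching the module's default average_attn_weights=True
    return mha.out_proj(ctx), weights.mean(dim=1)

torch.manual_seed(0)
x = torch.randn(1, 6, 16)
mha = nn.MultiheadAttention(16, num_heads=2, batch_first=True)  # dropout defaults to 0.0

with torch.no_grad():
    ref_out, ref_w = mha(x, x, x)
    my_out, my_w = manual_self_attention(x, mha)

print(torch.allclose(ref_out, my_out, atol=1e-5), torch.allclose(ref_w, my_w, atol=1e-5))
```

Both comparisons should print `True`, which confirms that the "local" context here is ordinary full self-attention over all tokens, with learned Q/K/V projections.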

In [None]:
# Global Attention (Token-State)

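# Compute global context using routing weights: each token's global context is a
# weighted average of the state vectors
# [B, N, S] x [B, S, D_state] -> [B, N, D_state]  (N = seq_length, S = num_states)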
global_context = torch.bmm(routing_weights, states)

print(f"Global context shape: {global_context.shape}")

# Visualize global attention (routing weights)
plt.figure(figsize=(8, 4))
sns.heatmap(routing_weights[0].detach().numpy(), annot=True, cmap="YlOrRd", fmt=".2f",
            xticklabels=[f'State {i}' for i in range(num_states)],
            yticklabels=[f'Token {i}' for i in range(seq_length)])
plt.title('Global Attention (Token-State)')
plt.xlabel('State Nodes')
plt.ylabel('Tokens')
plt.show()
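The routing-based global context above is a parameter-free form of token-to-state attention. The queries are the raw tokens, the keys and values are the raw states, and the scores are not scaled. If you want learned projections instead, one option is PyTorch's `nn.MultiheadAttention` with tokens as queries and states as keys and values. This is an illustrative assumption, not necessarily what NSM does. Its `kdim`/`vdim` arguments also remove the requirement that `state_dim` equal `token_dim`. The `demo_*` and `learned_global_ctx` names are hypothetical and chosen so they do not overwrite the notebook's variables.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
B, N, S = 1, 6, 4          # batch size, tokens, state nodes
D_tok, D_state = 16, 8     # state width deliberately differs from token width here

demo_tokens = torch.randn(B, N, D_tok)
demo_states = torch.randn(B, S, D_state)

# Tokens act as queries; state nodes act as both keys and values.
# kdim/vdim let keys and values have a different width than the queries.
token_state_attn = nn.MultiheadAttention(
    embed_dim=D_tok, num_heads=2, batch_first=True, kdim=D_state, vdim=D_state
)
learned_global_ctx, token_to_state = token_state_attn(demo_tokens, demo_states, demo_states)

print(learned_global_ctx.shape)    # torch.Size([1, 6, 16]): projected back to the token width
print(token_to_state.shape)        # torch.Size([1, 6, 4]): one weight per token-state pair
print(token_to_state.sum(dim=-1))  # each row sums to 1 (softmax over states, averaged over heads)
```

Because the output is already in the token width, it can be mixed with the local context without any extra projection.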

In [None]:
# Combine Local and Global Context

# Simple combination: fixed weighted sum (requires token_dim == state_dim)
# In practice, alpha could be replaced by a learned gating mechanism
alpha = 0.5  # Balance parameter
combined_context = alpha * local_context + (1 - alpha) * global_context

print(f"Combined context shape: {combined_context.shape}")

print("\nComparison of contexts for first token:")
print(f"Original token: {tokens[0, 0]}")
print(f"Local context: {local_context[0, 0]}")
print(f"Global context: {global_context[0, 0]}")
print(f"Combined context: {combined_context[0, 0]}")
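The cell above mentions that the fixed `alpha` could become a gating mechanism. A common choice computes a per-token, per-feature gate from both contexts, g = sigmoid(W[local; global] + b), and mixes them as g·local + (1 − g)·global. The sketch below implements this. `GatedCombiner` is a hypothetical class written for this tutorial, not an NSM API, and random tensors stand in for `local_context` and `global_context`.

```python
import torch
import torch.nn as nn

class GatedCombiner(nn.Module):
    """Hypothetical learned gate: g = sigmoid(W [local; global] + b), one value per token and feature."""

    def __init__(self, dim: int):
        super().__init__()
        self.gate = nn.Linear(2 * dim, dim)

    def forward(self, local_ctx: torch.Tensor, global_ctx: torch.Tensor):
        g = torch.sigmoid(self.gate(torch.cat([local_ctx, global_ctx], dim=-1)))
        # g near 1 favors the local context, g near 0 favors the global context
        return g * local_ctx + (1 - g) * global_ctx, g

torch.manual_seed(0)
local_ctx = torch.randn(1, 6, 16)   # stand-in for local_context
global_ctx = torch.randn(1, 6, 16)  # stand-in for global_context

combiner = GatedCombiner(dim=16)
gated, gate = combiner(local_ctx, global_ctx)
print(gated.shape, gate.shape)  # torch.Size([1, 6, 16]) torch.Size([1, 6, 16])

# Sanity check: a zeroed gate layer gives g = sigmoid(0) = 0.5 everywhere,
# which recovers the fixed alpha = 0.5 combination used above.
with torch.no_grad():
    combiner.gate.weight.zero_()
    combiner.gate.bias.zero_()
fixed, _ = combiner(local_ctx, global_ctx)
print(torch.allclose(fixed, 0.5 * local_ctx + 0.5 * global_ctx))  # True
```

The fixed-`alpha` version is the special case of a constant gate. A learned gate lets the model lean on local context for some tokens or features and on global context for others.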

## 🎓 Key Takeaways

1. **Dual Context Modeling**: Hybrid attention captures both immediate (local) and long-term (global) context.
2. **Complementary Information**: Local attention handles fine-grained details, while global attention maintains big-picture understanding.
3. **Flexibility**: The combination mechanism can be learned or adjusted for different tasks.
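To tie the three steps together, here is a minimal sketch of a single module. `HybridAttention` is a hypothetical class written for this tutorial, not the NSM layer from the prototype notebook. It keeps the parameter-free routing used above and makes the mixing weight α learnable through a sigmoid. The sigmoid's input starts at 0, so α starts at the notebook's value of 0.5.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class HybridAttention(nn.Module):
    """Hypothetical minimal hybrid-attention layer mirroring this notebook's three steps."""

    def __init__(self, dim: int, num_heads: int = 2):
        super().__init__()
        self.local_attn = nn.MultiheadAttention(dim, num_heads, batch_first=True)
        # alpha = sigmoid(alpha_logit); a logit of 0 gives alpha = 0.5, the notebook's fixed value
        self.alpha_logit = nn.Parameter(torch.zeros(()))

    def forward(self, tokens: torch.Tensor, states: torch.Tensor):
        # 1. Local (token-token) self-attention: [B, N, D]
        local_ctx, _ = self.local_attn(tokens, tokens, tokens)
        # 2. Global (token-state) routing, as in the notebook: [B, N, S], rows sum to 1
        routing = F.softmax(torch.bmm(tokens, states.transpose(1, 2)), dim=-1)
        global_ctx = torch.bmm(routing, states)  # [B, N, D]
        # 3. Learnable convex combination of the two contexts
        alpha = torch.sigmoid(self.alpha_logit)
        return alpha * local_ctx + (1 - alpha) * global_ctx, routing

torch.manual_seed(0)
layer = HybridAttention(dim=16)
demo_tokens = torch.randn(2, 6, 16)  # [B, N, D]
demo_states = torch.randn(2, 4, 16)  # [B, S, D]; state width must match token width here

out, routing = layer(demo_tokens, demo_states)
print(out.shape, routing.shape)  # torch.Size([2, 6, 16]) torch.Size([2, 6, 4])

out.sum().backward()  # gradients reach alpha_logit, so the local/global balance is trainable
print(layer.alpha_logit.grad is not None)  # True
```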

## 🚀 Next Steps

With these foundational concepts understood, you're ready to explore the full NSM architecture.

See `notebooks/research/nsm_prototype.ipynb` for a complete implementation of an NSM layer.

For a deeper dive into research aspects, check out:
- `notebooks/research/benchmarking.ipynb`: Performance comparisons
- `notebooks/research/interpretability.ipynb`: Visualization of state evolution