# Getting Started with Neural State Machines

This notebook provides a quick introduction to Neural State Machines (NSM) and how to use them for sequence processing tasks.

## Installation

First, let's make sure the required third-party packages are installed (the `nsm` package itself must also be importable in your environment):

In [None]:
!pip install torch numpy matplotlib seaborn scikit-learn

## Basic Usage

Let's start with the core components of NSM:

In [None]:
import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

# Import NSM components
from nsm import StatePropagator

print("✅ Neural State Machine components imported successfully!")

## 1. State Propagation

The core of NSM is the `StatePropagator` which handles how states are updated, retained, or reset:

In [None]:
# Create a state propagator
state_dim = 64
propagator = StatePropagator(
    state_dim=state_dim,
    gate_type='gru',  # or 'lstm'
    enable_communication=False  # Disable for now
)

print(f"State propagator created with state dimension: {state_dim}")
print(f"Gate type: {'GRU' if propagator.gate_type == 'gru' else 'LSTM'}")

In [None]:
# Single state update example
batch_size = 4
prev_state = torch.randn(batch_size, state_dim)
new_input = torch.randn(batch_size, state_dim)

print(f"Previous state shape: {prev_state.shape}")
print(f"New input shape: {new_input.shape}")

# Apply state update
updated_state = propagator(prev_state, new_input)

print(f"Updated state shape: {updated_state.shape}")

# Analyze the update
state_change = torch.mean(torch.abs(updated_state - prev_state))
print(f"Average state change: {state_change.item():.4f}")
print(f"Previous state norm: {torch.mean(torch.norm(prev_state, dim=1)):.4f}")
print(f"Updated state norm: {torch.mean(torch.norm(updated_state, dim=1)):.4f}")
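
To make "updated, retained, or reset" concrete, here is a minimal sketch that uses PyTorch's built-in `nn.GRUCell` as a stand-in for a GRU-gated propagator. It is an assumption that `StatePropagator` behaves similarly; its actual internals are not shown in this notebook. The helper `manual_gru_step` is hypothetical and only recomputes the standard GRU equations by hand so the reset and update gates are visible.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
state_dim, batch_size = 64, 4

# Stand-in for a gated propagator (assumption: StatePropagator may differ).
cell = nn.GRUCell(input_size=state_dim, hidden_size=state_dim)

def manual_gru_step(cell, x, h):
    """Recompute the GRUCell update by hand to expose the gates."""
    gi = x @ cell.weight_ih.T + cell.bias_ih   # contributions from the new input
    gh = h @ cell.weight_hh.T + cell.bias_hh   # contributions from the previous state
    i_r, i_z, i_n = gi.chunk(3, dim=1)
    h_r, h_z, h_n = gh.chunk(3, dim=1)
    r = torch.sigmoid(i_r + h_r)               # reset gate: how much old state feeds the candidate
    z = torch.sigmoid(i_z + h_z)               # update gate: how much old state is retained
    n = torch.tanh(i_n + r * h_n)              # candidate state
    return (1 - z) * n + z * h, z

prev_state = torch.randn(batch_size, state_dim)
new_input = torch.randn(batch_size, state_dim)

with torch.no_grad():
    updated_state = cell(new_input, prev_state)  # GRUCell takes (input, hidden)
    manual_state, update_gate = manual_gru_step(cell, new_input, prev_state)

print(updated_state.shape)
print(torch.allclose(updated_state, manual_state, atol=1e-5))
print(f"Mean update gate (fraction of old state kept): {update_gate.mean().item():.3f}")
```

An update gate near 1 keeps the previous state almost unchanged, while a value near 0 replaces it with the candidate computed from the new input.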

## 2. Multi-State Processing

NSM can process multiple states simultaneously with communication between them:

In [None]:
# Create multi-state propagator with communication
num_states = 8
multi_propagator = StatePropagator(
    state_dim=state_dim,
    gate_type='gru',
    num_heads=4,
    enable_communication=True  # Enable state-to-state communication
)

print(f"Multi-state propagator created (to be applied to {num_states} states)")
print(f"Communication enabled: {multi_propagator.enable_communication}")

In [None]:
# Multi-state update example
prev_states = torch.randn(batch_size, num_states, state_dim)
new_inputs = torch.randn(batch_size, num_states, state_dim)

print(f"Previous states shape: {prev_states.shape}")
print(f"New inputs shape: {new_inputs.shape}")

# Apply multi-state update
updated_states = multi_propagator(prev_states, new_inputs)

print(f"Updated states shape: {updated_states.shape}")

# Analyze the multi-state update
state_changes = torch.mean(torch.abs(updated_states - prev_states), dim=(1, 2))  # one value per sample
print(f"Mean state change across samples: {torch.mean(state_changes).item():.4f}")
print(f"Max per-sample state change: {torch.max(state_changes).item():.4f}")
print(f"Min per-sample state change: {torch.min(state_changes).item():.4f}")
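
One plausible way to realize state-to-state communication is to let every state attend to all the others before the gated update. The sketch below uses `nn.MultiheadAttention` followed by `nn.GRUCell`. This wiring is an assumption made for illustration, not the confirmed implementation behind `enable_communication=True`.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
batch_size, num_states, state_dim, num_heads = 4, 8, 64, 4

# Assumed building blocks: attention for communication, a GRU cell for gating.
attn = nn.MultiheadAttention(embed_dim=state_dim, num_heads=num_heads, batch_first=True)
cell = nn.GRUCell(input_size=state_dim, hidden_size=state_dim)

prev_states = torch.randn(batch_size, num_states, state_dim)
new_inputs = torch.randn(batch_size, num_states, state_dim)

with torch.no_grad():
    # Communication: each state queries all states of the same sample.
    messages, attn_weights = attn(prev_states, prev_states, prev_states)

    # Gated update applied independently to every (sample, state) pair.
    flat_inputs = (new_inputs + messages).reshape(-1, state_dim)
    flat_prev = prev_states.reshape(-1, state_dim)
    updated_states = cell(flat_inputs, flat_prev).reshape(batch_size, num_states, state_dim)

print(updated_states.shape)   # same shape as prev_states
print(attn_weights.shape)     # (batch, num_states, num_states), averaged over heads
```

Each row of `attn_weights` sums to 1 and indicates how strongly one state listens to every other state, which is a convenient quantity to inspect when debugging communication.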

## 3. Complete NSM Model

Let's create a complete NSM model for a simple classification task:

In [None]:
from nsm.models import SimpleNSM

# Create a simple NSM model for MNIST-like classification
model = SimpleNSM(
    input_dim=784,      # Flattened 28x28 images
    state_dim=128,      # State vector dimension
    num_states=16,      # Number of state nodes
    output_dim=10,      # 10-class classification
    gate_type='gru'     # Gating mechanism
)

print("✅ Simple NSM model created!")
print(f"Model parameters: {sum(p.numel() for p in model.parameters()):,}")

In [None]:
# Test the model with sample data
sample_batch_size = 8
x = torch.randn(sample_batch_size, 784)  # Random input (like flattened MNIST)

print(f"Input shape: {x.shape}")

# Forward pass
output = model(x)

print(f"Output shape: {output.shape}")
print(f"Output range: [{output.min().item():.3f}, {output.max().item():.3f}]")
print(f"Output probabilities sum (first sample): {torch.softmax(output[0], dim=0).sum().item():.3f}")
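
For intuition about how such a model might turn a flat input into class logits, here is a hypothetical, self-contained stand-in. `TinyStateClassifier` is not `nsm.models.SimpleNSM`; its layers (an encoder into per-state inputs, one GRU-gated update, mean pooling, a linear head) are assumptions chosen only to reproduce the same input and output shapes.

```python
import torch
import torch.nn as nn

class TinyStateClassifier(nn.Module):
    """Hypothetical illustration; not the SimpleNSM implementation."""

    def __init__(self, input_dim, state_dim, num_states, output_dim):
        super().__init__()
        self.num_states = num_states
        self.state_dim = state_dim
        # Learned initial states, shared by every sample in the batch.
        self.init_states = nn.Parameter(torch.zeros(num_states, state_dim))
        # Maps the flat input to one input vector per state.
        self.encode = nn.Linear(input_dim, num_states * state_dim)
        self.cell = nn.GRUCell(state_dim, state_dim)
        self.classify = nn.Linear(state_dim, output_dim)

    def forward(self, x):
        batch = x.shape[0]
        inputs = self.encode(x).reshape(batch * self.num_states, self.state_dim)
        prev = self.init_states.unsqueeze(0).expand(batch, -1, -1)
        prev = prev.reshape(batch * self.num_states, self.state_dim)
        states = self.cell(inputs, prev).reshape(batch, self.num_states, self.state_dim)
        # Pool over states, then project to class logits.
        return self.classify(states.mean(dim=1))

torch.manual_seed(0)
sketch_model = TinyStateClassifier(input_dim=784, state_dim=128, num_states=16, output_dim=10)
logits = sketch_model(torch.randn(8, 784))
print(logits.shape)
```

The output is unnormalized logits, so `nn.CrossEntropyLoss` can be applied to it directly without a separate softmax.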

## 4. Training Example

Let's train our NSM model on a simple synthetic dataset:

In [None]:
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split

# Generate synthetic classification data
X, y = make_classification(n_samples=1000, n_features=784, n_classes=10, 
                           n_informative=200, random_state=42)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

# Convert to PyTorch tensors
X_train = torch.as_tensor(X_train, dtype=torch.float32)
y_train = torch.as_tensor(y_train, dtype=torch.long)
X_test = torch.as_tensor(X_test, dtype=torch.float32)
y_test = torch.as_tensor(y_test, dtype=torch.long)

# Create data loaders
train_dataset = TensorDataset(X_train, y_train)
test_dataset = TensorDataset(X_test, y_test)
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)

print(f"Training samples: {len(train_dataset)}")
print(f"Test samples: {len(test_dataset)}")
print(f"Feature dimension: {X_train.shape[1]}")
print(f"Number of classes: {len(torch.unique(y_train))}")

In [None]:
# Setup training
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

model = model.to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)
# StepLR multiplies the learning rate by gamma once every step_size calls to scheduler.step()
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=20, gamma=0.9)
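
To see what this schedule does in practice, the following sketch steps a `StepLR` scheduler with the same `step_size=20` and `gamma=0.9` on a dummy parameter (the dummy parameter and the 40-epoch horizon are assumptions for illustration). With only a handful of epochs the learning rate never decays; the first drop happens after the 20th scheduler step.

```python
import torch
import torch.nn as nn
import torch.optim as optim

# Dummy parameter so the optimizer has something to manage.
params = [nn.Parameter(torch.zeros(1))]
optimizer = optim.Adam(params, lr=0.001)
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=20, gamma=0.9)

lrs = []
for epoch in range(40):
    optimizer.step()      # optimizer first, then scheduler, as PyTorch expects
    scheduler.step()
    lrs.append(scheduler.get_last_lr()[0])

print(f"LR after 5 epochs:  {lrs[4]:.6f}")
print(f"LR after 20 epochs: {lrs[19]:.6f}")
print(f"LR after 40 epochs: {lrs[39]:.6f}")
```

For a short demo run, a smaller `step_size` or no scheduler at all gives the same result with less machinery.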

In [None]:
# Training function
def train_epoch(model, loader, criterion, optimizer, device):
    model.train()
    total_loss = 0
    correct = 0
    total = 0
    
    for batch_idx, (data, target) in enumerate(loader):
        data, target = data.to(device), target.to(device)
        
        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()
        
        total_loss += loss.item()
        pred = output.argmax(dim=1, keepdim=True)
        correct += pred.eq(target.view_as(pred)).sum().item()
        total += target.size(0)
        
        # Limit training for the demo: stop after 12 batches (batch_idx 0-11)
        if batch_idx > 10:
            break
    
    avg_loss = total_loss / (batch_idx + 1)
    accuracy = 100. * correct / total
    
    return avg_loss, accuracy

In [None]:
# Evaluation function
def evaluate(model, loader, criterion, device):
    model.eval()
    total_loss = 0
    correct = 0
    total = 0
    
    with torch.no_grad():
        for batch_idx, (data, target) in enumerate(loader):
            data, target = data.to(device), target.to(device)
            output = model(data)
            loss = criterion(output, target)
            total_loss += loss.item()
            pred = output.argmax(dim=1, keepdim=True)
            correct += pred.eq(target.view_as(pred)).sum().item()
            total += target.size(0)
            
            # Limit evaluation for the demo: stop after 7 batches (batch_idx 0-6)
            if batch_idx > 5:
                break
    
    avg_loss = total_loss / (batch_idx + 1)
    accuracy = 100. * correct / total
    
    return avg_loss, accuracy

In [None]:
# Quick training demo
epochs = 5
print("Starting training demo...")
print("="*50)

for epoch in range(epochs):
    train_loss, train_acc = train_epoch(model, train_loader, criterion, optimizer, device)
    test_loss, test_acc = evaluate(model, test_loader, criterion, device)
    
    scheduler.step()
    
    print(f'Epoch {epoch+1}/{epochs}:')
    print(f'  Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.2f}%')
    print(f'  Test Loss: {test_loss:.4f}, Test Acc: {test_acc:.2f}%')
    print()

print("🎉 Training demo completed!")

## 5. Visualization and Analysis

Let's visualize some aspects of our NSM model:

In [None]:
# Analyze model outputs
model.eval()
with torch.no_grad():
    sample_input = X_test[:4].to(device)  # First 4 test samples
    sample_output = model(sample_input)
    sample_probs = torch.softmax(sample_output, dim=1)

print("Sample predictions:")
for i in range(4):
    predicted_class = sample_probs[i].argmax().item()
    confidence = sample_probs[i].max().item()
    actual_class = y_test[i].item()
    print(f"  Sample {i+1}: Predicted={predicted_class} (confidence={confidence:.3f}), Actual={actual_class}, "
          f"Correct={'✅' if predicted_class == actual_class else '❌'}")

In [None]:
# Visualize probability distributions
fig, axes = plt.subplots(2, 2, figsize=(12, 8))
axes = axes.ravel()

for i in range(4):
    axes[i].bar(range(10), sample_probs[i].cpu().numpy())
    axes[i].set_xlabel('Class')
    axes[i].set_ylabel('Probability')
    axes[i].set_title(f'Sample {i+1} - Actual Class: {y_test[i].item()}')
    axes[i].set_ylim(0, 1)
    axes[i].grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

## 6. Advanced Features

NSM also supports advanced features like dynamic state management:

In [None]:
from nsm import StateManager

# Create state manager with dynamic allocation
state_manager = StateManager(
    state_dim=128,
    max_states=32,
    initial_states=16,
    prune_threshold=0.3
)

print("✅ State manager created with dynamic allocation!")
print(f"Initial active states: {state_manager.get_active_count()}")

# Get current states
states = state_manager()
print(f"Current states shape: {states.shape}")

# Get importance scores
importance_scores = state_manager.get_importance_scores()
active_count = state_manager.get_active_count()
print(f"Average importance score: {torch.mean(importance_scores[:active_count]).item():.3f}")

In [None]:
# Simulate dynamic state management
print("\nDynamic state management demo:")
print("="*40)

for step in range(5):
    # Simulate some processing that affects state importance
    # (In real training, this would happen through backpropagation)
    
    current_count = state_manager.get_active_count()
    print(f"Step {step+1}: Active states = {current_count}")
    
    # Periodically manage states
    if step % 2 == 0 and step > 0:
        pruned = state_manager.prune_low_importance_states()
        allocated = state_manager.allocate_states(2)
        print(f"  Pruned: {pruned}, Allocated: {allocated}")
        print(f"  New active count: {state_manager.get_active_count()}")

final_count = state_manager.get_active_count()
print(f"\nFinal active states: {final_count}")
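
The idea of pruning and allocating states can be illustrated with a fixed-size pool and a boolean activity mask. `MaskedStatePool` below is a hypothetical sketch, not the `StateManager` implementation. The importance values are set by hand here, and the choice to give newly allocated slots an importance of 1.0 is an assumption.

```python
import torch

class MaskedStatePool:
    """Hypothetical illustration of threshold-based pruning; not nsm.StateManager."""

    def __init__(self, state_dim, max_states, initial_states, prune_threshold):
        self.states = torch.randn(max_states, state_dim)
        self.importance = torch.zeros(max_states)
        self.active = torch.zeros(max_states, dtype=torch.bool)
        self.active[:initial_states] = True
        self.prune_threshold = prune_threshold

    def active_count(self):
        return int(self.active.sum().item())

    def prune(self):
        # Deactivate active slots whose importance is below the threshold.
        to_prune = self.active & (self.importance < self.prune_threshold)
        self.active &= ~to_prune
        return int(to_prune.sum().item())

    def allocate(self, n):
        # Reuse up to n inactive slots, re-initializing their contents.
        free = (~self.active).nonzero(as_tuple=True)[0][:n]
        self.states[free] = torch.randn(free.numel(), self.states.shape[1])
        self.importance[free] = 1.0  # assumption: new states start as important
        self.active[free] = True
        return int(free.numel())

torch.manual_seed(0)
pool = MaskedStatePool(state_dim=128, max_states=32, initial_states=16, prune_threshold=0.3)
pool.importance[:16] = torch.linspace(0, 1, 16)  # hand-set scores for the demo

pruned = pool.prune()
count_after_prune = pool.active_count()
allocated = pool.allocate(2)

print(f"Pruned: {pruned}, active after pruning: {count_after_prune}")   # Pruned: 5, active after pruning: 11
print(f"Allocated: {allocated}, active now: {pool.active_count()}")     # Allocated: 2, active now: 13
```

Keeping the tensors at their maximum size and toggling a mask avoids reallocating memory whenever the number of active states changes.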

## Summary

In this notebook, we've covered:

- ✅ **Basic state propagation** with gated updates
- ✅ **Multi-state processing** with communication
- ✅ **Complete NSM models** for classification tasks
- ✅ **Training workflows** with optimization
- ✅ **Visualization and analysis** of model behavior
- ✅ **Advanced features** like dynamic state management

Neural State Machines aim to provide an alternative to traditional Transformers, with the following design goals (not benchmarked in this notebook):

- **Efficient computation**: Linear scaling with sequence length
- **Interpretability**: Explicit state management and tracking
- **Flexibility**: Dynamic state allocation and pruning
- **Performance**: Competitive results with reduced resource usage

For more advanced usage and detailed documentation, please refer to:
- [Architecture Overview](../docs/architecture/architecture_overview.md)
- [API Reference](../docs/api/api_reference.md)
- [Full Tutorial](../docs/tutorials/tutorial.md)

Happy experimenting with Neural State Machines! 🚀