# Neural State Machine (NSM) Prototype

This notebook implements a basic prototype of the Neural State Machine (NSM) layer, demonstrating the core components and their interactions.

## Objective

To implement and test a minimal NSM layer that includes:

1. **State Management**
2. **Token-to-State Routing**
3. **State Propagation**
4. **Hybrid Attention**

## Architecture Overview

```
Input Tokens → Local Attention → Token-to-State Interaction → Updated States
                      ↘ State-to-State Propagation ↙
```

## Implementation Details

- **Framework**: PyTorch
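- **State Nodes**: Fixed number of memory slots per sequence, zero-initialized in this prototype (the slots themselves are not yet learnable parameters)
- **Routing**: Soft attention, computed as a softmax over dot-product scores between projected tokens and states
- **Propagation**: Residual linear update applied to each state (a stand-in for a full recurrent update)
- **Attention**: Local token-to-token self-attention combined with a global token-to-state context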

Let's start by implementing the core components.

In [1]:
# Import necessary libraries
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt

# For better visualization
import seaborn as sns
sns.set_theme(style="whitegrid")  # set_theme is the preferred name; sns.set is an alias

print("Libraries imported successfully!")

Libraries imported successfully!


In [3]:
class StateManager(nn.Module):
    """
    State Manager for Neural State Machines (NSM).
    Manages a fixed number of state vectors that evolve over time.
    Each state vector has a dimension of D.
    For a batch of size B, it manages [B, S, D] tensors.
    """
    def __init__(self, num_states, state_dim):
        super(StateManager, self).__init__()
        self.num_states = num_states
        self.state_dim = state_dim

    def forward(self, initial_states=None, batch_size=None):
        """
        Returns initial states.
        Args:
            initial_states (torch.Tensor, optional): Predefined initial states of shape [B, S, D].
            batch_size (int, optional): Batch size to create initial states if not provided.
        Returns:
            torch.Tensor: Initial states of shape [B, S, D].
        """
        if initial_states is not None:
            return initial_states
        elif batch_size is not None:
            # Initialize with zeros
            return torch.zeros(batch_size, self.num_states, self.state_dim)
        else:
            raise ValueError("Either initial_states or batch_size must be provided.")
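
`StateManager` above returns zeros and holds no parameters, so the initial states cannot be trained. If trainable slots are wanted, one option is a single `nn.Parameter` table shared across the batch. The sketch below is a possible variant, not part of the notebook: the class name `LearnableStateManager` and the `init_std` argument are assumptions for illustration.

```python
import torch
import torch.nn as nn


class LearnableStateManager(nn.Module):
    """Hypothetical variant of StateManager with trainable initial states."""

    def __init__(self, num_states, state_dim, init_std=0.02):
        super().__init__()
        # One [S, D] table shared by every sequence in the batch.
        self.initial_states = nn.Parameter(torch.randn(num_states, state_dim) * init_std)

    def forward(self, batch_size):
        # expand() returns a broadcast view: [S, D] -> [B, S, D] without copying.
        return self.initial_states.unsqueeze(0).expand(batch_size, -1, -1)


manager = LearnableStateManager(num_states=5, state_dim=32)
states = manager(batch_size=2)
print(states.shape)  # torch.Size([2, 5, 32])
```

Because the table is registered as a parameter, it moves with the module on `.to(device)` and is picked up by any optimizer built from `manager.parameters()`.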

In [4]:
class TokenToStateRouter(nn.Module):
    """
    Routes token embeddings to relevant state nodes.
    Computes attention weights between tokens and states to determine routing.
    """
    def __init__(self, token_dim, state_dim, use_softmax=True):
        super(TokenToStateRouter, self).__init__()
        self.token_dim = token_dim
        self.state_dim = state_dim
        self.use_softmax = use_softmax
        # A linear layer to project token embeddings to state dimension for compatibility
        self.token_to_state_proj = nn.Linear(token_dim, state_dim)

    def forward(self, token_embeddings, state_embeddings):
        """
        Compute routing weights from tokens to states.
        Args:
            token_embeddings (torch.Tensor): [B, N, D_token]
            state_embeddings (torch.Tensor): [B, S, D_state]
        Returns:
            torch.Tensor: Routing weights of shape [B, N, S]
        """
        # Project tokens to state dimension
        projected_tokens = self.token_to_state_proj(token_embeddings)  # [B, N, D_state]
        # Compute compatibility scores (dot product)
        # [B, N, D_state] x [B, D_state, S] -> [B, N, S]
        compatibility = torch.bmm(projected_tokens, state_embeddings.transpose(1, 2))
        if self.use_softmax:
            routing_weights = F.softmax(compatibility, dim=-1)
        else:
            routing_weights = compatibility
        return routing_weights
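
One consequence is worth checking. When the states are all zeros, as produced by `StateManager` when only `batch_size` is given, every dot-product score is zero and the softmax gives each state the same weight 1/S. The snippet below reproduces this with plain tensors. The `1/sqrt(D)` scaling in the second part is the usual scaled dot-product convention and is an assumption here, since the router above uses unscaled scores.

```python
import math
import torch
import torch.nn.functional as F

torch.manual_seed(0)
B, N, S, D = 2, 10, 5, 32

projected_tokens = torch.randn(B, N, D)  # stands in for token_to_state_proj(tokens)
zero_states = torch.zeros(B, S, D)

# Zero states give zero scores, so the softmax is uniform over the S states.
scores = torch.bmm(projected_tokens, zero_states.transpose(1, 2))  # [B, N, S]
weights = F.softmax(scores, dim=-1)
print(weights[0, 0])  # each entry equals 1 / S = 0.2

# With non-zero states, dividing by sqrt(D) keeps the scores in a moderate
# range so the softmax is less likely to saturate.
random_states = torch.randn(B, S, D)
scaled_scores = torch.bmm(projected_tokens, random_states.transpose(1, 2)) / math.sqrt(D)
scaled_weights = F.softmax(scaled_scores, dim=-1)
```

In the zero-state case the aggregated global context, `torch.bmm(weights, zero_states)`, is also all zeros, so the first layer's token output depends only on the local attention branch.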

In [5]:
class StatePropagator(nn.Module):
    """
    Updates state embeddings across layers.
    Implements a simple recurrent update mechanism for state nodes.
    """
    def __init__(self, state_dim, hidden_dim=None):
        super(StatePropagator, self).__init__()
        self.state_dim = state_dim
        # NOTE: hidden_dim is stored for a possible wider update network but is not used yet
        self.hidden_dim = hidden_dim or state_dim
        # For simplicity, the update is a single linear transformation
        self.update_layer = nn.Linear(state_dim, state_dim)

    def forward(self, state_embeddings):
        """
        Update state embeddings for the next layer.
        Args:
            state_embeddings (torch.Tensor): [B, S, D_state]
        Returns:
            torch.Tensor: Updated state embeddings of shape [B, S, D_state]
        """
        # Simple update: add a transformed version of the current states
        update = self.update_layer(state_embeddings)
        new_states = state_embeddings + update
        return new_states
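
The update above transforms each state independently and never reads the tokens. If states should also absorb token information, one possible extension is to aggregate tokens per state with the transposed routing weights and pass the result to a gated recurrent cell. This is a sketch under assumptions: the class name `GatedStateUpdate`, the use of `nn.GRUCell`, and the normalization of the routing weights over tokens are illustrative choices, not the notebook's method.

```python
import torch
import torch.nn as nn


class GatedStateUpdate(nn.Module):
    """Hypothetical update: tokens write to states, a GRUCell gates the change."""

    def __init__(self, state_dim):
        super().__init__()
        self.cell = nn.GRUCell(input_size=state_dim, hidden_size=state_dim)

    def forward(self, states, projected_tokens, routing_weights):
        # states: [B, S, D], projected_tokens: [B, N, D], routing_weights: [B, N, S]
        B, S, D = states.shape
        # Normalize over tokens so each state receives a weighted mean of tokens.
        totals = routing_weights.sum(dim=1, keepdim=True).clamp_min(1e-9)
        write_weights = routing_weights / totals
        messages = torch.bmm(write_weights.transpose(1, 2), projected_tokens)  # [B, S, D]
        # GRUCell expects 2-D inputs, so fold the state axis into the batch axis.
        new_states = self.cell(messages.reshape(B * S, D), states.reshape(B * S, D))
        return new_states.view(B, S, D)


torch.manual_seed(0)
states = torch.zeros(2, 5, 32)
tokens = torch.randn(2, 10, 32)  # assumed to be already projected to state_dim
routing = torch.softmax(torch.randn(2, 10, 5), dim=-1)
updated = GatedStateUpdate(32)(states, tokens, routing)
print(updated.shape)  # torch.Size([2, 5, 32])
```

The gating lets each state decide how much of the incoming message to keep, which is closer to a recurrent update than the purely residual linear map.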

In [6]:
class HybridAttention(nn.Module):
    """
    Combines local (token-token) and global (token-state) attention mechanisms.
    """
    def __init__(self, token_dim, state_dim=None):
        super(HybridAttention, self).__init__()
        self.token_dim = token_dim
        # state_dim defaults to token_dim, in which case the two contexts are summed
        self.state_dim = state_dim if state_dim is not None else token_dim
        # Local attention mechanism (standard self-attention)
        self.local_attn = nn.MultiheadAttention(token_dim, num_heads=4, batch_first=True)
        # Create the projection here rather than lazily in forward(), so that it is
        # registered before an optimizer is built and moves with the module across devices
        if self.state_dim != token_dim:
            self.context_proj = nn.Linear(token_dim + self.state_dim, token_dim)
        else:
            self.context_proj = None

    def forward(self, token_embeddings, state_embeddings, routing_weights):
        """
        Apply hybrid attention.
        Args:
            token_embeddings (torch.Tensor): [B, N, D_token]
            state_embeddings (torch.Tensor): [B, S, D_state]
            routing_weights (torch.Tensor): [B, N, S] from TokenToStateRouter
        Returns:
            torch.Tensor: Context-enriched token embeddings of shape [B, N, D_token]
        """
        # 1. Local Attention (token-token)
        local_context, _ = self.local_attn(token_embeddings, token_embeddings, token_embeddings)
        # 2. Global Attention (token-state)
        # Use routing weights to aggregate state information
        # [B, N, S] x [B, S, D_state] -> [B, N, D_state]
        global_context = torch.bmm(routing_weights, state_embeddings)
        # 3. Combine local and global context
        if self.context_proj is None:
            # Same dimensions: sum the two contexts
            output = local_context + global_context
        else:
            # Different dimensions: concatenate and project back to token_dim
            combined_context = torch.cat([local_context, global_context], dim=-1)
            output = self.context_proj(combined_context)
        return output

In [9]:
class NSMLayer(nn.Module):
    """
    A single layer of the Neural State Machine (NSM).
    Integrates all components: State Management, Routing, Propagation, and Hybrid Attention.
    """
    def __init__(self, num_states, state_dim, token_dim):
        super(NSMLayer, self).__init__()
        self.num_states = num_states
        self.state_dim = state_dim
        self.token_dim = token_dim
        # Initialize components
        self.state_manager = StateManager(num_states, state_dim)
        self.router = TokenToStateRouter(token_dim, state_dim)
        self.propagator = StatePropagator(state_dim)
        # Pass state_dim so the context projection exists from construction when dims differ
        self.hybrid_attn = HybridAttention(token_dim, state_dim)
        # A layer norm for stability
        self.layer_norm = nn.LayerNorm(token_dim)

    def forward(self, token_embeddings, state_embeddings=None):
        """
        Forward pass of the NSM layer.
        Args:
            token_embeddings (torch.Tensor): [B, N, D_token]
            state_embeddings (torch.Tensor, optional): [B, S, D_state]. If None, initialized.
        Returns:
            tuple: (updated_token_embeddings [B, N, D_token], updated_state_embeddings [B, S, D_state])
        """
        batch_size = token_embeddings.size(0)
        # 1. Initialize or use provided state embeddings
        if state_embeddings is None:
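            # Match the tokens' device and dtype; StateManager allocates on the default device
            state_embeddings = self.state_manager(batch_size=batch_size).to(
                device=token_embeddings.device, dtype=token_embeddings.dtype
            )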
        # 2. Token-to-State Routing
        routing_weights = self.router(token_embeddings, state_embeddings)
        # 3. Hybrid Attention
        attended_tokens = self.hybrid_attn(token_embeddings, state_embeddings, routing_weights)
        # 4. Residual connection and layer normalization for tokens
        updated_tokens = self.layer_norm(token_embeddings + attended_tokens)
        # 5. State Propagation
        updated_states = self.propagator(state_embeddings)
        return updated_tokens, updated_states
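
`HybridAttention` calls its `local_attn` module without a mask, so the "local" branch currently attends over the whole sequence. If a sliding window is intended, `nn.MultiheadAttention` accepts a boolean `attn_mask` in which `True` marks pairs that may not attend to each other. The helper name `local_window_mask` and the window size below are assumptions for illustration.

```python
import torch
import torch.nn as nn


def local_window_mask(seq_len, window):
    """Boolean [N, N] mask: True where |i - j| > window (attention blocked)."""
    idx = torch.arange(seq_len)
    return (idx[None, :] - idx[:, None]).abs() > window


torch.manual_seed(0)
attn = nn.MultiheadAttention(embed_dim=32, num_heads=4, batch_first=True)
tokens = torch.randn(2, 10, 32)
mask = local_window_mask(seq_len=10, window=2)

# The returned attention weights are averaged over heads by default: [B, N, N]
local_context, attn_weights = attn(tokens, tokens, tokens, attn_mask=mask)
print(local_context.shape)           # torch.Size([2, 10, 32])
print(attn_weights[0, 0, 3:].sum())  # zero: token 0 cannot see positions 3 and beyond
```

Every row keeps its diagonal entry unmasked, so no token ends up with an empty attention set.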

In [10]:
# Test the NSM Layer

# Parameters
batch_size = 2
seq_length = 10
token_dim = 32
num_states = 5
state_dim = 32

# Create random token embeddings
tokens = torch.randn(batch_size, seq_length, token_dim)

# Initialize NSM Layer
nsm_layer = NSMLayer(num_states, state_dim, token_dim)

# Forward pass
updated_tokens, updated_states = nsm_layer(tokens)

print(f"Input tokens shape: {tokens.shape}")
print(f"Updated tokens shape: {updated_tokens.shape}")
print(f"Updated states shape: {updated_states.shape}")

# Check if shapes are as expected
assert updated_tokens.shape == tokens.shape, "Token shapes do not match"
assert updated_states.shape == (batch_size, num_states, state_dim), "State shapes do not match"

print("\n✅ NSM Layer test passed!")

Input tokens shape: torch.Size([2, 10, 32])
Updated tokens shape: torch.Size([2, 10, 32])
Updated states shape: torch.Size([2, 5, 32])

✅ NSM Layer test passed!


## Next Steps

This prototype demonstrates the core components of an NSM layer. In the next steps, we will:

1. **Stack Multiple NSM Layers** to create a deeper architecture
2. **Integrate with a Classification Head** for downstream tasks
3. **Train on Simple Datasets** like MNIST to validate performance
4. **Compare with Baseline Models** (Transformer, LSTM, etc.)
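
Steps 1 and 2 can be sketched as follows. To keep the snippet self-contained, `ToyLayer` is a stand-in with the same `(tokens, states) -> (tokens, states)` interface as `NSMLayer`. The name `NSMClassifier`, the mean pooling over tokens, and the class count are likewise assumptions for illustration.

```python
import torch
import torch.nn as nn


class ToyLayer(nn.Module):
    """Stand-in for NSMLayer: same call signature, trivial computation."""

    def __init__(self, token_dim, state_dim):
        super().__init__()
        self.token_mix = nn.Linear(token_dim, token_dim)
        self.state_mix = nn.Linear(state_dim, state_dim)

    def forward(self, tokens, states):
        return tokens + self.token_mix(tokens), states + self.state_mix(states)


class NSMClassifier(nn.Module):
    """Hypothetical stack of layers followed by a linear classification head."""

    def __init__(self, layers, token_dim, num_classes):
        super().__init__()
        self.layers = nn.ModuleList(layers)
        self.head = nn.Linear(token_dim, num_classes)

    def forward(self, tokens, states):
        # Thread both tokens and states through every layer in order.
        for layer in self.layers:
            tokens, states = layer(tokens, states)
        pooled = tokens.mean(dim=1)  # [B, N, D] -> [B, D]
        return self.head(pooled), states


layers = [ToyLayer(32, 32) for _ in range(3)]
model = NSMClassifier(layers, token_dim=32, num_classes=10)
logits, final_states = model(torch.randn(2, 10, 32), torch.zeros(2, 5, 32))
print(logits.shape)  # torch.Size([2, 10])
```

Replacing `ToyLayer(32, 32)` with `NSMLayer(5, 32, 32)` from this notebook should work with the same call, since both take and return a `(tokens, states)` pair.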

See `notebooks/research/benchmarking.ipynb` for the next phase of research.