# Redefining Context in Transformers: The Role of Neural Attention Memory Models

This notebook demonstrates the key concepts behind Neural Attention Memory Models (NAMMs) and a basic implementation for transformer architectures. Through practical examples and visualizations, we'll explore how NAMMs aim to improve context management and memory handling in transformer models.

## Setup
First, let's import the required libraries:

In [None]:
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

# Set random seed for reproducibility
tf.random.set_seed(42)
np.random.seed(42)

## Traditional Transformer Context Management

Let's implement the basic scaled dot-product attention mechanism used in traditional transformers:

In [None]:
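def scaled_dot_product_attention(q, k, v, mask=None):
    """Compute scaled dot-product attention.

    q, k, v must have matching leading dimensions, and k and v must share
    the same sequence length. mask, if given, must be broadcastable to
    (..., seq_len_q, seq_len_k), with 1 marking positions to hide.

    Returns the attention output and the attention weights.
    """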
    matmul_qk = tf.matmul(q, k, transpose_b=True)
    
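    # Scale by sqrt(d_k) so the logits do not grow with the key dimension
    dk = tf.cast(tf.shape(k)[-1], matmul_qk.dtype)
    scaled_attention_logits = matmul_qk / tf.math.sqrt(dk)

    # Add the mask if provided: masked positions (mask == 1) receive a large
    # negative logit, so softmax assigns them near-zero weight
    if mask is not None:
        scaled_attention_logits += tf.cast(mask, scaled_attention_logits.dtype) * -1e9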

    # Softmax is normalized on the last axis (seq_len_k)
    attention_weights = tf.nn.softmax(scaled_attention_logits, axis=-1)
    output = tf.matmul(attention_weights, v)

    return output, attention_weights
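
To see the mask convention in action, here is a small NumPy sketch of the same computation applied with a look-ahead (causal) mask. It mirrors the TensorFlow function above rather than calling it, so it runs on its own. The name `attention_np`, the toy shapes, and the choice of a causal mask are illustrative assumptions, not part of the original notebook.

```python
import numpy as np

def attention_np(q, k, v, mask=None):
    """NumPy mirror of scaled dot-product attention (illustrative sketch)."""
    logits = q @ np.swapaxes(k, -1, -2) / np.sqrt(k.shape[-1])
    if mask is not None:
        logits = logits + mask * -1e9  # mask == 1 hides a position
    # Numerically stable softmax over the key axis
    logits = logits - logits.max(axis=-1, keepdims=True)
    weights = np.exp(logits)
    weights = weights / weights.sum(axis=-1, keepdims=True)
    return weights @ v, weights

rng = np.random.default_rng(42)
seq_len, d_model = 4, 8
x = rng.normal(size=(1, seq_len, d_model))

# Causal mask: 1 above the diagonal, so position i cannot attend to j > i
causal_mask = np.triu(np.ones((seq_len, seq_len)), k=1)

out, w = attention_np(x, x, x, mask=causal_mask)
print(out.shape)  # (1, 4, 8)
print(w.shape)    # (1, 4, 4)
print(np.round(w[0], 2))  # entries above the diagonal are zero
```

Each row of `w` still sums to 1, and the first position can only attend to itself, so its weight on position 0 is 1.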

## Neural Attention Memory Model (NAMM) Implementation

Now let's implement a basic NAMM layer that extends the traditional attention mechanism with memory management:

In [None]:
class NAMM(tf.keras.layers.Layer):
    def __init__(self, memory_size, embedding_dim, **kwargs):
        super().__init__(**kwargs)
        self.memory_size = memory_size
        self.embedding_dim = embedding_dim

        # Learnable memory slots shared across the batch,
        # shape (1, memory_size, embedding_dim)
        self.memory = self.add_weight(
            name='memory',
            shape=(1, memory_size, embedding_dim),
            initializer=tf.keras.initializers.RandomNormal(stddev=1.0),
            trainable=True,
        )
        
    def call(self, inputs):
        # Combine input with memory
        batch_size = tf.shape(inputs)[0]
        memory_batch = tf.tile(self.memory, [batch_size, 1, 1])
        combined = tf.concat([memory_batch, inputs], axis=1)
        
        # Apply attention
        output, attention_weights = scaled_dot_product_attention(
            combined, combined, combined
        )
        
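        # Update memory: the first memory_size positions of the output are the
        # attended memory slots. Average them over the batch and write them back.
        # Note: this overwrites the trainable variable on every forward pass,
        # including at inference time.
        new_memory = output[:, :self.memory_size, :]
        self.memory.assign(tf.reduce_mean(new_memory, axis=0, keepdims=True))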
        
        return output[:, self.memory_size:, :]
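
To make the shape bookkeeping in `call` concrete, the following NumPy sketch traces one forward pass: the memory is tiled across the batch, prepended to the input, attended over jointly, then split back into memory and sequence parts. It is a standalone illustration under assumed toy sizes (`batch_size=2`, `seq_len=5`, `memory_size=3`, `embedding_dim=4`). The helper `self_attention` is hypothetical and stands in for `scaled_dot_product_attention` with no mask.

```python
import numpy as np

def self_attention(x):
    """Unmasked scaled dot-product self-attention in NumPy (illustrative)."""
    logits = x @ np.swapaxes(x, -1, -2) / np.sqrt(x.shape[-1])
    logits = logits - logits.max(axis=-1, keepdims=True)
    weights = np.exp(logits)
    weights = weights / weights.sum(axis=-1, keepdims=True)
    return weights @ x

rng = np.random.default_rng(0)
batch_size, seq_len, memory_size, embedding_dim = 2, 5, 3, 4

memory = rng.normal(size=(1, memory_size, embedding_dim))
inputs = rng.normal(size=(batch_size, seq_len, embedding_dim))

# Tile the shared memory across the batch and prepend it to the inputs
memory_batch = np.tile(memory, (batch_size, 1, 1))
combined = np.concatenate([memory_batch, inputs], axis=1)

attended = self_attention(combined)

# Split back: the first memory_size positions are memory, the rest is the sequence
new_memory = attended[:, :memory_size, :].mean(axis=0, keepdims=True)
sequence_out = attended[:, memory_size:, :]

print(combined.shape)      # (2, 8, 4)
print(new_memory.shape)    # (1, 3, 4)
print(sequence_out.shape)  # (2, 5, 4)
```

The output keeps the input's shape, so a layer built this way can be dropped in wherever a sequence-to-sequence attention block is expected. Because the memory is averaged over the batch and written back, the result for a given input depends on the batches processed before it.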