# 12 - Context and Memory Handling

This notebook covers techniques for handling long contexts and memory in language models.

## Topics Covered:
- Context length limitations
- Sliding window attention
- Long-context models
- Retrieval-augmented generation (RAG)

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from typing import List, Tuple, Dict

# Seed NumPy's global RNG so the np.random.rand / np.random.randn draws used
# by the simulated attention weights and embeddings repeat across runs.
np.random.seed(42)

## 1. Context Length Limitations

In [None]:
class ContextAnalyzer:
    """Cost estimates for full attention plus a toy sliding-window attention.

    Both helpers are static methods; the class holds no instance state.
    """

    @staticmethod
    def attention_complexity(seq_len: int, d_model: int, num_heads: int = 1) -> Dict[str, float]:
        """Estimate the memory and compute cost of full self-attention.

        Args:
            seq_len: Number of tokens ``n`` in the sequence.
            d_model: Model width ``d``.
            num_heads: Number of ``n x n`` score matrices held at once.
                Defaults to 1 (a single head, single batch element).

        Returns:
            A dict with:
              * ``'memory_ops'``  - entries in the attention score matrices,
                ``num_heads * n**2``.
              * ``'compute_ops'`` - multiply-accumulates for the score
                computation, ``n**2 * d``.
              * ``'memory_gb'``   - ``memory_ops`` stored as float32
                (4 bytes each), expressed in GiB.
        """
        # Memory: one (n x n) score matrix per head -> O(n^2). It does not
        # scale with d_model; that factor belongs to the compute cost only.
        memory_ops = num_heads * seq_len ** 2

        # Compute: n^2 query/key dot products of length d -> O(n^2 * d).
        compute_ops = seq_len ** 2 * d_model

        # float32 = 4 bytes per stored score.
        memory_gb = memory_ops * 4 / (1024 ** 3)

        return {
            'memory_ops': memory_ops,
            'compute_ops': compute_ops,
            'memory_gb': memory_gb
        }

    @staticmethod
    def sliding_window_attention(x: np.ndarray, window_size: int) -> np.ndarray:
        """Apply banded (sliding-window) attention with placeholder weights.

        Position ``i`` may attend to positions ``j`` with
        ``|i - j| <= window_size // 2``, i.e. up to
        ``2 * (window_size // 2) + 1`` positions including itself.

        The attention scores are random placeholders drawn from
        ``np.random.rand`` (they are not computed from ``x``); only the
        masking and normalisation structure is illustrated.

        Args:
            x: Array of shape ``(batch_size, seq_len, d_model)``.
            window_size: Total window width; must be non-negative.

        Returns:
            Array with the same shape as ``x``.

        Raises:
            ValueError: If ``window_size`` is negative, or ``x`` is not 3-D.
        """
        if window_size < 0:
            # A negative width would yield an empty window for every token
            # and silently produce an all-zero output.
            raise ValueError(f"window_size must be non-negative, got {window_size}")

        batch_size, seq_len, _ = x.shape

        # Band mask: 1.0 where the key lies within half a window of the query.
        half_window = window_size // 2
        positions = np.arange(seq_len)
        offsets = np.abs(positions[:, None] - positions[None, :])
        mask = (offsets <= half_window).astype(float)

        # Placeholder scores, zeroed outside the window.
        attention_weights = np.random.rand(batch_size, seq_len, seq_len)
        attention_weights = attention_weights * mask[None, :, :]

        # Row-normalise; the epsilon guards against a zero row sum.
        row_sums = np.sum(attention_weights, axis=-1, keepdims=True)
        attention_weights = attention_weights / (row_sums + 1e-10)

        # Weighted average of the value vectors inside each window.
        return attention_weights @ x

class RAGSystem:
    """Toy retrieval-augmented generation pipeline.

    Documents and queries are embedded as random unit vectors drawn from
    NumPy's global RNG, so retrieval scores are simulated rather than
    derived from the text.
    """

    def __init__(self, d_model: int = 256):
        # Embedding dimensionality shared by documents and queries.
        self.d_model = d_model
        # Parallel lists: knowledge_base[i] is the text for embeddings[i].
        self.knowledge_base = []
        self.embeddings = []

    def _random_unit_vector(self) -> np.ndarray:
        """Draw a Gaussian vector of length ``d_model`` and L2-normalise it."""
        vec = np.random.randn(self.d_model)
        return vec / np.linalg.norm(vec)

    def add_document(self, text: str):
        """Store ``text`` together with a simulated unit-norm embedding."""
        doc_vec = self._random_unit_vector()
        self.knowledge_base.append(text)
        self.embeddings.append(doc_vec)

    def retrieve(self, query: str, top_k: int = 3) -> List[Tuple[str, float]]:
        """Return the ``top_k`` stored documents ranked by similarity.

        The query embedding is a fresh random unit vector (``query`` itself
        is not inspected); scores are dot products with each stored
        embedding, sorted from highest to lowest.
        """
        query_vec = self._random_unit_vector()

        scored = [
            (text, np.dot(query_vec, doc_vec))
            for text, doc_vec in zip(self.knowledge_base, self.embeddings)
        ]

        # Highest similarity first, then truncate.
        ranked = sorted(scored, key=lambda pair: pair[1], reverse=True)
        return ranked[:top_k]

    def generate_with_retrieval(self, query: str, max_length: int = 50) -> str:
        """Build a canned response that quotes the retrieved context.

        Retrieves with the default ``top_k``, joins the document texts with
        spaces and embeds the first 100 characters in a fixed template.
        ``max_length`` is accepted but not used by this simplified version.
        """
        hits = self.retrieve(query)
        context = " ".join(text for text, _score in hits)
        return f"Based on the context: {context[:100]}..., the answer is: [Generated response]"

def _print_context_insights():
    """Print the closing summary of trade-offs for each technique."""
    print("\nContext Handling Insights:")

    print("\nContext Length Limitations:")
    print("  - Quadratic memory growth with sequence length")
    print("  - GPU memory becomes bottleneck for long sequences")
    print("  - Training becomes prohibitively expensive")

    print("\nSliding Window Attention:")
    print("  + Linear complexity instead of quadratic")
    print("  + Can process arbitrarily long sequences")
    print("  - Limited long-range dependencies")

    print("\nRAG Systems:")
    print("  + Access to external knowledge")
    print("  + Better factual accuracy")
    print("  + Updatable knowledge base")
    print("  - Additional retrieval latency")
    print("  - Requires high-quality embeddings")


def demonstrate_context_handling():
    """Run the context/memory demo: print analyses and draw a 2x3 figure.

    Steps:
      1. Print the estimated attention memory for several context lengths.
      2. Run sliding-window attention for a few window sizes and print the
         output shapes.
      3. Fill a small RAG knowledge base, print the top-3 retrieval results
         for one query and a generated response.
      4. Plot six panels (memory growth, window mask, complexity comparison,
         retrieval scores, simulated context usage, memory/performance
         trade-off) and show the figure.
      5. Print a summary of the trade-offs.

    Relies on ``ContextAnalyzer``, ``RAGSystem``, ``np`` and ``plt`` being
    defined at module level. Returns nothing.
    """
    print("Context and Memory Handling Analysis:")

    # Analyze context length limitations
    analyzer = ContextAnalyzer()

    context_lengths = [512, 1024, 2048, 4096, 8192, 16384]
    analysis_d_model = 768

    print("\nContext Length Complexity Analysis:")
    # Keep the figures so the plot below shows exactly what is printed here
    # (previously the plot was recomputed with a different d_model).
    memory_usage = []
    for context_len in context_lengths:
        metrics = analyzer.attention_complexity(context_len, analysis_d_model)
        memory_usage.append(metrics['memory_gb'])
        print(f"  Length {context_len}: {metrics['memory_gb']:.2f} GB memory")

    # Demonstrate sliding window attention
    batch_size, seq_len, d_model = 2, 16, 64
    x = np.random.randn(batch_size, seq_len, d_model)

    window_sizes = [4, 8, 12]
    print(f"\nSliding Window Attention (seq_len={seq_len}):")

    for window_size in window_sizes:
        output = analyzer.sliding_window_attention(x, window_size)
        print(f"  Window size {window_size}: Output shape {output.shape}")

    # Demonstrate RAG system
    rag = RAGSystem(d_model=128)

    # Add sample documents
    documents = [
        "The transformer architecture uses self-attention mechanisms.",
        "Large language models are trained on massive text corpora.",
        "Attention mechanisms allow models to focus on relevant information.",
        "RAG combines retrieval with generation for better factual accuracy.",
        "Context windows limit how much text a model can process at once."
    ]

    for doc in documents:
        rag.add_document(doc)

    # Rank every document once; the same ranking feeds both the printed
    # top-3 and the score chart, because each retrieve() call draws a new
    # random query embedding and would otherwise give unrelated scores.
    query = "How do attention mechanisms work?"
    ranked = rag.retrieve(query, top_k=len(documents))
    retrieved = ranked[:3]

    print("\nRAG System Demo:")
    print(f"Query: {query}")
    print("Retrieved documents:")
    for i, (doc, score) in enumerate(retrieved):
        print(f"  {i+1}. {doc} (score: {score:.3f})")

    # Generate response
    response = rag.generate_with_retrieval(query)
    print(f"\nGenerated response: {response}")

    # Visualizations
    plt.figure(figsize=(15, 10))

    # Context length vs memory usage
    plt.subplot(2, 3, 1)

    plt.loglog(context_lengths, memory_usage, 'b-o', linewidth=2)
    plt.xlabel('Context Length')
    plt.ylabel('Memory Usage (GB)')
    plt.title('Quadratic Memory Growth')
    plt.grid(True, alpha=0.3)

    # Sliding window attention pattern
    plt.subplot(2, 3, 2)

    mask_len = 12
    mask_window = 6

    # Band mask: query i sees keys j with |i - j| <= mask_window // 2.
    positions = np.arange(mask_len)
    offsets = np.abs(positions[:, None] - positions[None, :])
    mask = (offsets <= mask_window // 2).astype(float)

    plt.imshow(mask, cmap='Blues', aspect='auto')
    plt.title(f'Sliding Window Attention\n(window={mask_window})')
    plt.xlabel('Key Position')
    plt.ylabel('Query Position')
    plt.colorbar()

    # Attention complexity comparison
    plt.subplot(2, 3, 3)

    methods = ['Full\nAttention', 'Sliding\nWindow', 'Sparse\nAttention']

    # Pairwise interactions for a 4096-token sequence
    comparison_len = 4096
    full_complexity = comparison_len ** 2
    window_complexity = comparison_len * 256  # Assuming window size 256
    sparse_complexity = comparison_len * 64   # Assuming 64 connections per token

    complexities = [full_complexity, window_complexity, sparse_complexity]

    plt.bar(methods, complexities, alpha=0.7)
    plt.ylabel('Computational Complexity')
    plt.title('Attention Complexity Comparison')
    plt.yscale('log')

    # RAG retrieval similarity scores
    plt.subplot(2, 3, 4)

    # Label each bar with the document's position in `documents`; the
    # ranking is sorted by score, so positional labels would be wrong.
    doc_names = [f'Doc {documents.index(doc) + 1}' for doc, _ in ranked]
    similarities = [score for _, score in ranked]

    plt.bar(doc_names, similarities, alpha=0.7)
    plt.xlabel('Documents')
    plt.ylabel('Similarity Score')
    plt.title('RAG Retrieval Scores')
    plt.xticks(rotation=45)

    # Context window utilization
    plt.subplot(2, 3, 5)

    # Simulate context utilization over time: sinusoid plus Gaussian noise,
    # clipped to the 0-100 % range.
    time_steps = np.arange(0, 100, 5)
    context_usage = 50 + 30 * np.sin(time_steps * 0.1) + np.random.randn(len(time_steps)) * 5
    context_usage = np.clip(context_usage, 0, 100)

    plt.plot(time_steps, context_usage, 'g-', linewidth=2)
    plt.axhline(y=80, color='r', linestyle='--', label='Context Limit')
    plt.xlabel('Time Steps')
    plt.ylabel('Context Usage (%)')
    plt.title('Dynamic Context Usage')
    plt.legend()
    plt.grid(True, alpha=0.3)

    # Memory vs performance trade-off
    plt.subplot(2, 3, 6)

    # Illustrative, hand-picked values (not measured).
    techniques = ['Full Context', 'Sliding Window', 'Hierarchical', 'RAG']
    memory_efficiency = [20, 80, 70, 90]  # Higher = more efficient
    performance = [100, 85, 90, 95]       # Performance retention

    plt.scatter(memory_efficiency, performance, s=100, alpha=0.7)

    for technique, eff, perf in zip(techniques, memory_efficiency, performance):
        plt.annotate(technique, (eff, perf),
                     xytext=(5, 5), textcoords='offset points')

    plt.xlabel('Memory Efficiency')
    plt.ylabel('Performance Retention')
    plt.title('Memory vs Performance Trade-off')
    plt.grid(True, alpha=0.3)

    plt.tight_layout()
    plt.show()

    _print_context_insights()

# Run the demonstration when this cell executes (prints the analyses and shows the figure).
demonstrate_context_handling()