# Long Context Architectures for LLMs

## Overview

Handling long sequences efficiently is crucial for many LLM applications. This notebook covers:

- **Memory-Augmented Transformers**: External memory mechanisms
- **Hierarchical Attention**: Multi-scale attention patterns
- **Recurrent Memory**: Integrating recurrence with transformers
- **Context Compression**: Efficient long-range modeling

Let's implement architectures that can handle extended contexts efficiently.

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
import math
from typing import Optional, Tuple, List

print("Libraries imported successfully!")

## 1. Memory-Augmented Transformer

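External memory lets the model store and retrieve information beyond the immediate context window. The cell below implements this memory-augmented block and, alongside it, the multi-scale `HierarchicalAttention` module that is exercised in the tests that follow: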

In [None]:
class MemoryAugmentedTransformer(nn.Module):
    """Transformer block augmented with a learned external memory bank.

    Each forward pass applies, in order: self-attention over the input,
    a dot-product read from the memory bank, a gated write back into the
    memory, and a position-wise feed-forward network. Every sub-layer is
    followed by a residual connection and LayerNorm. The updated memory is
    returned so callers can thread it through successive chunks of a long
    sequence.

    Args:
        d_model: Feature size of the input and output tokens.
        n_heads: Number of heads for the self-attention sub-layer.
        memory_size: Number of slots in the learned memory bank.
        memory_dim: Feature size of each memory slot; falls back to
            ``d_model`` when ``None`` (or any other falsy value).
    """

    def __init__(self, d_model, n_heads, memory_size=1024, memory_dim=None):
        super().__init__()
        self.d_model = d_model
        self.n_heads = n_heads
        self.memory_size = memory_size
        self.memory_dim = memory_dim or d_model

        # External memory: learned initial content, shape [memory_size, memory_dim]
        self.memory = nn.Parameter(torch.randn(memory_size, self.memory_dim))

        # Memory access mechanisms
        self.memory_query = nn.Linear(d_model, self.memory_dim)
        self.memory_key = nn.Linear(self.memory_dim, self.memory_dim)
        self.memory_value = nn.Linear(self.memory_dim, d_model)

        # Standard attention
        self.self_attention = nn.MultiheadAttention(d_model, n_heads, batch_first=True)

        # Memory update mechanism
        self.memory_update = nn.Linear(d_model, self.memory_dim)
        # The gate sees [write query ; memory slot], both of width memory_dim,
        # so its input width is 2 * memory_dim (not d_model + memory_dim, which
        # only matches when memory_dim == d_model).
        self.memory_gate = nn.Linear(2 * self.memory_dim, 1)

        # Layer norm and feedforward
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)

        self.ffn = nn.Sequential(
            nn.Linear(d_model, 4 * d_model),
            nn.ReLU(),
            nn.Linear(4 * d_model, d_model)
        )

    def forward(self, x, memory_state=None):
        """Process one chunk and return it together with the updated memory.

        Args:
            x: Input tensor of shape [B, T, d_model].
            memory_state: Memory carried over from a previous call, shape
                [B, M, memory_dim]. When ``None`` the learned ``self.memory``
                is broadcast over the batch.

        Returns:
            Tuple ``(output, updated_memory)`` with shapes [B, T, d_model] and
            [B, M, memory_dim].
        """
        batch_size = x.shape[0]

        # Use provided memory state or initialize from the learned bank
        if memory_state is None:
            current_memory = self.memory.unsqueeze(0).expand(batch_size, -1, -1)
        else:
            current_memory = memory_state

        # Self-attention
        attn_out, _ = self.self_attention(x, x, x)
        x = self.norm1(x + attn_out)

        # Memory attention (read)
        memory_out = self.memory_attention(x, current_memory)
        x = self.norm2(x + memory_out)

        # Update memory (write) from the post-read representation
        updated_memory = self.update_memory(x, current_memory)

        # Feedforward
        ffn_out = self.ffn(x)
        x = self.norm3(x + ffn_out)

        return x, updated_memory

    def memory_attention(self, queries, memory):
        """Read from memory with single-head scaled dot-product attention.

        Args:
            queries: Token representations, shape [B, T, d_model].
            memory: Memory bank, shape [B, M, memory_dim].

        Returns:
            Tensor of shape [B, T, d_model] holding the attended memory values.
        """
        slot_dim = memory.shape[-1]

        q = self.memory_query(queries)  # [B, T, MD]
        k = self.memory_key(memory)     # [B, M, MD]
        v = self.memory_value(memory)   # [B, M, D]

        # Attention of every token over every memory slot
        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(slot_dim)
        attn_weights = F.softmax(scores, dim=-1)  # [B, T, M]

        return torch.matmul(attn_weights, v)

    def update_memory(self, x, memory):
        """Write a summary of the current chunk into memory through a gate.

        The chunk is summarised by the mean of its projected tokens. That
        summary is distributed over the slots by a softmax addressing
        distribution and added to each slot, scaled by a per-slot sigmoid
        gate.

        Args:
            x: Token representations, shape [B, T, d_model].
            memory: Memory bank, shape [B, M, memory_dim].

        Returns:
            Updated memory of shape [B, M, memory_dim].
        """
        # Use the slot count of the memory actually passed in, so an external
        # memory_state with a different number of slots still works.
        num_slots = memory.shape[1]

        # Per-token write candidates and their chunk-level summary
        memory_updates = self.memory_update(x)  # [B, T, MD]
        update_queries = memory_updates.mean(dim=1, keepdim=True)  # [B, 1, MD]

        # Addressing: how strongly the summary targets each slot
        memory_keys = self.memory_key(memory)  # [B, M, MD]
        update_scores = torch.matmul(update_queries, memory_keys.transpose(-2, -1))
        update_weights = F.softmax(update_scores / math.sqrt(self.memory_dim), dim=-1)  # [B, 1, M]

        # Per-slot gate computed from the summary and the slot's current content
        gate_input = torch.cat([update_queries.expand(-1, num_slots, -1), memory], dim=-1)
        gates = torch.sigmoid(self.memory_gate(gate_input))  # [B, M, 1]

        # Outer product of slot weights and summary gives the per-slot delta
        memory_delta = update_weights.transpose(-2, -1) * update_queries  # [B, M, MD]

        return memory + gates * memory_delta

class HierarchicalAttention(nn.Module):
    """Multi-scale attention: full attention over several pooled resolutions.

    For each scale ``s`` the sequence is average-pooled by a factor of ``s``,
    multi-head self-attention is run on the pooled sequence with that scale's
    own Q/K/V projections, and the result is linearly interpolated back to the
    original length. The per-scale outputs are concatenated on the feature
    axis and fused with a linear layer.

    Args:
        d_model: Feature size of the input and output tokens.
        n_heads: Number of attention heads; must divide ``d_model``.
        scales: Pooling factors, one attention branch per entry. A scale of 1
            means no pooling.

    Raises:
        ValueError: If ``d_model`` is not divisible by ``n_heads``.
    """

    def __init__(self, d_model, n_heads, scales=(1, 4, 16)):
        super().__init__()
        if d_model % n_heads != 0:
            raise ValueError(
                f"d_model ({d_model}) must be divisible by n_heads ({n_heads})"
            )
        self.d_model = d_model
        self.n_heads = n_heads
        # Copy into a fresh list so instances never share (or alias) the
        # caller's / default sequence.
        self.scales = list(scales)
        self.d_head = d_model // n_heads

        # Multi-scale projections: independent Q/K/V per scale
        self.scale_projections = nn.ModuleList([
            nn.ModuleDict({
                'q': nn.Linear(d_model, d_model),
                'k': nn.Linear(d_model, d_model),
                'v': nn.Linear(d_model, d_model)
            }) for _ in self.scales
        ])

        # Scale fusion
        self.scale_fusion = nn.Linear(len(self.scales) * d_model, d_model)

        # Output projection
        self.out_proj = nn.Linear(d_model, d_model)

    def forward(self, x):
        """Apply attention at every scale and fuse the results.

        Args:
            x: Input tensor of shape [B, T, d_model].

        Returns:
            Tensor of shape [B, T, d_model].
        """
        seq_len = x.shape[1]
        scale_outputs = []

        for scale_idx, scale in enumerate(self.scales):
            if scale > 1:
                # Average pooling over the time axis. ceil_mode keeps the
                # final, possibly partial, window so trailing tokens are not
                # dropped and sequences shorter than `scale` still yield one
                # pooled position instead of raising.
                x_scaled = F.avg_pool1d(
                    x.transpose(1, 2),
                    kernel_size=scale,
                    stride=scale,
                    padding=0,
                    ceil_mode=True
                ).transpose(1, 2)
            else:
                x_scaled = x

            # Apply attention at this scale
            scale_out = self.scale_attention(x_scaled, scale_idx)

            # Upsample back to original resolution
            if scale > 1:
                scale_out = F.interpolate(
                    scale_out.transpose(1, 2),
                    size=seq_len,
                    mode='linear',
                    align_corners=False
                ).transpose(1, 2)

            scale_outputs.append(scale_out)

        # Fuse multi-scale outputs: [B, T, len(scales) * D] -> [B, T, D]
        fused = torch.cat(scale_outputs, dim=-1)
        output = self.scale_fusion(fused)

        return self.out_proj(output)

    def scale_attention(self, x, scale_idx):
        """Run multi-head self-attention with the projections of one scale.

        Args:
            x: (Possibly pooled) input of shape [B, T', d_model].
            scale_idx: Index into ``self.scale_projections``.

        Returns:
            Tensor of shape [B, T', d_model].
        """
        B, T, D = x.shape

        def split_heads(t):
            # [B, T', D] -> [B, n_heads, T', d_head]
            return t.view(B, T, self.n_heads, self.d_head).transpose(1, 2)

        proj = self.scale_projections[scale_idx]
        q = split_heads(proj['q'](x))
        k = split_heads(proj['k'](x))
        v = split_heads(proj['v'](x))

        # Scaled dot-product attention over all T' positions
        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_head)
        attn = F.softmax(scores, dim=-1)
        out = torch.matmul(attn, v)

        # Merge heads back: [B, n_heads, T', d_head] -> [B, T', D]
        return out.transpose(1, 2).contiguous().view(B, T, D)

print("Long context architecture modules implemented!")

### Testing Long Context Architectures

Let's test memory retention and hierarchical processing:

In [None]:
def test_memory_retention(model, seq_length=2048, memory_test_positions=(100, 500, 1000, 1500)):
    """Probe how similar later outputs are to patterns planted earlier.

    A random sequence is generated, constant-valued marker vectors are written
    at ``memory_test_positions``, and the sequence is fed through ``model`` in
    chunks of 256 tokens. For every marker that has at least 100 tokens after
    it, the cosine similarity between the marker and the model output 50, 100
    and 200 tokens later is recorded.

    Args:
        model: Module exposing ``d_model``. If it has a ``memory`` attribute
            it is called as ``model(chunk, memory_state)`` and must return
            ``(output, memory_state)``; otherwise as ``model(chunk)``.
        seq_length: Total number of tokens in the synthetic sequence.
        memory_test_positions: Token indices at which markers are inserted.

    Returns:
        List of dicts with keys ``insert_pos``, ``test_pos``, ``distance`` and
        ``similarity`` (a Python float).
    """
    d_model = model.d_model

    # Create test sequence with distinctive patterns at the test positions
    x = torch.randn(1, seq_length, d_model)
    for pos in memory_test_positions:
        if pos < seq_length:
            # Constant vector whose magnitude encodes the relative position
            x[:, pos:pos + 1, :] = torch.ones(1, 1, d_model) * (pos / seq_length)

    # Process through model chunk by chunk, carrying the memory state along.
    # This is evaluation only: without no_grad the carried memory would keep
    # the autograd graph of every previous chunk alive.
    memory_state = None
    chunk_size = 256
    outputs = []
    uses_memory = hasattr(model, 'memory')

    with torch.no_grad():
        for start in range(0, seq_length, chunk_size):
            chunk = x[:, start:start + chunk_size, :]

            if uses_memory:
                output, memory_state = model(chunk, memory_state)
            else:
                output = model(chunk)

            outputs.append(output)

    # Analyze memory retention
    full_output = torch.cat(outputs, dim=1)
    retention_scores = []

    for pos in memory_test_positions:
        if pos >= seq_length - 100:  # Not enough room after the marker
            continue

        original_pattern = x[:, pos, :]  # [1, D]

        for offset in (50, 100, 200):
            test_pos = pos + offset
            if test_pos >= seq_length:
                continue

            output_at_test = full_output[:, test_pos, :]  # [1, D]

            # Compare along the feature axis. The default dim=1 on unsqueezed
            # [1, 1, D] tensors reduces a size-1 axis and yields D values,
            # which cannot be converted with .item().
            similarity = F.cosine_similarity(
                original_pattern, output_at_test, dim=-1
            ).item()

            retention_scores.append({
                'insert_pos': pos,
                'test_pos': test_pos,
                'distance': offset,
                'similarity': similarity
            })

    return retention_scores


# The name starts with "test_", so pytest would try to collect this helper as
# a test case and fail on the unknown `model` fixture; opt out explicitly.
test_memory_retention.__test__ = False

def analyze_hierarchical_patterns(model, seq_length=1024):
    """Run a hierarchical model once on random input and summarise the result.

    This is a smoke test, not an attention analysis: no attention weights are
    captured. It only confirms that the model runs on a sequence of the given
    length and reports the output shape and the configured scales.

    Args:
        model: Module exposing ``scales`` and ``d_model`` and callable as
            ``model(x)`` with ``x`` of shape [1, seq_length, d_model].
        seq_length: Number of tokens in the random test input.

    Returns:
        ``None`` if ``model`` has no ``scales`` attribute; otherwise a dict
        with keys ``output_shape``, ``scales_processed`` and
        ``hierarchical_processing`` (always ``True``).
    """
    if not hasattr(model, 'scales'):
        return None

    x = torch.randn(1, seq_length, model.d_model)

    # Inference only; no gradients needed
    with torch.no_grad():
        output = model(x)

    return {
        'output_shape': output.shape,
        'scales_processed': model.scales,
        'hierarchical_processing': True
    }

# Test models
# Shared hyperparameters for every model built in this cell; `d_model` is also
# read as a module-level global by benchmark_memory_efficiency below.
d_model, n_heads = 256, 8

print("Testing Memory-Augmented Transformer...")
memory_model = MemoryAugmentedTransformer(d_model, n_heads, memory_size=512)
# The 1024-token sequence is processed in chunks of 256 tokens. Of the default
# marker positions (100, 500, 1000, 1500) only 100 and 500 leave the 100-token
# margin that test_memory_retention requires before scoring a marker.
memory_retention = test_memory_retention(memory_model, seq_length=1024)

print("Testing Hierarchical Attention...")
hierarchical_model = HierarchicalAttention(d_model, n_heads, scales=[1, 2, 4, 8])
# Single forward pass on random input; returns a small summary dict.
hierarchical_analysis = analyze_hierarchical_patterns(hierarchical_model, seq_length=512)

print("Testing Standard Transformer (baseline)...")
# Only constructed here; it is exercised later by the memory benchmark.
standard_model = nn.TransformerEncoderLayer(d_model, n_heads, batch_first=True)

# Benchmark memory efficiency
def benchmark_memory_efficiency(models, seq_lengths, feature_dim=None):
    """Benchmark peak CUDA memory of a forward pass across sequence lengths.

    Each model is run once per sequence length, without gradients, on a random
    batch of shape [2, seq_len, dim] created on the device the model's
    parameters live on. Peak memory is only measured for models on a CUDA
    device; for CPU models ``memory_mb`` is reported as 0.

    Args:
        models: Mapping of display name to model. A model with a ``memory``
            attribute is expected to return ``(output, memory)``.
        seq_lengths: Iterable of sequence lengths to try.
        feature_dim: Feature size of the random input. When ``None`` the
            model's own ``d_model`` attribute is used, falling back to the
            module-level ``d_model`` for models that do not expose one.

    Returns:
        Dict mapping each model name to a list of per-length result dicts with
        keys ``seq_len``, ``memory_mb`` and ``success``; failed runs (a
        ``RuntimeError`` during the forward pass) additionally carry ``error``
        and report ``memory_mb`` as ``inf``.
    """
    results = {}

    for name, model in models.items():
        results[name] = []

        # Build inputs on the model's own device so that GPU-resident models
        # neither hit a device mismatch nor get measured against CPU tensors.
        param_source = getattr(model, 'parameters', None)
        first_param = next(iter(param_source()), None) if callable(param_source) else None
        device = first_param.device if first_param is not None else torch.device('cpu')
        track_cuda = device.type == 'cuda'

        dim = feature_dim
        if dim is None:
            dim = getattr(model, 'd_model', None)
        if dim is None:
            dim = d_model  # module-level default shared by the notebook's models

        returns_memory = hasattr(model, 'memory')

        for seq_len in seq_lengths:
            x = torch.randn(2, seq_len, dim, device=device)

            # Start the peak counter from the current allocation
            if track_cuda:
                torch.cuda.empty_cache()
                torch.cuda.reset_peak_memory_stats(device)

            try:
                with torch.no_grad():
                    if returns_memory:
                        model(x)[0]
                    else:
                        model(x)

                memory_used = torch.cuda.max_memory_allocated(device) if track_cuda else 0

                results[name].append({
                    'seq_len': seq_len,
                    'memory_mb': memory_used / (1024 * 1024),
                    'success': True
                })

            except RuntimeError as e:
                # Typically out-of-memory or a shape error; record and move on
                results[name].append({
                    'seq_len': seq_len,
                    'memory_mb': float('inf'),
                    'success': False,
                    'error': str(e)
                })

    return results

# Models to benchmark. The keys are the names under which results are stored
# in `memory_benchmark` and shown in the plot legend.
models = {
    'Standard': standard_model,
    'Memory-Augmented': memory_model,
    'Hierarchical': hierarchical_model
}

# Sequence lengths to benchmark; each one is run with a batch of 2
# (see benchmark_memory_efficiency).
seq_lengths = [128, 256, 512, 1024]
memory_benchmark = benchmark_memory_efficiency(models, seq_lengths)

# Visualize results
# All panels are drawn on one 2x3 figure through pyplot's implicit
# "current axes" state, so the order of the plt.* calls below matters.
plt.figure(figsize=(16, 12))

# Memory retention analysis
# Panel 1: one point per (marker, later position) pair from test_memory_retention.
plt.subplot(2, 3, 1)
if memory_retention:
    distances = [r['distance'] for r in memory_retention]
    similarities = [r['similarity'] for r in memory_retention]
    plt.scatter(distances, similarities, alpha=0.6)
    plt.xlabel('Distance from Original')
    plt.ylabel('Similarity Score')
    plt.title('Memory Retention Analysis')
    plt.grid(True, alpha=0.3)

# Memory usage comparison
# Panel 2: only successful runs are plotted. The benchmark records memory_mb
# as 0 whenever CUDA is unavailable, so on CPU-only machines every curve is
# flat at zero.
plt.subplot(2, 3, 2)
for name, results in memory_benchmark.items():
    seq_lens = [r['seq_len'] for r in results if r['success']]
    memory_usage = [r['memory_mb'] for r in results if r['success']]
    if seq_lens:
        plt.plot(seq_lens, memory_usage, 'o-', label=name, linewidth=2, markersize=6)

plt.xlabel('Sequence Length')
plt.ylabel('Memory Usage (MB)')
plt.title('Memory Usage Comparison')
plt.legend()
plt.grid(True, alpha=0.3)

# Hierarchical scale visualization
# Panel 3: the plotted "efficiency" is simply 1/scale per configured scale;
# it is an illustration, not a measurement.
plt.subplot(2, 3, 3)
if hierarchical_analysis:
    scales = hierarchical_model.scales
    scale_complexities = [1/s for s in scales]  # Inverse of the pooling factor
    plt.bar(range(len(scales)), scale_complexities, alpha=0.7)
    plt.xlabel('Scale Index')
    plt.ylabel('Relative Efficiency')
    plt.title('Hierarchical Scale Efficiency')
    plt.xticks(range(len(scales)), [f'Scale {s}' for s in scales])

# Memory retention by distance
# Panel 4: mean similarity per half-open distance bucket [min, max).
# test_memory_retention only produces distances of 50, 100 and 200, so the
# 0-50 bucket is always empty and is drawn as 0.
plt.subplot(2, 3, 4)
if memory_retention:
    # Group by distance ranges
    distance_ranges = [(0, 50), (50, 100), (100, 200), (200, 400)]
    avg_similarities = []
    
    for min_dist, max_dist in distance_ranges:
        range_similarities = [r['similarity'] for r in memory_retention 
                            if min_dist <= r['distance'] < max_dist]
        # Empty buckets are reported as 0 rather than NaN
        avg_sim = np.mean(range_similarities) if range_similarities else 0
        avg_similarities.append(avg_sim)
    
    range_labels = [f'{min_d}-{max_d}' for min_d, max_d in distance_ranges]
    plt.bar(range_labels, avg_similarities, alpha=0.7, color='skyblue')
    plt.xlabel('Distance Range')
    plt.ylabel('Average Similarity')
    plt.title('Memory Retention by Distance')
    plt.xticks(rotation=45)

# Complexity comparison
# Panel 5: closed-form reference curves evaluated at the benchmarked lengths;
# nothing here is measured from the models.
# NOTE(review): HierarchicalAttention still runs full attention over all T
# tokens for scale 1, and MemoryAugmentedTransformer runs full self-attention
# within each call, so the sub-quadratic labels describe the idealised
# techniques rather than these implementations - confirm the intended claim.
plt.subplot(2, 3, 5)
seq_lens = np.array(seq_lengths)
standard_complexity = seq_lens ** 2  # O(n²)
hierarchical_complexity = seq_lens * np.log(seq_lens)  # O(n log n)
memory_complexity = seq_lens  # O(n) with fixed memory

plt.plot(seq_lens, standard_complexity, 'o-', label='Standard O(n²)', linewidth=2)
plt.plot(seq_lens, hierarchical_complexity, 's-', label='Hierarchical O(n log n)', linewidth=2)
plt.plot(seq_lens, memory_complexity, '^-', label='Memory-Aug O(n)', linewidth=2)
plt.xlabel('Sequence Length')
plt.ylabel('Theoretical Complexity')
plt.title('Computational Complexity')
plt.legend()
plt.yscale('log')  # log axis so the linear and quadratic curves are both readable
plt.grid(True, alpha=0.3)

# Architecture comparison summary
# Panel 6: for each architecture, the longest benchmarked sequence length whose
# forward pass completed successfully (0 if none did or it was not benchmarked).
plt.subplot(2, 3, 6)
architectures = ['Standard', 'Memory-Aug', 'Hierarchical']
# The short tick labels are not always the keys used in `memory_benchmark`
# ('Memory-Aug' is stored as 'Memory-Augmented'); map them explicitly so the
# lookup does not silently miss and plot 0.
benchmark_keys = {'Memory-Aug': 'Memory-Augmented'}
max_seq_lens = []

for arch in architectures:
    runs = memory_benchmark.get(benchmark_keys.get(arch, arch), [])
    max_seq_lens.append(max((r['seq_len'] for r in runs if r['success']), default=0))

plt.bar(architectures, max_seq_lens, alpha=0.7, color=['red', 'green', 'blue'])
plt.xlabel('Architecture')
plt.ylabel('Max Sequence Length')
plt.title('Maximum Supported Sequence Length')
plt.xticks(rotation=45)

plt.tight_layout()
plt.show()

# Text summary of the configured models and the retention probe.
print("\n=== LONG CONTEXT ANALYSIS ===")
print("Memory-Augmented Transformer:")
print(f"  Memory size: {memory_model.memory_size}")
print(f"  Memory dimension: {memory_model.memory_dim}")

print("\nHierarchical Attention:")
print(f"  Scales: {hierarchical_model.scales}")
# The analysis is None when the model exposes no scales; report N/A then.
if hierarchical_analysis:
    multi_scale_flag = hierarchical_analysis['hierarchical_processing']
else:
    multi_scale_flag = 'N/A'
print(f"  Multi-scale processing: {multi_scale_flag}")

# Retention statistics are only printed when at least one score was recorded.
if memory_retention:
    avg_retention = np.mean([score['similarity'] for score in memory_retention])
    print("\nMemory Retention:")
    print(f"  Average similarity: {avg_retention:.3f}")
    print(f"  Total test points: {len(memory_retention)}")