# Attention Optimizations for LLMs

## Overview

Self-attention is often the main computational bottleneck in transformer models, because its time and memory cost grow quadratically with sequence length. This notebook covers:

- **FlashAttention**: Memory-efficient attention computation
- **Sparse Attention**: Reducing quadratic complexity with structured sparsity
- **Linear Attention**: Approximating attention with linear complexity
- **Performance Analysis**: Benchmarking different attention variants

Let's implement and compare various attention optimization techniques.

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import time
import math
from typing import Optional, Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F

# Reached only when every import above succeeded; a failed import raises first.
print("Libraries imported successfully!")

## 1. FlashAttention Implementation

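FlashAttention reduces memory usage by computing attention in blocks and combining the blocks with an online (streaming) softmax, so the full T×T attention matrix is never materialized. The result is exact attention: the amount of computation is still quadratic in sequence length, but the extra memory needed is linear: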

In [None]:
class FlashAttention(nn.Module):
    """Causal multi-head self-attention computed tile by tile.

    Keys and values are visited in tiles of ``block_size`` positions and
    merged with an online (streaming) softmax, so the full ``T x T`` score
    matrix is never materialized. The output equals ordinary causal softmax
    attention up to floating-point rounding.

    Args:
        d_model: Width of the input and output embeddings.
        n_heads: Number of attention heads; must divide ``d_model``.
        block_size: Number of positions per query/key tile.

    Raises:
        ValueError: If ``d_model`` is not divisible by ``n_heads`` or
            ``block_size`` is not positive.
    """

    def __init__(self, d_model, n_heads, block_size=64):
        super().__init__()
        if d_model % n_heads != 0:
            raise ValueError(
                f"d_model ({d_model}) must be divisible by n_heads ({n_heads})"
            )
        if block_size < 1:
            raise ValueError(f"block_size must be positive, got {block_size}")

        self.d_model = d_model
        self.n_heads = n_heads
        self.d_head = d_model // n_heads
        self.block_size = block_size
        # Standard 1/sqrt(d_head) scaling of the dot products.
        self.scale = 1.0 / math.sqrt(self.d_head)

        # One fused projection produces Q, K and V.
        self.qkv_proj = nn.Linear(d_model, 3 * d_model)
        self.out_proj = nn.Linear(d_model, d_model)

    def forward(self, x, mask=None):
        """Apply causal self-attention to ``x``.

        Args:
            x: Tensor of shape ``(B, T, d_model)``.
            mask: Accepted for interface compatibility and forwarded to
                ``flash_attention_forward``; attention is always causal, so
                its value does not change the result.

        Returns:
            Tensor of shape ``(B, T, d_model)``.
        """
        B, T, C = x.shape

        # Project to Q, K, V
        q, k, v = self.qkv_proj(x).chunk(3, dim=-1)

        # (B, T, C) -> (B, n_heads, T, d_head)
        q = q.view(B, T, self.n_heads, self.d_head).transpose(1, 2)
        k = k.view(B, T, self.n_heads, self.d_head).transpose(1, 2)
        v = v.view(B, T, self.n_heads, self.d_head).transpose(1, 2)

        out = self.flash_attention_forward(q, k, v, mask)

        # Merge the heads back and apply the output projection.
        out = out.transpose(1, 2).contiguous().view(B, T, C)
        return self.out_proj(out)

    def flash_attention_forward(self, q, k, v, mask=None):
        """Block-wise causal attention with an online softmax.

        Args:
            q, k, v: Tensors of shape ``(B, H, T, D)``.
            mask: Unused; kept so existing callers that pass it keep working.

        Returns:
            Tensor shaped like ``v`` holding ``softmax(q k^T * scale) v`` with
            every key position after the query position excluded.
        """
        B, H, T, D = q.shape
        neg_inf = -float('inf')
        out = torch.zeros_like(v)

        for i in range(0, T, self.block_size):
            end_i = min(i + self.block_size, T)
            rows = end_i - i
            q_block = q[:, :, i:end_i, :]

            # Running statistics for this query tile: unnormalized weighted
            # sum of values, softmax denominator, and row-wise maximum.
            acc = q.new_zeros(B, H, rows, D)
            l_run = q.new_zeros(B, H, rows, 1)
            m_run = q.new_full((B, H, rows, 1), neg_inf)

            # Key tiles that start at or after end_i lie entirely in the
            # future of every query in this tile, so they are never visited.
            for j in range(0, end_i, self.block_size):
                end_j = min(j + self.block_size, T)
                k_block = k[:, :, j:end_j, :]
                v_block = v[:, :, j:end_j, :]

                scores = torch.matmul(q_block, k_block.transpose(-2, -1)) * self.scale

                # Only the tile on the diagonal mixes past and future keys.
                if end_j - 1 > i:
                    query_pos = torch.arange(i, end_i, device=scores.device)
                    key_pos = torch.arange(j, end_j, device=scores.device)
                    future = key_pos[None, :] > query_pos[:, None]
                    scores = scores.masked_fill(future, neg_inf)

                # Online softmax: re-express the old statistics relative to
                # the new running maximum before adding this tile.
                m_new = torch.maximum(m_run, scores.amax(dim=-1, keepdim=True))
                alpha = torch.exp(m_run - m_new)
                probs = torch.exp(scores - m_new)

                l_run = alpha * l_run + probs.sum(dim=-1, keepdim=True)
                acc = alpha * acc + torch.matmul(probs, v_block)
                m_run = m_new

            # Normalize once per query tile, after all key tiles were merged.
            out[:, :, i:end_i, :] = acc / l_run

        return out

class SparseAttention(nn.Module):
    """Multi-head self-attention restricted by a boolean sparsity mask.

    The dense score matrix is still computed; positions where the mask is
    ``False`` are set to ``-inf`` before the softmax.

    Args:
        d_model: Width of the input and output embeddings.
        n_heads: Number of attention heads; must divide ``d_model``.
        sparsity_pattern: ``'local'``, ``'strided'`` or ``'random'``. Any
            other value yields an all-``True`` mask (dense attention).
        window_size: Width parameter used by the local and strided patterns.

    Raises:
        ValueError: If ``d_model`` is not divisible by ``n_heads``.
    """

    def __init__(self, d_model, n_heads, sparsity_pattern='local', window_size=128):
        super().__init__()
        if d_model % n_heads != 0:
            raise ValueError(
                f"d_model ({d_model}) must be divisible by n_heads ({n_heads})"
            )

        self.d_model = d_model
        self.n_heads = n_heads
        self.d_head = d_model // n_heads
        self.sparsity_pattern = sparsity_pattern
        self.window_size = window_size
        self.scale = 1.0 / math.sqrt(self.d_head)

        self.qkv_proj = nn.Linear(d_model, 3 * d_model)
        self.out_proj = nn.Linear(d_model, d_model)

    def forward(self, x):
        """Apply masked self-attention to ``x`` of shape ``(B, T, d_model)``.

        Returns:
            Tensor of shape ``(B, T, d_model)``.
        """
        B, T, C = x.shape

        q, k, v = self.qkv_proj(x).chunk(3, dim=-1)

        # (B, T, C) -> (B, n_heads, T, d_head)
        q = q.view(B, T, self.n_heads, self.d_head).transpose(1, 2)
        k = k.view(B, T, self.n_heads, self.d_head).transpose(1, 2)
        v = v.view(B, T, self.n_heads, self.d_head).transpose(1, 2)

        scores = torch.matmul(q, k.transpose(-2, -1)) * self.scale

        # The mask helpers build CPU tensors; move the mask next to the
        # scores so GPU inputs do not hit a device mismatch.
        mask = self.create_sparsity_mask(T).to(scores.device)
        scores = scores.masked_fill(~mask, -float('inf'))

        attn = F.softmax(scores, dim=-1)
        out = torch.matmul(attn, v)

        out = out.transpose(1, 2).contiguous().view(B, T, C)
        return self.out_proj(out)

    def create_sparsity_mask(self, seq_len):
        """Return a ``(seq_len, seq_len)`` bool mask; ``True`` means "may attend".

        Rows index query positions and columns index key positions.
        """
        if self.sparsity_pattern == 'local':
            return self.local_mask(seq_len)
        elif self.sparsity_pattern == 'strided':
            return self.strided_mask(seq_len)
        elif self.sparsity_pattern == 'random':
            return self.random_mask(seq_len)
        else:
            # Unknown pattern names fall back to dense attention.
            return torch.ones(seq_len, seq_len, dtype=torch.bool)

    def local_mask(self, seq_len):
        """Allow keys within ``window_size // 2`` positions on either side."""
        pos = torch.arange(seq_len)
        offset = pos[None, :] - pos[:, None]  # key index minus query index
        return offset.abs() <= self.window_size // 2

    def strided_mask(self, seq_len):
        """Local window plus every ``stride``-th earlier position.

        ``stride`` is ``window_size // 4``, clamped to at least 1 so that a
        window smaller than 4 does not produce a zero step.
        """
        stride = max(1, self.window_size // 4)

        pos = torch.arange(seq_len)
        query_pos = pos[:, None]
        key_pos = pos[None, :]

        local = (key_pos - query_pos).abs() <= self.window_size // 2
        # Earlier keys whose index is a multiple of the stride.
        strided = (key_pos % stride == 0) & (key_pos < query_pos)
        return local | strided

    def random_mask(self, seq_len, sparsity=0.1):
        """Random causal pattern keeping roughly ``sparsity`` of the past keys.

        The diagonal is always kept so that every query has at least one
        allowed key; a row with none would make the softmax return NaN.
        """
        keep = torch.rand(seq_len, seq_len) < sparsity
        causal = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool))
        diagonal = torch.eye(seq_len, dtype=torch.bool)
        return (keep & causal) | diagonal

# Progress message printed after the attention classes in this cell are defined.
print("Attention optimization modules implemented!")

### Testing Attention Optimizations

Let's benchmark different attention mechanisms:

In [None]:
def _benchmark_device(attention_module):
    """Return the device of the module's first parameter, or CPU if it has none."""
    parameters = getattr(attention_module, 'parameters', None)
    if callable(parameters):
        first = next(iter(parameters()), None)
        if first is not None:
            return first.device
    return torch.device('cpu')


def benchmark_attention(attention_module, seq_lengths, d_model=512, n_heads=8,
                        batch_size=4, warmup=3, iters=10):
    """Time the forward pass of ``attention_module`` for several sequence lengths.

    Args:
        attention_module: Callable taking one ``(batch, seq_len, d_model)``
            tensor. If it exposes ``parameters()``, the input is created on
            the device of its first parameter; otherwise on the CPU.
        seq_lengths: Iterable of sequence lengths to benchmark.
        d_model: Feature width of the random input.
        n_heads: Unused; kept so existing call sites keep working.
        batch_size: Batch dimension of the random input.
        warmup: Untimed calls made before each measurement.
        iters: Timed calls averaged for each measurement.

    Returns:
        A list with one dict per sequence length, holding ``'seq_len'``,
        ``'time'`` (mean seconds per call) and ``'memory'`` (peak CUDA bytes
        allocated during the timed calls, or 0 when not running on CUDA).

    Raises:
        ValueError: If ``iters`` is smaller than 1.
    """
    if iters < 1:
        raise ValueError(f"iters must be at least 1, got {iters}")

    device = _benchmark_device(attention_module)
    on_cuda = device.type == 'cuda'
    results = []

    for seq_len in seq_lengths:
        x = torch.randn(batch_size, seq_len, d_model, device=device)

        # Warmup: keep one-time setup costs out of the measurement.
        for _ in range(warmup):
            attention_module(x)

        if on_cuda:
            # CUDA kernels launch asynchronously; wait for pending work and
            # restart peak-memory tracking before the clock starts.
            torch.cuda.synchronize(device)
            torch.cuda.reset_peak_memory_stats(device)

        # perf_counter is the high-resolution clock meant for short intervals.
        start_time = time.perf_counter()
        for _ in range(iters):
            attention_module(x)
        if on_cuda:
            torch.cuda.synchronize(device)
        avg_time = (time.perf_counter() - start_time) / iters

        results.append({
            'seq_len': seq_len,
            'time': avg_time,
            'memory': torch.cuda.max_memory_allocated(device) if on_cuda else 0,
        })

        print(f"Seq len {seq_len}: {avg_time:.4f}s")

    return results

# Benchmark every attention variant with the same procedure, then plot
# timings, fitted growth curves, speedups and one sparsity pattern.
d_model, n_heads = 512, 8
seq_lengths = [128, 256, 512, 1024]

print("Benchmarking Standard Attention...")
standard_attn = nn.MultiheadAttention(d_model, n_heads, batch_first=True)
# Route the baseline through benchmark_attention so that it gets the same
# warmup and timing loop as the other variants. need_weights=False skips
# returning the averaged attention map, which the other variants never
# produce either.
standard_results = benchmark_attention(
    lambda x: standard_attn(x, x, x, need_weights=False)[0],
    seq_lengths,
    d_model=d_model,
    n_heads=n_heads,
)

print("\nBenchmarking FlashAttention...")
flash_attn = FlashAttention(d_model, n_heads, block_size=64)
flash_results = benchmark_attention(flash_attn, seq_lengths, d_model=d_model, n_heads=n_heads)

print("\nBenchmarking Sparse Attention (Local)...")
sparse_attn_local = SparseAttention(d_model, n_heads, 'local', window_size=64)
sparse_local_results = benchmark_attention(sparse_attn_local, seq_lengths, d_model=d_model, n_heads=n_heads)

print("\nBenchmarking Sparse Attention (Strided)...")
sparse_attn_strided = SparseAttention(d_model, n_heads, 'strided', window_size=64)
sparse_strided_results = benchmark_attention(sparse_attn_strided, seq_lengths, d_model=d_model, n_heads=n_heads)

# Visualize results
plt.figure(figsize=(15, 10))

# Performance comparison
plt.subplot(2, 2, 1)
plt.plot([r['seq_len'] for r in standard_results], [r['time'] for r in standard_results], 'o-', label='Standard', linewidth=2)
plt.plot([r['seq_len'] for r in flash_results], [r['time'] for r in flash_results], 's-', label='FlashAttention', linewidth=2)
plt.plot([r['seq_len'] for r in sparse_local_results], [r['time'] for r in sparse_local_results], '^-', label='Sparse (Local)', linewidth=2)
plt.plot([r['seq_len'] for r in sparse_strided_results], [r['time'] for r in sparse_strided_results], 'd-', label='Sparse (Strided)', linewidth=2)
plt.xlabel('Sequence Length')
plt.ylabel('Time (seconds)')
plt.title('Attention Performance Comparison')
plt.legend()
plt.grid(True, alpha=0.3)

# Complexity analysis
plt.subplot(2, 2, 2)
seq_lens = np.array([r['seq_len'] for r in standard_results])
standard_times = np.array([r['time'] for r in standard_results])
flash_times = np.array([r['time'] for r in flash_results])

# Both variants compute exact attention over all query/key pairs, so both
# are fitted with a quadratic; FlashAttention saves memory, not asymptotic
# compute.
standard_fit = np.polyfit(seq_lens, standard_times, 2)
flash_fit = np.polyfit(seq_lens, flash_times, 2)

x_smooth = np.linspace(seq_lens.min(), seq_lens.max(), 100)
plt.plot(seq_lens, standard_times, 'o', label='Standard (Measured)', markersize=8)
plt.plot(x_smooth, np.polyval(standard_fit, x_smooth), '--', label='Standard $O(n^2)$ Fit', alpha=0.7)
plt.plot(seq_lens, flash_times, 's', label='FlashAttention (Measured)', markersize=8)
plt.plot(x_smooth, np.polyval(flash_fit, x_smooth), '--', label='FlashAttention $O(n^2)$ Fit', alpha=0.7)
plt.xlabel('Sequence Length')
plt.ylabel('Time (seconds)')
plt.title('Complexity Analysis')
plt.legend()
plt.grid(True, alpha=0.3)

# Speedup analysis: baseline time divided by variant time (>1 means faster).
plt.subplot(2, 2, 3)
flash_speedup = [s['time'] / f['time'] for s, f in zip(standard_results, flash_results)]
sparse_local_speedup = [s['time'] / sp['time'] for s, sp in zip(standard_results, sparse_local_results)]
sparse_strided_speedup = [s['time'] / sp['time'] for s, sp in zip(standard_results, sparse_strided_results)]

bar_positions = np.arange(len(seq_lengths))
plt.bar(bar_positions - 0.2, flash_speedup, 0.2, label='FlashAttention', alpha=0.8)
plt.bar(bar_positions, sparse_local_speedup, 0.2, label='Sparse (Local)', alpha=0.8)
plt.bar(bar_positions + 0.2, sparse_strided_speedup, 0.2, label='Sparse (Strided)', alpha=0.8)
plt.xlabel('Sequence Length')
plt.ylabel('Speedup vs Standard')
plt.title('Speedup Comparison')
plt.xticks(range(len(seq_lengths)), seq_lengths)
plt.legend()
plt.grid(True, alpha=0.3)

# Attention pattern visualization
plt.subplot(2, 2, 4)
seq_len = 64
local_mask = sparse_attn_local.create_sparsity_mask(seq_len).float().numpy()
plt.imshow(local_mask, cmap='Blues', aspect='auto')
plt.title('Local Sparse Attention Pattern')
plt.xlabel('Key Position')
plt.ylabel('Query Position')
plt.colorbar()

plt.tight_layout()
plt.show()

# Summary for the longest benchmarked sequence (last entry of each result list).
print("\n=== PERFORMANCE SUMMARY ===")
print(f"At sequence length {seq_lengths[-1]}:")
print(f"  Standard Attention: {standard_results[-1]['time']:.4f}s")
print(f"  FlashAttention: {flash_results[-1]['time']:.4f}s ({standard_results[-1]['time']/flash_results[-1]['time']:.1f}x speedup)")
print(f"  Sparse Local: {sparse_local_results[-1]['time']:.4f}s ({standard_results[-1]['time']/sparse_local_results[-1]['time']:.1f}x speedup)")
print(f"  Sparse Strided: {sparse_strided_results[-1]['time']:.4f}s ({standard_results[-1]['time']/sparse_strided_results[-1]['time']:.1f}x speedup)")