# Paged Attention Implementation

This notebook demonstrates the core idea behind paged attention, a memory optimization used in LLM inference engines such as vLLM, through a small NumPy and PyTorch simulation.

## What is Paged Attention?

Paged attention is a memory management technique for the key-value (KV) cache that attention reads during inference. Instead of storing each sequence's key and value tensors contiguously in GPU memory, it divides them into fixed-size blocks (pages) that can be placed anywhere in memory. A per-sequence block table records where each block lives, much like virtual memory paging in an operating system.
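To make the block idea concrete, the sketch below translates a token position into a physical block and an offset inside that block. The names `BLOCK_SIZE`, `block_table`, and `locate` are hypothetical and chosen for illustration; they are not taken from vLLM or any other library.

```python
# Illustrative sketch with hypothetical names; not a real library API.
BLOCK_SIZE = 4  # tokens per block (assumed value)

# Block table for one sequence: logical block index -> physical block ID.
# The physical IDs are deliberately out of order to show non-contiguous storage.
block_table = [7, 2, 5]

def locate(token_pos, block_table, block_size=BLOCK_SIZE):
    """Return (physical_block_id, offset_within_block) for a token position."""
    logical_block, offset = divmod(token_pos, block_size)
    return block_table[logical_block], offset

for pos in (0, 3, 4, 9):
    print(f"token {pos} -> physical block, offset = {locate(pos, block_table)}")
```

Tokens 0 to 3 land in physical block 7, tokens 4 to 7 in block 2, and tokens 8 to 11 in block 5, so neighbouring tokens in the sequence need not be neighbours in memory.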

## Key Benefits

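1. **Memory Efficiency**: Fixed-size blocks avoid the fragmentation caused by large contiguous per-sequence allocations
2. **Dynamic Sequence Lengths**: Blocks are allocated as a sequence grows, so variable-length sequences need no padding to a common length
3. **Better Memory Utilization**: The memory saved can hold more sequences, which allows for higher batch sizes
4. **Reduced Memory Waste**: Unused space is limited to the partially filled last block of each sequence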
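The waste reduction can be checked with simple arithmetic. The sketch below counts KV-cache slots (one slot holds one token's key and value) for a made-up batch, comparing preallocation to a maximum length against block-rounded allocation. The sequence lengths, `max_seq_len`, and `block_size` are assumed values chosen for illustration.

```python
def padded_slots(seq_lens, max_seq_len):
    """Slots reserved when every sequence is preallocated to max_seq_len."""
    return len(seq_lens) * max_seq_len

def paged_slots(seq_lens, block_size):
    """Slots reserved when each sequence is rounded up to whole blocks."""
    return sum(((n + block_size - 1) // block_size) * block_size for n in seq_lens)

seq_lens = [5, 17, 30, 64]  # made-up sequence lengths
max_seq_len = 64            # assumed preallocation length
block_size = 16             # assumed block size

used = sum(seq_lens)
padded = padded_slots(seq_lens, max_seq_len)
paged = paged_slots(seq_lens, block_size)

print(f"Slots actually used:      {used}")
print(f"Padded allocation:        {padded} (wasted: {padded - used})")
print(f"Block-rounded allocation: {paged} (wasted: {paged - used})")
```

With these numbers the padded layout reserves 256 slots for 116 tokens, while the block-rounded layout reserves 144, so the remaining waste sits entirely in each sequence's last block.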

In [None]:
import numpy as np
import torch

# Let's first understand the basic concept with a simple example
print("Paged Attention Concept Demonstration")
print("=====================================")

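# Scaled dot-product attention over contiguous tensors; this is the baseline
# result that a paged implementation must reproduce
def simple_attention(query, key, value):
    # Attention scores, scaled by sqrt(head_dim) as in standard attention
    scores = torch.matmul(query, key.transpose(-2, -1)) / (query.shape[-1] ** 0.5)
    # Normalize the scores over the key dimension
    weights = torch.softmax(scores, dim=-1)
    # Weighted sum of the values
    output = torch.matmul(weights, value)
    return output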

# Create sample data
batch_size = 2
seq_len = 4
head_dim = 8

query = torch.randn(batch_size, seq_len, head_dim)
key = torch.randn(batch_size, seq_len, head_dim)
value = torch.randn(batch_size, seq_len, head_dim)

print(f"Query shape: {query.shape}")
print(f"Key shape: {key.shape}")
print(f"Value shape: {value.shape}")

# Compute attention
attention_output = simple_attention(query, key, value)
print(f"Attention output shape: {attention_output.shape}")

## Paged Attention Concept

In traditional attention, all key and value tensors for a sequence are stored contiguously in memory. In paged attention, we divide these tensors into fixed-size blocks:

```
Traditional storage:
[K0, K1, K2, K3, K4, K5, K6, K7] - All keys stored consecutively

Paged storage:
Block 0: [K0, K1, K2, K3]
Block 1: [K4, K5, K6, K7]
```

This allows for more flexible memory management, especially when dealing with variable-length sequences.
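One way to picture this flexibility is a pool of free blocks from which each sequence takes a new block only when its last block is full. The sketch below is a toy model under that assumption; `BlockAllocator` and `Sequence` are hypothetical names, and a real engine would also manage the KV data itself.

```python
class BlockAllocator:
    """Toy free-list allocator over a fixed pool of physical block IDs."""

    def __init__(self, num_blocks):
        self.free = list(range(num_blocks))

    def allocate(self):
        if not self.free:
            raise MemoryError("no free KV-cache blocks")
        return self.free.pop()

    def release(self, block_ids):
        self.free.extend(block_ids)


class Sequence:
    """Tracks the block table of one sequence as tokens are appended."""

    def __init__(self, allocator, block_size):
        self.allocator = allocator
        self.block_size = block_size
        self.block_table = []
        self.num_tokens = 0

    def append_token(self):
        # A new block is needed only when all existing blocks are full
        if self.num_tokens % self.block_size == 0:
            self.block_table.append(self.allocator.allocate())
        self.num_tokens += 1

    def release(self):
        # Return every block to the pool so other sequences can reuse it
        self.allocator.release(self.block_table)
        self.block_table = []
        self.num_tokens = 0


allocator = BlockAllocator(num_blocks=8)
seq_a = Sequence(allocator, block_size=4)
seq_b = Sequence(allocator, block_size=4)

# Grow both sequences in lockstep, as a batched decode loop would
for _ in range(5):
    seq_a.append_token()
    seq_b.append_token()

print("seq_a block table:", seq_a.block_table)
print("seq_b block table:", seq_b.block_table)

seq_a.release()
print("free blocks after releasing seq_a:", len(allocator.free))
```

Because the two sequences grow in an interleaved fashion, their block tables end up interleaved in the physical pool, and releasing one sequence returns its blocks for reuse without moving the other.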

In [None]:
# Simulate paged attention concept
class PagedAttentionSimulator:
    def __init__(self, block_size=4, head_dim=8):
        self.block_size = block_size
        self.head_dim = head_dim
        self.blocks = {}        # physical block ID -> {'keys', 'values'} arrays
        self.block_tables = {}  # sequence ID -> list of physical block IDs
    
    def allocate_blocks(self, seq_id, seq_len):
        # Number of blocks needed (ceiling division)
        num_blocks = (seq_len + self.block_size - 1) // self.block_size
        
        # Allocate a fresh block ID for each block
        block_ids = []
        for _ in range(num_blocks):
            block_id = len(self.blocks)
            self.blocks[block_id] = {
                'keys': np.zeros((self.block_size, self.head_dim)),
                'values': np.zeros((self.block_size, self.head_dim))
            }
            block_ids.append(block_id)
        
        # Store block table for this sequence
        self.block_tables[seq_id] = block_ids
        
        return block_ids
    
    def store_kv_cache(self, seq_id, keys, values):
        block_ids = self.block_tables[seq_id]
        
        # Refuse to silently drop tokens that do not fit in the allocated blocks
        if len(keys) > len(block_ids) * self.block_size:
            raise ValueError("sequence is longer than its allocated blocks")
        
        # Copy keys and values into the blocks, up to block_size tokens each
        idx = 0
        for block_id in block_ids:
            block = self.blocks[block_id]
            
            # Calculate how many tokens to store in this block
            remaining = len(keys) - idx
            to_store = min(remaining, self.block_size)
            
            # Store keys and values
            block['keys'][:to_store] = keys[idx:idx+to_store]
            block['values'][:to_store] = values[idx:idx+to_store]
            
            idx += to_store
            
            if idx >= len(keys):
                break
    
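    def retrieve_kv_cache(self, seq_id):
        # Note: the result covers whole blocks, so it may contain zero-filled
        # rows past the end of the sequence; the caller trims to the true length
        block_ids = self.block_tables[seq_id]
        
        # Gather keys and values block by block, in logical order
        all_keys = []
        all_values = []
        
        for block_id in block_ids:
            block = self.blocks[block_id]
            all_keys.append(block['keys'])
            all_values.append(block['values'])
        
        return np.concatenate(all_keys, axis=0), np.concatenate(all_values, axis=0)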

# Demonstrate paged attention
print("\nPaged Attention Demonstration")
print("============================")

simulator = PagedAttentionSimulator(block_size=3)

# Simulate two sequences of different lengths
seq1_len = 7  # Will need 3 blocks (3+3+1)
seq2_len = 4  # Will need 2 blocks (3+1)

# Allocate blocks for sequences
seq1_blocks = simulator.allocate_blocks(1, seq1_len)
seq2_blocks = simulator.allocate_blocks(2, seq2_len)

print(f"Sequence 1 (length {seq1_len}) allocated to blocks: {seq1_blocks}")
print(f"Sequence 2 (length {seq2_len}) allocated to blocks: {seq2_blocks}")

# Generate sample keys and values
seq1_keys = np.random.randn(seq1_len, 8)
seq1_values = np.random.randn(seq1_len, 8)
seq2_keys = np.random.randn(seq2_len, 8)
seq2_values = np.random.randn(seq2_len, 8)

# Store in paged format
simulator.store_kv_cache(1, seq1_keys, seq1_values)
simulator.store_kv_cache(2, seq2_keys, seq2_values)

# Retrieve and verify
retrieved_seq1_keys, retrieved_seq1_values = simulator.retrieve_kv_cache(1)
retrieved_seq2_keys, retrieved_seq2_values = simulator.retrieve_kv_cache(2)

print("\nVerification:")
print(f"Sequence 1 keys match: {np.allclose(seq1_keys, retrieved_seq1_keys[:seq1_len])}")
print(f"Sequence 1 values match: {np.allclose(seq1_values, retrieved_seq1_values[:seq1_len])}")
print(f"Sequence 2 keys match: {np.allclose(seq2_keys, retrieved_seq2_keys[:seq2_len])}")
print(f"Sequence 2 values match: {np.allclose(seq2_values, retrieved_seq2_values[:seq2_len])}")

## Benefits of Paged Attention

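1. **Memory Efficiency**: Sequences no longer need padding to a maximum length, so unused space is limited to the partially filled last block of each sequence.
2. **Block Sharing**: Different sequences can map their block tables to the same physical memory blocks, for example when they share a common prompt prefix.
3. **Reduced Fragmentation**: Because every block has the same size, any free block can serve any sequence, which avoids the gaps left by variable-size contiguous allocations.
4. **Scalability**: Better memory utilization allows for larger batch sizes and longer sequences.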
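Sharing physical blocks between sequences needs some bookkeeping so that a shared block is not overwritten or freed while another sequence still uses it. The sketch below assumes a reference count per block with copy-on-write; the class `SharedBlockPool` and its methods are hypothetical, and the actual copy of the KV data is omitted.

```python
from collections import Counter

class SharedBlockPool:
    """Toy block pool with reference counts and copy-on-write (assumed design)."""

    def __init__(self, num_blocks):
        self.free = list(range(num_blocks))
        self.ref_count = Counter()

    def allocate(self):
        block_id = self.free.pop()
        self.ref_count[block_id] = 1
        return block_id

    def fork(self, block_table):
        """Create a new block table that shares every block of an existing one."""
        for block_id in block_table:
            self.ref_count[block_id] += 1
        return list(block_table)

    def writable(self, block_table, logical_idx):
        """Return a block ID that is safe to write, replacing a shared block first."""
        block_id = block_table[logical_idx]
        if self.ref_count[block_id] > 1:
            # Shared block: drop our reference and switch to a private block.
            # A real system would also copy the existing KV data across.
            self.ref_count[block_id] -= 1
            block_id = self.allocate()
            block_table[logical_idx] = block_id
        return block_id


pool = SharedBlockPool(num_blocks=8)
parent = [pool.allocate(), pool.allocate()]  # prompt occupying two blocks
child = pool.fork(parent)                    # second sequence sharing the prompt

# The child is about to write into its last block, so it needs a private copy
pool.writable(child, logical_idx=1)

print("parent block table:", parent)
print("child block table: ", child)
print("reference counts:  ", dict(pool.ref_count))
```

After the write, both sequences still share the first block, while the last block has diverged into two private copies.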

## Implementation Considerations

When implementing paged attention in CUDA:

1. **Block Management**: Efficiently allocate and deallocate memory blocks
2. **Block Tables**: Maintain mappings between logical and physical block addresses
3. **Kernel Optimization**: Design kernels that can efficiently access non-contiguous memory
4. **Memory Coalescing**: Ensure efficient memory access patterns despite non-contiguous storage
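The access pattern such a kernel follows can be sketched on the CPU. The NumPy function below computes attention for one query vector by walking the block table and combining per-block results with a running (online) softmax. This is an illustrative choice rather than a description of any particular CUDA kernel, and all names are hypothetical.

```python
import numpy as np

def paged_decode_attention(q, key_pool, value_pool, block_table, seq_len):
    """Attention for one query over a paged KV cache.

    q:           (head_dim,) query vector
    key_pool:    (num_physical_blocks, block_size, head_dim)
    value_pool:  same shape as key_pool
    block_table: physical block IDs in logical order
    seq_len:     number of valid tokens in the sequence
    """
    block_size, head_dim = key_pool.shape[1], key_pool.shape[2]
    scale = 1.0 / np.sqrt(head_dim)

    running_max = -np.inf      # largest score seen so far
    running_sum = 0.0          # softmax denominator, relative to running_max
    acc = np.zeros(head_dim)   # unnormalized weighted sum of values

    for logical_idx, physical_id in enumerate(block_table):
        # The last block may be only partially filled
        valid = min(block_size, seq_len - logical_idx * block_size)
        if valid <= 0:
            break
        k = key_pool[physical_id, :valid]
        v = value_pool[physical_id, :valid]
        scores = (k @ q) * scale

        # Rescale earlier partial results to the new maximum before adding
        new_max = max(running_max, scores.max())
        correction = np.exp(running_max - new_max)
        p = np.exp(scores - new_max)
        running_sum = running_sum * correction + p.sum()
        acc = acc * correction + p @ v
        running_max = new_max

    return acc / running_sum


rng = np.random.default_rng(0)
block_size, head_dim, seq_len = 4, 8, 10
key_pool = rng.standard_normal((6, block_size, head_dim))
value_pool = rng.standard_normal((6, block_size, head_dim))
block_table = [4, 1, 3]  # non-contiguous physical blocks
q = rng.standard_normal(head_dim)

# Reference: gather the blocks into contiguous arrays and apply ordinary attention
k_full = np.concatenate([key_pool[b] for b in block_table])[:seq_len]
v_full = np.concatenate([value_pool[b] for b in block_table])[:seq_len]
s = (k_full @ q) / np.sqrt(head_dim)
w = np.exp(s - s.max())
reference = (w / w.sum()) @ v_full

out = paged_decode_attention(q, key_pool, value_pool, block_table, seq_len)
print("Paged result matches contiguous attention:", np.allclose(out, reference))
```

Processing one block at a time means the keys and values never have to be gathered into a contiguous buffer, and reads within a block remain contiguous even though the blocks themselves are scattered.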

In [None]:
# Memory usage comparison simulation
print("\nMemory Usage Comparison Simulation")
print("==================================")

# Estimate KV-cache bytes for a single attention head in a single layer
def calculate_memory_usage(batch_size, avg_seq_len, max_seq_len, head_dim, use_paged=False):
    if use_paged:
        # Paged attention - allocate only for tokens actually present
        # (simplification: ignores the unused tail of each sequence's last block)
        total_tokens = batch_size * avg_seq_len
        memory_per_token = head_dim * 2 * 4  # 2 for key+value, 4 bytes per float32
        return total_tokens * memory_per_token
    else:
        # Traditional attention - reserve max_seq_len slots for every sequence
        memory_per_sequence = max_seq_len * head_dim * 2 * 4
        return batch_size * memory_per_sequence

# Example scenario
batch_size = 32
avg_seq_len = 128
max_seq_len = 2048
head_dim = 128

traditional_memory = calculate_memory_usage(batch_size, avg_seq_len, max_seq_len, head_dim, use_paged=False)
paged_memory = calculate_memory_usage(batch_size, avg_seq_len, max_seq_len, head_dim, use_paged=True)

print(f"Batch size: {batch_size}")
print(f"Average sequence length: {avg_seq_len}")
print(f"Max sequence length: {max_seq_len}")
print(f"Head dimension: {head_dim}")
print()
print(f"Traditional attention memory usage: {traditional_memory / (1024**2):.2f} MB")
print(f"Paged attention memory usage: {paged_memory / (1024**2):.2f} MB")
print(f"Memory savings: {(traditional_memory - paged_memory) / traditional_memory * 100:.1f}%")

## Summary

Paged attention is a memory optimization used in modern LLM inference engines. By dividing the key and value caches into fixed-size blocks, it:

1. Reduces memory waste from padding
2. Enables more efficient memory utilization
3. Supports dynamic batching with variable sequence lengths
4. Allows for better scalability to larger batch sizes

The implementation requires careful management of block allocation, block tables, and optimized CUDA kernels that can efficiently access non-contiguous memory layouts.