In [1]:
import torch
import torch.nn as nn
from typing import Optional, Tuple

In [2]:
class KVCache:
    """
    Key-Value cache for transformer attention mechanism.
    Stores past key and value states to avoid recomputation during autoregressive generation.
    """
    def __init__(self, batch_size: int, max_seq_len: int, num_heads: int, head_dim: int, device: str = 'cpu'):
        """
        Initialize KV cache.

        Args:
            batch_size: Batch size
            max_seq_len: Maximum sequence length to cache
            num_heads: Number of attention heads
            head_dim: Dimension of each attention head
            device: Device to store cache on ('cpu' or 'cuda')
        """
        self.batch_size = batch_size
        self.max_seq_len = max_seq_len
        self.num_heads = num_heads
        self.head_dim = head_dim
        self.device = device

        # Initialize cache tensors
        self.k_cache = torch.zeros(
            batch_size, num_heads, max_seq_len, head_dim,
            dtype=torch.float32, device=device
        )
        self.v_cache = torch.zeros(
            batch_size, num_heads, max_seq_len, head_dim,
            dtype=torch.float32, device=device
        )

        # Track current position in cache
        self.current_len = 0

    def update(self, k: torch.Tensor, v: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Update cache with new key and value tensors.

        Args:
            k: New keys of shape (batch_size, num_heads, seq_len, head_dim)
            V: New values of shape (batch_size, num_heads, seq_len, head_dim)

        Returns:
            Tuple of (all_keys, all_values) including cached and new states
        """
        seq_len = k.shape[2]

        # Check if we need to extend cache
        if self.current_len + seq_len > self.max_seq_len:
            raise ValueError(f"Cache overflow: current_len={self.current_len}, new_seq_len={seq_len}, max_seq_len={self.max_seq_len}")

        # Update cache at current position
        self.k_cache[:, :, self.current_len:self.current_len + seq_len, :] = k
        self.v_cache[:, :, self.current_len:self.current_len + seq_len, :] = v

        # Increment position
        self.current_len += seq_len

        # Return all cached keys and values up to current position
        return (
            self.k_cache[:, :, :self.current_len, :],
            self.v_cache[:, :, :self.current_len, :],
        )

    def get(self) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Get current cached keys and values.

        Returns:
            Tuple of (keys, values) from cache
        """
        return (
            self.k_cache[:, :, :self.current_len, :],
            self.v_cache[:, :, :self.current_len, :],
        )

    def reset(self):
        """Reset cache to empty state."""
        self.current_len = 0
        self.k_cache.zero_()
        self.v_cache.zero_()

    def get_seq_len(self) -> int:
        """Get current sequence length in cache."""
        return self.current_len

In [3]:
class MultiHeadAttentionWithCache(nn.Module):
    """
    Multi-head attention with KV cache support for efficient autoregressive generation.
    """
    def __init__(self, d_model: int, num_heads: int, dropout: float = 0.1):
        super().__init__()
        assert d_model % num_heads == 0, "d_model must be divisible by num_heads"

        self.d_model = d_model
        self.num_heads = num_heads
        self.head_dim = d_model // num_heads
        self.scale = self.head_dim ** -0.5

        # Linear projections
        self.q_proj = nn.Linear(d_model, d_model)
        self.k_proj = nn.Linear(d_model, d_model)
        self.v_proj = nn.Linear(d_model, d_model)
        self.out_proj = nn.Linear(d_model, d_model)

        self.dropout = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor, kv_cache: Optional[KVCache] = None, mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        """
        Forward pass with optional KV cache.

        Args:
            x: Input tensor of shape (batch_size, seq_len, d_model)
            kv_cache: Optional KV cache for autoregressive generation
            mask: Optional attention mask

        Returns:
            Output tensor of shape (batch_size, seq_len, d_model)
        """
        B, L, D = x.shape

        # Project queries, keys, values
        q = self.q_proj(x).view(B, L, self.num_heads, self.head_dim).transpose(1, 2)
        k = self.k_proj(x).view(B, L, self.num_heads, self.head_dim).transpose(1, 2)
        v = self.v_proj(x).view(B, L, self.num_heads, self.head_dim).transpose(1, 2)

        # Use cache if provided
        if kv_cache is not None:
            k, v = kv_cache.update(k, v)

        # Compute attention scores
        scores = torch.matmul(q, k.transpose(-1, -2)) * self.scale

        # Apply mask if provided
        if mask is not None:
            scores = scores.masked_fill(mask == 0, float('-inf'))

        # Apply softmax and dropout
        attn = torch.softmax(scores, dim=-1)
        attn = self.dropout(attn)

        # Apply attention to values
        out = torch.matmul(attn, v)

        # Reshape and project output
        out = out.transpose(1, 2).contiguous().view(B, L, D)
        out = self.out_proj(out)

        return out

In [4]:
# Example usage
if __name__ == "__main__":
    # Configuration
    batch_size = 2
    seq_len = 10
    d_model = 512
    num_heads = 8
    max_seq_len = 100

    # Create attention layer
    attention = MultiHeadAttentionWithCache(d_model, num_heads)

    # Create KV cache
    head_dim = d_model // num_heads
    kv_cache = KVCache(batch_size, max_seq_len, num_heads, head_dim)

    print("=== Autoregressive Generation with KV Cache ===\n")

    # Simulate autoregressive generation
    for step in range(3):
        # In generation, we typically process one token at a time
        token_len = 1 if step > 0 else seq_len  # First step: full sequence, then one token
        x = torch.randn(batch_size, token_len, d_model)

        print(f"Step {step + 1}:")
        print(f"  Input shape: {x.shape}")
        print(f"  Cache length before: {kv_cache.get_seq_len()}")

        # Forward pass with cache
        output = attention(x, kv_cache=kv_cache)

        print(f"  Cache length after: {kv_cache.get_seq_len()}")
        print(f"  Output shape: {output.shape}\n")

    # Reset cache
    print("Resetting cache...")
    kv_cache.reset()
    print(f"Cache length after reset: {kv_cache.get_seq_len()}\n")

    # Compare with and without cache (timing)
    print("=== Performance Comparison ===")
    import time

    # Without cache
    x_full = torch.randn(batch_size, 50, d_model)
    start = time.time()
    for _ in range(10):
        _ = attention(x_full, kv_cache=None)
    no_cache_time = time.time() - start
    print(f"Without cache (10 iterations): {no_cache_time:.4f}s")

    # With cache (simulating generation)
    kv_cache.reset()
    start = time.time()
    for i in range(10):
        x_token = torch.randn(batch_size, 1, d_model)
        _ = attention(x_token, kv_cache=kv_cache)
    cache_time = time.time() - start
    print(f"With cache (10 tokens): {cache_time:.4f}s")
    print(f"Speedup: {no_cache_time / cache_time:.2f}x")

=== Autoregressive Generation with KV Cache ===

Step 1:
  Input shape: torch.Size([2, 10, 512])
  Cache length before: 0
  Cache length after: 10
  Output shape: torch.Size([2, 10, 512])

Step 2:
  Input shape: torch.Size([2, 1, 512])
  Cache length before: 10
  Cache length after: 11
  Output shape: torch.Size([2, 1, 512])

Step 3:
  Input shape: torch.Size([2, 1, 512])
  Cache length before: 11
  Cache length after: 12
  Output shape: torch.Size([2, 1, 512])

Resetting cache...
Cache length after reset: 0

=== Performance Comparison ===
Without cache (10 iterations): 0.0082s
With cache (10 tokens): 0.0029s
Speedup: 2.87x
