In [1]:
import torch
import torch.nn.functional as F
import math

In [2]:
def standard_attention(Q, K, V, mask=None):
    """Dense scaled dot-product attention, used as the reference implementation.

    Builds the complete score matrix in one go, so memory grows with the
    product of the query and key sequence lengths.

    Args:
        Q, K, V: Query, key and value tensors; the last dimension of Q is the
            head dimension used for scaling.
        mask: Optional tensor broadcastable to the score matrix. Positions
            where ``mask == 0`` are excluded from attention.

    Returns:
        The attention-weighted combination of V.
    """
    head_dim = Q.shape[-1]
    logits = (Q @ K.transpose(-1, -2)) / math.sqrt(head_dim)

    # Excluded positions get -inf so they receive zero weight after softmax.
    if mask is not None:
        logits = logits.masked_fill(mask == 0, float('-inf'))

    probs = logits.softmax(dim=-1)
    return probs @ V

In [3]:
def flash_attention_forward(Q, K, V, block_size=64):
    """
    Simplified Flash Attention forward pass (no masking).

    Computes softmax(Q @ K^T / sqrt(head_dim)) @ V one tile at a time. A
    running row maximum and row sum (online softmax) are carried across the
    key/value tiles, so only a [block_size, block_size] slice of the score
    matrix exists at any moment.

    Args:
        Q: Queries, [batch, num_heads, q_len, head_dim].
        K: Keys, [batch, num_heads, kv_len, head_dim].
        V: Values, [batch, num_heads, kv_len, head_dim].
            kv_len may differ from q_len (cross-attention).
        block_size: Number of rows per query tile and per key/value tile.
            Must be at least 1.

    Returns:
        Attention output with the same shape and dtype as Q.

    Raises:
        ValueError: If block_size is smaller than 1.
    """
    if block_size < 1:
        raise ValueError(f"block_size must be >= 1, got {block_size}")

    batch_size, num_heads, q_len, head_dim = Q.shape
    kv_len = K.size(-2)
    scale = 1.0 / math.sqrt(head_dim)

    O = torch.zeros_like(Q)
    if kv_len == 0:
        # Nothing to attend to: keep the zero output instead of producing
        # 0 / 0 = NaN in the final normalization.
        return O

    # Accumulate in at least float32. Half-precision inputs would otherwise
    # mix dtypes inside matmul (which does not type-promote) and lose
    # precision in the running sums; float32/float64 inputs are unaffected.
    acc_dtype = torch.promote_types(Q.dtype, torch.float32)
    K_acc = K.to(acc_dtype)
    V_acc = V.to(acc_dtype)

    # Outer loop: one tile of query rows at a time.
    for q_start in range(0, q_len, block_size):
        q_end = min(q_start + block_size, q_len)
        Q_block = Q[:, :, q_start:q_end, :].to(acc_dtype)
        stats_shape = (batch_size, num_heads, q_end - q_start, 1)

        # Per-row running state: unnormalized output, sum of exponentials,
        # and the maximum score seen so far.
        O_block = torch.zeros_like(Q_block)
        l_block = torch.zeros(stats_shape, dtype=acc_dtype, device=Q.device)
        m_block = torch.full(stats_shape, float('-inf'), dtype=acc_dtype, device=Q.device)

        # Inner loop: sweep over every key/value tile (tiled by kv_len, not q_len).
        for k_start in range(0, kv_len, block_size):
            k_end = min(k_start + block_size, kv_len)
            K_block = K_acc[:, :, k_start:k_end, :]
            V_block = V_acc[:, :, k_start:k_end, :]

            S_block = torch.matmul(Q_block, K_block.transpose(-1, -2)) * scale

            # New running maximum after seeing this tile.
            m_new = torch.maximum(m_block, S_block.max(dim=-1, keepdim=True)[0])

            # Exponentials relative to the new maximum; exp_m_diff rescales
            # everything accumulated under the previous maximum.
            exp_scores = torch.exp(S_block - m_new)
            exp_m_diff = torch.exp(m_block - m_new)

            l_block = exp_m_diff * l_block + exp_scores.sum(dim=-1, keepdim=True)
            O_block = exp_m_diff * O_block + torch.matmul(exp_scores, V_block)
            m_block = m_new

        # Divide by the softmax denominator; the assignment casts back to Q's dtype.
        O[:, :, q_start:q_end, :] = O_block / l_block

    return O

In [4]:
def flash_attention_causal(Q, K, V, block_size=64):
    """
    Flash Attention with causal masking (for autoregressive models).

    Query position q may only attend to key positions k <= q. Key/value
    tiles that lie entirely after the current query tile are never visited;
    the tile on the diagonal gets an element-wise mask.

    Args:
        Q, K, V: Query, Key, Value tensors [batch, num_heads, seq_len, head_dim],
            all sharing the same seq_len.
        block_size: Number of rows per query tile and per key/value tile.
            Must be at least 1.

    Returns:
        Causally masked attention output with the same shape and dtype as Q.

    Raises:
        ValueError: If block_size is smaller than 1.
    """
    if block_size < 1:
        raise ValueError(f"block_size must be >= 1, got {block_size}")

    batch_size, num_heads, seq_len, head_dim = Q.shape
    scale = 1.0 / math.sqrt(head_dim)

    # Accumulate in at least float32. Half-precision inputs would otherwise
    # mix dtypes inside matmul (which does not type-promote) and lose
    # precision in the running sums; float32/float64 inputs are unaffected.
    acc_dtype = torch.promote_types(Q.dtype, torch.float32)
    K_acc = K.to(acc_dtype)
    V_acc = V.to(acc_dtype)

    O = torch.zeros_like(Q)

    for q_start in range(0, seq_len, block_size):
        q_end = min(q_start + block_size, seq_len)
        Q_block = Q[:, :, q_start:q_end, :].to(acc_dtype)
        stats_shape = (batch_size, num_heads, q_end - q_start, 1)

        # Per-row running state: unnormalized output, sum of exponentials,
        # and the maximum score seen so far.
        O_block = torch.zeros_like(Q_block)
        l_block = torch.zeros(stats_shape, dtype=acc_dtype, device=Q.device)
        m_block = torch.full(stats_shape, float('-inf'), dtype=acc_dtype, device=Q.device)

        # Visit key tiles up to and including the diagonal one (k_start == q_start).
        for k_start in range(0, q_start + 1, block_size):
            k_end = min(k_start + block_size, seq_len)
            K_block = K_acc[:, :, k_start:k_end, :]
            V_block = V_acc[:, :, k_start:k_end, :]

            S_block = torch.matmul(Q_block, K_block.transpose(-1, -2)) * scale

            if k_start == q_start:
                # Diagonal tile: hide keys that come after each query. The
                # first key of the tile is always visible, so no row is
                # masked out completely.
                q_indices = torch.arange(q_start, q_end, device=Q.device).unsqueeze(1)
                k_indices = torch.arange(k_start, k_end, device=Q.device).unsqueeze(0)
                S_block = S_block.masked_fill(q_indices < k_indices, float('-inf'))

            # Online softmax update; exp_m_diff rescales what was accumulated
            # under the previous running maximum.
            m_new = torch.maximum(m_block, S_block.max(dim=-1, keepdim=True)[0])
            exp_scores = torch.exp(S_block - m_new)
            exp_m_diff = torch.exp(m_block - m_new)

            l_block = exp_m_diff * l_block + exp_scores.sum(dim=-1, keepdim=True)
            O_block = exp_m_diff * O_block + torch.matmul(exp_scores, V_block)
            m_block = m_new

        # Divide by the softmax denominator; the assignment casts back to Q's dtype.
        O[:, :, q_start:q_end, :] = O_block / l_block

    return O

In [5]:
# Example usage and verification: compare the tiled implementations against
# the dense reference on random inputs and fail loudly if they diverge.
if __name__ == "__main__":
    # Setup
    batch_size = 2
    num_heads = 8
    seq_len = 128
    head_dim = 64

    # Choose the device once so Q, K and V always end up on the same one.
    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    # Fixed seed; Q, K, V are drawn in this order so the run is reproducible.
    torch.manual_seed(42)
    Q = torch.randn(batch_size, num_heads, seq_len, head_dim, device=device)
    K = torch.randn(batch_size, num_heads, seq_len, head_dim, device=device)
    V = torch.randn(batch_size, num_heads, seq_len, head_dim, device=device)

    # Compare standard vs flash attention
    output_standard = standard_attention(Q, K, V)
    output_flash = flash_attention_forward(Q, K, V, block_size=32)

    # Check if outputs match (they should be very close)
    diff = (output_standard - output_flash).abs().max()
    outputs_match = torch.allclose(output_standard, output_flash, atol=1e-4)
    print(f"Max difference between standard and flash attention: {diff.item():.6e}")
    print(f"Outputs match: {outputs_match}")

    # Test causal version
    output_flash_causal = flash_attention_causal(Q, K, V, block_size=32)

    # Create causal mask for standard attention (1 = may attend, 0 = hidden),
    # with two leading singleton dims so it broadcasts over batch and heads.
    causal_mask = torch.tril(torch.ones(seq_len, seq_len, device=Q.device)).unsqueeze(0).unsqueeze(0)
    output_standard_causal = standard_attention(Q, K, V, mask=causal_mask)

    diff_causal = (output_standard_causal - output_flash_causal).abs().max()
    causal_outputs_match = torch.allclose(output_standard_causal, output_flash_causal, atol=1e-4)
    print(f"\nMax difference for causal attention: {diff_causal.item():.6e}")
    print(f"Causal outputs match: {causal_outputs_match}")

    print(f"\nInput shape: {Q.shape}")
    print(f"Output shape: {output_flash.shape}")

    # Asserted after the report is printed, so a failing run still shows the
    # measured differences before stopping.
    assert outputs_match, "flash_attention_forward diverges from standard_attention"
    assert causal_outputs_match, "flash_attention_causal diverges from masked standard_attention"

Max difference between standard and flash attention: 4.768372e-07
Outputs match: True

Max difference for causal attention: 4.768372e-07
Causal outputs match: True

Input shape: torch.Size([2, 8, 128, 64])
Output shape: torch.Size([2, 8, 128, 64])
