In [2]:
import torch
import numpy as np
import math

####################################
# 1. Generate toy Q/K/V
####################################

# Seed both RNGs so the toy tensors are identical on every run.
torch.manual_seed(0)
np.random.seed(0)

# Problem size: batch, sequence length, feature dimension.
B, T, D = 2, 7, 16

# Queries, keys and values are drawn in that order from the seeded generator.
Q_t, K_t, V_t = (torch.randn((B, T, D), dtype=torch.float32) for _ in range(3))

# NumPy counterparts of the same tensors, used by the NumPy implementation.
Q, K, V = (tensor.numpy() for tensor in (Q_t, K_t, V_t))


In [None]:
####################################
# 2. Exact attention (reference)
####################################
# Reference result: materialise the full score matrix, softmax over the key
# axis, then take the probability-weighted sum of the values.
scores = torch.matmul(Q_t, K_t.transpose(-2, -1)) / math.sqrt(D)
attn_probs = scores.softmax(dim=-1)
out_exact = torch.matmul(attn_probs, V_t)
out_exact_np = out_exact.numpy()


####################################
# 3. FlashAttention (PyTorch)
####################################
def flash_attention_torch(Q, K, V, causal=False):
    """
    Attention via PyTorch's fused ``scaled_dot_product_attention``.

    Args:
        Q: [B, Tq, D] query tensor.
        K: [B, Tk, D] key tensor.
        V: [B, Tk, Dv] value tensor.
        causal: if True, forwarded as ``is_causal=True`` so that query
            position i only attends to key positions j <= i (same convention
            as ``flash_attention_numpy(..., causal=True)``).

    Returns:
        Tensor of shape [B, Tq, Dv].

    Note:
        Which kernel actually runs (flash / memory-efficient / math) is
        chosen by PyTorch for the current device and dtype; this wrapper
        does not force a particular backend.
    """
    # Insert a singleton head axis so the inputs are (B, 1, T, D).
    Qh, Kh, Vh = (x.unsqueeze(1) for x in (Q, K, V))
    out = torch.nn.functional.scaled_dot_product_attention(
        Qh, Kh, Vh, dropout_p=0.0, is_causal=causal
    )
    # Drop the head axis again: (B, 1, T, Dv) -> (B, T, Dv).
    return out.squeeze(1)


# Run the PyTorch fused kernel on the toy tensors and keep a NumPy copy for comparison.
out_torch_flash = flash_attention_torch(Q=Q_t, K=K_t, V=V_t).numpy()


####################################
# 4. FlashAttention (NumPy)
####################################
def flash_attention_numpy(Q, K, V, block_size=4, causal=False):
    """
    FlashAttention-style streaming (online-softmax) attention in NumPy.

    Computes softmax(Q @ K^T / sqrt(D)) @ V for each batch element without
    materialising the full [Tq, Tk] score matrix. Queries and keys are
    processed in tiles of ``block_size`` rows; a running row-max and row-sum
    are kept per query row and rescaled whenever a new key tile raises the max.

    Args:
        Q: [B, Tq, D] queries.
        K: [B, Tk, D] keys.
        V: [B, Tk, Dv] values.
        block_size: tile length used for both the query and key axes (>= 1).
        causal: if True, query position i only attends to key positions j <= i.

    Returns:
        float32 array of shape [B, Tq, Dv].

    Raises:
        ValueError: if ``block_size`` < 1 (a negative step would make the
            tile loops empty and silently return zeros), or if K and V
            disagree on the number of key positions.
    """
    if block_size < 1:
        raise ValueError(f"block_size must be >= 1, got {block_size!r}")

    B, Tq, d = Q.shape
    _, Tk, dv = V.shape
    if K.shape[1] != Tk:
        raise ValueError(
            f"K and V must have the same sequence length, got {K.shape[1]} and {Tk}"
        )
    scale = 1.0 / math.sqrt(d)

    O = np.zeros((B, Tq, dv), dtype=np.float32)

    for b in range(B):
        Qb = Q[b]
        Kb = K[b]
        Vb = V[b]

        # Process queries in blocks
        for qi in range(0, Tq, block_size):
            q_end = min(qi + block_size, Tq)
            Q_blk = Qb[qi:q_end]  # [bq, d]
            bq = Q_blk.shape[0]

            # Running accumulators (per query row). The max starts at a large
            # negative finite value rather than -inf so the first
            # exp(running_max - M) is a clean 0 instead of a potential NaN.
            running_max = np.full(bq, -1e9, dtype=np.float32)
            running_sum = np.zeros(bq, dtype=np.float32)  # softmax denominator
            running_Y = np.zeros((bq, dv), dtype=np.float32)  # unnormalised output

            # Process keys in blocks
            for kj in range(0, Tk, block_size):
                # Under causal masking, a key tile that starts at or after the
                # end of this query tile is entirely in the future: every
                # weight would be exp(-1e9 - M) == 0. Tiles are visited in
                # increasing order, so all remaining ones can be skipped.
                if causal and kj >= q_end:
                    break

                k_end = min(kj + block_size, Tk)
                K_blk = Kb[kj:k_end]  # [bk, d]
                V_blk = Vb[kj:k_end]  # [bk, dv]

                # Score block: [bq, bk]
                S = (Q_blk @ K_blk.T) * scale

                # Causal mask inside a partially visible tile
                if causal:
                    qpos = np.arange(qi, q_end)[:, None]
                    kpos = np.arange(kj, k_end)[None, :]
                    S = np.where(qpos < kpos, -1e9, S)

                # Local block max
                block_max = S.max(axis=-1)  # [bq]
                # New global max after seeing this tile
                M = np.maximum(running_max, block_max)
                # Factor that re-expresses the old accumulators relative to M
                compensation_factor = np.exp(running_max - M)

                exp_S = np.exp(S - M[:, None])  # [bq, bk]
                block_sum = exp_S.sum(axis=-1)  # [bq]
                block_Y = exp_S @ V_blk  # [bq, dv]

                running_Y = running_Y * compensation_factor[:, None] + block_Y
                running_sum = running_sum * compensation_factor + block_sum
                running_max = M

            # Normalise once per query tile, after all key tiles were merged
            O[b, qi:q_end] = running_Y / running_sum[:, None]

    return O


# Run the streaming NumPy implementation on the same toy inputs.
out_numpy_flash = flash_attention_numpy(Q, K, V, block_size=4)

####################################
# 5. Compare all results
####################################
# Largest element-wise absolute deviation between each pair of outputs.
print("Max diff (Torch Flash vs Exact):     ", np.abs(out_torch_flash - out_exact_np).max())
print("Max diff (NumPy Flash vs Exact):     ", np.abs(out_numpy_flash - out_exact_np).max())
print("Max diff (NumPy Flash vs Torch):     ", np.abs(out_numpy_flash - out_torch_flash).max())

# Spot-check: first two query rows of batch 0, first four output features.
print("\nExample row (batch0, first 2 rows):")
print("Exact:\n", out_exact_np[0, 0:2, 0:4])
print("Torch Flash:\n", out_torch_flash[0, 0:2, 0:4])
print("NumPy Flash:\n", out_numpy_flash[0, 0:2, 0:4])

Max diff (Torch Flash vs Exact):      1.7881393e-07
Max diff (NumPy Flash vs Exact):      2.3841858e-07
Max diff (NumPy Flash vs Torch):      2.3841858e-07

Example row (batch0, first 2 rows):
Exact:
 [[-0.20096852 -0.5869937  -0.05182338 -0.4397468 ]
 [-0.1666379  -0.5908006  -0.6343283  -0.5527998 ]]
Torch Flash:
 [[-0.20096853 -0.5869937  -0.05182336 -0.4397468 ]
 [-0.1666379  -0.5908006  -0.63432837 -0.5527999 ]]
NumPy Flash:
 [[-0.20096852 -0.58699363 -0.05182335 -0.43974683]
 [-0.16663791 -0.5908006  -0.63432837 -0.5527999 ]]
