In [None]:
import torch
import numpy as np
import math

####################################
# 1. Generate toy Q/K/V
####################################

# Seed both RNGs so the toy tensors are reproducible from run to run.
torch.manual_seed(0)
np.random.seed(0)

# Batch size, sequence length and per-token feature width of the toy problem.
B, T, D = 2, 7, 16

# Draw Q, then K, then V: the draw order decides which random values each gets.
Q_t, K_t, V_t = (torch.randn((B, T, D), dtype=torch.float32) for _ in range(3))

# NumPy counterparts of the same tensors, used by the NumPy implementation.
Q, K, V = Q_t.numpy(), K_t.numpy(), V_t.numpy()

In [None]:
####################################
# 2. Exact attention (reference)
####################################
# Dense reference: build the full score matrix for every (query, key) pair,
# softmax over the key axis, then take the weighted sum of the values.
scores = torch.matmul(Q_t, K_t.transpose(-2, -1)) / math.sqrt(D)
attn_probs = scores.softmax(dim=-1)
out_exact = torch.matmul(attn_probs, V_t)
out_exact_np = out_exact.numpy()


####################################
# 3. FlashAttention (PyTorch)
####################################
def flash_attention_torch(Q, K, V, causal=False):
    """Single-head attention through PyTorch's scaled_dot_product_attention.

    Args:
        Q: Query tensor, [B, Tq, D].
        K: Key tensor, [B, Tk, D].
        V: Value tensor, [B, Tk, Dv].
        causal: When True, forwarded as ``is_causal=True`` so each query only
            attends to keys at the same or an earlier position. Defaults to
            False, which matches the previous behaviour of this function.

    Returns:
        Attention output with the temporary head axis removed, [B, Tq, Dv].

    NOTE(review): which backend (flash / memory-efficient / math) actually
    runs is chosen by PyTorch for the given device and dtype — confirm before
    relying on this being a true FlashAttention kernel.
    """
    # The call is made on (B, H, T, D) tensors, so add a singleton head axis.
    out = torch.nn.functional.scaled_dot_product_attention(
        Q.unsqueeze(1),
        K.unsqueeze(1),
        V.unsqueeze(1),
        dropout_p=0.0,
        is_causal=causal,
    )
    # Drop the singleton head axis again.
    return out.squeeze(1)


# Run the SDPA-based wrapper on the toy tensors and convert the result to
# NumPy so it can be compared element-wise with the other implementations.
out_torch_flash = flash_attention_torch(Q_t, K_t, V_t).numpy()


####################################
# 4. FlashAttention (NumPy)
####################################
def flash_attention_numpy(Q, K, V, block_size=4, causal=False):
    """
    FlashAttention-style streaming (online-softmax) attention in NumPy.

    The full [Tq, Tk] score matrix is never built. Queries are processed in
    tiles of ``block_size`` rows; for each query tile the keys/values are
    streamed in tiles of ``block_size`` rows while a running row maximum,
    a running softmax denominator and a running un-normalised output are kept.

    Q: [B, Tq, D]
    K: [B, Tk, D]
    V: [B, Tk, Dv]
    block_size: rows per query tile and per key/value tile; must be >= 1.
    causal: when True, the query at position i ignores keys at positions > i
        (positions are counted from 0 on both axes).

    Returns:
        float32 array of shape [B, Tq, Dv].

    Raises:
        ValueError: if ``block_size`` is smaller than 1. (A negative value
        would otherwise make every tile loop empty and silently return zeros.)
    """
    if block_size < 1:
        raise ValueError(f"block_size must be a positive integer, got {block_size}")

    B, Tq, d = Q.shape
    _, Tk, dv = V.shape
    scale = 1.0 / math.sqrt(d)

    O = np.zeros((B, Tq, dv), dtype=np.float32)

    for b in range(B):
        Qb, Kb, Vb = Q[b], K[b], V[b]

        # Process queries in blocks
        for qi in range(0, Tq, block_size):
            q_end = min(qi + block_size, Tq)
            Q_blk = Qb[qi:q_end]  # [bq, d]
            bq = Q_blk.shape[0]

            # Running accumulators (one entry per query row in this tile).
            # -1e9 stands in for "no score seen yet"; exp(-1e9 - m) is 0.
            running_max = np.full(bq, -1e9, dtype=np.float32)
            running_sum = np.zeros(bq, dtype=np.float32)  # softmax denominator
            running_Y = np.zeros((bq, dv), dtype=np.float32)  # un-normalised output

            # Process keys in blocks
            for kj in range(0, Tk, block_size):
                # Under causal masking every key from q_end onwards lies in the
                # future of every query in this tile, so this tile and all
                # later ones would contribute exactly zero weight.
                if causal and kj >= q_end:
                    break

                k_end = min(kj + block_size, Tk)
                K_blk = Kb[kj:k_end]  # [bk, d]
                V_blk = Vb[kj:k_end]  # [bk, dv]

                # Score block: [bq, bk]
                S = (Q_blk @ K_blk.T) * scale

                # Optional causal mask: hide keys positioned after the query.
                if causal:
                    qpos = np.arange(qi, q_end)[:, None]
                    kpos = np.arange(kj, k_end)[None, :]
                    S = np.where(qpos < kpos, -1e9, S)

                # Row maximum of this tile, merged into the running maximum.
                block_max = S.max(axis=-1)  # [bq]
                M = np.maximum(running_max, block_max)  # new running max

                # Everything accumulated so far was expressed relative to the
                # old maximum; this factor re-expresses it relative to M.
                compensation_factor = np.exp(running_max - M)

                exp_S = np.exp(S - M[:, None])  # [bq, bk]
                block_sum = exp_S.sum(axis=-1)  # [bq]
                block_Y = exp_S @ V_blk  # [bq, dv]

                running_Y = running_Y * compensation_factor[:, None] + block_Y
                running_sum = running_sum * compensation_factor + block_sum
                running_max = M

            # Normalise once per query tile, after all key tiles were merged.
            O[b, qi:q_end] = running_Y / running_sum[:, None]

    return O


out_numpy_flash = flash_attention_numpy(Q, K, V, block_size=4)

####################################
# 5. Compare all results
####################################
# Largest element-wise gap for each pair of implementations.
for label, lhs, rhs in (
    ("Max diff (Torch Flash vs Exact):     ", out_torch_flash, out_exact_np),
    ("Max diff (NumPy Flash vs Exact):     ", out_numpy_flash, out_exact_np),
    ("Max diff (NumPy Flash vs Torch):     ", out_numpy_flash, out_torch_flash),
):
    print(label, np.abs(lhs - rhs).max())

# Eyeball check: the same small corner of the output from each implementation.
print("\nExample row (batch0, first 2 rows):")
for label, result in (
    ("Exact:\n", out_exact_np),
    ("Torch Flash:\n", out_torch_flash),
    ("NumPy Flash:\n", out_numpy_flash),
):
    print(label, result[0, :2, :4])

# %%
import numpy as np

# Tile height used by the hand-worked example below.
block_size = 2

# Six query rows, three features each.
q = np.array([[1, 2, 2], [1, 1, 2], [1, 2, 1], [1, 1, 1], [1, 5, 1], [3, 1, 0]])

# Four key rows, three features each.
k = np.array([[3, 2, 2], [3, 1, 2], [1, 3, 1], [1, 1, 3]])

# Four value rows (one per key), three features each.
v = np.array([[1, 2, 4], [4, 1, 2], [4, 2, 1], [1, 1, 4]])

```python
#### Exact implementation ####
>>> s = np.dot(q, k.T)
>>> s
array([[11,  9,  9,  9],
       [ 9,  8,  6,  8],
       [ 9,  7,  8,  6],
       [ 7,  6,  5,  5],
       [15, 10, 17,  9],
       [11, 10,  6,  4]])
>>> s_exp = np.exp(s)
>>> s_prob = s_exp / s_exp.sum(axis=-1, keepdims=True)
>>> s_prob
array([[7.11234594e-01, 9.62551353e-02, 9.62551353e-02, 9.62551353e-02],
       [5.60052795e-01, 2.06031909e-01, 2.78833868e-02, 2.06031909e-01],
       [6.43914260e-01, 8.71443187e-02, 2.36882818e-01, 3.20586033e-02],
       [6.10295685e-01, 2.24515236e-01, 8.25945394e-02, 8.25945394e-02],
       [1.19072103e-01, 8.02301516e-04, 8.79830446e-01, 2.95150233e-04],
       [7.26992890e-01, 2.67445738e-01, 4.89843956e-03, 6.62931706e-04]])
>>> np.dot(s_prob, v) # target
array([[1.57753081, 1.80748973, 3.51872432],
       [1.70174589, 1.58793618, 3.50428602],
       [1.97208141, 1.88079708, 3.11506291],
       [1.92132933, 1.69289022, 3.30318591],
       [3.64189824, 1.99890255, 1.35890406],
       [1.81703253, 1.73189133, 3.4504132 ]])

#### FlashAttention ####
>>> s11 = np.dot(q[:2], k[0:2].T); max11 = s11.max(axis=-1, keepdims=True); M11=np.maximum(0, max11); f11=np.exp(max11-M11); exp11=np.exp(s11-M11); exp_sum11=exp11.sum(axis=-1,keepdims=True); prob11 = exp11 / exp_sum11 + 0*f11;
>>> s12 = np.dot(q[:2], k[2:4].T); max12 = s12.max(axis=-1, keepdims=True); M12=np.maximum(M11, max12); f12=np.exp(max12-M12); exp12=np.exp(s12-M12); exp_sum12=exp12.sum(axis=-1,keepdims=True); prob12 = exp12 / exp_sum12 + prob11*f12;
>>> s21 = np.dot(q[2:4], k[0:2].T); max21 = s21.max(axis=-1, keepdims=True)
>>> s22 = np.dot(q[2:4], k[2:4].T); max22 = s22.max(axis=-1, keepdims=True)
>>> s31 = np.dot(q[4:6], k[0:2].T); max31 = s31.max(axis=-1, keepdims=True)
>>> s32 = np.dot(q[4:6], k[2:4].T); max32 = s32.max(axis=-1, keepdims=True)

>>> s11
array([[11,  9],
       [ 9,  8]])
>>> max11
array([11,  9])

>>> s12
array([[9, 9],
       [6, 8]])
>>> max12
array([9, 8])
>>> s21
array([[9, 7],
       [7, 6]])
>>> max21
array([9, 7])
>>> s22
array([[8, 6],
       [5, 5]])
>>> max22
array([8, 5])
>>> s31
array([[15, 10],
       [11, 10]])
>>> max31
array([15, 11])
>>> s32
array([[17,  9],
       [ 6,  4]])
>>> max32
array([17,  6])


# Manual walk-through of ONE streaming step: query rows 0-1 against key/value
# rows 0-1. Scores here are raw dot products (no 1/sqrt(d) scaling).

# Per-query-row accumulators, one entry per row of the query tile.
running_max = np.full(block_size, -1e9, dtype=np.float32)  # -1e9: no score seen yet
running_sum = np.zeros(block_size, dtype=np.float32)  # L accumulator (softmax denominator)
running_Y = np.zeros((block_size, v.shape[-1]), dtype=np.float32)  # un-normalised output
O = np.zeros((q.shape[0], v.shape[-1]), dtype=np.float32)  # one output row per query

V_blk = v[:2]  # values that belong to the first key tile
blk1 = np.dot(q[:2], k[:2].T)  # scores of query rows 0-1 against key rows 0-1
blk1_max = blk1.max(axis=-1)  # row-wise maximum of this tile
M = np.maximum(running_max, blk1_max)  # merged running maximum


# Compensation factor: re-expresses what was accumulated so far relative to the
# new maximum M (nothing has been accumulated yet at this first step).
compensation_factor = np.exp(running_max - M)

exp_blk1 = np.exp(blk1 - M[:, None])  # [bq, bk]
block_sum = exp_blk1.sum(axis=-1)  # [bq]
block_Y = exp_blk1 @ V_blk  # [bq, dv]

# Fold this tile into the running accumulators.
running_Y = running_Y * compensation_factor[:, None] + block_Y
running_sum = running_sum * compensation_factor + block_sum

running_max = M

# NOTE: at this point only key rows 0-1 have been merged, so these two output
# rows are normalised over the first key tile only; k[2:4] / v[2:4] have not
# been folded in yet.
O[0:2, ...] = running_Y / running_sum[:, None]

```

In [None]:
from graphviz import Digraph

# Left-to-right box diagram with three nodes: Q, K and the QKᵀ score matrix.
dot = Digraph(comment="FlashAttention")
dot.attr(rankdir="LR")

# Node labels carry the multiplication sign and the superscript T as real
# Unicode characters so they render as "L×d", "QKᵀ" and "L×L".
dot.node("Q", "Q\n(L×d)", shape="box", style="filled", fillcolor="lightblue")
dot.node("K", "K\n(L×d)", shape="box", style="filled", fillcolor="lightgreen")
dot.node("QK", "QKᵀ\n(L×L)", shape="box", style="filled", fillcolor="lightyellow")

dot.edge("Q", "QK")
# NOTE(review): this edge runs QK -> K, which keeps the Q | QKᵀ | K left-to-right
# layout; K is an input of QKᵀ, so K -> QK may be what is meant — confirm.
dot.edge("QK", "K")

# Bare expression as the last statement; presumably relies on notebook display.
dot

In [None]:
import plotly.graph_objects as go

fig = go.Figure()

# One entry per box:
# (label, left edge, right edge, label x, outline colour, fill colour).
# All boxes span y = 0..4 and carry their label just above, at y = 4.3.
box_specs = (
    ("Q", 0, 1, 0.5, "blue", "lightblue"),
    ("K", 4, 5, 4.5, "green", "lightgreen"),
    ("QK<sup>T</sup>", 1.5, 3.5, 2.5, "orange", "lightyellow"),
)

for box_label, left, right, label_x, outline, fill in box_specs:
    fig.add_shape(
        type="rect",
        x0=left,
        y0=0,
        x1=right,
        y1=4,
        line=dict(color=outline),
        fillcolor=fill,
    )
    fig.add_annotation(x=label_x, y=4.3, text=box_label, showarrow=False)

# Hide legend and axes so that only the boxes and their labels remain.
fig.update_layout(showlegend=False, xaxis_visible=False, yaxis_visible=False)
fig.show()

In [None]:
import matplotlib.pyplot as plt
import matplotlib.patches as patches

fig, ax = plt.subplots(figsize=(8, 5))

# One entry per panel:
# (left edge, width, outline colour, fill colour, title, size caption).
# Every panel is 4 units tall and sits on y = 0.
panel_specs = (
    (0, 1, "blue", "lightblue", "Q", r"$L \times d$"),
    (4, 1, "green", "lightgreen", "K", r"$L \times d$"),
    (1.5, 2, "orange", "lightyellow", r"$QK^{\top}$", r"$L \times L$"),
)

for left, width, outline, fill, title, caption in panel_specs:
    ax.add_patch(
        patches.Rectangle(
            (left, 0), width, 4, linewidth=1, edgecolor=outline, facecolor=fill
        )
    )
    centre_x = left + width / 2
    ax.text(centre_x, 4.3, title, ha="center", fontsize=12)  # title above the box
    ax.text(centre_x, -0.3, caption, ha="center", fontsize=10)  # size below it

# Two short horizontal arrows at mid-height, starting at x = 1 and x = 3.5.
for arrow_x in (1, 3.5):
    ax.arrow(
        arrow_x, 2, 0.4, 0, head_width=0.2, head_length=0.1, fc="black", ec="black"
    )

# Fixed view window with the axes hidden.
ax.set_xlim(-0.5, 5.5)
ax.set_ylim(-1, 5)
ax.axis("off")
plt.tight_layout()
plt.show()

Below is a **set of clean, lecture-ready TikZ diagrams** that illustrate the FlashAttention algorithm exactly as explained in:

👉 [https://dev.to/lewis_won/flashattention-by-hand-34im](https://dev.to/lewis_won/flashattention-by-hand-34im)
👉 Using **official TikZ syntax** ([https://tikz.dev](https://tikz.dev))

I provide **four diagrams**:

1. **Naive Attention (Quadratic Memory)**
2. **Tiling Q/K/V into Blocks**
3. **Streaming Softmax Logic (m, l updates)**
4. **Final Output Assembly**

Each diagram is ready for LaTeX slides.

---

# ✅ **1. Naive Attention (Full QKᵀ Materialization)**

### *Shows why attention is O(L²) memory.*

```latex
\begin{tikzpicture}[>=latex, scale=1.0]

% Q vector
\draw[fill=blue!10] (0,0) rectangle (1,4);
\node at (0.5,4.3) {$Q$};
\node at (0.5,-0.3) {$L \times d$};

% K vector
\draw[fill=green!10] (4,0) rectangle (5,4);
\node at (4.5,4.3) {$K$};
\node at (4.5,-0.3) {$L \times d$};

% QK^T matrix
\draw[fill=orange!10] (1.5,0) rectangle (3.5,4);
\node at (2.5,4.3) {$QK^{\top}$};
\node at (2.5,-0.3) {$L \times L$};

% Arrows
\draw[->, thick] (1,2) -- (1.5,2);
\draw[->, thick] (3.5,2) -- (4,2);

\end{tikzpicture}
```

---

# ✅ **2. FlashAttention: Tiling Q/K/V into Blocks**

### *Shows block-processing instead of full matrix.*

```latex
\begin{tikzpicture}[>=latex, scale=1.0]

% Q blocks
\foreach \i in {0,1,2} {
  \draw[fill=blue!10] (0,3-\i*1.3) rectangle (1,4-\i*1.3);
  \node at (0.5,3.65-\i*1.3) {$Q_{\i}$};
}

% K blocks
\foreach \i in {0,1,2} {
  \draw[fill=green!10] (2.5,3-\i*1.3) rectangle (3.5,4-\i*1.3);
  \node at (3,3.65-\i*1.3) {$K_{\i}$};
}

% V blocks
\foreach \i in {0,1,2} {
  \draw[fill=red!10] (5,3-\i*1.3) rectangle (6,4-\i*1.3);
  \node at (5.5,3.65-\i*1.3) {$V_{\i}$};
}

% Arrows
\foreach \i in {0,1,2} {
  \draw[->, thick] (1,3.5-\i*1.3) -- (2.5,3.5-\i*1.3);
  \draw[->, thick] (3.5,3.5-\i*1.3) -- (5,3.5-\i*1.3);
}

\node at (3,-1) {\Large Blocked Processing Instead of Full $QK^\top$};

\end{tikzpicture}
```

---

# ✅ **3. Streaming Softmax (Online Softmax)**

### *The key algorithm of FlashAttention.*

This diagram illustrates how each block updates:

* $m$ = running max
* $l$ = running softmax denominator
* contributions to output $O$

```latex
\begin{tikzpicture}[>=latex, scale=1.1]

% Block scores Sb
\draw[fill=orange!20] (0,0) rectangle (3,1);
\node at (1.5,0.5) {$S_b = Q_b K_b^{\top}$};

% Running max m_old -> m_new
\draw[->, thick] (3,0.5) -- (4,0.5);
\node at (4.7,0.5) {$m_{\text{new}} = \max(m_{\text{old}}, \max S_b)$};

% Running denominator
\draw[fill=blue!10] (0,-1.5) rectangle (3,-0.5);
\node at (1.5,-1) {$l_{\text{old}}$};

\draw[->, thick] (3,-1) -- (4,-1);
\node at (6,-1)
{
$\displaystyle
l_{\text{new}} = e^{m_{\text{old}} - m_{\text{new}}} l_{\text{old}}
+ \sum e^{S_b - m_{\text{new}}}
$
};

% Probabilities
\draw[fill=green!10] (0,-3) rectangle (3,-2);
\node at (1.5,-2.5)
{
$\displaystyle
P_b = \frac{e^{S_b - m_{\text{new}}}}{l_{\text{new}}}
$
};

% Output accumulation
\draw[->, thick] (3,-2.5) -- (4,-2.5);
\node at (6,-2.5)
{
$\displaystyle
O \;\leftarrow\; \frac{l_{\text{old}}\, e^{m_{\text{old}} - m_{\text{new}}}}{l_{\text{new}}}\, O \;+\; P_b V_b
$
};

\end{tikzpicture}
```

---

# ✅ **4. Final Assembly of Attention Output**

### *Combining block contributions into the final O.*

```latex
\begin{tikzpicture}[>=latex, scale=1.0]

% Outputs
\foreach \i in {0,1,2} {
  \draw[fill=purple!10] (0,3-\i*1.3) rectangle (2,4-\i*1.3);
  \node at (1,3.65-\i*1.3) {$P_{\i} V_{\i}$};
}

\draw[->, thick] (2,3.5) -- (3,2);
\draw[->, thick] (2,2.2) -- (3,2);
\draw[->, thick] (2,0.9) -- (3,2);

% Final sum
\draw[fill=purple!30] (3,1.5) rectangle (5,2.5);
\node at (4,2) {$O = \sum_b P_b V_b$};

\end{tikzpicture}
```

---

# 📘 **These diagrams can be used directly in Beamer slides**

Example wrapper:

```latex
\begin{frame}{FlashAttention: Streaming Softmax}
\centering
\input{streaming_softmax_diagram.tex}
\end{frame}
```

---

# 🎁 **Would you like more diagrams?**

I can add:

* KV-cache growth over time
* Rotary embedding visualization
* GQA head grouping diagram
* Memory comparison: Naive vs FlashAttention
* Inference vs Training attention computation paths

Just tell me what you want for your lecture notes!
