<a href="https://colab.research.google.com/github/scalixte-mdsol/llm_inferences/blob/main/flash_paged_attention.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# =========================
# Minimal attention showdown
# =========================
import math, time, torch, torch.nn.functional as F

# assert torch.cuda.is_available(), "Enable GPU in Colab: Runtime > Change runtime type > GPU"
# device = "cuda"
# dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16

device = "cuda" if torch.cuda.is_available() else "cpu"
if device == "cuda":
    dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
else:
    dtype = torch.float32  # CPU prefers fp32 for correctness



## Flash Attention

**What the function is for (big picture)**

It computes causal self-attention for a full sequence (“prefill”) without ever materializing the huge S × S attention matrix.
Instead, it processes the sequence in tiles (blocks of queries and keys/values) and uses an online softmax trick to stitch results together numerically stably. This is the core idea behind FlashAttention—just written in simple PyTorch, not a fused CUDA kernel.

**Inputs & outputs**

* `q`, `k`, `v`: tensors of shape (B, H, S, D), where B = batch size, H = number of heads, S = sequence length, D = head dimension.
* `causal=True`: enforce “no peeking into the future.”
* `q_block`, `k_block`: tile sizes for queries and keys/values.
* Returns: a (B, H, S, D) tensor with the attended values for every token.

In [None]:
# ---------- 1) Simple "flash-like" attention (prefill): tiling + online softmax ----------
@torch.no_grad()
def flash_attn_simple(q, k, v, *, causal=True, q_block=128, k_block=128):
    """
    Educational flash-like attention (no custom kernels).
    q,k,v: (B,H,S,D); returns (B,H,S,D)
    """
    B, H, S, D = q.shape

    # Prepare output & scaling. The scale is standard in attention: softmax(QKᵀ / √D) V.
    scale = 1.0 / math.sqrt(D)

    out = torch.zeros(B, H, S, D, device=q.device, dtype=q.dtype)

    # Loop over queries in tiles. We’ll compute attention for a small chunk of queries Qb at a time to save memory.
    for qs in range(0, S, q_block):
        qe = min(qs + q_block, S)
        q_blk = q[:, :, qs:qe, :]  # shape (B,H,Qb,D)

        # Running max/log-sum-exp accumulators for online softmax
        """
        What are qs, qe, ks, ke?

        You’re processing the sequence in tiles (blocks).
        qs:qe is the query tile’s global index range.
        ks:ke is the key/value tile’s global index range.
        So:
        Qb = qe - qs queries in this tile,
        Kb = ke - ks keys in this tile.

        Initialize running statistics for online softmax
        These three hold the softmax state as we accumulate contributions from many key/value tiles:
        m: the current max of attention scores seen so far (per query position).
        l: the denominator of softmax (sum of exponentials), rescaled by m.
        o: the weighted sum of values (numerator), also rescaled by m.
        """
        m = torch.full((B, H, qe - qs, 1), -float("inf"), device=q.device, dtype=q.dtype)
        l = torch.zeros(B, H, qe - qs, 1, device=q.device, dtype=q.dtype)
        o = torch.zeros(B, H, qe - qs, D, device=q.device, dtype=q.dtype)


        # Loop over keys/values in tiles. This computes a Qb × Kb slice of the big QKᵀ matrix.
        for ks in range(0, S, k_block):
            ke = min(ks + k_block, S)
            k_blk = k[:, :, ks:ke, :]
            v_blk = v[:, :, ks:ke, :]
            scores = torch.einsum("bhqd,bhkd->bhqk", q_blk, k_blk) * scale  # (B,H,Qb,Kb)

            # Apply causal masking (if requested). Any key to the right of a query (future) is masked out. This keeps generation valid.
            if causal:
                """
                For each query position (row) and key position (column), it checks: is this key in the future?
                If k_idx > q_idx, that key is to the right of (i.e., after) the query → mask it out.
                Because we use > (strictly greater), a token can attend to itself (k == q is allowed). Using >= would forbid self-attention, which we don’t want.

                """
                # mask future keys inside the current tiles
                """
                torch.arange(qs, qe) is [qs, qs+1, ..., qe-1] — the global positions for the queries in this tile.
                torch.arange(ks, ke) is [ks, ks+1, ..., ke-1] — the global positions for the keys in this tile.
                unsqueeze(-1) makes q_idx a column of shape (Qb, 1).
                unsqueeze(0) makes k_idx a row of shape (1, Kb).
                """
                q_idx = torch.arange(qs, qe, device=q.device).unsqueeze(-1)  # (Qb,1) # global positions for current queries
                k_idx = torch.arange(ks, ke, device=q.device).unsqueeze(0)  # (1,Kb)  # global positions for current keys
                # When you compare them, PyTorch uses broadcasting to form a full (Qb, Kb) matrix of comparisons.
                mask = (k_idx > q_idx)  # (Qb,Kb): future keys relative to each query
                # The code lifts this to match the attention score tensor (B, H, Qb, Kb):
                scores = scores.masked_fill(mask.unsqueeze(0).unsqueeze(0), float("-inf"))

            """
            Online softmax merge (numerical stability)
            This is the heart of the method:
            If you did softmax across all keys at once, you’d use
            softmax(scores) = exp(scores - max) / sum(exp(scores - max)).
            Here we only see a slice of keys at a time, so we:
            keep a running max m and running sums (o, l),
            rescale the old accumulators to the new max (m_new) before adding the new chunk.
            After processing all K/V tiles, (o / l) equals the same result you’d get from the full softmax.
            """

            m_new = torch.maximum(m, scores.amax(dim=-1, keepdim=True))
            exp_scores = torch.exp(scores - m_new)
            o = o * torch.exp(m - m_new) + torch.einsum("bhqk,bhkd->bhqd", exp_scores, v_blk)
            l = l * torch.exp(m - m_new) + exp_scores.sum(dim=-1, keepdim=True)
            m = m_new

        # Write the finished block to output. That’s the final softmax-weighted sum of V for those query positions.
        out[:, :, qs:qe, :] = o / l
    return out

**Why this works (and why it’s useful)**

Memory savings: We never build the full S × S attention matrix. We only hold small Qb × Kb pieces, so peak memory is much lower.

Numerical stability: The running-max (log-sum-exp) trick lets us merge tiles without losing precision or blowing up exponentials.
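A tiny sketch of why the max subtraction matters (float32 on CPU; the logit values are arbitrary, chosen just large enough to overflow `exp`):

```python
import torch

# Large logits, as can appear when QK^T is poorly scaled.
z = torch.tensor([1000.0, 999.0, 998.0])

# Naive softmax: exp(1000) overflows float32 to inf, and inf / inf = nan.
naive = torch.exp(z) / torch.exp(z).sum()

# Stable softmax: subtract the max first, so the largest exponent is exp(0) = 1.
m = z.max()
stable = torch.exp(z - m) / torch.exp(z - m).sum()

print(naive)   # tensor([nan, nan, nan])
print(stable)  # same values as torch.softmax(z, dim=0)
```

The online softmax in `flash_attn_simple` applies the same idea, except that `m` is only the max seen *so far*, which is why old accumulators get rescaled by `exp(m - m_new)`.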

Causal correctness: Masking is done per tile using global positions, so the lower-triangular constraint is preserved across tiles.
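The per-tile masking can be checked directly: building each tile's mask from global indices, as `flash_attn_simple` does, and stitching the tiles together reproduces the full causal mask. The tile sizes here are arbitrary and deliberately do not divide S.

```python
import torch

S, q_block, k_block = 10, 4, 3  # uneven tiles on purpose

# Reference: True marks a future key (k > q) in the full S x S matrix.
full_mask = torch.triu(torch.ones(S, S, dtype=torch.bool), diagonal=1)

# Rebuild the same mask tile by tile from global positions.
stitched = torch.zeros(S, S, dtype=torch.bool)
for qs in range(0, S, q_block):
    qe = min(qs + q_block, S)
    for ks in range(0, S, k_block):
        ke = min(ks + k_block, S)
        q_idx = torch.arange(qs, qe).unsqueeze(-1)  # (Qb, 1) global query positions
        k_idx = torch.arange(ks, ke).unsqueeze(0)   # (1, Kb) global key positions
        stitched[qs:qe, ks:ke] = k_idx > q_idx      # (Qb, Kb) via broadcasting

print(torch.equal(stitched, full_mask))  # True
```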

**Tuning knobs (and trade-offs)**

q_block, k_block:

Smaller blocks → lower peak memory, but more loops → slower (especially on GPU, due to Python overhead and many kernel launches).

Larger blocks → faster, but need more memory.
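To put numbers on the memory side, here is a back-of-the-envelope sketch. `scores_bytes` is a hypothetical helper; B and H match the benchmark settings used later (B=1, H=8), the 2 bytes per element assumes bf16/fp16, and it counts only one score-sized temporary (the `exp_scores` tensor is another of the same size).

```python
def scores_bytes(B, H, rows, cols, bytes_per_elem=2):
    """Size of one (B, H, rows, cols) score tensor; 2 bytes/elem assumes bf16/fp16."""
    return B * H * rows * cols * bytes_per_elem

B, H, S = 1, 8, 1024
full = scores_bytes(B, H, S, S)      # full S x S scores (what tiling avoids)
tile = scores_bytes(B, H, 128, 128)  # one Qb x Kb tile with q_block = k_block = 128

print(full / 2**20, "MiB")  # 16.0 MiB
print(tile / 2**10, "KiB")  # 256.0 KiB
```

The full tensor grows as S², so at S = 32,768 it would be 16 GiB, while the 128 × 128 tile stays at 256 KiB regardless of S.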

@torch.no_grad():

This function is wrapped with no_grad (good for demos/inference/benchmarks).

If you want to train with it, remove @torch.no_grad() so gradients are tracked (but expect it to be much slower than fused kernels).

## Paged Attention

Paged attention is an inference-time way to run attention for decoder-only LLMs that’s optimized for long contexts and many concurrent requests. It combines two ideas:

Paged KV cache (memory layout):
Store each sequence’s keys/values (K/V) in fixed-size pages from a global pool. Each sequence has a page table (a list of page IDs). When a sequence grows, you append into its current tail page; if it fills, you allocate a new page. When the sequence finishes, you return its pages to the pool.
Why: avoids big memcopies and fragmentation, and enables “continuous batching” of requests with different lengths.

Paged decode (compute):
For each new token (so one query per head), the kernel walks the page table and computes attention page by page instead of over the whole context at once. It uses an online (streaming) softmax to merge each page’s contribution stably, so you never need to materialize the full logits vector.
Why: keeps peak activations low and scales well as context grows.


### How a decode step works (conceptually)

For a single new query $q$:

1. Look up the sequence’s **page IDs**: `[p0, p1, …, pN]`.
2. For each page `pi`:

   * Load that page’s `K_i, V_i`.
   * Compute logits $q K_i^\top / \sqrt{d}$.
   * Update running softmax **max/denominator/numerator** (log-sum-exp trick) with this page’s chunk.
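3. After all pages, output $\text{softmax}(qK^\top/\sqrt{d})\,V$ from the accumulated numerator/denominator.
4. Append the new token’s K/V to the sequence’s **tail page** (allocate a new page if the tail is full). In practice (e.g. vLLM), this write usually happens *before* step 2, so the new token also attends to its own key.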

> The **math result** is the same as computing attention over the whole context; it’s just streamed in pages for memory and scheduling efficiency.



### Why it’s useful

* **Continuous batching:** Mix sequences of very different lengths without copying K/V around.
* **Long context friendly:** per-step working memory is O(page size), since only one page of logits is in flight; the KV cache itself still grows as O(S).
* **High throughput serving:** Less memory churn, fewer large reallocations.



### How it differs from other terms

* **Paged attention vs. Flash attention:**
  *Flash* is a fused kernel that speeds up full attention inside a block (great for **prefill**). *Paged attention* is about how you **store** K/V and **stream** compute during **decode**. Systems often use both: Flash for prefill; Paged for decode.
* **Paged decode vs. Paged KV cache:**
  *Paged decode* = the **compute** strategy (process K/V in chunks with online softmax).
  *Paged KV cache* = the **memory** strategy (fixed-size pages + allocator + page tables).
  In production, kernels read **directly from pages** while doing the paged decode.


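### Tiny sketch (decode for one new token)

```python
import math
import torch

def decode_one_token(q, page_table, load_page):
    """q: (H, D) new-token query; load_page(page_id) -> (K_page, V_page), each (H, PAGE, D)."""
    H, D = q.shape
    m = torch.full((H, 1), float("-inf"), device=q.device, dtype=q.dtype)  # running max logit
    num = torch.zeros(H, D, device=q.device, dtype=q.dtype)                # running numerator
    den = torch.zeros(H, 1, device=q.device, dtype=q.dtype)                # running denominator
    for page_id in page_table:             # page_table: list of page IDs for this sequence
        K_page, V_page = load_page(page_id)                                # (H, PAGE, D)
        logits = torch.einsum("hd,hpd->hp", q, K_page) / math.sqrt(D)     # (H, PAGE)
        new_max = torch.maximum(m, logits.amax(dim=-1, keepdim=True))     # (H, 1)
        p = torch.exp(logits - new_max)
        num = num * torch.exp(m - new_max) + torch.einsum("hp,hpd->hd", p, V_page)
        den = den * torch.exp(m - new_max) + p.sum(dim=-1, keepdim=True)
        m = new_max
    return num / den                                                      # (H, D)
```

`load_page` is a placeholder for however your cache fetches a page; this sketch also assumes every page is full (a real implementation masks the unused slots of the tail page).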

### Practical notes

* **Page sizes:** The **storage** page size (for KV cache) and the **compute** chunk size don’t have to be equal, but implementations pick sizes that map well to GPU tiles.
* **Libraries:** vLLM/FlashInfer implement paged attention with **fused CUDA kernels** and a KV allocator; PyTorch SDPA can be used for prefill (Flash backend) and for simple decode, but the full paging system is what unlocks serving efficiency at scale.

**Summary:** *Paged attention streams attention over a KV cache organized in fixed-size pages, letting servers handle very long, many, and variable-length sequences efficiently during decode.*



* **Flash Attention** = a **turbo-charged kernel** to compute attention fast and memory-efficiently when you already have contiguous Q/K/V for a block. It’s about **how the math is computed** (fused, tiled on-chip, online softmax).
* **Paged Attention** = a **serving strategy** for the **decode step** that organizes and reads the **KV cache** in fixed-size **pages** so many sequences of different lengths can be handled **without copying**. It’s about **how K/V are stored and streamed** during inference.

## Side-by-side (at a glance)

| Dimension               | **Flash Attention**                                                                              | **Paged Attention**                                                                                                          |
| ----------------------- | ------------------------------------------------------------------------------------------------ | ---------------------------------------------------------------------------------------------------------------------------- |
| Main goal               | Make the attention **computation** itself fast & low-memory (fused kernel).                      | Make **inference with long/variable contexts** efficient via a paged **KV cache** and streaming compute.                     |
| Used when               | **Training** and **prefill** (full S×S); also works for decode if you have contiguous K/V.       | **Decode** (1 new token) across long history and many concurrent requests.                                                   |
| What it optimizes       | **Kernel efficiency**: fuses QKᵀ → softmax → (·V), tiles in SRAM, online softmax.                | **Memory layout & scheduling**: split K/V into **pages**, keep a **page table** per sequence, reuse pages, avoid K/V copies. |
| Input layout assumption | Typically **contiguous** Q/K/V for the current block.                                            | K/V can be **non-contiguous**, scattered across pages; kernel **walks the page table**.                                      |
| Memory during compute   | Very low **activation** memory for full attention (doesn’t materialize S×S).                     | Low **working** memory at decode (process one page at a time); total K/V stored in pages.                                    |
| Pairing                 | Orthogonal to paging; can be used inside each chunk/page.                                        | Often paired with a fused kernel (flash-style) while streaming pages.                                                        |

## A concrete timeline (what happens in a server)

1. **Prefill** (first prompt tokens):

   * You usually run **Flash Attention** (or PyTorch SDPA “flash” backend).
   * This computes attention over the whole prompt quickly (without materializing the S×S matrix), and the prompt’s K/V are written into the **paged KV cache** (pages allocated as needed).

2. **Decode** (generate next tokens):

   * For each new token, the kernel performs **Paged Attention**:

     * walks the sequence’s **page table**,
     * reads each K/V **page** in order,
     * does attention **page-by-page** with an **online softmax**,
     * appends the new token’s K/V to the tail page (allocate a new one if full).
   * Inside each page’s compute, you can still use **flash-style fused math**.

## Two quick mental models

* **Flash = race car engine** (how fast you compute attention on the data in front of you).
* **Paged = highway & exits** (how you lay out and fetch the data for many trips without traffic jams/copies).

## Minimal code pointers

* **Flash (compute):**

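  ```python
  from torch.nn.attention import SDPBackend, sdpa_kernel

  # Restrict SDPA to the FlashAttention backend (needs a supported CUDA GPU and fp16/bf16 inputs).
  # Replaces torch.backends.cuda.sdp_kernel, which recent PyTorch 2.x releases deprecate.
  with sdpa_kernel(SDPBackend.FLASH_ATTENTION):
      out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
  ```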
* **Paged decode (compute streaming idea, assuming contiguous K/V):**
  see `paged_decode(q1, k, v, page=256)` later in this notebook, which processes K/V in chunks and merges them with online softmax.
* **Paged KV cache (memory layout):**
  Real systems (e.g., vLLM) keep K/V in global **pages** with a per-sequence **page table** and an **allocator**. The fused kernel reads directly from those pages (no Python concat) *while* doing the online softmax.

## Summary

* **Flash Attention**: *how to compute it fast* (kernel fusion/tiling).
* **Paged Attention**: *how to store & stream K/V at decode* (pages + page tables + online softmax).
* They’re **complementary**: Flash for speed, Paged for scalable serving, and modern LLM inference uses both together.


### Page Size

page=256 is the page size—how many past tokens’ K/V you process at a time in the loop.

So with page=256 and context length S:
- The loop runs ceil(S / 256) times.
- Each iteration computes logits for 256 keys (except the last, which may be smaller), then merges its softmax contribution
into the running total using the log-sum-exp trick.
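A small helper (hypothetical, for illustration) that lists the chunks the loop visits, using an S that is not a multiple of the page size:

```python
import math

def page_schedule(S, page):
    """(start, chunk_len) for each iteration of `for s in range(0, S, page)`."""
    # Slicing k[:, :, s:s+page] stops at S, so the last chunk may be shorter.
    return [(s, min(page, S - s)) for s in range(0, S, page)]

sched = page_schedule(S=1000, page=256)
print(len(sched), math.ceil(1000 / 256))  # 4 4
print(sched)  # [(0, 256), (256, 256), (512, 256), (768, 232)]
```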
   

**What does the page size change?**

1) Peak memory (temporary activations)

Bigger page ⇒ larger temporary tensors (e.g., logits of shape (B, H, 1, page)), so higher peak memory.

Smaller page ⇒ smaller temporaries ⇒ lower peak memory.

2) Speed / kernel launch overhead

Bigger page ⇒ fewer loop iterations ⇒ fewer kernel launches ⇒ usually faster.

Smaller page ⇒ more iterations ⇒ more launches / Python overhead ⇒ usually slower.

3) Correctness

Results are (up to floating-point roundoff) the same regardless of page. The online softmax merge makes paging mathematically equivalent to processing all keys at once.


**Rules of thumb**

Start with page=256 as a reasonable default for moderate head dims (e.g., D≈64/80/128), then measure on your own GPU.

If you hit OOM or want to push to longer contexts, drop to 128 (or even 64).

If you have headroom and want a bit more speed, try 512 (or 1024 on bigger GPUs).

Make page a multiple of 64 or 128 (often maps better to GPU tile sizes).



**Edge cases to remember**

If S ≤ page, you effectively do one chunk—fastest and highest per-call memory.

The last chunk can be shorter than page; the code already handles that.

This function is for decode (1 new token). For prefill (full S×S), use SDPA “flash” instead; paging K/V matters much less there.

Bottom line: page=256 is just the chunk size for K/V. Increase it for speed (if you have memory), decrease it for memory savings (at some speed cost).

Exercise goal: **compute attention for one new token** (the query `q1`) over a **long past context** (`k`, `v`) **in pages** so we keep memory small, while getting **the exact same result** as full softmax.
The function below streams through keys/values in **chunks (“pages”)** and uses an **online softmax** to merge each chunk’s contribution **stably**—so you never build the full logits vector.

#### Inputs / Output (shapes)

* `q1`: **(B, H, 1, D)** — one new query token per batch/head
* `k`, `v`: **(B, H, S, D)** — cached keys/values for **S** past tokens
* `page`: e.g. **256** — how many past tokens to process per loop
* Returns: **(B, H, 1, D)** — the attended value for that new token

> B = batch, H = heads, S = context length, D = head size


```python
B, H, _, D = q1.shape
S = k.shape[2]
scale = 1.0 / math.sqrt(D)
```

* Standard attention scale $1/\sqrt{D}$.

We keep **running softmax state** across pages:

```python
max_log = None  # running max logit  (B,H,1,1)
num = None      # running numerator  (B,H,1,D)
den = None      # running denominator (B,H,1,1)
```

Why? If you had all logits $z$ at once, softmax uses:

$$
\frac{\sum_i e^{z_i} V_i}{\sum_i e^{z_i}}
\quad\text{(implemented stably as }e^{z_i - m}\text{ with }m=\max z\text{)}
$$

Paging means we don’t see all $z$ at once, so we **accumulate** numerator/denominator with a **running max** for numerical stability.

When we stream pages, we don’t know the global max yet, so we keep a running max `m` (called `max_log` in the code) and rebase the old sums whenever a page has a higher max.

What exactly is stored in m?

Shape: (B, H, 1, 1) — one value per batch and head (for the single new query).

Semantics: the maximum logit seen so far across all pages processed.

Now we loop over pages of size `page`:

```python
for s in range(0, S, page):
    kk = k[:, :, s:s+page, :]  # (B,H,chunk,D)
    vv = v[:, :, s:s+page, :]  # (B,H,chunk,D)
```

Compute this page’s logits and its local max:

```python
logits = torch.matmul(q1, kk.transpose(-2, -1)) * scale  # (B,H,1,chunk)
cur_max = logits.amax(dim=-1, keepdim=True)              # (B,H,1,1)
```

#### First page: initialize the running state

```python
if max_log is None:
    max_log = cur_max
    e = torch.exp(logits - max_log)         # (B,H,1,chunk)
    num = e @ vv                            # (B,H,1,D)
    den = e.sum(dim=-1, keepdim=True)       # (B,H,1,1)
```

This is just “softmax over what we’ve seen so far” (the first chunk).

#### Next pages: merge with **online softmax**

```python
else:
    new_max = torch.maximum(max_log, cur_max)
    # rebase old accumulators from max_log -> new_max
    """
    This rebasing is the key: if the new page has a larger max, we scale down the old accumulators so everything is measured relative to the new (bigger) max, keeping exponentials numerically safe.
    """
    num = num * torch.exp(max_log - new_max) + (torch.exp(logits - new_max) @ vv)
    den = den * torch.exp(max_log - new_max) + torch.exp(logits - new_max).sum(dim=-1, keepdim=True)
    max_log = new_max
```

**What’s happening:**
Suppose old accumulators were based on $m_{\text{old}}$ and the new page has max $m_{\text{page}}$. For stability, switch to $m_{\text{new}}=\max(m_{\text{old}}, m_{\text{page}})$. Then

$$
\sum e^{z - m_{\text{new}}}
= e^{m_{\text{old}} - m_{\text{new}}}\!\!\sum_{\text{old}}\! e^{z - m_{\text{old}}}
+ \sum_{\text{page}} e^{z - m_{\text{new}}}
$$

Same scaling for the numerator $\sum e^{z - m} V$. That’s exactly what those lines implement.

After all pages:

```python
return num / den   # (B,H,1,D)
```

That’s the **softmax-weighted sum of V over the entire context**, identical to doing it in one shot—just streamed.


#### Tiny numeric example

Suppose after page 1:

* `max_log = 5`
* `den = ∑ e^{z_old - 5}`
* `num = ∑ e^{z_old - 5} V_old`

Page 2 has `cur_max = 7` → `new_max = 7`. To merge:

* Rescale old terms by `e^{5-7}`:

  * `den ← den * e^{5-7} + ∑ e^{z_new - 7}`
  * `num ← num * e^{5-7} + ∑ e^{z_new - 7} V_new`
* Set `max_log = 7`.

At the end, the output is `num / den`, which equals the full softmax result.

`m` is the **running per-(B,H) max logit** used to do a **stable, streaming softmax** across pages. It lets you combine chunks without ever building all logits at once, while avoiding numerical issues.

#### Intuition (no equations)

* Think of having **100,000** past tokens. Instead of reading all keys at once, you read **a few hundred at a time**.
* For each chunk, you compute “how much this chunk contributes” and **blend** it with what you already had.
* The **running max** lets you keep numbers well-scaled so the softmax doesn’t blow up or underflow.


#### Why no causal mask here?

This is **decode**: K/V contain **only past tokens** for the current position. There are no “future” tokens to mask, so `is_causal` isn’t needed.
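A related pitfall worth checking: PyTorch documents `is_causal=True` as a lower-triangular mask aligned to the top-left corner. With one query row and S keys, that mask keeps only key 0, so passing `is_causal=True` at decode would be wrong, not merely unnecessary; this is why the decode benchmarks call `sdpa(..., causal=False)`. The sketch below checks it on CPU in float32 with small random tensors.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
B, H, S, D = 1, 2, 6, 4
q1 = torch.randn(B, H, 1, D)  # one new query
k = torch.randn(B, H, S, D)   # S cached keys
v = torch.randn(B, H, S, D)   # S cached values

# Correct decode: the new token attends to all S cached positions.
ref = torch.softmax(q1 @ k.transpose(-2, -1) / D**0.5, dim=-1) @ v
good = F.scaled_dot_product_attention(q1, k, v, is_causal=False)

# Pitfall: a (1, S) top-left causal mask keeps only key 0, so the output is just v[..., 0, :].
bad = F.scaled_dot_product_attention(q1, k, v, is_causal=True)

print(torch.allclose(good, ref, atol=1e-5))            # True
print(torch.allclose(bad, v[:, :, :1, :], atol=1e-6))  # True
```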


#### What does `page=256` change?

* **Bigger** page (e.g., 512/1024) → fewer iterations, usually **faster**, but **higher** temporary memory.
* **Smaller** page (e.g., 64/128) → lower peak memory, but more loops/overhead → **slower**.


#### Micro example (one head, tiny numbers)

Say logits for all S keys were `[2, 1, 3, 0]`, and `page=2`.

* Page 1 logits: `[2, 1]` → init `num, den` with those.
* Page 2 logits: `[3, 0]` → compute its exponentials **relative to a new overall max (3)**, **rescale** old `num/den` to the new max, then add this page’s contribution.
* Final `num/den` equals softmax over `[2,1,3,0]` in one shot.
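The micro example can be run directly. `streamed_softmax_weighted_sum` is a hypothetical single-head, single-query version of the page loop in `paged_decode`; the values `V` are arbitrary.

```python
import torch

def streamed_softmax_weighted_sum(z, V, page):
    """Online-softmax merge over pages of logits z (S,) and values V (S, D)."""
    m = num = den = None
    for s in range(0, z.numel(), page):
        zz, vv = z[s:s+page], V[s:s+page]
        cur_max = zz.max()
        if m is None:                      # first page: plain softmax pieces
            m = cur_max
            e = torch.exp(zz - m)
            num, den = e @ vv, e.sum()
        else:                              # later pages: rebase to the new max, then add
            new_max = torch.maximum(m, cur_max)
            e = torch.exp(zz - new_max)
            num = num * torch.exp(m - new_max) + e @ vv
            den = den * torch.exp(m - new_max) + e.sum()
            m = new_max
    return num / den

z = torch.tensor([2.0, 1.0, 3.0, 0.0])                  # logits from the example
V = torch.arange(8, dtype=torch.float32).reshape(4, 2)  # arbitrary values, D = 2

one_shot = torch.softmax(z, dim=0) @ V
paged = streamed_softmax_weighted_sum(z, V, page=2)
print(torch.allclose(paged, one_shot))  # True
```

Changing `page` to 1, 3, or 4 gives the same result up to floating-point roundoff.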


#### Summary

* **What:** Compute attention for a single new token by **streaming** K/V in pages and **merging softmax** chunk by chunk.
* **Why:** Keep **peak memory low**, perfect for **long contexts** at decode.
* **Correctness:** Same result as full softmax, thanks to the **online (log-sum-exp) merge**.


In [None]:
# ---------- 2) Paged attention (decode): 1 new token attends over K/V in pages ----------
@torch.no_grad()
def paged_decode(q1, k, v, page=256):
    # Online softmax over K/V chunks: peak temporaries are O(page), while K/V stay O(S).
    B, H, _, D = q1.shape
    S = k.shape[2]
    scale = 1.0 / math.sqrt(D)

    max_log = None  # running max of logits (B,H,1,1), keeps exps stable
    num = None      # running numerator: sum_j exp(logit_j - max) * V_j  (B,H,1,D)
    den = None      # running denominator: sum_j exp(logit_j - max)      (B,H,1,1)

    for s in range(0, S, page):
        kk = k[:, :, s:s+page, :]
        vv = v[:, :, s:s+page, :]
        logits = torch.matmul(q1, kk.transpose(-2, -1)) * scale  # (B,H,1,chunk)
        cur_max = logits.amax(dim=-1, keepdim=True)              # (B,H,1,1)

        if max_log is None:
            max_log = cur_max
            e = torch.exp(logits - max_log)
            num = e @ vv
            den = e.sum(dim=-1, keepdim=True)
        else:
            new_max = torch.maximum(max_log, cur_max)
            # merge old accumulators with new chunk (log-sum-exp trick)
            num = num * torch.exp(max_log - new_max) + (torch.exp(logits - new_max) @ vv)
            den = den * torch.exp(max_log - new_max) + torch.exp(logits - new_max).sum(dim=-1, keepdim=True)
            max_log = new_max
    return num / den

The `paged_decode(q1, k, v, page=256)` function above **doesn’t implement the paged KV cache memory system** (pages, page tables, allocator) described earlier. In my demo:

* `page=256` is just a **compute chunk size**: we iterate through a *contiguous* `k, v` tensor in slices to keep activations small.
* Real “paged attention” adds a **KV cache layout + allocator** on top of the math so you **don’t even need K/V to be contiguous per sequence**.

Here’s the difference, side by side:

#### 1) What our demo does (compute paging)

* Input: contiguous `k, v` of shape `(B, H, S, D)`.
* We loop over `s : s+page` to **compute** softmax in chunks and merge with online softmax.
* Pros: simple, teaches the math.
* Limitation: assumes you already **have** K/V for the sequence in one contiguous block.

#### 2) What real paged KV cache does (memory paging)

* Global K/V storage is a big pool split into **fixed-size pages**, e.g.:

  ```
  kv_cache_k: (NUM_PAGES, H, PAGE_SIZE, D)
  kv_cache_v: (NUM_PAGES, H, PAGE_SIZE, D)
  ```
* Each sequence keeps a **page table** listing which pages hold its tokens, plus how many slots are filled in the last page:

  ```
  seq.page_ids = [p0, p1, p2, ...]     # which pages belong to this seq
  seq.tail_filled = r                   # filled slots in the tail page (1..PAGE_SIZE)
  ```
* **Append**: write new tokens into the current tail page; when full, **allocate** a fresh page from a free list and append its id to `page_ids`.
* **Finish**: when the sequence ends, **return** its pages to the free list (reuse for other requests).
* **Continuous batching**: Because every sequence has its own page table, you can batch requests with **different lengths** without padding or copying large K/V blocks.
* **Attention read**: kernels read K/V **directly from the listed pages** (no concat/copy), computing attention over a *logical* contiguous stream formed by those pages.


> In **real** paged attention (vLLM/FlashInfer), K/V pages are *not* gathered with a Python concat. The fused CUDA kernel:
>
> * walks the **page table**,
> * reads from each page in order,
> * does the softmax-and-weighted-sum **on the fly**,
>   so you avoid copies **and** keep memory + latency low.
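The page-table walk can be sketched in plain PyTorch. This is a loop-based stand-in for what a fused kernel does, not vLLM’s or FlashInfer’s API: `kv_cache_k`/`kv_cache_v` follow the layout shown earlier, `paged_decode_from_pool` is a hypothetical name, and `tail_filled` is assumed to count the valid slots in the tail page. The `torch.cat` at the end only builds a reference for checking; the decode itself reads each page in place.

```python
import math
import torch

torch.manual_seed(0)
NUM_PAGES, H, PAGE_SIZE, D = 8, 2, 4, 16

# Hypothetical global page pool: (NUM_PAGES, H, PAGE_SIZE, D)
kv_cache_k = torch.randn(NUM_PAGES, H, PAGE_SIZE, D)
kv_cache_v = torch.randn(NUM_PAGES, H, PAGE_SIZE, D)

def paged_decode_from_pool(q, page_ids, tail_filled):
    """q: (H, D); page_ids: the sequence's page table; tail_filled: valid slots in the last page."""
    H, D = q.shape
    scale = 1.0 / math.sqrt(D)
    m = torch.full((H, 1), float("-inf"))  # running max logit
    num = torch.zeros(H, D)                # running numerator
    den = torch.zeros(H, 1)                # running denominator
    for i, pid in enumerate(page_ids):
        n = tail_filled if i == len(page_ids) - 1 else PAGE_SIZE  # only the tail page is partial
        kk = kv_cache_k[pid, :, :n, :]     # (H, n, D): a view into the pool, no copy
        vv = kv_cache_v[pid, :, :n, :]
        logits = torch.einsum("hd,hnd->hn", q, kk) * scale        # (H, n)
        new_max = torch.maximum(m, logits.amax(dim=-1, keepdim=True))
        e = torch.exp(logits - new_max)
        num = num * torch.exp(m - new_max) + torch.einsum("hn,hnd->hd", e, vv)
        den = den * torch.exp(m - new_max) + e.sum(dim=-1, keepdim=True)
        m = new_max
    return num / den  # (H, D)

# One sequence of 10 tokens stored in non-contiguous pages 5, 2, 7 (the tail page holds 2).
page_ids, tail_filled = [5, 2, 7], 2
q = torch.randn(H, D)
out = paged_decode_from_pool(q, page_ids, tail_filled)

# Reference only: gather the logical sequence contiguously and do a one-shot softmax.
k_seq = torch.cat([kv_cache_k[p] for p in page_ids], dim=1)[:, :10]  # (H, 10, D)
v_seq = torch.cat([kv_cache_v[p] for p in page_ids], dim=1)[:, :10]
w = torch.softmax(torch.einsum("hd,hsd->hs", q, k_seq) / math.sqrt(D), dim=-1)
ref = torch.einsum("hs,hsd->hd", w, v_seq)
print(torch.allclose(out, ref, atol=1e-5))  # True
```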

#### Where the **dynamic allocation/reuse** lives

* A tiny **allocator** manages a `free_pages` list.
* On append:

  * If the current tail page has space, write into it.
  * Else pop one page from `free_pages`, push its id into the sequence’s `page_ids`.
* On finish/cancel: push all `page_ids` back into `free_pages`.
* This is how servers **continuously batch** requests that grow to different lengths without expensive K/V copies.
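A minimal sketch of that bookkeeping, using hypothetical names (`Seq`, `PageAllocator`, `append_token`) and tracking page IDs only. A real server would also write the token’s K/V into the returned `(page_id, slot)`, and would preempt or queue requests instead of raising when pages run out.

```python
from dataclasses import dataclass, field

@dataclass
class Seq:
    page_ids: list = field(default_factory=list)  # this sequence's page table
    tail_filled: int = 0                          # filled slots in the tail page

class PageAllocator:
    """Hypothetical free-list allocator over page IDs (not the K/V tensors)."""

    def __init__(self, num_pages, page_size):
        self.page_size = page_size
        self.free_pages = list(range(num_pages))

    def append_token(self, seq):
        """Reserve one slot for a new token and return (page_id, slot)."""
        if not seq.page_ids or seq.tail_filled == self.page_size:
            if not self.free_pages:
                raise RuntimeError("out of KV pages")
            seq.page_ids.append(self.free_pages.pop())  # allocate a fresh tail page
            seq.tail_filled = 0
        slot = seq.tail_filled
        seq.tail_filled += 1
        return seq.page_ids[-1], slot

    def free(self, seq):
        """Return all pages of a finished/cancelled sequence to the pool."""
        self.free_pages.extend(seq.page_ids)
        seq.page_ids.clear()
        seq.tail_filled = 0

alloc = PageAllocator(num_pages=4, page_size=4)
a, b = Seq(), Seq()
for _ in range(6):
    alloc.append_token(a)  # 6 tokens -> 2 pages (4 + 2)
alloc.append_token(b)      # 1 token  -> 1 page
print(a.page_ids, a.tail_filled, b.page_ids, len(alloc.free_pages))  # [3, 2] 2 [1] 1
alloc.free(a)
print(len(alloc.free_pages))  # 3
```

The two sequences grow independently and their pages interleave in the pool; freeing `a` makes its pages immediately reusable by any other request.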


#### Summary

* My earlier `paged_decode(q1, k, v, page=256)` showed **compute paging** (small-chunk softmax merge) assuming contiguous K/V.
* Real **paged KV cache** also includes a **memory layout + allocator + page tables** so sequences of different lengths can **share, extend, and free** K/V **without copying**.
* Production systems use **fused kernels** that read directly from those pages while doing the **same online-softmax math** you saw in the demo.


In [None]:
# ---------- 3) SDPA wrapper (PyTorch built-in) ----------
def sdpa(q, k, v, *, causal=True, attn_mask=None):
    return F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask, is_causal=causal)


In [None]:
# ---------- 4) Quick benchmark ----------
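def time_cuda(fn, warmup=5, iters=20):
    # CUDA-only timer (uses CUDA events); CPU users should use time_cpu from the CPU section.
    for _ in range(warmup): fn()  # warmup
    torch.cuda.synchronize()
    start = torch.cuda.Event(enable_timing=True); end = torch.cuda.Event(enable_timing=True)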
    start.record()
    for _ in range(iters): fn()
    end.record(); torch.cuda.synchronize()
    return start.elapsed_time(end) / iters  # ms/call

torch.manual_seed(0)
"""
Convenience hyperparameters you’ll use to shape tensors:

B = batch size

H = number of attention heads

S = number of tokens in the sequence (how many time steps you’re attending over)

D = per-head hidden size (head dimension)

"""
B, H, D = 1, 8, 64

In [None]:
# -- Prefill (full SxS): compare SDPA vs simple flash-like --
SEQ_LIST = [128, 256, 512, 1024]
print("=== Prefill (full attention) — SDPA vs flash_attn_simple ===")
for S in SEQ_LIST:
    q = torch.randn(B,H,S,D, device=device, dtype=dtype)
    k = torch.randn(B,H,S,D, device=device, dtype=dtype)
    v = torch.randn(B,H,S,D, device=device, dtype=dtype)

    t_sdpa = time_cuda(lambda: sdpa(q,k,v, causal=True), warmup=3, iters=10)
    t_flash = time_cuda(lambda: flash_attn_simple(q,k,v, causal=True, q_block=128, k_block=128),
                        warmup=2, iters=5)  # a bit heavier in Python, fewer iters

    tokps_sdpa  = (B*S) / (t_sdpa/1e3)   # tokens/sec consumed
    tokps_flash = (B*S) / (t_flash/1e3)
    print(f"S={S:4d} | SDPA: {t_sdpa:6.2f} ms  ({tokps_sdpa:8.0f} tok/s)  |  flash-simple: {t_flash:6.2f} ms  ({tokps_flash:8.0f} tok/s)")

=== Prefill (full attention) — SDPA vs flash_attn_simple ===
S= 128 | SDPA:   0.30 ms  (  422913 tok/s)  |  flash-simple:   0.60 ms  (  212832 tok/s)
S= 256 | SDPA:   0.31 ms  (  821946 tok/s)  |  flash-simple:   2.03 ms  (  125931 tok/s)
S= 512 | SDPA:   0.88 ms  (  579729 tok/s)  |  flash-simple:   7.72 ms  (   66363 tok/s)
S=1024 | SDPA:   2.70 ms  (  379701 tok/s)  |  flash-simple:  29.67 ms  (   34514 tok/s)


In [None]:
# -- Decode (1 new token over context S): compare SDPA vs paged_decode --
CTX_LIST = [512, 1024, 2048, 4096]
print("\n=== Decode (1 token) — SDPA vs paged_decode (page=256) ===")
for S in CTX_LIST:
    q1 = torch.randn(B,H,1,D, device=device, dtype=dtype)
    k = torch.randn(B,H,S,D, device=device, dtype=dtype)
    v = torch.randn(B,H,S,D, device=device, dtype=dtype)

    t_sdpa = time_cuda(lambda: sdpa(q1,k,v, causal=False), warmup=5, iters=50 if S<=2048 else 25)
    t_paged = time_cuda(lambda: paged_decode(q1,k,v, page=256), warmup=5, iters=50 if S<=2048 else 25)

    tokps_sdpa  = 1 / (t_sdpa/1e3)   # tokens/sec generated (1 token per call)
    tokps_paged = 1 / (t_paged/1e3)
    print(f"S={S:4d} | SDPA: {t_sdpa:7.3f} ms  ({tokps_sdpa:8.1f} tok/s)  |  paged: {t_paged:7.3f} ms  ({tokps_paged:8.1f} tok/s)")



=== Decode (1 token) — SDPA vs paged_decode (page=256) ===
S= 512 | SDPA:   0.188 ms  (  5306.2 tok/s)  |  paged:   0.460 ms  (  2174.6 tok/s)
S=1024 | SDPA:   0.184 ms  (  5421.3 tok/s)  |  paged:   1.061 ms  (   942.8 tok/s)
S=2048 | SDPA:   0.259 ms  (  3855.7 tok/s)  |  paged:   2.513 ms  (   397.9 tok/s)
S=4096 | SDPA:   0.405 ms  (  2470.2 tok/s)  |  paged:   4.528 ms  (   220.9 tok/s)


### If you don't have a GPU, use the CPU version of the code below

In [None]:
# ==== CPU prefill benchmark: SDPA vs flash_attn_simple ====
# CPU setup
device = "cpu"
dtype = torch.float32
torch.manual_seed(0)
# Optional: stabilize timings by using 1 thread (or set a fixed number you prefer)
# torch.set_num_threads(1)

def time_cpu(fn, warmup=3, iters=10):
    # Warmup
    for _ in range(warmup):
        fn()
    # Measure
    t0 = time.perf_counter()
    for _ in range(iters):
        fn()
    t1 = time.perf_counter()
    return (t1 - t0) * 1000.0 / iters  # ms per call

# Minimal SDPA wrapper (CPU path uses PyTorch's math kernel)
def sdpa(q, k, v, *, causal=True):
    return F.scaled_dot_product_attention(q, k, v, is_causal=causal)

# ---- Benchmark params (keep modest on CPU) ----
B, H, D = 1, 8, 64
SEQ_LIST = [64, 128, 256, 512]  # reduce or increase as your CPU allows

print("=== Prefill (full attention, CPU) — SDPA vs flash_attn_simple ===")
for S in SEQ_LIST:
    q = torch.randn(B, H, S, D, device=device, dtype=dtype)
    k = torch.randn(B, H, S, D, device=device, dtype=dtype)
    v = torch.randn(B, H, S, D, device=device, dtype=dtype)

    # SDPA timing
    t_sdpa = time_cpu(lambda: sdpa(q, k, v, causal=True), warmup=2, iters=5)

    # flash_attn_simple timing (Python loops; slower on CPU)
    # Use smaller blocks on CPU to keep memory in check; speed will still be lower than SDPA
    t_flash = time_cpu(
        lambda: flash_attn_simple(q, k, v, causal=True, q_block=64, k_block=64),
        warmup=1, iters=2
    )

    tokps_sdpa  = (B * S) / (t_sdpa / 1e3)   # tokens/sec "consumed"
    tokps_flash = (B * S) / (t_flash / 1e3)

    print(f"S={S:4d} | SDPA: {t_sdpa:7.2f} ms  ({tokps_sdpa:8.0f} tok/s)  |  flash-simple: {t_flash:7.2f} ms  ({tokps_flash:8.0f} tok/s)")


=== Prefill (full attention, CPU) — SDPA vs flash_attn_simple ===
S=  64 | SDPA:    0.21 ms  (  301979 tok/s)  |  flash-simple:    1.35 ms  (   47395 tok/s)
S= 128 | SDPA:    0.75 ms  (  170197 tok/s)  |  flash-simple:    3.21 ms  (   39891 tok/s)
S= 256 | SDPA:    3.05 ms  (   83884 tok/s)  |  flash-simple:   13.77 ms  (   18592 tok/s)
S= 512 | SDPA:   12.63 ms  (   40527 tok/s)  |  flash-simple:   52.15 ms  (    9817 tok/s)


In [None]:
# ---- Benchmark params (keep modest on CPU) ----
B, H, D = 1, 8, 64
CTX_LIST = [256, 512, 1024]   # increase if your CPU is fast

print("\n=== Decode (CPU): SDPA vs paged_decode (page=128) ===")
for S in CTX_LIST:
    q1 = torch.randn(B, H, 1, D, device=device, dtype=dtype)      # one new token
    k  = torch.randn(B, H, S, D, device=device, dtype=dtype)      # past keys
    v  = torch.randn(B, H, S, D, device=device, dtype=dtype)      # past values

    t_sdpa  = time_cpu(lambda: sdpa(q1, k, v, causal=False), warmup=3, iters=20)
    t_paged = time_cpu(lambda: paged_decode(q1, k, v, page=128), warmup=3, iters=20)

    tokps_sdpa  = 1 / (t_sdpa / 1e3)   # tokens/sec (1 token per call)
    tokps_paged = 1 / (t_paged / 1e3)

    print(f"S={S:4d} | SDPA: {t_sdpa:7.2f} ms  ({tokps_sdpa:7.1f} tok/s)  |  paged: {t_paged:7.2f} ms  ({tokps_paged:7.1f} tok/s)")


=== Decode (CPU): SDPA vs paged_decode (page=128) ===
S= 256 | SDPA:    0.06 ms  (16257.3 tok/s)  |  paged:    0.34 ms  ( 2919.5 tok/s)
S= 512 | SDPA:    0.12 ms  ( 8472.3 tok/s)  |  paged:    0.66 ms  ( 1509.1 tok/s)
S=1024 | SDPA:    0.23 ms  ( 4316.8 tok/s)  |  paged:    1.34 ms  (  747.9 tok/s)


### Exercise (after running functions above)
- Try different page sizes (128/256/512). Log the observations. Which is faster on your GPU? What trade-off does page control?
- In prefill, how does ms scale as S doubles (≈4× slower or ≈2× slower)? Why?
- In decode, does ms/token grow ~linearly with S for both SDPA and Paged?
- Tweak these and report your observations:
  - `B_list = [1, 2]`
  - `H_list = [8, 16]`
  - `D_list = [64, 128]`
- For folks who can use both GPU and CPU version, what observations do you see?