In [None]:
import os
import sys
import subprocess

# Git repository that provides the `utils` package imported by later cells.
REPO_URL = "https://github.com/openai/llm-guesstimator.git"
# Checkout location used when the notebook runs on Google Colab.
COLAB_REPO_PATH = "/content/llm-guesstimator"

# Decide once whether we are on Colab and use that single answer both for
# cloning and for choosing `repo_path`; previously the clone step looked at
# the COLAB_RELEASE_TAG environment variable while the path selection looked
# at sys.modules, so the two could disagree.
in_colab = "google.colab" in sys.modules or bool(os.environ.get("COLAB_RELEASE_TAG"))

if in_colab:
    repo_path = COLAB_REPO_PATH
    if not os.path.isdir(repo_path):
        # Argument-list form avoids building a shell script from an f-string,
        # and check=True makes a failed clone raise here instead of surfacing
        # later as a FileNotFoundError from os.chdir.
        subprocess.run(["git", "clone", REPO_URL, repo_path], check=True)
else:
    # Outside Colab the notebook is expected to be run from the repo checkout.
    repo_path = os.path.abspath('.')

# Make the repo importable and make relative paths resolve against it.
if repo_path not in sys.path:
    sys.path.append(repo_path)
os.chdir(repo_path)


# LLM Prefill & Decode: Compute vs Memory Boundness — Corrected Formulas, Intuition, and Plots

This notebook consolidates the corrected equations for **prefill** (encode) and **decode** (per new token) phases of Transformer-based LLM inference, derives **arithmetic intensity vs machine balance** sanity checks, and visualizes **compute time**, **memory time**, and **max(time)** across relevant ranges of sequence length $L$ and batch size $S$.

We follow the convention that **every matmul costs $2mnp$ FLOPs** (multiply + add).

## Parameters (Default Values)

- $S$: concurrent requests at the step (batch size)
- $L$: context length (prefill) / past tokens (decode)
- $P = S \cdot L$: "token budget"
- $d$: hidden size
- $r$: FFN expansion ratio ($\approx 4$)
- $n_\ell$: number of Transformer layers
- $\text{dtype}_B$: bytes per element (e.g., 2 for FP16/BF16)
- $\mathcal{F}$: sustained GPU FLOPs/s (e.g., A100-40GB FP16 TC $\approx 3.12\times 10^{14}$)
- $BW$: device memory bandwidth in B/s (e.g., A100-40GB $\approx 1.555\times 10^{12}$)
- $c_{\text{act}}$: activation I/O multiplier (use 12)

We also define the **machine balance** $\beta = \mathcal{F} / BW$ (FLOPs per byte).

## Corrected Formulas

### Prefill (encode) — per step
**Compute FLOPs (per layer, per request):**

$$(8+4r) L d^2 + 4 L^2 d$$

**Compute time:**

$$T_{\text{prefill}}^{\text{compute}}(S,L) = \frac{n_\ell S \big((8+4r) L d^2 + 4 L^2 d\big)}{\mathcal{F}}$$

**Memory bytes:**

$$\text{Bytes}_{\text{prefill}} = n_\ell \Big((4+2r) d^2 + (2+c_{\text{act}}) S L d \Big) \, \text{dtype}_B$$

**Memory time:**

$$T_{\text{prefill}}^{\text{memory}}(S,L) = \frac{n_\ell \Big((4+2r) d^2 + (2+c_{\text{act}}) S L d \Big) \, \text{dtype}_B}{BW}$$

---

### Decode (per new token) — per step
**Compute FLOPs (per layer, per request):**

$$(8+4r) d^2 + 4 L d$$

**Compute time per token:**

$$T_{\text{token}}^{\text{compute}}(S,L) = \frac{n_\ell S \big((8+4r) d^2 + 4 L d\big)}{\mathcal{F}}$$

**Memory bytes (per token, across all $S$ and layers):**

$$\text{Bytes}_{\text{token}} = n_\ell \Big((4+2r) d^2 + (2SL + (2+c_{\text{act}})S) d \Big) \, \text{dtype}_B$$

**Memory time per token:**

$$T_{\text{token}}^{\text{memory}}(S,L) = \frac{n_\ell \Big((4+2r) d^2 + (2SL + (2+c_{\text{act}})S) d \Big) \, \text{dtype}_B}{BW}$$

## Intuition & Insights

- **Prefill** compute grows as $O(L^2)$ due to attention, so it tends to be compute-bound for large $L$.
- **Decode** cost grows linearly in $L$; KV-cache reads dominate memory traffic, so decode tends to be memory-bound beyond short contexts.
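- $S$ scales compute and the activation/KV-cache memory terms linearly, while the weight-read term $(4+2r)d^2$ is independent of $S$; once the $S$-dependent terms dominate, $L$ mainly drives boundness.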
- **Arithmetic Intensity (AI)** $= \text{FLOPs}/\text{Bytes}$: if $AI > \beta$ → compute-bound; else memory-bound.
- For A100-40GB, $\beta \approx 200$ FLOPs/byte.

In [None]:
import numpy as np
import matplotlib.pyplot as plt

from utils.config import GRID_SETTINGS, get_hardware_config, get_model_config
from utils.math_utils import (
    decode_compute_time,
    decode_memory_time,
    prefill_compute_time,
    prefill_memory_time,
)

# Hardware description handed to every timing helper in the cells below.
hardware = get_hardware_config()
# Per-phase grid settings, looked up from the shared GRID_SETTINGS mapping.
prefill_settings, decode_settings = (
    GRID_SETTINGS[phase] for phase in ("prefill", "decode")
)


## Prefill Plots

In [None]:
cfg = get_model_config("70B")
# Not used within this cell; kept so the notebook-level names stay available.
d, n_layers, r = cfg.hidden_size, cfg.num_layers, cfg.expansion_ratio

# The setting is indexed as (start, stop, step); adding one step to the upper
# bound makes np.arange include `stop` itself.
prefill_ctx_range = prefill_settings["surface_context_range"]
L = np.arange(
    prefill_ctx_range[0],
    prefill_ctx_range[1] + prefill_ctx_range[2],
    prefill_ctx_range[2],
)

# One figure per batch size: compute time, memory time, and their maximum.
for S in [1, 16]:
    Tc = prefill_compute_time(S, L, cfg, hardware)
    Tm = prefill_memory_time(S, L, cfg, hardware)

    fig, ax = plt.subplots()
    ax.plot(L, Tc, label='T_compute')
    ax.plot(L, Tm, label='T_memory')
    ax.plot(L, np.maximum(Tc, Tm), label='max')
    ax.set_title(f'Prefill Times S={S}')
    # Axis labels so each figure stands alone when the notebook is skimmed.
    ax.set_xlabel('L (tokens)')
    ax.set_ylabel('Time per prefill step (s)')
    ax.legend()
    plt.show()


## Decode Plots

In [None]:
cfg = get_model_config("70B")

# The setting is indexed as (start, stop, step); adding one step to the upper
# bound makes np.arange include `stop` itself.
decode_past_range = decode_settings["surface_past_length_range"]
L = np.arange(
    decode_past_range[0],
    decode_past_range[1] + decode_past_range[2],
    decode_past_range[2],
)

# One figure per batch size: compute time, memory time, and their maximum.
for S in [1, 32]:
    Tc = decode_compute_time(S, L, cfg, hardware)
    Tm = decode_memory_time(S, L, cfg, hardware)

    fig, ax = plt.subplots()
    ax.plot(L, Tc, label='T_compute')
    ax.plot(L, Tm, label='T_memory')
    ax.plot(L, np.maximum(Tc, Tm), label='max')
    ax.set_title(f'Decode Times S={S}')
    # Axis labels so each figure stands alone when the notebook is skimmed.
    ax.set_xlabel('L (tokens)')
    ax.set_ylabel('Time per token (s)')
    ax.legend()
    plt.show()


In [None]:
import plotly.graph_objs as go

from utils.config import prefill_surface_grids, decode_surface_grids
from utils.math_utils import safe_ratio

cfg = get_model_config("70B")

# Grid axes: S is given as explicit values, L as an indexable
# (start, stop, step) triple whose `stop` is included by extending the
# arange bound by one step.
prefill_S_vals, prefill_L_range = prefill_surface_grids()
prefill_S_vals = np.array(prefill_S_vals, dtype=float)
prefill_L_vals = np.arange(
    prefill_L_range[0],
    prefill_L_range[1] + prefill_L_range[2],
    prefill_L_range[2],
    dtype=float,
)

# Broadcast-ready views: S runs along columns, L along rows.
prefill_Sg = prefill_S_vals[np.newaxis, :]
prefill_Lg = prefill_L_vals[:, np.newaxis]

prefill_T_compute = prefill_compute_time(prefill_Sg, prefill_Lg, cfg, hardware)
prefill_T_memory = prefill_memory_time(prefill_Sg, prefill_Lg, cfg, hardware)
prefill_T_ratio = safe_ratio(prefill_T_compute, prefill_T_memory)

# A grid point is flagged as "boundary" when compute and memory time differ
# by at most `prefill_eps`, relative to the larger of the two.
prefill_eps = 0.05
prefill_den = np.maximum(prefill_T_compute, prefill_T_memory)
prefill_gap = np.abs(prefill_T_compute - prefill_T_memory)
# Substitute 1 for zero denominators so the division is always defined.
prefill_safe_den = np.where(prefill_den == 0, 1, prefill_den)
prefill_boundary = (prefill_gap / prefill_safe_den) <= prefill_eps

# Row indices select L values, column indices select S values.
pi, pj = np.nonzero(prefill_boundary)
prefill_boundary_L = prefill_L_vals[pi]
prefill_boundary_S = prefill_S_vals[pj]
prefill_boundary_Z = prefill_T_compute[pi, pj]

def make_surface(x_vals, y_vals, z, boundary_S, boundary_L, boundary_Z, title, zlabel):
    """Build and display a 3-D surface with a scatter overlay of marker points.

    Args:
        x_vals: Values for the surface x axis (labelled "S (batch size)").
        y_vals: Values for the surface y axis (labelled "L (tokens)").
        z: Surface heights passed to ``go.Surface``.
        boundary_S: x coordinates of the overlay markers.
        boundary_L: y coordinates of the overlay markers.
        boundary_Z: z coordinates of the overlay markers.
        title: Figure title.
        zlabel: Title for the z axis.

    Returns:
        None. The figure is shown via ``fig.show()``.
    """
    surface_trace = go.Surface(x=x_vals, y=y_vals, z=z, showscale=True)
    marker_trace = go.Scatter3d(
        x=boundary_S,
        y=boundary_L,
        z=boundary_Z,
        mode="markers",
        marker={"size": 3, "opacity": 0.8},
    )

    fig = go.Figure(data=[surface_trace, marker_trace])

    scene_titles = {
        "xaxis_title": "S (batch size)",
        "yaxis_title": "L (tokens)",
        "zaxis_title": zlabel,
    }
    # Tight margins; only the top keeps room for the title.
    tight_margin = {"l": 0, "r": 0, "t": 40, "b": 0}
    fig.update_layout(title=title, scene=scene_titles, margin=tight_margin)
    fig.show()

# Draw one surface per quantity. The boundary markers are placed at the
# height of the surface being drawn (sampled at the boundary indices pi, pj).
# Previously every plot reused the T_compute heights, which put the markers
# far from the ratio surface (whose z values are dimensionless) and slightly
# off the T_memory surface.
prefill_surfaces = [
    (prefill_T_compute, "Prefill — T_compute(S,L)", "Time per prefill step (s)"),
    (prefill_T_memory, "Prefill — T_memory(S,L)", "Time per prefill step (s)"),
    (prefill_T_ratio, "Prefill — T_compute/T_memory(S,L)", "Ratio"),
]

for surface_z, surface_title, surface_zlabel in prefill_surfaces:
    make_surface(
        prefill_S_vals,
        prefill_L_vals,
        surface_z,
        prefill_boundary_S,
        prefill_boundary_L,
        # np.asarray guards against a helper returning a non-ndarray sequence.
        np.asarray(surface_z)[pi, pj],
        surface_title,
        surface_zlabel,
    )


In [None]:
import plotly.graph_objs as go

cfg = get_model_config("70B")

# Both axes are given as indexable (start, stop, step) triples.
decode_S_range, decode_L_range = decode_surface_grids()
# NOTE(review): the S axis stops *before* decode_S_range[1], whereas the L
# axis below extends the bound by one step to include its endpoint — confirm
# this asymmetry is intended.
decode_S_vals = np.arange(
    decode_S_range[0],
    decode_S_range[1],
    decode_S_range[2],
    dtype=float,
)
decode_L_vals = np.arange(
    decode_L_range[0],
    decode_L_range[1] + decode_L_range[2],
    decode_L_range[2],
    dtype=float,
)

# Broadcast-ready views: S runs along columns, L along rows.
decode_Sg = decode_S_vals[np.newaxis, :]
decode_Lg = decode_L_vals[:, np.newaxis]

decode_T_compute = decode_compute_time(decode_Sg, decode_Lg, cfg, hardware)
decode_T_memory = decode_memory_time(decode_Sg, decode_Lg, cfg, hardware)
decode_T_max = np.maximum(decode_T_compute, decode_T_memory)

# A grid point is flagged as "boundary" when compute and memory time differ
# by at most `decode_eps`, relative to the larger of the two.
decode_eps = 0.05
decode_den = np.maximum(decode_T_compute, decode_T_memory)
decode_gap = np.abs(decode_T_compute - decode_T_memory)
# Substitute 1 for zero denominators so the division is always defined.
decode_safe_den = np.where(decode_den == 0, 1, decode_den)
decode_boundary = (decode_gap / decode_safe_den) <= decode_eps

# Row indices select L values, column indices select S values.
di, dj = np.nonzero(decode_boundary)
decode_boundary_L = decode_L_vals[di]
decode_boundary_S = decode_S_vals[dj]
decode_boundary_Z = decode_T_compute[di, dj]

# Draw one surface per quantity. The boundary markers are placed at the
# height of the surface being drawn (sampled at the boundary indices di, dj)
# instead of always reusing the T_compute heights, so they sit on the
# T_memory and max(T) surfaces rather than slightly above or below them.
decode_surfaces = [
    (decode_T_compute, "Decode — T_compute(S,L) per token"),
    (decode_T_memory, "Decode — T_memory(S,L) per token"),
    (decode_T_max, "Decode — max(T) per token"),
]

for surface_z, surface_title in decode_surfaces:
    make_surface(
        decode_S_vals,
        decode_L_vals,
        surface_z,
        decode_boundary_S,
        decode_boundary_L,
        # np.asarray guards against a helper returning a non-ndarray sequence.
        np.asarray(surface_z)[di, dj],
        surface_title,
        "Time per token (s)",
    )
