# Day 3 — Inference Physics (NIM-first)

**Goal:** Build intuition for throughput, tail latency, queueing, and why batching/concurrency matter — using NIM as the serving target.

**Bonus (if you have an H200):** we’ll run a few “hardware physics experiments” so engineers can *see* why VRAM, bandwidth, and KV cache dominate real-world inference.

**Target:** NIM gateway (OpenAI-compatible) at `NIM_BASE_URL` (default `http://localhost:8000`).

**Outputs (what you’ll see):**
- GPU “engine check” + simple bandwidth sanity test
- KV-cache concurrency cliff visualization
- Latency distribution (p50/p95/p99)
- Latency vs concurrency curves (the p95 cliff)
- Throughput vs concurrency curve
- Latency vs token budget curve

**Timebox:** 60–90 minutes.


In [None]:
# Setup + preflight

import os
import sys
import time
import math
from dataclasses import dataclass
from typing import Any
from concurrent.futures import ThreadPoolExecutor, as_completed

import numpy as np
import pandas as pd
import plotly.express as px
import requests

import ipywidgets as widgets
from IPython.display import display, clear_output

NIM_BASE_URL = os.environ.get("NIM_BASE_URL", "http://localhost:8000").rstrip("/")
NIM_CHAT_PATH = os.environ.get("NIM_CHAT_PATH", "/v1/chat/completions")
NIM_GEN_MODEL = os.environ.get("NIM_GEN_MODEL", "meta/llama-3.1-8b-instruct")

px.defaults.template = "plotly_white"

print("sys.executable:", sys.executable)
print("NIM_BASE_URL:", NIM_BASE_URL)
print("NIM_CHAT_PATH:", NIM_CHAT_PATH)
print("NIM_GEN_MODEL:", NIM_GEN_MODEL)



## Cell 1: The engine check (know your metal)

If you’re on an H200, you’re holding a rare card. Let’s prove what we’re running on.

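Key idea: **token-by-token decode is usually memory-bandwidth-bound**, not FLOPs-bound. Every generated token re-reads the model weights from VRAM, while prefill over a long prompt is the more compute-heavy phase.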

We’ll:
- confirm GPU model + total VRAM via NVML
- show a quick “back-of-the-envelope” bandwidth comparison (A100 vs H100 vs H200)
- run a tiny GPU memory-copy benchmark to sanity-check that we can move data fast
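
The bandwidth-bound idea can be turned into a rough ceiling before running anything. If each generated token has to stream all model weights from VRAM once, single-sequence decode speed cannot exceed bandwidth divided by weight bytes. The sketch below uses the spec-level bandwidths quoted in this section; the 8B-parameter FP16 model, batch size 1, and zero KV-cache or kernel overhead are illustrative assumptions, so treat the output as an upper bound and not a prediction.

```python
# Rough decode ceiling: tokens/s <= memory bandwidth / bytes of weights read per token.
# Assumptions (illustrative): spec-sheet bandwidths, an 8B-parameter model in FP16,
# batch size 1, and no KV-cache reads or kernel-launch overhead.

SPEC_BANDWIDTH_TBPS = {"A100": 2.0, "H100": 3.3, "H200": 4.8}


def decode_ceiling_tok_s(*, params_b: float, bytes_per_param: float, bandwidth_tbps: float) -> float:
    """Upper bound on single-sequence decode speed if weight reads are the only cost."""
    weight_bytes = params_b * 1e9 * bytes_per_param
    return (bandwidth_tbps * 1e12) / weight_bytes


ceilings = {
    gpu: decode_ceiling_tok_s(params_b=8.0, bytes_per_param=2.0, bandwidth_tbps=bw)
    for gpu, bw in SPEC_BANDWIDTH_TBPS.items()
}
for gpu, tok_s in ceilings.items():
    print(f"{gpu}: <= {tok_s:.0f} tokens/s per sequence")
```

Batching amortizes the same weight reads across many sequences, which is one reason throughput rises with concurrency until something else saturates.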


In [None]:
# NVML-based GPU identity + VRAM

from pynvml import (
    nvmlInit,
    nvmlShutdown,
    nvmlDeviceGetCount,
    nvmlDeviceGetHandleByIndex,
    nvmlDeviceGetName,
    nvmlDeviceGetMemoryInfo,
    nvmlDeviceGetDriverVersion,
)


def _bytes_gib(x: float) -> float:
    return float(x) / (1024.0**3)


def gpu_engine_check() -> dict:
    nvmlInit()
    try:
        n = nvmlDeviceGetCount()
        if n < 1:
            raise RuntimeError("No NVIDIA GPUs visible to NVML")

        h = nvmlDeviceGetHandleByIndex(0)
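        name = nvmlDeviceGetName(h)
        info = nvmlDeviceGetMemoryInfo(h)
        drv = nvmlDeviceGetDriverVersion()
        # Older pynvml releases return bytes, newer ones return str; normalize to str.
        if isinstance(name, bytes):
            name = name.decode("utf-8", errors="ignore")
        if isinstance(drv, bytes):
            drv = drv.decode("utf-8", errors="ignore")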

        out = {
            "gpu_name": name,
            "driver": drv,
            "vram_total_gib": _bytes_gib(info.total),
            "vram_used_gib": _bytes_gib(info.used),
            "vram_free_gib": _bytes_gib(info.free),
        }
        return out
    finally:
        try:
            nvmlShutdown()
        except Exception:
            pass


info = gpu_engine_check()
print("GPU detected:", info["gpu_name"])
print("Driver:", info["driver"])
print(f"Total VRAM: {info['vram_total_gib']:.2f} GiB")
print(f"Used VRAM:  {info['vram_used_gib']:.2f} GiB")

print("\n--- Bandwidth context (spec-level, for intuition) ---")
print("A100 (80GB HBM2e):  ~2.0 TB/s")
print("H100 (80GB HBM3):   ~3.3 TB/s")
print("H200 (141GB HBM3e): ~4.8 TB/s")
print("\nSpeaker note: VRAM size is nice; bandwidth is the real villain/hero for inference.")


In [None]:
# A tiny memory-bandwidth sanity check (GPU memcpy benchmark)
# Caveat: microbenchmarks lie. But they lie consistently.

import torch


def memcpy_gbps(*, mib: int = 2048, iters: int = 50) -> float:
    if not torch.cuda.is_available():
        raise RuntimeError("CUDA not available")

    n_bytes = int(mib) * 1024 * 1024
    n = n_bytes // 4  # float32

    a = torch.empty(n, device="cuda", dtype=torch.float32)
    b = torch.empty_like(a)

    # Warmup
    for _ in range(5):
        b.copy_(a)
    torch.cuda.synchronize()

    t0 = time.perf_counter()
    for _ in range(int(iters)):
        b.copy_(a)
    torch.cuda.synchronize()
    dt = time.perf_counter() - t0

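    # Bytes copied per second. Each copy reads n_bytes and writes n_bytes,
    # so actual memory traffic is about 2x this figure when comparing against spec bandwidth.
    total = n_bytes * int(iters)
    gbps = (total / dt) / 1e9
    return float(gbps)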


try:
    gbps = memcpy_gbps(mib=1024, iters=80)
    print(f"Approx device copy throughput: {gbps:.1f} GB/s")
    print("Interpretation: higher is better; compare runs on the same machine.")
except Exception as e:
    print("Memcpy benchmark skipped:", type(e).__name__, str(e)[:200])


In [None]:
# Live VRAM watcher (sample NVML + plot). Useful for demos and debugging.

from pynvml import (
    nvmlInit,
    nvmlShutdown,
    nvmlDeviceGetHandleByIndex,
    nvmlDeviceGetMemoryInfo,
    nvmlDeviceGetUtilizationRates,
    nvmlDeviceGetPowerUsage,
    nvmlDeviceGetTemperature,
    NVML_TEMPERATURE_GPU,
)


def nvml_snapshot() -> dict[str, float]:
    nvmlInit()
    try:
        h = nvmlDeviceGetHandleByIndex(0)
        mem = nvmlDeviceGetMemoryInfo(h)
        util = nvmlDeviceGetUtilizationRates(h)
        pwr_w = float(nvmlDeviceGetPowerUsage(h)) / 1000.0
        temp_c = float(nvmlDeviceGetTemperature(h, NVML_TEMPERATURE_GPU))
        return {
            "vram_used_gib": _bytes_gib(mem.used),
            "vram_total_gib": _bytes_gib(mem.total),
            "gpu_util_pct": float(util.gpu),
            "mem_util_pct": float(util.memory),
            "power_w": float(pwr_w),
            "temp_c": float(temp_c),
        }
    finally:
        try:
            nvmlShutdown()
        except Exception:
            pass


def watch_gpu(*, seconds: float = 10.0, interval_s: float = 0.25) -> pd.DataFrame:
    rows = []
    t0 = time.perf_counter()
    while True:
        t = time.perf_counter() - t0
        if t > float(seconds):
            break
        try:
            snap = nvml_snapshot()
        except Exception:
            # NVML unavailable: record NaNs so the plotted columns still exist.
            nan = float("nan")
            snap = {"error": 1.0, "vram_used_gib": nan, "gpu_util_pct": nan, "power_w": nan, "temp_c": nan}
        snap["t_s"] = float(t)
        rows.append(snap)
        time.sleep(float(interval_s))

    return pd.DataFrame(rows)


# Quick demo: a short watch window (should be mostly idle).
df_watch = watch_gpu(seconds=3.0, interval_s=0.25)
display(df_watch.tail(5))

if not df_watch.empty and "vram_used_gib" in df_watch.columns:
    fig = px.line(df_watch, x="t_s", y=["vram_used_gib", "gpu_util_pct", "power_w"], title="GPU watch (sampled via NVML)")
    fig.show()


## Cell 2: The KV cache physics experiment (why concurrency eats VRAM)

This is the most important “inference physics” mental model:

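- Model weights are a fixed cost: they occupy the same VRAM whether you serve 1 user or 100.
- The **KV cache** grows with:
  - number of concurrent sequences (users)
  - sequence length (prompt tokens plus every token generated so far)
  - number of layers
  - hidden size (more precisely, KV heads × head dim; models with grouped-query attention cache less)
  - precision (FP16 vs FP8)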

If you’ve ever wondered why a serving system falls off a cliff at p95, KV cache is often holding the cliff sign.
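
The growth factors above can be reduced to a per-token cost, which makes the arithmetic easy to check by hand. The sketch below compares full multi-head attention (where KV heads × head dim equals the hidden size) with a grouped-query layout that stores fewer KV heads. The shapes are illustrative assumptions for a 70B-class model and are not read from NIM or any specific checkpoint.

```python
# KV-cache bytes per token: 2 (K and V) * layers * kv_heads * head_dim * bytes_per_value.
# With full multi-head attention, kv_heads * head_dim equals the hidden size.
# The shapes below are illustrative assumptions (a 70B-class layout), not read from NIM.

def kv_bytes_per_token(*, layers: int, kv_heads: int, head_dim: int, bytes_per: int) -> int:
    return 2 * layers * kv_heads * head_dim * bytes_per


def kv_gib(*, users: int, seq_len: int, per_token_bytes: int) -> float:
    return users * seq_len * per_token_bytes / (1024.0**3)


mha = kv_bytes_per_token(layers=80, kv_heads=64, head_dim=128, bytes_per=2)  # 64 * 128 = 8192 hidden
gqa = kv_bytes_per_token(layers=80, kv_heads=8, head_dim=128, bytes_per=2)   # grouped-query: 8 KV heads

for label, per_tok in [("MHA", mha), ("GQA", gqa)]:
    total = kv_gib(users=20, seq_len=8192, per_token_bytes=per_tok)
    print(f"{label}: {per_tok / 1024:.0f} KiB/token, {total:.1f} GiB for 20 users x 8192 tokens")
```

The gap between the two layouts is why a hidden-size-based estimate should be read as a worst case for models that use grouped-query attention.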


In [None]:
# KV cache VRAM cliff: interactive visualization


def kv_cache_gib(*, batch: int, seq_len: int, layers: int, hidden: int, bytes_per: int) -> float:
    # Formula: 2 (K and V) * layers * hidden * seq_len * batch * bytes_per
    # Assumes full multi-head attention; this overestimates models that use grouped-query attention.
    cache_per_token = 2 * int(layers) * int(hidden) * int(bytes_per)
    total_bytes = int(batch) * int(seq_len) * int(cache_per_token)
    return float(total_bytes) / (1024.0**3)


w_users = widgets.IntSlider(value=20, min=1, max=200, step=1, description="Concurrent users", style={"description_width": "initial"})
w_seq = widgets.IntSlider(value=8192, min=512, max=16384, step=512, description="Context length (tokens)", style={"description_width": "initial"})
w_layers = widgets.IntSlider(value=80, min=12, max=120, step=4, description="Layers", style={"description_width": "initial"})
w_hidden = widgets.IntSlider(value=8192, min=1024, max=16384, step=1024, description="Hidden dim", style={"description_width": "initial"})
w_bytes = widgets.Dropdown(options=[("FP16 (2 bytes)", 2), ("FP8 (1 byte)", 1)], value=2, description="KV precision", style={"description_width": "initial"})

w_weights = widgets.FloatSlider(value=130.0, min=1.0, max=200.0, step=1.0, description="Model weights (GiB)", style={"description_width": "initial"})

w_line_a100 = widgets.FloatSlider(value=80.0, min=40.0, max=160.0, step=1.0, description="Line: A100/H100 VRAM (GiB)", style={"description_width": "initial"})
w_line_h200 = widgets.FloatSlider(value=141.0, min=80.0, max=200.0, step=1.0, description="Line: H200 VRAM (GiB)", style={"description_width": "initial"})

out_kv = widgets.Output()


def _render_kv(*_):
    with out_kv:
        clear_output(wait=True)

        users_max = int(w_users.value)
        xs = list(range(1, users_max + 1))

        v_kv = [kv_cache_gib(batch=u, seq_len=int(w_seq.value), layers=int(w_layers.value), hidden=int(w_hidden.value), bytes_per=int(w_bytes.value)) for u in xs]
        v_total = [float(w_weights.value) + v for v in v_kv]

        df = pd.DataFrame({"users": xs, "kv_cache_gib": v_kv, "total_vram_gib": v_total})

        fig = px.line(df, x="users", y="total_vram_gib", title=f"KV cache cliff (seq_len={int(w_seq.value)} tokens)")
        fig.add_hline(y=float(w_line_a100.value), line_dash="dash", annotation_text="~80 GiB line")
        fig.add_hline(y=float(w_line_h200.value), line_dash="dash", annotation_text="~141 GiB line")
        fig.update_layout(xaxis_title="concurrent users", yaxis_title="GiB (weights + KV cache)")
        fig.show()

        # Where do we crash?
        def first_over(limit: float):
            over = df[df["total_vram_gib"] > float(limit)]
            if over.empty:
                return None
            return int(over.iloc[0]["users"])

        crash_80 = first_over(float(w_line_a100.value))
        crash_141 = first_over(float(w_line_h200.value))

        print("=== Back-of-the-envelope ===")
        print(f"Weights: {float(w_weights.value):.1f} GiB")
        print(f"KV precision bytes: {int(w_bytes.value)}")
        print(f"KV at 1 user: {float(v_kv[0]):.2f} GiB")
        print(f"KV at {users_max} users: {float(v_kv[-1]):.2f} GiB")
        print("")
        print("Crash estimate (first user count over the line):")
        print("- ~80 GiB line:", crash_80)
        print("- ~141 GiB line:", crash_141)

        print("\nSpeaker note:")
        print("- This is why long context + concurrency is the real VRAM tax.")
        print("- Batching is good until it becomes KV cache debt.")


for w in [w_users, w_seq, w_layers, w_hidden, w_bytes, w_weights, w_line_a100, w_line_h200]:
    w.observe(_render_kv, names="value")

display(widgets.VBox([
    widgets.HBox([w_users, w_seq]),
    widgets.HBox([w_layers, w_hidden, w_bytes]),
    widgets.HBox([w_weights, w_line_a100, w_line_h200]),
    out_kv,
]))

_render_kv()


In [None]:
# The FP8 "magic trick" (capacity math + optional real run)
# We keep this safe: if vLLM isn't installed (or models aren't available), we still teach the concept.


def estimate_weight_gib(*, params_b: float, bytes_per: int) -> float:
    # Back-of-the-envelope: params * bytes.
    total_bytes = float(params_b) * 1e9 * float(bytes_per)
    return total_bytes / (1024.0**3)


print("=== FP16 vs FP8 memory intuition (very rough) ===")

# Example: 70B params (order-of-magnitude). Actual weights depend on architecture + overhead.
params_b = 70.0
w_fp16 = estimate_weight_gib(params_b=params_b, bytes_per=2)
w_fp8 = estimate_weight_gib(params_b=params_b, bytes_per=1)

print(f"Model size (params ~{params_b:.0f}B):")
print(f"- FP16 weights ~ {w_fp16:.1f} GiB")
print(f"- FP8  weights ~ {w_fp8:.1f} GiB")
print("\nTakeaway: FP8 can roughly halve weight memory (and often improves throughput on modern tensor cores).")

print("\nOptional: if you have vLLM installed, you can try a real FP16 vs FP8 load/run.")
print("(This is optional because large models require access and disk/cache; the math demo above always works.)")

try:
    import vllm  # type: ignore

    print("vLLM detected:", getattr(vllm, "__version__", "?"))
    print("If you want to run vLLM here, consider a smaller model unless you're sure the 70B weights are cached.")
except Exception as e:
    print("vLLM not available (ok):", type(e).__name__, str(e)[:120])


In [None]:
# Throughput test (NIM): approximate tokens/sec
# Note: OpenAI-compatible servers often return usage tokens, but not always. We'll use usage if present, else a crude proxy.


def nim_chat_raw(*, prompt: str, max_tokens: int = 128, temperature: float = 0.0, timeout_s: float = 60.0) -> tuple[dict, float]:
    url = f"{NIM_BASE_URL}{NIM_CHAT_PATH}"
    payload = {
        "model": NIM_GEN_MODEL,
        "messages": [
            {"role": "system", "content": "You are a helpful assistant."},
            {"role": "user", "content": prompt},
        ],
        "max_tokens": int(max_tokens),
        "temperature": float(temperature),
    }
    t0 = time.perf_counter()
    r = requests.post(url, headers={"Content-Type": "application/json"}, json=payload, timeout=float(timeout_s))
    dt = time.perf_counter() - t0
    r.raise_for_status()
    return r.json(), float(dt)


def _count_tokens_from_usage(j: dict) -> int | None:
    u = j.get("usage") or {}
    # OpenAI-ish: prompt_tokens / completion_tokens / total_tokens
    ct = u.get("completion_tokens")
    if isinstance(ct, (int, float)):
        return int(ct)
    return None


PROMPT = "Explain quantum physics in one sentence."  # intentionally small
N_REQ = 40
MAX_TOKENS = 64

lat = []
out_tokens = []

print(f"Running {N_REQ} requests...")
for _ in range(N_REQ):
    j, dt = nim_chat_raw(prompt=PROMPT, max_tokens=MAX_TOKENS, temperature=0.0, timeout_s=60)
    lat.append(float(dt))
    t = _count_tokens_from_usage(j)
    if t is None:
        # crude proxy: words (not real tokens, but directionally useful)
        txt = (((j.get("choices") or [{}])[0].get("message") or {}).get("content") or "")
        t = max(1, len(str(txt).split()))
    out_tokens.append(int(t))

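# Requests run one at a time, so summed latency is the wall time
# and the figure below is single-stream throughput, not server capacity.
sec = float(sum(lat))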
tokens = int(sum(out_tokens))
print("Total time (s):", f"{sec:.2f}")
print("Total completion tokens (approx):", tokens)
print("Throughput (tokens/s):", f"{tokens/sec:.1f}")

# Show p50/p95 just for feel (np.percentile interpolates between samples)
print("p50 latency (s):", f"{np.percentile(lat, 50):.3f}")
print("p95 latency (s):", f"{np.percentile(lat, 95):.3f}")

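# Optional: watch the GPU while a short burst is actually running.
# The watcher samples NVML in a background thread; the main thread keeps sending requests.
print("\nGPU watch during a short burst (if NVML works):")
with ThreadPoolExecutor(max_workers=1) as ex:
    fut = ex.submit(watch_gpu, seconds=2.5, interval_s=0.25)
    while not fut.done():
        nim_chat_raw(prompt=PROMPT, max_tokens=MAX_TOKENS, temperature=0.0, timeout_s=60)
    df_burst = fut.result()
if not df_burst.empty:
    px.line(df_burst, x="t_s", y=["vram_used_gib", "gpu_util_pct", "power_w"], title="GPU during burst").show()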


In [None]:
# Dollars & cents (toy ROI calculator)
# This is deliberately hand-wavy. The purpose is to force the conversation: "cost per concurrent user".

w_users_a100 = widgets.IntSlider(value=5, min=1, max=100, step=1, description="Users on A100/H100", style={"description_width": "initial"})
w_users_h200 = widgets.IntSlider(value=20, min=1, max=200, step=1, description="Users on H200", style={"description_width": "initial"})

w_price_a100 = widgets.IntSlider(value=12000, min=2000, max=80000, step=500, description="Price A100/H100 ($)", style={"description_width": "initial"})
w_price_h200 = widgets.IntSlider(value=25000, min=5000, max=120000, step=500, description="Price H200 ($)", style={"description_width": "initial"})

out_roi = widgets.Output()


def _render_roi(*_):
    with out_roi:
        clear_output(wait=True)

        a = float(w_price_a100.value) / max(1, int(w_users_a100.value))
        h = float(w_price_h200.value) / max(1, int(w_users_h200.value))

        print("=== Cost per concurrent user (toy) ===")
        print(f"A100/H100: ${a:,.0f} per user")
        print(f"H200:      ${h:,.0f} per user")

        if h < a:
            print("\nTakeaway: the expensive GPU can be cheaper per user if it buys enough concurrency.")
        else:
            print("\nTakeaway: if concurrency doesn't increase, you're just buying a nicer space heater.")


for w in [w_users_a100, w_users_h200, w_price_a100, w_price_h200]:
    w.observe(_render_roi, names="value")

display(widgets.VBox([widgets.HBox([w_users_a100, w_users_h200]), widgets.HBox([w_price_a100, w_price_h200]), out_roi]))

_render_roi()


In [None]:
def nim_chat_once(*, prompt: str, max_tokens: int = 32, temperature: float = 0.0, timeout_s: float = 60.0) -> tuple[str, float]:
    """One OpenAI-style chat request to NIM; returns (text, latency_s)."""
    url = f"{NIM_BASE_URL}{NIM_CHAT_PATH}"
    payload = {
        "model": NIM_GEN_MODEL,
        "messages": [
            {"role": "system", "content": "You are a helpful assistant."},
            {"role": "user", "content": prompt},
        ],
        "max_tokens": int(max_tokens),
        "temperature": float(temperature),
    }

    t0 = time.perf_counter()
    r = requests.post(url, headers={"Content-Type": "application/json"}, json=payload, timeout=float(timeout_s))
    dt = time.perf_counter() - t0
    r.raise_for_status()
    j = r.json()

    choices = j.get("choices") or []
    msg = (choices[0].get("message") if choices else {}) or {}
    content = msg.get("content")
    if content is None:
        content = (choices[0].get("text") if choices else "")

    return str(content or "").strip(), float(dt)


def nim_preflight() -> bool:
    print("\n=== NIM preflight ===")
    try:
        txt, dt = nim_chat_once(prompt="Reply with only: OK", max_tokens=4, temperature=0.0, timeout_s=20)
        print(f"✅ NIM reachable: {dt:.3f}s | sample={txt!r}")
        return True
    except Exception as e:
        print(f"❌ NIM not reachable at {NIM_BASE_URL}: {type(e).__name__}: {str(e)[:200]}")
        print("\nTo start local NIMs:")
        print("  cd fico")
        print("  export NGC_API_KEY=...   # needed to pull nvcr.io images")
        print("  ./scripts/start_nims.sh")
        print("\nThen re-run this cell.")
        return False


if not nim_preflight():
    raise RuntimeError("NIM preflight failed")



## Mental model (pretty short)

End-to-end latency is roughly:

\[
T = T_{queue} + T_{net} + T_{compute}
\]

- Under light load, **p50** is mostly \(T_{net} + T_{compute}\).
- Under contention, **p95/p99** are dominated by \(T_{queue}\) (waiting).

A useful intuition (Little’s Law):

\[
L = \lambda W
\]

- \(L\): average number of requests in the system
- \(\lambda\): throughput (req/s)
- \(W\): time-in-system (seconds)

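Below saturation, extra concurrency raises \(\lambda\) and \(W\) barely moves. Once \(\lambda\) hits its ceiling, every additional in-flight request raises \(L\) while \(\lambda\) stays flat, so \(W = L / \lambda\) has to grow.

Your goal is to find the **knee**: the concurrency level at which throughput stops improving while p95 starts to explode.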
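
A quick worked example makes Little's Law concrete. The sketch below rearranges it to \(W = L / \lambda\) and evaluates it for a server that has already saturated. The 10 req/s ceiling and the in-flight counts are illustrative assumptions, not measurements from NIM.

```python
# Little's Law: L = lambda * W  =>  W = L / lambda.
# Illustrative numbers (assumptions, not measurements): a server that saturates at 10 req/s.

def time_in_system_s(*, in_flight: float, throughput_rps: float) -> float:
    """Average time a request spends in the system, from Little's Law."""
    return in_flight / throughput_rps


SATURATED_RPS = 10.0
little = {n: time_in_system_s(in_flight=n, throughput_rps=SATURATED_RPS) for n in (8, 16, 32, 64)}
for n, w in little.items():
    print(f"in-flight={n:>2}  throughput={SATURATED_RPS:.0f} req/s  avg latency={w:.1f} s")
```

Past saturation, doubling the number of in-flight requests doubles average latency without serving anyone faster.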




## Baseline: single-request latency distribution

We send the same request repeatedly (sequentially) to estimate the **noise floor** before adding concurrency. With only a dozen samples, p95 and p99 sit close to the maximum, so read them as rough indicators.
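
Small baselines deserve some caution when reading tail percentiles. The sketch below uses synthetic latencies (an illustrative assumption: eleven fast requests and one slow outlier) to show how a single observation drags p95 and p99 far from the median when there are only twelve samples.

```python
import numpy as np

# With few samples, high percentiles are set by one or two observations.
# Synthetic latencies (illustrative assumption): 11 fast requests and one slow outlier.
samples_s = [0.50] * 11 + [2.00]

p50 = float(np.percentile(samples_s, 50))
p95 = float(np.percentile(samples_s, 95))
p99 = float(np.percentile(samples_s, 99))
print(f"n={len(samples_s)}  p50={p50:.3f}s  p95={p95:.3f}s  p99={p99:.3f}s  max={max(samples_s):.3f}s")
```

If the tail matters for a decision, collecting a few hundred samples is usually a better use of time than reading much into a dozen.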


In [None]:
def pct(xs: list[float], p: float) -> float:
    if not xs:
        return float("nan")
    return float(np.percentile(np.array(xs, dtype=np.float64), p))


def summarize_latencies(lat_s: list[float]) -> dict[str, float]:
    return {
        "n": float(len(lat_s)),
        "mean_s": float(np.mean(lat_s)) if lat_s else float("nan"),
        "p50_s": pct(lat_s, 50),
        "p95_s": pct(lat_s, 95),
        "p99_s": pct(lat_s, 99),
        "min_s": float(min(lat_s)) if lat_s else float("nan"),
        "max_s": float(max(lat_s)) if lat_s else float("nan"),
    }


BASE_PROMPT = "Summarize why batching affects throughput in one short paragraph."
BASE_MAX_TOKENS = 96
BASE_TEMP = 0.2



In [None]:
N_BASELINE = 12

lat = []
for _ in range(N_BASELINE):
    _, dt = nim_chat_once(prompt=BASE_PROMPT, max_tokens=BASE_MAX_TOKENS, temperature=BASE_TEMP, timeout_s=60)
    lat.append(float(dt))

display(pd.DataFrame([summarize_latencies(lat)]))

df_lat = pd.DataFrame({"latency_s": lat})

fig_h = px.histogram(df_lat, x="latency_s", nbins=12, title="Baseline latency histogram", labels={"latency_s": "seconds"})
fig_h.show()

fig_b = px.box(df_lat, y="latency_s", title="Baseline latency box plot", labels={"latency_s": "seconds"})
fig_b.show()



## Concurrency sweep: the p95 cliff

We’ll run a small load test by issuing many requests while limiting **in-flight concurrency**.

What to look for:
- Throughput rises, then saturates.
- p95/p99 often blow up near saturation.
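
One way to make "rises, then saturates" operational is to look for the first step in the sweep that no longer buys a meaningful throughput gain. The helper below is a minimal sketch: `find_knee` and its 10% threshold are hypothetical choices for illustration, and the sweep numbers are made up, not NIM measurements.

```python
# Find the "knee": the last concurrency level that still buys a meaningful throughput gain.
# The sweep numbers and the 10% threshold are illustrative assumptions, not NIM measurements.

def find_knee(levels: list[int], throughput_rps: list[float], *, min_gain: float = 0.10) -> int:
    """Return the concurrency after which the next step improves throughput by < min_gain."""
    for i in range(len(levels) - 1):
        gain = (throughput_rps[i + 1] - throughput_rps[i]) / throughput_rps[i]
        if gain < min_gain:
            return levels[i]
    return levels[-1]  # throughput never flattened within the tested range


levels = [1, 2, 4, 8, 16]
rps = [2.0, 3.9, 7.2, 7.6, 7.7]
knee = find_knee(levels, rps)
print("knee at concurrency:", knee)
```

The same function can be fed the `concurrency` and `throughput_rps` columns of a sweep table. A threshold-based rule is crude, so it is worth checking the result against the p95 curve before acting on it.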



In [None]:
@dataclass
class RunResult:
    ok: bool
    latency_s: float
    error: str | None = None


def _sync_one(timeout_s: float) -> RunResult:
    t0 = time.perf_counter()
    try:
        nim_chat_once(prompt=BASE_PROMPT, max_tokens=BASE_MAX_TOKENS, temperature=BASE_TEMP, timeout_s=timeout_s)
        return RunResult(ok=True, latency_s=float(time.perf_counter() - t0), error=None)
    except Exception as e:
        return RunResult(ok=False, latency_s=float(time.perf_counter() - t0), error=f"{type(e).__name__}: {str(e)[:160]}")


def run_load(*, total_requests: int, concurrency: int, timeout_s: float) -> dict[str, Any]:
    """Threadpool-based loadgen: portable and works in notebooks."""
    total_requests = int(total_requests)
    concurrency = int(concurrency)

    t0 = time.perf_counter()
    results: list[RunResult] = []

    with ThreadPoolExecutor(max_workers=concurrency) as ex:
        futs = [ex.submit(_sync_one, float(timeout_s)) for _ in range(total_requests)]
        for f in as_completed(futs):
            results.append(f.result())

    wall_s = float(time.perf_counter() - t0)
    lat_ok = [r.latency_s for r in results if r.ok]
    err = [r for r in results if not r.ok]

    # Count only successful requests: timeouts and errors are not useful throughput.
    throughput = (len(lat_ok) / wall_s) if wall_s > 0 else float("nan")
    err_rate = (len(err) / len(results)) if results else float("nan")

    return {
        "concurrency": concurrency,
        "total_requests": total_requests,
        "timeout_s": float(timeout_s),
        "wall_s": wall_s,
        "throughput_rps": float(throughput),
        "error_rate": float(err_rate),
        "p50_s": pct(lat_ok, 50),
        "p95_s": pct(lat_ok, 95),
        "p99_s": pct(lat_ok, 99),
        "mean_s": float(np.mean(lat_ok)) if lat_ok else float("nan"),
        "ok": int(len(lat_ok)),
        "err": int(len(err)),
        "sample_error": (err[0].error if err else None),
    }


CONCURRENCY_LEVELS = [1, 2, 4, 8, 16]
TOTAL_REQ = 80
TIMEOUT_S = 60.0

rows = [run_load(total_requests=TOTAL_REQ, concurrency=c, timeout_s=TIMEOUT_S) for c in CONCURRENCY_LEVELS]
df_c = pd.DataFrame(rows).sort_values("concurrency")
display(df_c)

fig_lat = px.line(
    df_c,
    x="concurrency",
    y=["p50_s", "p95_s", "p99_s"],
    markers=True,
    title="Latency vs concurrency (p50/p95/p99)",
    labels={"value": "seconds"},
)
fig_lat.show()

fig_tp = px.line(df_c, x="concurrency", y="throughput_rps", markers=True, title="Throughput vs concurrency", labels={"throughput_rps": "req/s"})
fig_tp.show()

fig_err = px.bar(df_c, x="concurrency", y="error_rate", title="Error rate vs concurrency")
fig_err.update_yaxes(range=[0, max(0.05, float(df_c["error_rate"].max()) * 1.2 if len(df_c) else 0.1)])
fig_err.show()



## Token budget: generation dominates

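Roughly, request latency scales with:
- output tokens: decode produces one token per step, so latency grows about linearly with the number generated (`max_tokens` is only the cap)
- input size (prompt length): prefill processes the prompt in one parallel pass, so it mainly adds to time-to-first-token

We’ll vary both and plot how p95 latency responds.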
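
A simple way to read the resulting curves is to fit a straight line: a fixed cost (network, prefill, first token) plus a per-token decode cost. The sketch below does that with `np.polyfit` on synthetic measurements. The numbers are an illustrative assumption, and a real run should use the actual completion token counts (for example from the `usage` field when the server returns it) instead of `max_tokens`.

```python
import numpy as np

# Fit latency ~= fixed_cost + per_token_cost * output_tokens with a least-squares line.
# The measurements below are synthetic (illustrative assumption), not from a real NIM run.
output_tokens = np.array([32, 64, 128, 256], dtype=np.float64)
latency_s = np.array([0.52, 0.84, 1.48, 2.76], dtype=np.float64)

per_token_s, fixed_s = np.polyfit(output_tokens, latency_s, deg=1)  # slope, intercept
print(f"fixed cost ~ {fixed_s * 1000:.0f} ms (network + prefill + first token)")
print(f"decode cost ~ {per_token_s * 1000:.1f} ms per output token")
```

If the slope dominates, trimming output length helps most; if the intercept dominates, look at prompt size and network overhead first.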



In [None]:
def make_prompt(chars: int) -> str:
    base = "Summarize the following text in one sentence.\n\n"
    filler = ("lorem ipsum ") * 5000
    return base + filler[: int(chars)]


def measure_p95(*, prompt: str, max_tokens: int, n: int = 6) -> dict[str, Any]:
    lats = []
    for _ in range(int(n)):
        _, dt = nim_chat_once(prompt=prompt, max_tokens=int(max_tokens), temperature=0.2, timeout_s=60)
        lats.append(float(dt))
    return {
        "prompt_chars": int(len(prompt)),
        "max_tokens": int(max_tokens),
        "p95_s": pct(lats, 95),
        "p50_s": pct(lats, 50),
    }


PROMPT_SIZES = [200, 800, 2000]
MAX_TOKENS_GRID = [32, 64, 128, 256]

rows = []
for pc in PROMPT_SIZES:
    p = make_prompt(pc)
    for mt in MAX_TOKENS_GRID:
        rows.append(measure_p95(prompt=p, max_tokens=mt, n=6))

df_tok = pd.DataFrame(rows)
display(df_tok)

fig_tok = px.line(
    df_tok,
    x="max_tokens",
    y="p95_s",
    color=df_tok["prompt_chars"].astype(str),
    markers=True,
    title="p95 latency vs max_tokens (colored by input size)",
    labels={"p95_s": "p95 seconds", "color": "prompt_chars"},
)
fig_tok.show()

hm = df_tok.pivot_table(index="prompt_chars", columns="max_tokens", values="p95_s", aggfunc="mean")
fig_hm = px.imshow(hm, title="Heatmap: p95 latency (seconds)", labels={"x": "max_tokens", "y": "prompt_chars", "color": "p95 seconds"})
fig_hm.show()



## Backpressure + timeouts (the failure mode)

When you overload an inference server:
- p95/p99 get worse (queueing)
- then requests start timing out / failing

A practical rule: **cap concurrency** and set a timeout that matches your SLO.
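
Capping concurrency on the client side can be as simple as a semaphore around the request call. The sketch below is one possible shape, not a NIM feature: `InFlightLimiter` and `fake_request` are hypothetical names, and `fake_request` stands in for a real HTTP call so the example runs without a server.

```python
import threading
import time
from concurrent.futures import ThreadPoolExecutor


class InFlightLimiter:
    """Caps concurrent calls and records the peak concurrency actually observed."""

    def __init__(self, max_in_flight: int) -> None:
        self._slots = threading.BoundedSemaphore(max_in_flight)
        self._lock = threading.Lock()
        self._in_flight = 0
        self.peak = 0

    def call(self, fn, *args):
        with self._slots:  # blocks while max_in_flight calls are already running
            with self._lock:
                self._in_flight += 1
                self.peak = max(self.peak, self._in_flight)
            try:
                return fn(*args)
            finally:
                with self._lock:
                    self._in_flight -= 1


def fake_request(i: int) -> int:
    """Hypothetical stand-in for a real request to the server."""
    time.sleep(0.01)
    return i


limiter = InFlightLimiter(max_in_flight=4)
with ThreadPoolExecutor(max_workers=16) as ex:  # 16 callers, but only 4 may be in flight
    results = list(ex.map(lambda i: limiter.call(fake_request, i), range(32)))

print("completed:", len(results), "| peak in-flight:", limiter.peak)
```

A thread pool with a small `max_workers` achieves the same cap when you own the pool. The semaphore is useful when callers live in threads you do not control.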



In [None]:
# Demonstrate overload by lowering timeout and increasing concurrency

OVERLOAD_TIMEOUT_S = 5.0
OVERLOAD_TOTAL_REQ = 60
OVERLOAD_CONCURRENCY = [4, 8, 16, 32]

rows = [run_load(total_requests=OVERLOAD_TOTAL_REQ, concurrency=c, timeout_s=OVERLOAD_TIMEOUT_S) for c in OVERLOAD_CONCURRENCY]
df_over = pd.DataFrame(rows).sort_values("concurrency")
display(df_over)

fig = px.line(
    df_over,
    x="concurrency",
    y=["p95_s", "error_rate"],
    markers=True,
    title=f"Overload demo (timeout={OVERLOAD_TIMEOUT_S}s)",
)
fig.show()



## Practical tuning checklist (NIM / TensorRT-style serving)

- **Start with a baseline**: measure p50/p95 at concurrency=1.
- **Find the knee**: increase concurrency until throughput stops improving.
- **Protect tail latency**:
  - cap concurrency below the knee
  - set timeouts intentionally (match your SLO)
- **Control generation cost**:
  - reduce `max_tokens`
  - keep prompts short; avoid dumping huge context
- **Operationally**:
  - watch error rate
  - avoid blind retries under overload (use backoff)
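
The last checklist item is easy to get wrong, so here is a minimal sketch of retries with capped exponential backoff and full jitter. The helper names (`backoff_delays`, `call_with_backoff`) and the default delays are hypothetical, and `flaky` simulates an overloaded server so the example runs offline.

```python
import random
import time


def backoff_delays(*, retries: int, base_s: float = 0.5, cap_s: float = 8.0, rng=None) -> list[float]:
    """One jittered delay per retry: uniform in [0, min(cap_s, base_s * 2**attempt)]."""
    rng = rng or random.Random()
    return [rng.uniform(0.0, min(cap_s, base_s * 2**attempt)) for attempt in range(retries)]


def call_with_backoff(call, *, retries: int = 4, base_s: float = 0.5, cap_s: float = 8.0, sleep=time.sleep):
    """Try `call` once, then up to `retries` more times, sleeping a jittered delay between tries."""
    for delay in backoff_delays(retries=retries, base_s=base_s, cap_s=cap_s):
        try:
            return call()
        except Exception:
            sleep(delay)
    return call()  # final attempt: let the exception propagate


attempts = {"n": 0}


def flaky() -> str:
    """Simulated overloaded endpoint: fails twice, then succeeds."""
    attempts["n"] += 1
    if attempts["n"] < 3:
        raise TimeoutError("simulated overload")
    return "ok"


slept: list[float] = []
result = call_with_backoff(flaky, retries=4, sleep=slept.append)  # record delays instead of sleeping
print(result, "after", attempts["n"], "attempts; delays:", [round(s, 2) for s in slept])
```

In a real client it is usually safer to retry only errors that signal overload or transient failure, and to keep the total retry time inside the request's SLO.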



## Optional appendix: mapping the same ideas to vLLM (concepts only)

Even though we used NIM here, the same “physics” shows up in any serving stack:
- **Batching** improves GPU utilization (higher throughput) but can add queueing delay.
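- **KV cache** trades memory for compute: it avoids recomputing attention over past tokens, but its footprint grows with every token in flight, so long contexts and long generations cap how many sequences fit in a batch. Output tokens still dominate latency because decode is sequential.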
- **Scheduling** decisions move you along the latency/throughput frontier.

If you later benchmark vLLM, you’ll typically see the same p95 cliff as you approach saturation.

