<a href="https://colab.research.google.com/github/vsayyagari/RouteLLM/blob/main/micro_llm.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import os
import sys
import statistics
import time
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

# Initiate Model

In [2]:
# MODEL_ID = "meta-llama/Llama-3.2-1B"  # or "meta-llama/Llama-3.2-1B-Instruct"
MODEL_ID = "meta-llama/Llama-3.2-1B-Instruct"
# If required (gated), set HF_TOKEN in your env:
# export HF_TOKEN=...

token = os.environ.get("HF_TOKEN", None)

device = "cuda" if torch.cuda.is_available() else "cpu"
dtype = torch.float16 if device == "cuda" else torch.float32

tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, token=token, use_fast=True)
model = AutoModelForCausalLM.from_pretrained(MODEL_ID, token=token, dtype=dtype)
model.to(device).eval()

eos_id = tokenizer.eos_token_id



tokenizer_config.json:   0%|          | 0.00/54.5k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/9.09M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/296 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/877 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/2.47G [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/189 [00:00<?, ?B/s]

In [3]:
def sample_next_token(logits, temperature=0.8, top_k=50):
    # logits: [batch, vocab]
    if temperature <= 0:
        return torch.argmax(logits, dim=-1, keepdim=True)

    logits = logits / temperature
    print("logits shape", logits.shape)
    if top_k is not None:
        topk_vals, topk_idx = torch.topk(logits, k=min(top_k, logits.size(-1)), dim=-1)
        probs = torch.softmax(topk_vals, dim=-1)
        next_local = torch.multinomial(probs, num_samples=1)              # [batch, 1]
        next_token = topk_idx.gather(-1, next_local)                      # [batch, 1]
        return next_token

    probs = torch.softmax(logits, dim=-1)
    return torch.multinomial(probs, num_samples=1)

In [4]:
@torch.inference_mode()
def generate(prompt, max_new_tokens=80, temperature=0.8, top_k=50):
    # Encode prompt once
    input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(device)

    # First forward pass over the whole prompt
    out = model(input_ids=input_ids, use_cache=True)
    # logits shape: [batch, seq_len, vocab_size]
    # KV cache:
    past = out.past_key_values

    generated = []

    for _ in range(max_new_tokens):
        next_logits = out.logits[:, -1, :]                 # [batch, vocab]
        next_id = sample_next_token(next_logits, temperature, top_k)  # [batch, 1]

        if eos_id is not None and (next_id == eos_id).all():
            break

        generated.append(next_id)

        # Next step: feed ONLY the new token + KV cache
        out = model(input_ids=next_id, past_key_values=past, use_cache=True)
        past = out.past_key_values

    if generated:
        gen_ids = torch.cat(generated, dim=1)              # [batch, new_len]
        return tokenizer.decode(gen_ids[0], skip_special_tokens=True)
    return ""


In [5]:
print(generate("Explain KV cache in one short paragraph:", max_new_tokens=80))

logits shape torch.Size([1, 128256])
logits shape torch.Size([1, 128256])
logits shape torch.Size([1, 128256])
logits shape torch.Size([1, 128256])
logits shape torch.Size([1, 128256])
logits shape torch.Size([1, 128256])
logits shape torch.Size([1, 128256])
logits shape torch.Size([1, 128256])
logits shape torch.Size([1, 128256])
logits shape torch.Size([1, 128256])
logits shape torch.Size([1, 128256])
logits shape torch.Size([1, 128256])
logits shape torch.Size([1, 128256])
logits shape torch.Size([1, 128256])
logits shape torch.Size([1, 128256])
logits shape torch.Size([1, 128256])
logits shape torch.Size([1, 128256])
logits shape torch.Size([1, 128256])
logits shape torch.Size([1, 128256])
logits shape torch.Size([1, 128256])
logits shape torch.Size([1, 128256])
logits shape torch.Size([1, 128256])
logits shape torch.Size([1, 128256])
logits shape torch.Size([1, 128256])
logits shape torch.Size([1, 128256])
logits shape torch.Size([1, 128256])
logits shape torch.Size([1, 128256])
l

### Apply repetition penality

### Apply top-p, top-k filtering

In [6]:

def apply_repetition_penalty_(logits: torch.Tensor, generated_ids: torch.Tensor, penalty: float):
    """
    logits: [batch, vocab]
    generated_ids: [batch, total_generated] token ids seen so far (can include prompt if you want)
    penalty > 1.0 discourages repeats
    """
    if penalty is None or penalty <= 1.0 or generated_ids.numel() == 0:
        return logits

    # Apply penalty per token id that has appeared (common simple approach)
    # For each token that was generated, modify its logit.
    for b in range(logits.size(0)):
        seen = torch.unique(generated_ids[b])
        # If logit > 0 => divide by penalty; else multiply by penalty (matches common implementations)
        l = logits[b, seen]
        logits[b, seen] = torch.where(l > 0, l / penalty, l * penalty)
    return logits

def top_k_top_p_filtering(logits: torch.Tensor, top_k: int = 0, top_p: float = 1.0):
    """
    logits: [batch, vocab]
    returns logits with filtered values set to -inf
    """
    top_k = int(top_k or 0)
    top_p = float(top_p if top_p is not None else 1.0)

    # Top-K
    if top_k > 0:
        k = min(top_k, logits.size(-1))
        kth_vals = torch.topk(logits, k=k, dim=-1).values[:, -1].unsqueeze(-1)
        logits = torch.where(logits < kth_vals, torch.full_like(logits, float("-inf")), logits)

    # Top-P (nucleus)
    if top_p < 1.0:
        sorted_logits, sorted_idx = torch.sort(logits, descending=True, dim=-1)
        probs = torch.softmax(sorted_logits, dim=-1)
        cumprobs = torch.cumsum(probs, dim=-1)

        # Mask tokens with cumulative prob above top_p
        mask = cumprobs > top_p
        # Keep at least 1 token
        mask[:, 0] = False

        sorted_logits = torch.where(mask, torch.full_like(sorted_logits, float("-inf")), sorted_logits)
        # Scatter back
        logits = torch.full_like(logits, float("-inf"))
        logits.scatter_(dim=-1, index=sorted_idx, src=sorted_logits)

    return logits

def sample_from_logits(
    logits: torch.Tensor,
    temperature: float = 1.0,
    top_k: int = 0,
    top_p: float = 1.0,
):
    """
    logits: [batch, vocab]
    returns: [batch, 1]
    """
    if temperature is None:
        temperature = 1.0

    if temperature <= 0:
        return torch.argmax(logits, dim=-1, keepdim=True)

    logits = logits / temperature
    logits = top_k_top_p_filtering(logits, top_k=top_k, top_p=top_p)

    probs = torch.softmax(logits, dim=-1)

    # If everything got filtered (can happen with extreme settings), fall back to argmax
    if torch.isnan(probs).any() or (probs.sum(dim=-1) == 0).any():
        return torch.argmax(logits, dim=-1, keepdim=True)

    return torch.multinomial(probs, num_samples=1)

@torch.inference_mode()
def generate_stream(
    prompt: str,
    max_new_tokens: int = 128,
    temperature: float = 0.8,
    top_k: int = 40,
    top_p: float = 0.95,
    repetition_penalty: float = 1.1,
    stop_on_eos: bool = True,
    print_stream: bool = True,
):
    """
    Streams decoded text to stdout (optional) and returns the full generated suffix.
    """
    # Encode prompt
    enc = tokenizer(prompt, return_tensors="pt")
    input_ids = enc.input_ids.to(device)

    # First pass on the whole prompt
    out = model(input_ids=input_ids, use_cache=True)
    past = out.past_key_values

    # Track generated token IDs (for repetition penalty)
    generated_ids = torch.empty((input_ids.size(0), 0), dtype=torch.long, device=device)

    pieces = []

    for _ in range(max_new_tokens):
        next_logits = out.logits[:, -1, :]  # [batch, vocab]

        # Apply repetition penalty based on what we've generated so far
        next_logits = apply_repetition_penalty_(next_logits, generated_ids, repetition_penalty)

        next_id = sample_from_logits(
            next_logits,
            temperature=temperature,
            top_k=top_k,
            top_p=top_p,
        )  # [batch, 1]

        if stop_on_eos and eos_id is not None and (next_id == eos_id).all():
            break

        generated_ids = torch.cat([generated_ids, next_id], dim=1)

        # Decode just the new token for streaming
        token_text = tokenizer.decode(next_id[0], skip_special_tokens=True)
        pieces.append(token_text)
        if print_stream:
            sys.stdout.write(token_text)
            sys.stdout.flush()

        # Next step uses KV-cache: feed only the new token
        out = model(input_ids=next_id, past_key_values=past, use_cache=True)
        past = out.past_key_values

    if print_stream:
        sys.stdout.write("\n")

    return "".join(pieces)

# --- Example ---
prompt = "Explain top-p (nucleus) sampling vs top-k in 5-7 sentences:\n"
suffix = generate_stream(
    prompt,
    max_new_tokens=140,
    temperature=0.8,
    top_k=40,
    top_p=0.95,
    repetition_penalty=1.1,
    print_stream=True,
)


Top-p sampling is a type of random sampling where the number of samples selected from each stratum is proportional to its size, but not necessarily at the top k. Top-k sampling, on the other hand, selects samples based on their position in descending order (e.g., 1st, 2nd, etc.). The main difference between the two approaches is that top-p sampling focuses on the proportionality of sample sizes, while top-k sampling prioritizes the ranking of observations by their likelihood or probability. This means that top-p sampling can be more efficient for estimating probabilities or proportions, while top-k sampling may be better suited for hypothesis testing and confidence intervals.
Overall, the choice between top-p


# Benchmark

In [7]:
def _decode_loop_with_cache(input_ids, max_new_tokens: int):
    """
    KV-cache decode: prefill once, then feed 1 token per step with past_key_values.
    Returns: prefill_s, decode_s, tokens_generated
    """
    t0 = time.perf_counter()
    out = model(input_ids=input_ids, use_cache=True)
    past = out.past_key_values
    t1 = time.perf_counter()

    generated = 0
    next_id = torch.argmax(out.logits[:, -1, :], dim=-1, keepdim=True)  # greedy for benchmarking

    for _ in range(max_new_tokens):
        if eos_id is not None and (next_id == eos_id).all():
            break
        out = model(input_ids=next_id, past_key_values=past, use_cache=True)
        past = out.past_key_values
        next_id = torch.argmax(out.logits[:, -1, :], dim=-1, keepdim=True)
        generated += 1

    t2 = time.perf_counter()
    return (t1 - t0), (t2 - t1), generated


def _decode_loop_no_cache(full_ids, max_new_tokens: int):
    """
    No KV-cache: each step re-runs the entire sequence so far (slow).
    Returns: prefill_s, decode_s, tokens_generated
    """
    # "prefill" is just the first full forward
    t0 = time.perf_counter()
    out = model(input_ids=full_ids, use_cache=False)
    t1 = time.perf_counter()

    generated = 0
    next_id = torch.argmax(out.logits[:, -1, :], dim=-1, keepdim=True)

    for _ in range(max_new_tokens):
        if eos_id is not None and (next_id == eos_id).all():
            break
        full_ids = torch.cat([full_ids, next_id], dim=1)
        out = model(input_ids=full_ids, use_cache=False)
        next_id = torch.argmax(out.logits[:, -1, :], dim=-1, keepdim=True)
        generated += 1

    t2 = time.perf_counter()
    return (t1 - t0), (t2 - t1), generated


@torch.inference_mode()
def benchmark(
    prompt: str,
    max_new_tokens: int = 64,
    runs: int = 5,
    warmup: int = 1,
    use_cache: bool = True,
):
    """
    Prints timing stats for CPU:
      - prefill seconds
      - decode seconds
      - decode tokens/sec
      - total tokens/sec (prompt tokens + generated tokens) / total time
    """
    enc = tokenizer(prompt, return_tensors="pt")
    input_ids = enc.input_ids.to(device)
    prompt_tokens = input_ids.size(1)

    # Warmup
    for _ in range(warmup):
        if use_cache:
            _decode_loop_with_cache(input_ids, max_new_tokens=8)
        else:
            _decode_loop_no_cache(input_ids.clone(), max_new_tokens=8)

    prefill_s_list = []
    decode_s_list = []
    gen_list = []

    for _ in range(runs):
        if use_cache:
            prefill_s, decode_s, gen = _decode_loop_with_cache(input_ids, max_new_tokens)
        else:
            prefill_s, decode_s, gen = _decode_loop_no_cache(input_ids.clone(), max_new_tokens)

        prefill_s_list.append(prefill_s)
        decode_s_list.append(decode_s)
        gen_list.append(gen)

    # Aggregate
    prefill_med = statistics.median(prefill_s_list)
    decode_med = statistics.median(decode_s_list)
    gen_med = int(statistics.median(gen_list))

    decode_tps = (gen_med / decode_med) if decode_med > 0 else float("inf")
    total_tokens = prompt_tokens + gen_med
    total_time = prefill_med + decode_med
    total_tps = (total_tokens / total_time) if total_time > 0 else float("inf")

    mode = "KV-cache ON" if use_cache else "KV-cache OFF"
    print(f"\n=== Benchmark ({mode}) ===")
    print(f"Prompt tokens:       {prompt_tokens}")
    print(f"Generated tokens:    {gen_med} (target {max_new_tokens})")
    print(f"Prefill time (med):  {prefill_med:.4f} s")
    print(f"Decode time (med):   {decode_med:.4f} s")
    print(f"Decode tok/s:        {decode_tps:.2f}")
    print(f"Total tok/s:         {total_tps:.2f}  (prompt+gen over prefill+decode)")
    print(f"Runs: {runs} (warmup {warmup})")


# ---- Example usage ----
prompt = "Write a short explanation of KV cache in LLM inference.\n"
benchmark(prompt, max_new_tokens=64, runs=5, warmup=1, use_cache=True)
benchmark(prompt, max_new_tokens=64, runs=3, warmup=1, use_cache=False)  # fewer runs: it's very slow



=== Benchmark (KV-cache ON) ===
Prompt tokens:       13
Generated tokens:    64 (target 64)
Prefill time (med):  0.0206 s
Decode time (med):   1.3159 s
Decode tok/s:        48.64
Total tok/s:         57.61  (prompt+gen over prefill+decode)
Runs: 5 (warmup 1)

=== Benchmark (KV-cache OFF) ===
Prompt tokens:       13
Generated tokens:    64 (target 64)
Prefill time (med):  0.0241 s
Decode time (med):   1.7970 s
Decode tok/s:        35.61
Total tok/s:         42.28  (prompt+gen over prefill+decode)
Runs: 3 (warmup 1)


# Benchmark with optional torch.compile

In [8]:

def maybe_compile(m, use_compile: bool, *, mode="reduce-overhead", fullgraph=False):
    """
    mode: "reduce-overhead" (often good for decode), or "max-autotune" (can be slower to compile)
    fullgraph=False is safer with Transformers.
    """
    if not use_compile:
        return m
    if not hasattr(torch, "compile"):
        raise RuntimeError("torch.compile not available. Install PyTorch 2.x.")
    return torch.compile(m, mode=mode, fullgraph=fullgraph)

@torch.inference_mode()
def _decode_loop_with_cache(m, input_ids, max_new_tokens: int):
    """
    KV-cache decode: prefill once, then feed 1 token per step with past_key_values.
    Returns: prefill_s, decode_s, tokens_generated
    """
    t0 = time.perf_counter()
    out = m(input_ids=input_ids, use_cache=True)
    past = out.past_key_values
    t1 = time.perf_counter()

    generated = 0
    next_id = torch.argmax(out.logits[:, -1, :], dim=-1, keepdim=True)  # greedy for benchmarking

    for _ in range(max_new_tokens):
        if eos_id is not None and (next_id == eos_id).all():
            break
        out = m(input_ids=next_id, past_key_values=past, use_cache=True)
        past = out.past_key_values
        next_id = torch.argmax(out.logits[:, -1, :], dim=-1, keepdim=True)
        generated += 1

    t2 = time.perf_counter()
    return (t1 - t0), (t2 - t1), generated

@torch.inference_mode()
def _decode_loop_no_cache(m, full_ids, max_new_tokens: int):
    """
    No KV-cache: each step re-runs the entire sequence so far (slow).
    Returns: prefill_s, decode_s, tokens_generated
    """
    t0 = time.perf_counter()
    out = m(input_ids=full_ids, use_cache=False)
    t1 = time.perf_counter()

    generated = 0
    next_id = torch.argmax(out.logits[:, -1, :], dim=-1, keepdim=True)

    for _ in range(max_new_tokens):
        if eos_id is not None and (next_id == eos_id).all():
            break
        full_ids = torch.cat([full_ids, next_id], dim=1)
        out = m(input_ids=full_ids, use_cache=False)
        next_id = torch.argmax(out.logits[:, -1, :], dim=-1, keepdim=True)
        generated += 1

    t2 = time.perf_counter()
    return (t1 - t0), (t2 - t1), generated

def benchmark(
    prompt: str,
    max_new_tokens: int = 64,
    runs: int = 5,
    warmup: int = 1,
    use_cache: bool = True,
    use_compile: bool = False,
    compile_mode: str = "reduce-overhead",
):
    """
    Benchmarks median prefill/decode times.
    Warmup runs include compilation overhead so measured runs don't.
    """
    enc = tokenizer(prompt, return_tensors="pt")
    input_ids = enc.input_ids.to(device)
    prompt_tokens = input_ids.size(1)

    m = maybe_compile(model, use_compile, mode=compile_mode, fullgraph=False)

    # Warmup (triggers compilation if enabled)
    for _ in range(warmup):
        if use_cache:
            _decode_loop_with_cache(m, input_ids, max_new_tokens=8)
        else:
            _decode_loop_no_cache(m, input_ids.clone(), max_new_tokens=8)

    prefill_s_list, decode_s_list, gen_list = [], [], []

    for _ in range(runs):
        if use_cache:
            prefill_s, decode_s, gen = _decode_loop_with_cache(m, input_ids, max_new_tokens)
        else:
            prefill_s, decode_s, gen = _decode_loop_no_cache(m, input_ids.clone(), max_new_tokens)

        prefill_s_list.append(prefill_s)
        decode_s_list.append(decode_s)
        gen_list.append(gen)

    prefill_med = statistics.median(prefill_s_list)
    decode_med = statistics.median(decode_s_list)
    gen_med = int(statistics.median(gen_list))

    decode_tps = (gen_med / decode_med) if decode_med > 0 else float("inf")
    total_tokens = prompt_tokens + gen_med
    total_time = prefill_med + decode_med
    total_tps = (total_tokens / total_time) if total_time > 0 else float("inf")

    mode = []
    mode.append("KV-cache ON" if use_cache else "KV-cache OFF")
    mode.append("compile ON" if use_compile else "compile OFF")
    if use_compile:
        mode.append(f"({compile_mode})")

    print(f"\n=== Benchmark [{' | '.join(mode)}] ===")
    print(f"Prompt tokens:       {prompt_tokens}")
    print(f"Generated tokens:    {gen_med} (target {max_new_tokens})")
    print(f"Prefill time (med):  {prefill_med:.4f} s")
    print(f"Decode time (med):   {decode_med:.4f} s")
    print(f"Decode tok/s:        {decode_tps:.2f}")
    print(f"Total tok/s:         {total_tps:.2f}  (prompt+gen over prefill+decode)")
    print(f"Runs: {runs} (warmup {warmup})")

# ---- Example usage ----
prompt = "Write a short explanation of KV cache in LLM inference.\n"

# Eager baseline
benchmark(prompt, max_new_tokens=64, runs=5, warmup=1, use_cache=True,  use_compile=False)
benchmark(prompt, max_new_tokens=64, runs=3, warmup=1, use_cache=False, use_compile=False)

# torch.compile
# benchmark(prompt, max_new_tokens=64, runs=5, warmup=1, use_cache=True,  use_compile=True, compile_mode="reduce-overhead")
# benchmark(prompt, max_new_tokens=64, runs=3, warmup=1, use_cache=False, use_compile=True, compile_mode="reduce-overhead")
benchmark(prompt, max_new_tokens=64, runs=5, warmup=1, use_cache=True,  use_compile=True, compile_mode="default")
benchmark(prompt, max_new_tokens=64, runs=3, warmup=1, use_cache=False, use_compile=True, compile_mode="default")



=== Benchmark [KV-cache ON | compile OFF] ===
Prompt tokens:       13
Generated tokens:    64 (target 64)
Prefill time (med):  0.0220 s
Decode time (med):   1.3164 s
Decode tok/s:        48.62
Total tok/s:         57.53  (prompt+gen over prefill+decode)
Runs: 5 (warmup 1)

=== Benchmark [KV-cache OFF | compile OFF] ===
Prompt tokens:       13
Generated tokens:    64 (target 64)
Prefill time (med):  0.0236 s
Decode time (med):   2.0557 s
Decode tok/s:        31.13
Total tok/s:         37.03  (prompt+gen over prefill+decode)
Runs: 3 (warmup 1)


  return torch._C._get_cublas_allow_tf32()
W1217 17:18:02.028000 789 torch/_inductor/utils.py:1558] [0/0] Not enough SMs to use max_autotune_gemm mode



=== Benchmark [KV-cache ON | compile ON | (default)] ===
Prompt tokens:       13
Generated tokens:    64 (target 64)
Prefill time (med):  0.0070 s
Decode time (med):   0.8456 s
Decode tok/s:        75.69
Total tok/s:         90.31  (prompt+gen over prefill+decode)
Runs: 5 (warmup 1)

=== Benchmark [KV-cache OFF | compile ON | (default)] ===
Prompt tokens:       13
Generated tokens:    64 (target 64)
Prefill time (med):  0.0063 s
Decode time (med):   1.0086 s
Decode tok/s:        63.45
Total tok/s:         75.87  (prompt+gen over prefill+decode)
Runs: 3 (warmup 1)


# Benchmark harness with backend/mode sweep (CPU/GPU/TPU-ready)

In [13]:


# ---------- device selection ----------
def pick_device():
    # TPU (torch_xla) if available + XLA device present
    try:
        import torch_xla.core.xla_model as xm
        return xm.xla_device(), "xla"
    except Exception:
        pass
    # CUDA if available
    if torch.cuda.is_available():
        return torch.device("cuda"), "cuda"
    # CPU fallback
    return torch.device("cpu"), "cpu"

device, dev_kind = pick_device()
print("Device:", device, "(kind:", dev_kind, ")")

# ---------- dtype ----------
if dev_kind == "cuda":
    dtype = torch.float16
elif dev_kind == "xla":
    # BF16 is commonly a good choice on TPU if supported
    dtype = torch.bfloat16
else:
    dtype = torch.float32

# ---------- sync helpers ----------
def sync():
    if dev_kind == "cuda":
        torch.cuda.synchronize()
    elif dev_kind == "xla":
        import torch_xla.core.xla_model as xm
        xm.mark_step()

# ---------- compilation ----------
def maybe_compile(m, use_compile: bool, backend: str, mode: str):
    if not use_compile:
        return m
    if not hasattr(torch, "compile"):
        raise RuntimeError("torch.compile not available (need PyTorch 2.x).")
    # TPU: PyTorch/XLA integrates with torch.compile via backend='openxla'
    # See PyTorch/XLA torch.compile docs.
    return torch.compile(m, backend=backend, mode=mode, fullgraph=False)

# ---------- decode kernels ----------
@torch.inference_mode()
def run_with_cache(m, input_ids, max_new_tokens: int):
    # prefill
    sync()
    t0 = time.perf_counter()
    out = m(input_ids=input_ids, use_cache=True)
    sync()
    t1 = time.perf_counter()

    past = out.past_key_values
    next_id = torch.argmax(out.logits[:, -1, :], dim=-1, keepdim=True)
    gen = 0

    # decode
    sync()
    t2 = time.perf_counter()
    for _ in range(max_new_tokens):
        if eos_id is not None and (next_id == eos_id).all():
            break
        out = m(input_ids=next_id, past_key_values=past, use_cache=True)
        past = out.past_key_values
        next_id = torch.argmax(out.logits[:, -1, :], dim=-1, keepdim=True)
        gen += 1
    sync()
    t3 = time.perf_counter()

    return (t1 - t0), (t3 - t2), gen

@torch.inference_mode()
def run_no_cache(m, full_ids, max_new_tokens: int):
    # prefill == first full forward
    sync()
    t0 = time.perf_counter()
    out = m(input_ids=full_ids, use_cache=False)
    sync()
    t1 = time.perf_counter()

    next_id = torch.argmax(out.logits[:, -1, :], dim=-1, keepdim=True)
    gen = 0

    # decode (recompute full sequence each step)
    sync()
    t2 = time.perf_counter()
    for _ in range(max_new_tokens):
        if eos_id is not None and (next_id == eos_id).all():
            break
        full_ids = torch.cat([full_ids, next_id], dim=1)
        out = m(input_ids=full_ids, use_cache=False)
        next_id = torch.argmax(out.logits[:, -1, :], dim=-1, keepdim=True)
        gen += 1
    sync()
    t3 = time.perf_counter()

    return (t1 - t0), (t3 - t2), gen

def safe_list_backends():
    # This is an internal API; may vary by version, so we keep it optional.
    try:
        import torch._dynamo as dynamo
        return sorted(list(dynamo.list_backends()))
    except Exception:
        return []

def bench_one(prompt: str, max_new_tokens=64, runs=5, warmup=1,
              use_compile=False, backend="inductor", mode="default", use_cache=True):
    enc = tokenizer(prompt, return_tensors="pt")
    input_ids = enc.input_ids.to(device)
    prompt_toks = input_ids.size(1)

    m = maybe_compile(model, use_compile=use_compile, backend=backend, mode=mode)

    # warmup: include compilation + caches
    for _ in range(warmup):
        if use_cache:
            run_with_cache(m, input_ids, max_new_tokens=8)
        else:
            run_no_cache(m, input_ids.clone(), max_new_tokens=8)

    prefill_s, decode_s, gens = [], [], []
    for _ in range(runs):
        if use_cache:
            p, d, g = run_with_cache(m, input_ids, max_new_tokens=max_new_tokens)
        else:
            p, d, g = run_no_cache(m, input_ids.clone(), max_new_tokens=max_new_tokens)
        prefill_s.append(p); decode_s.append(d); gens.append(g)

    p_med = statistics.median(prefill_s)
    d_med = statistics.median(decode_s)
    g_med = int(statistics.median(gens))

    decode_tps = g_med / d_med if d_med > 0 else float("inf")
    total_tps = (prompt_toks + g_med) / (p_med + d_med) if (p_med + d_med) > 0 else float("inf")

    return {
        "compile": use_compile,
        "backend": backend,
        "mode": mode,
        "kv_cache": use_cache,
        "prompt_toks": prompt_toks,
        "gen_toks": g_med,
        "prefill_s": p_med,
        "decode_s": d_med,
        "decode_tok_s": decode_tps,
        "total_tok_s": total_tps,
    }

def print_rows(rows):
    # minimal pretty-print without pandas
    cols = ["compile","backend","mode","kv_cache","prompt_toks","gen_toks","prefill_s","decode_s","decode_tok_s","total_tok_s"]
    header = " | ".join(f"{c:>11}" for c in cols)
    print(header)
    print("-" * len(header))
    for r in rows:
        print(" | ".join([
            f"{str(r['compile']):>11}",
            f"{r['backend'][:11]:>11}",
            f"{r['mode'][:11]:>11}",
            f"{str(r['kv_cache']):>11}",
            f"{r['prompt_toks']:>11}",
            f"{r['gen_toks']:>11}",
            f"{r['prefill_s']:>11.4f}",
            f"{r['decode_s']:>11.4f}",
            f"{r['decode_tok_s']:>11.4f}",
            f"{r['total_tok_s']:>11.2f}",
        ]))

# ----------- sweep config -----------
prompt = "Explain KV cache vs recomputing attention in 6 sentences.\n"

# Choose backends to try.
# - Debug/ablation backends: eager, aot_eager (good for isolating failures) :contentReference[oaicite:4]{index=4}
# - Performance backend: inductor (default) :contentReference[oaicite:5]{index=5}
# - TPU: openxla :contentReference[oaicite:6]{index=6}
if dev_kind == "xla":
    backends = ["openxla"]
else:
    backends = ["eager", "aot_eager", "inductor"]

modes = ["default", "reduce-overhead", "max-autotune"]  # trade compile time vs runtime :contentReference[oaicite:7]{index=7}

rows = []

# Eager baseline (no compile)
rows.append(bench_one(prompt, use_compile=False, use_cache=True,  runs=5, warmup=1))
rows.append(bench_one(prompt, use_compile=False, use_cache=False, runs=2, warmup=1))  # slow

print("Done Eager mode --------------------")

# Compile sweeps
for b in backends:
    for md in modes:
        print(f"Backend: {b},  Mode: {md} ------------------------------------------")
        rows.append(bench_one(prompt, use_compile=True, backend=b, mode=md, use_cache=True,  runs=5, warmup=1))

        rows.append(bench_one(prompt, use_compile=True, backend=b, mode=md, use_cache=False, runs=2, warmup=1))


print("\nAvailable torch.compile backends (if discoverable):", safe_list_backends())

print_rows(rows)


Device: cuda (kind: cuda )
Done Eage mode --------------------
Backend: eager,  Mode: default ------------------------------------------
Backend: eager,  Mode: reduce-overhead ------------------------------------------
Backend: eager,  Mode: max-autotune ------------------------------------------
Backend: aot_eager,  Mode: default ------------------------------------------
Backend: aot_eager,  Mode: reduce-overhead ------------------------------------------
Backend: aot_eager,  Mode: max-autotune ------------------------------------------
Backend: inductor,  Mode: default ------------------------------------------
Backend: inductor,  Mode: reduce-overhead ------------------------------------------
Backend: inductor,  Mode: max-autotune ------------------------------------------

Available torch.compile backends (if discoverable): ['cudagraphs', 'inductor', 'openxla', 'tvm']
    compile |     backend |        mode |    kv_cache | prompt_toks |    gen_toks |   prefill_s |    decode_s | d