In [None]:
import os
import subprocess
import sys

# google.colab is only installed inside Colab runtimes. Import it defensively so
# that this cell also runs on a local Jupyter kernel, where the notebook falls
# back to the current working directory as the repository root.
try:
    import google.colab  # noqa: F401  (imported only to detect the runtime)
except ImportError:
    IN_COLAB = False
else:
    IN_COLAB = True

REPO_URL = "https://github.com/zivbeker42/llm-guesstimator.git"
COLAB_REPO_PATH = "/content/llm-guesstimator"

# Shell snippet that clones the repository once. It is a no-op when the
# COLAB_RELEASE_TAG environment variable is unset or the checkout already exists.
bash_setup = f"""
setup_llm_guesstimator() {{
  if [[ -n \"${{COLAB_RELEASE_TAG:-}}\" ]]; then
    local repo_path=\"{COLAB_REPO_PATH}\"
    if [[ ! -d \"$repo_path\" ]]; then
      git clone \"{REPO_URL}\" \"$repo_path\"
    fi
  fi
}}
setup_llm_guesstimator
"""

if IN_COLAB:
    # check=True surfaces a failed clone here instead of as a later, less
    # informative FileNotFoundError from os.chdir.
    subprocess.run(["bash", "-lc", bash_setup], check=True)

repo_path = COLAB_REPO_PATH if IN_COLAB else os.path.abspath('.')

# Make the repository's packages (e.g. `utils`) importable and resolve relative
# paths against the repository root.
if repo_path not in sys.path:
    sys.path.append(repo_path)
os.chdir(repo_path)



# Prefill Time Estimates for Transformers - ($2·m·n·p$ convention)

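This notebook derives and visualizes **prefill** (encode) time across $S$ concurrent requests,
counting the product of an $m \times n$ matrix with an $n \times p$ matrix as **$2mnp$ FLOPs** (one multiply and one add per inner-product term).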

It includes:
- Compute and memory time formulas (with inline math in `$...$`).
- Parameters for NVIDIA **A100‑40GB**.
- Functions for $T_{prefill}^{compute}(S,L)$ and $T_{prefill}^{memory}(S,L)$.
- Plots vs $L$ and vs $S$.



## Formulas

**Compute FLOPs per layer & request (length $L$):**
- Projections $Q,K,V,O$: $8Ld^2$
- Attention ($QK^T$ and $AV$): $4L^2d$
- FFN: $4rLd^2$

So, per layer and per request:

$$FLOPS_{layer, req} = (8+4r)Ld^2 + 4L^2d$$

Across $S$ requests and $n_\ell$ layers:  
$$T_{prefill}^{compute}(S,L) = \dfrac{n_\ell S((8+4r)Ld^2 + 4L^2d)}{GPU\ FLOPs/s}$$

---

**Memory traffic per prefill round:**
1. Weights read: $n_\ell(4+2r)d^2 \times dtype\_bytes$  
2. KV write: $n_\ell(2SLd) \times dtype\_bytes$  
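3. Activations: $n_\ell(c_{act}SLd) \times dtype\_bytes$, where $c_{act}$ is a dimensionless factor giving the per-layer activation traffic in units of $SLd$ elements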

Total bytes:  
$n_\ell((4+2r)d^2 + (2+c_{act})SLd) \cdot dtype\_bytes$

So  
$T_{prefill}^{memory}(S,L) = \dfrac{n_\ell((4+2r)d^2 + (2+c_{act})SLd) \cdot dtype\_bytes}{BW}$



## Example Parameters (A100‑40GB, model Llama-70B)

- Hidden size $d = 4096$  
- FFN expansion $r = 4$  
- Layers $n_\ell = 64$  
- dtype\_bytes = 2 (FP16/BF16)  
- GPU peak FLOPs/s = $312 \times 10^{12}$  
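- Bandwidth $BW \approx 1.555 \times 10^{12}$ Bytes/s (≈1555 GB/s)  

The code cells below do not hard-code these numbers: they load the model and hardware parameters with `get_model_config("70B")` and `get_hardware_config()` from `utils.config`. If that configuration differs from the values listed here, the plots and table follow the configuration.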


In [None]:
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd

from utils.config import S_L_GRID_SETTINGS, get_hardware_config, get_model_config
from utils.math_utils import prefill_compute_time, prefill_memory_time

# Shared inputs for every cell below, all loaded from utils.config.
# Hardware description passed to the timing helpers (default configuration).
hardware = get_hardware_config()
# Model description selected by the "70B" key.
prefill_model = get_model_config("70B")
# Sweep settings for the prefill plots and table; later cells read the keys
# "context_linspace", "batch_linspace", "sample_batch_sizes",
# "sample_context_lengths" and "summary_points" from it.
prefill_settings = S_L_GRID_SETTINGS["prefill"]


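## Functions

The timing functions `prefill_compute_time` and `prefill_memory_time` are imported from `utils.math_utils` in the cell above rather than defined in this notebook.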

## Plots vs $L$ for various $S$

In [None]:
# Sweep the context length L over the range configured under "context_linspace".
# The upper bound comes from the configuration (previously a hard-coded 10000
# silently overrode the configured `stop`), matching the S sweep further below.
start, stop, num = prefill_settings["context_linspace"]
L_vals = np.linspace(start, stop, int(num))
S_list = prefill_settings["sample_batch_sizes"]

fig, ax = plt.subplots(figsize=(8, 6))

# One compute-time curve per sampled batch size. The 1e3 factor matches the
# "Time (ms)" axis label (assumes the helpers return seconds - TODO confirm).
for S in S_list:
    Tcomp = prefill_compute_time(S, L_vals, prefill_model, hardware) * 1e3
    ax.plot(L_vals, Tcomp, label=f'S={S} compute')

# Memory time is drawn for S=1 only, as a dashed reference line.
Tmem = prefill_memory_time(1, L_vals, prefill_model, hardware) * 1e3
ax.plot(L_vals, Tmem, '--', label='S=1 memory')

ax.set_xlabel("Context length L (tokens)")
ax.set_ylabel("Time (ms)")
ax.set_title("Prefill time vs L on A100-40GB")
ax.legend()
ax.grid(True, linestyle=":")
plt.show()


## Plots vs $S$ for fixed $L$

In [None]:
# Sweep the concurrency S over the range configured under "batch_linspace".
start, stop, num = prefill_settings["batch_linspace"]
S_vals = np.linspace(start, stop, int(num))
L_list = prefill_settings["sample_context_lengths"]

fig, ax = plt.subplots(figsize=(8, 6))

# One compute-time curve per sampled context length, scaled by 1e3 to match
# the "Time (ms)" axis label.
for L in L_list:
    compute_ms = prefill_compute_time(S_vals, L, prefill_model, hardware) * 1e3
    ax.plot(S_vals, compute_ms, label=f'L={L} compute')

# Memory time is drawn only for the first sampled context length, dashed.
memory_ms = prefill_memory_time(S_vals, L_list[0], prefill_model, hardware) * 1e3
ax.plot(S_vals, memory_ms, '--', label=f'L={L_list[0]} memory')

ax.set(
    xlabel="Concurrency S",
    ylabel="Time (ms)",
    title="Prefill time vs S on A100-40GB",
)
ax.legend()
ax.grid(True, linestyle=":")
plt.show()


## Sanity Check Table

In [None]:
# Evaluate both time models at each configured (S, L) summary point.
# The values are stored exactly as the helpers return them (not scaled by 1e3
# as in the plots above). The per-point debug print was removed because the
# table displayed below already shows every S and L.
rows = [
    {
        "S": point["S"],
        "L": point["L"],
        "T_compute": float(
            prefill_compute_time(point["S"], point["L"], prefill_model, hardware)
        ),
        "T_memory": float(
            prefill_memory_time(point["S"], point["L"], prefill_model, hardware)
        ),
    }
    for point in prefill_settings["summary_points"]
]

df = pd.DataFrame(rows)
df
