<a href="https://colab.research.google.com/github/srv-sh/kv_caching/blob/main/PyramidKV.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

**- Environment + installs**

In [2]:
%env HF_USE_LEGACY_CACHE=1
%pip -q install "torch>=2.3" "transformers==4.44.2" accelerate safetensors sentencepiece huggingface_hub




env: HF_USE_LEGACY_CACHE=1

**- Hugging Face login**

In [3]:
from huggingface_hub import login
login()  # paste your HF token



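**- Load model & tokenizer**

Asserts that CUDA is available, then selects the cuda device and the torch.float16 dtype.

Sets model_id = "meta-llama/Llama-3.2-1B-Instruct" (requires Hugging Face access, hence the login above).

Loads the tokenizer (AutoTokenizer.from_pretrained) and the model (AutoModelForCausalLM.from_pretrained), moves the model to the GPU, and puts it in eval() mode.

Tries to switch to fast SDPA attention (model.set_attn_implementation("sdpa")) for speed whenever attention weights are not being inspected. The call is wrapped in try/except, so it is skipped silently if it is unavailable or fails.

Prints the number of transformer layers for reference.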

In [5]:
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

assert torch.cuda.is_available(), "Switch Colab runtime to GPU."

device = torch.device("cuda")
dtype  = torch.float16   # T4/A100: fp16 is safe

model_id = "meta-llama/Llama-3.2-1B-Instruct"  # requires HF access

tok = AutoTokenizer.from_pretrained(model_id, use_fast=True)
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=dtype, device_map=None).to(device).eval()

# start fast; we'll flip to eager only briefly when we need attentions
# (best effort: skipped if the call is unavailable or fails)
try: model.set_attn_implementation("sdpa")
except Exception: pass

print("Loaded", model_id, "| layers:", model.config.num_hidden_layers)




Loaded meta-llama/Llama-3.2-1B-Instruct | layers: 16


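**- PyramidKV core**

Defines the main logic and helpers:

Cache utilities that read per-layer K/V tensors, cache lengths, and KV memory from either a legacy tuple cache or a DynamicCache-style object.

make_pyramid_budgets(): per-layer KV budgets, interpolated linearly from the bottom layer to the top layer.

_print_progress(): prints the most recent generated text, the per-layer cache lengths, and the KV memory in MB.

prune_with_single_query_attn(): for each layer, keeps the last W tokens plus the highest-attention older tokens. With per_head_select=True every head picks its own top tokens and the union is kept, so a layer can hold more tokens than its budget (in the test run below, the bottom layer holds 2203 tokens against a budget of 512).

pyramidkv_decode_with_progress(): the full greedy generation loop with periodic pruning and progress display.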
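
As a worked example of the budget schedule, the sketch below reproduces the linear interpolation for the 16-layer model and the 512/128 defaults used in this notebook. The name linear_budgets is only illustrative; it mirrors make_pyramid_budgets() and needs no GPU.

```python
def linear_budgets(n_layers: int, bottom: int, top: int) -> list:
    """Linearly interpolate a per-layer token budget from `bottom` down to `top`."""
    if n_layers == 1:
        return [bottom]
    step = (bottom - top) / (n_layers - 1)   # tokens removed per layer going up
    return [int(round(bottom - i * step)) for i in range(n_layers)]

# 16 layers, bottom budget 512, top budget 128 (the notebook defaults)
budgets = linear_budgets(16, bottom=512, top=128)
print(budgets)
```

The list starts at 512 for layer 0, ends at 128 for the top layer, and drops by about 25.6 tokens per layer in between.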

In [6]:
import os, torch, textwrap
from collections import deque

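# Optional (helps surface errors synchronously while we debug).
# CUDA reads this variable when it initializes, so it must be set before the first CUDA call;
# the model is already on the GPU at this point, so restart the runtime for it to take effect.
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"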

def _is_tuple_cache(past): return isinstance(past, tuple)

def _get_L(past):
    return len(past) if _is_tuple_cache(past) else len(past.key_cache)

def _get_kv(past, i):
    if _is_tuple_cache(past): return past[i]
    return past.key_cache[i], past.value_cache[i]

def _set_kv(past, i, K, V):
    if _is_tuple_cache(past):
        raise RuntimeError("internal: tuple cache set attempted (we rebuild tuple instead)")
    past.key_cache[i] = K
    past.value_cache[i] = V

def _cache_lengths(past):
    if _is_tuple_cache(past):
        return [K.size(2) for (K,V) in past]
    return [K.size(2) for K in past.key_cache]

def _kv_bytes(past, dtype_bytes=None):
    # bytes held by all K and V tensors; the element size is read from the tensors unless overridden
    total = 0
    for i in range(_get_L(past)):
        K, V = _get_kv(past, i)
        total += (K.numel() + V.numel()) * (dtype_bytes or K.element_size())
    return total

def make_pyramid_budgets(n_layers: int, bottom: int, top: int):
    if n_layers == 1: return [bottom]
    step = (bottom - top) / (n_layers - 1)
    return [int(round(bottom - i*step)) for i in range(n_layers)]

def _print_progress(step, generated_ids, past, budgets, tok, tail=200):
    text = tok.decode(generated_ids[-tail:], skip_special_tokens=True)
    lens = _cache_lengths(past)
    print(f"\n=== step {step} ===")
    print("▶ latest text:\n" + textwrap.fill(text, width=100))
    print("▶ cache lengths (bottom→top):")
    print(lens)
    print(f"   top layer len / budget: {lens[-1]} / {budgets[-1]}")
    print(f"   KV memory now: {_kv_bytes(past)/(1024*1024):.2f} MB")

@torch.no_grad()
def pyramidkv_decode_with_progress(
    prompt: str,
    *,
    max_new_tokens=128,
    window_W=128,          # always keep last W per layer
    bottom_budget=512,     # layer-0 budget
    top_budget=128,        # top-layer budget
    refresh_every=16,      # prune every N generated tokens
    start_refresh_after=256,
    per_head_select=True,  # per-head selection then union
    show_every_prune=True
):
    # budgets
    L = model.config.num_hidden_layers
    budgets = make_pyramid_budgets(L, bottom=bottom_budget, top=top_budget)

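    # prefill fast
    try: model.set_attn_implementation("sdpa")
    except Exception: pass
    x = tok(prompt, return_tensors="pt")
    x = {k: v.to(model.device) for k,v in x.items()}
    # Prefill everything except the last prompt token: the decode loop feeds that token
    # as its first input, so caching it here as well would store it twice.
    # (single unpadded sequence, so no attention_mask is needed; the prompt must have >= 2 tokens)
    out = model(input_ids=x["input_ids"][:, :-1], use_cache=True, output_attentions=False)
    past = out.past_key_values
    generated = []
    hist = deque(x["input_ids"][0].tolist(), maxlen=window_W)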

    def prune_with_single_query_attn(past, single_query_attns):
        """
        Keep last-W + top-(B-W) by single-query attention; supports tuple or DynamicCache.
        With per_head_select the per-head picks are unioned, so a layer can keep more than B tokens.
        """
        is_tuple = _is_tuple_cache(past)
        new_tuple = []  # only used if tuple path

        for i in range(_get_L(past)):
            K, V = _get_kv(past, i)              # [B, n_kv, T, Hd]
            T = K.size(2)
            B_i = max(budgets[i], window_W)
            if T <= B_i:
                if is_tuple: new_tuple.append((K,V))
                else: _set_kv(past, i, K, V)
                continue

            # trailing window
            win_start = max(0, T - window_W)
            win_idx = torch.arange(win_start, T, device=K.device)

            remain = B_i - len(win_idx)
            if remain <= 0:
                keep = win_idx
            else:
                # attention of the single probe query over all keys (T cached tokens + the probe token)
                A = single_query_attns[i]                     # [1, n_heads, 1, T+1]
                scores_h = A.squeeze(0).squeeze(1)           # [n_heads, T+1]; only columns < T are indexed below
                # build candidates without in-place boolean indexing (avoids CUDA assert)
                all_idx = torch.arange(T, device=K.device)
                # exclude window via set-diff using isin
                non_win_mask = ~torch.isin(all_idx, win_idx)
                cand = all_idx[non_win_mask]
                k = min(remain, cand.numel())

                if k > 0:
                    if per_head_select:
                        head_keeps = []
                        for h in range(scores_h.size(0)):
                            vals = scores_h[h][cand]
                            topk = torch.topk(vals, k=k, largest=True).indices
                            head_keeps.append(cand[topk])
                        extra = torch.unique(torch.cat(head_keeps)).sort().values
                        keep = torch.cat([win_idx, extra]).unique().sort().values
                    else:
                        vals = scores_h.mean(dim=0)[cand]
                        topk = torch.topk(vals, k=k, largest=True).indices
                        extra = cand[topk]
                        keep = torch.cat([win_idx, extra]).unique().sort().values
                else:
                    keep = win_idx

            Kp = K.index_select(2, keep).contiguous()
            Vp = V.index_select(2, keep).contiguous()
            if is_tuple: new_tuple.append((Kp,Vp))
            else: _set_kv(past, i, Kp, Vp)

        return (tuple(new_tuple) if is_tuple else past)

    # decode loop
    cur = x["input_ids"][:, -1:]
    for t in range(max_new_tokens):
        # attention is already on the fast path: set before prefill and restored after each eager probe below
        out = model(input_ids=cur, use_cache=True, output_attentions=False, past_key_values=past)
        logits = out.logits[:, -1, :]
        cur = logits.argmax(dim=-1, keepdim=True)
        tok_id = cur.item()
        generated.append(tok_id)
        hist.append(tok_id)
        past = out.past_key_values

        long_enough = _cache_lengths(past)[0] >= start_refresh_after
        if (t+1) % refresh_every == 0 and long_enough:
            # one eager step to read single-query attentions (robust)
            try: model.set_attn_implementation("eager")
            except Exception: pass
            # `past` is a legacy tuple in this run, so the probe call leaves it unchanged
            # (a Cache object passed in would be extended in place)
            outA = model(input_ids=cur, use_cache=True, output_attentions=True, past_key_values=past)
            attns = outA.attentions  # per layer: [1, n_heads, 1, T_i + 1] (cached tokens + the probe token)
            # prune using these attentions, WITHOUT advancing past further
            past = prune_with_single_query_attn(past, attns)
            try: model.set_attn_implementation("sdpa")
            except Exception: pass

            if show_every_prune:
                _print_progress(t+1, generated, past, budgets, tok)

    print("\n=== FINAL ===")
    _print_progress(max_new_tokens, generated, past, budgets, tok)
    return tok.decode(generated, skip_special_tokens=True), past, budgets



**- Test PyramidKV**

Creates a long repeated prompt.

Calls pyramidkv_decode_with_progress() with window/budget settings.

Prints sample output and final per-layer budgets.
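
To see when pruning fires for a given setting, the following sketch replays only the trigger condition of the decode loop, `(t+1) % refresh_every == 0` together with the `start_refresh_after` length check. It is a simplification: prune_steps is a hypothetical helper, the prompt lengths are made up, and the layer-0 cache length is modelled as prompt length plus step, ignoring the shrinkage caused by earlier prunes.

```python
def prune_steps(prompt_len, max_new_tokens, refresh_every, start_refresh_after):
    """Return the 1-based decode steps at which a prune would be triggered."""
    steps = []
    for t in range(max_new_tokens):
        cache_len = prompt_len + t + 1          # simplified layer-0 cache length
        if (t + 1) % refresh_every == 0 and cache_len >= start_refresh_after:
            steps.append(t + 1)
    return steps

# Long prompt (hypothetical length): the length check passes from the start.
long_prompt = prune_steps(3000, max_new_tokens=128, refresh_every=16, start_refresh_after=256)
print(long_prompt)    # [16, 32, 48, 64, 80, 96, 112, 128]

# Short prompt (hypothetical length): pruning waits until the cache reaches 256 tokens.
short_prompt = prune_steps(200, max_new_tokens=128, refresh_every=16, start_refresh_after=256)
print(short_prompt)   # [64, 80, 96, 112, 128]
```

With the long repeated prompt used in this test, the first case applies, which is why progress reports appear every 16 steps.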

In [7]:
prompt = " ".join(["In a distant library, scholars debated memory, context, and attention."] * 200)

text, past_kv, budgets = pyramidkv_decode_with_progress(
    prompt,
    max_new_tokens=128,
    window_W=128,
    bottom_budget=512,
    top_budget=128,
    refresh_every=16,
    start_refresh_after=256,
    per_head_select=True,
    show_every_prune=True
)



We detected that you are passing `past_key_values` as a tuple and this is deprecated and will be removed in v4.43. Please use an appropriate `Cache` class (https://huggingface.co/docs/transformers/v4.41.3/en/internal/generation_utils#transformers.Cache)



=== step 16 ===
▶ latest text:
 In a distant library, scholars debated memory, context, and attention. In a
▶ cache lengths (bottom→top):
[2203, 1959, 2128, 1745, 1712, 1520, 1440, 1353, 1243, 1360, 1309, 1036, 1002, 674, 437, 128]
   top layer len / budget: 128 / 128
   KV memory now: 41.50 MB

=== step 32 ===
▶ latest text:
 In a distant library, scholars debated memory, context, and attention. In a................
▶ cache lengths (bottom→top):
[1732, 1506, 1565, 1250, 1182, 1266, 1076, 937, 860, 825, 756, 658, 584, 412, 259, 128]
   top layer len / budget: 128 / 128
   KV memory now: 29.29 MB

=== step 48 ===
▶ latest text:
 In a distant library, scholars debated memory, context, and attention. In
a................................
▶ cache lengths (bottom→top):
[1582, 1317, 1352, 1052, 1025, 1037, 938, 831, 763, 684, 588, 520, 449, 343, 244, 128]
   top layer len / budget: 128 / 128
   KV memory now: 25.10 MB

=== step 64 ===
▶ latest text:
 In a distant library, scholars debated me
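
The "KV memory now" figures can be cross-checked by hand from the printed cache lengths. The sketch below assumes 8 key/value heads with a head dimension of 64 and 2 bytes per element (fp16). These values are not printed anywhere in the notebook, so treat them as assumptions to verify against model.config. kv_megabytes is a hypothetical helper.

```python
def kv_megabytes(cache_lengths, n_kv_heads, head_dim, bytes_per_elem=2):
    """KV cache size in MB for per-layer token counts (batch size 1)."""
    per_token = 2 * n_kv_heads * head_dim * bytes_per_elem   # one K and one V vector per KV head
    return sum(cache_lengths) * per_token / (1024 * 1024)

# per-layer cache lengths copied from the "step 16" report above
step16 = [2203, 1959, 2128, 1745, 1712, 1520, 1440, 1353,
          1243, 1360, 1309, 1036, 1002, 674, 437, 128]

mb = kv_megabytes(step16, n_kv_heads=8, head_dim=64)   # head counts are assumptions
print(f"{mb:.2f} MB")   # 41.50 MB
```

Under these assumptions the result agrees with the 41.50 MB reported at step 16, which suggests the assumed head configuration matches the loaded model.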

**- Chat comparison (Full vs PyramidKV)**

Implements a chat loop that answers each turn twice, once with the standard (full) KV cache and once with the pruned cache:

Sets the hyperparameters (W, bottom/top budgets, refresh rate).

Defines prefill helpers for both the full and the pruned paths.

_prune_with_single_query(): same rule as before (last W tokens plus the highest-attention older tokens).

chat(): generates a reply in both modes, measures speed and KV memory, and prints a comparison summary. Only the full-cache reply is appended to the chat history.
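
The selection rule itself can be illustrated without a model. The sketch below is a plain-Python stand-in for the tensor code, not the notebook's implementation: select_keep and top_k_indices are hypothetical names, the scores are made up, and ties are broken by lower index (torch.topk makes no such promise). It shows the last-W window plus the top-scoring older tokens, and how the per-head union can keep more tokens than the budget.

```python
def top_k_indices(scores, candidates, k):
    """Indices of the k highest-scoring candidates (ties broken by lower index)."""
    return sorted(candidates, key=lambda j: (-scores[j], j))[:k]

def select_keep(head_scores, budget, window, per_head_union=True):
    """Sorted token indices to keep: the last `window` tokens plus top-attention older tokens."""
    T = len(head_scores[0])
    budget = max(budget, window)
    if T <= budget:
        return list(range(T))                      # nothing to prune
    win = list(range(max(0, T - window), T))       # trailing window, always kept
    cand = list(range(0, max(0, T - window)))      # older tokens compete for the rest
    k = min(budget - len(win), len(cand))
    if k <= 0:
        return win
    if per_head_union:
        extra = set()
        for scores in head_scores:                 # each head picks its own top-k
            extra.update(top_k_indices(scores, cand, k))
    else:
        n = len(head_scores)
        mean = [sum(h[j] for h in head_scores) / n for j in range(T)]
        extra = set(top_k_indices(mean, cand, k))  # one shared top-k on the head average
    return sorted(extra | set(win))

# Two heads, 8 cached tokens, budget 4, window 2 (all values made up).
heads = [
    [0.40, 0.30, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05],   # head 0 favours tokens 0 and 1
    [0.05, 0.05, 0.05, 0.05, 0.30, 0.40, 0.05, 0.05],   # head 1 favours tokens 4 and 5
]
keep_mean = select_keep(heads, budget=4, window=2, per_head_union=False)
keep_union = select_keep(heads, budget=4, window=2, per_head_union=True)
print(keep_mean)    # [0, 5, 6, 7]
print(keep_union)   # [0, 1, 4, 5, 6, 7]
```

Averaging over heads stays within the budget of 4, while the per-head union keeps 6 tokens here. That is the same effect that lets the lower layers in the test run above hold far more tokens than their budgets.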

In [20]:
# ========= Conversational A/B: Full vs PyramidKV (paper-like attention) =========
import time, torch, textwrap
from collections import deque

# ---- knobs (quality-friendly defaults) ----
W             = 160     # always-keep sliding window
BOTTOM_BUDGET = 640     # bottom-layer budget
TOP_BUDGET    = 128     # top-layer budget
REFRESH_EVERY = 24      # prune every N generated tokens
START_AFTER   = 256     # start pruning once cache length >= this
PREFILL_CHUNK = 512     # chunk size for extending caches

# ---- helpers ----
def _as_tuple_cache(past):
    if isinstance(past, tuple): return past
    if hasattr(past, "to_legacy_cache"): return past.to_legacy_cache()
    if hasattr(past, "values"):
        vals = list(past.values())
        if vals and isinstance(vals[0], tuple) and len(vals[0]) == 2:
            return tuple(vals)
    raise TypeError(f"Unsupported cache type: {type(past)}")

def _T(past): return past[0][0].size(2)  # current cached length
def kv_mb(past, bytes_per=2): return sum((K.numel()+V.numel())*bytes_per for K,V in past)/(1024*1024)

def make_pyramid(n_layers, bottom, top):
    if n_layers == 1: return [bottom]
    step = (bottom - top)/(n_layers - 1)
    return [int(round(bottom - i*step)) for i in range(n_layers)]

L = model.config.num_hidden_layers
BUDGETS = make_pyramid(L, BOTTOM_BUDGET, TOP_BUDGET)

# ---- system prompt & history rendering ----
SYS = "You are a helpful assistant. Be concise."
def _render(history):
    # simple chat transcript -> text prompt
    lines = [f"System: {SYS}"]
    for role, msg in history:
        lines.append(f"{role.capitalize()}: {msg}")
    lines.append("Assistant:")
    return "\n".join(lines)

# ---- extend caches with new tokens (chunked) ----
@torch.no_grad()
def _prefill_into_full(past, input_ids):
    i = 0
    while i < input_ids.size(1):
        ids = input_ids[:, i:i+PREFILL_CHUNK]
        out = model(input_ids=ids, use_cache=True, past_key_values=past)
        past = _as_tuple_cache(out.past_key_values)
        i += ids.size(1)
    return past

@torch.no_grad()
def _prefill_into_pyr_eager(past, input_ids):
    # force eager for the Pyramid path (stable with attention selection)
    try: model.set_attn_implementation("eager")
    except Exception: pass
    i = 0
    while i < input_ids.size(1):
        ids = input_ids[:, i:i+PREFILL_CHUNK]
        T_prev = 0 if past is None else _T(past)
        cache_pos = torch.arange(T_prev, T_prev + ids.size(1), device=ids.device)
        out = model(input_ids=ids, use_cache=True, output_attentions=False,
                    past_key_values=past, cache_position=cache_pos)
        past = _as_tuple_cache(out.past_key_values)
        i += ids.size(1)
    return past

# ---- attention-based (paper-like) pruning: last-W + top-(B-W) by single-query attention ----
def _prune_with_single_query(past, attn_list, budgets=BUDGETS, window=W, per_head_union=True):
    new_layers = []
    for i, ((K,V), A) in enumerate(zip(past, attn_list)):
        T = K.size(2)
        B = max(budgets[i], window)
        if T <= B:
            new_layers.append((K,V)); continue

        win = torch.arange(max(0, T-window), T, device=K.device)   # last-W always kept
        remain = B - len(win)
        if remain <= 0:
            keep = win
        else:
            # A: [B, n_heads, 1, T] -> [n_heads, T]
            Sh = A.squeeze(0).squeeze(1)
            all_idx = torch.arange(T, device=K.device)
            cand = all_idx[~torch.isin(all_idx, win)]
            k = min(remain, cand.numel())
            if k > 0:
                if per_head_union:
                    head_keeps = []
                    for h in range(Sh.size(0)):
                        vals = Sh[h][cand]
                        head_keeps.append(cand[torch.topk(vals, k=k).indices])
                    extra = torch.unique(torch.cat(head_keeps)).sort().values
                else:
                    vals = Sh.mean(dim=0)[cand]
                    extra = cand[torch.topk(vals, k=k).indices]
                keep = torch.cat([win, extra]).unique().sort().values
            else:
                keep = win

        new_layers.append((K.index_select(2, keep).contiguous(),
                           V.index_select(2, keep).contiguous()))
    return tuple(new_layers)

# ---- conversation state ----
history   = []
past_full = None
pyr_past  = None

def reset_chat():
    global history, past_full, pyr_past
    history, past_full, pyr_past = [], None, None
    try: model.set_attn_implementation("sdpa")  # reset default for Full path
    except Exception: pass
    print("Conversation reset.")

@torch.no_grad()
def chat(user_message, max_new_tokens=128, show_text=True, per_head_union=True):
    """
    Send one conversational turn.
    Full path: default attention (usually SDPA).
    Pyramid path: eager-only + single-query attention selection at prune steps.
    """
    global history, past_full, pyr_past

    # 1) render prompt with chat history
    history.append(("user", user_message))
    prompt_text = _render(history)
    ids = tok(prompt_text, return_tensors="pt")["input_ids"].to(model.device)

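    # 2) extend caches (prefill)
    # NOTE: `ids` is the whole re-rendered transcript, so from the second turn on the earlier
    # turns are appended to both caches again; only reset_chat() clears them.
    # The last prompt token is held back because each generation loop feeds it as its first input.
    t0 = time.time()
    past_full = _prefill_into_full(past_full, ids[:, :-1])
    torch.cuda.synchronize()   # GPU work is asynchronous; wait before reading the clock
    t_full_prefill = time.time() - t0
    t0 = time.time()
    pyr_past = _prefill_into_pyr_eager(pyr_past, ids[:, :-1])
    torch.cuda.synchronize()
    t_pyr_prefill = time.time() - t0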

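    # 3) generate: FULL (normal fast path)
    # the Pyramid prefill above switched to eager, so switch back before timing the Full path
    try: model.set_attn_implementation("sdpa")
    except Exception: pass
    gen_full = []; cur = ids[:, -1:]
    t0 = time.time()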
    for _ in range(max_new_tokens):
        out = model(input_ids=cur, use_cache=True, past_key_values=past_full)
        past_full = _as_tuple_cache(out.past_key_values)
        cur = out.logits[:, -1, :].argmax(dim=-1, keepdim=True)
        tok_id = cur.item(); gen_full.append(tok_id)
        if tok_id == tok.eos_token_id: break
    t_full = time.time() - t0

    # 4) generate — PYRAMID (eager-only + attention-based selection)
    try: model.set_attn_implementation("eager")
    except Exception: pass
    gen_pyr = []; cur = ids[:, -1:]; steps = 0
    t0 = time.time()
    for _ in range(max_new_tokens):
        T_now = _T(pyr_past)
        cache_pos = torch.arange(T_now, T_now + 1, device=cur.device)
        out = model(input_ids=cur, use_cache=True, output_attentions=False,
                    past_key_values=pyr_past, cache_position=cache_pos)
        pyr_past = _as_tuple_cache(out.past_key_values)
        cur = out.logits[:, -1, :].argmax(dim=-1, keepdim=True)
        tok_id = cur.item(); gen_pyr.append(tok_id)
        if tok_id == tok.eos_token_id: break

        steps += 1
        if steps >= REFRESH_EVERY and _T(pyr_past) >= START_AFTER:
            # single-query attention read (eager) for selection
            T_now = _T(pyr_past)
            cache_pos = torch.arange(T_now, T_now + 1, device=cur.device)
            outA = model(input_ids=cur, use_cache=True, output_attentions=True,
                         past_key_values=pyr_past, cache_position=cache_pos)
            attns = outA.attentions  # list of [B, n_heads, 1, T]
            pyr_past = _prune_with_single_query(pyr_past, attns, budgets=BUDGETS, window=W,
                                                per_head_union=per_head_union)
            steps = 0
    t_pyr = time.time() - t0

    # 5) decode, commit, metrics
    ans_full = tok.decode(gen_full, skip_special_tokens=True).strip()
    ans_pyr  = tok.decode(gen_pyr,  skip_special_tokens=True).strip()
    history.append(("assistant", ans_full))  # keep transcript natural

    mb_full = kv_mb(past_full); mb_pyr = kv_mb(pyr_past)
    retention = 100*mb_pyr/max(1e-9, mb_full)
    tps_full = len(gen_full)/max(1e-6, t_full)
    tps_pyr  = len(gen_pyr)/max(1e-6, t_pyr)

    print("\n================ FULL ================")
    if show_text: print(textwrap.fill(ans_full, width=100))
    print(f"\nTime: {t_full:.2f}s | {tps_full:.1f} tok/s | KV: {mb_full:.2f} MB")

    print("\n================ PYRAMIDKV (attention-based) ================")
    if show_text: print(textwrap.fill(ans_pyr, width=100))
    print(f"\nTime: {t_pyr:.2f}s | {tps_pyr:.1f} tok/s | KV: {mb_pyr:.2f} MB")

    print("\n---------------- SUMMARY ----------------")
    print(f"Budgets (bottom→top): {BUDGETS[0]} … {BUDGETS[-1]}  |  W={W}  |  refresh_every={REFRESH_EVERY}  |  start_after={START_AFTER}")
    print(f"Retention: {retention:.1f}%   (smaller is better)")
    print(f"Prefill(s) — Full: {t_full_prefill:.2f}  |  Pyr: {t_pyr_prefill:.2f}")

    return {"full_text": ans_full, "pyr_text": ans_pyr,
            "full_time": t_full, "pyr_time": t_pyr,
            "full_mb": mb_full, "pyr_mb": mb_pyr, "retention_pct": retention}

# ---- usage examples ----
# reset_chat()
# chat("Summarize KV caching and why it speeds up long chats in 4 bullets.")
# chat("Give a simple analogy, then a 5-line Python pseudo-code example.")
# chat("Now, explain what PyramidKV changes relative to a full cache, in one short paragraph.")


**- Chat demo**

Runs 3 chat turns testing both modes:

KV cache summary.

Analogy + small code example.

One-paragraph PyramidKV explanation.

Shows outputs, timing, and memory usage side-by-side.

In [21]:
# start fresh
reset_chat()

# turn 1
chat("Summarize KV caching and why it speeds up long chats in 4 bullets.")

# turn 2
chat("Give a simple analogy, then a 5-line Python pseudo-code example.")

# turn 3
chat("Now, explain what PyramidKV changes compared to a full cache, in one short paragraph.")


Conversation reset.

• Key-value (KV) caching is a technique used to store frequently accessed data in memory, reducing
the need for disk I/O. • It speeds up long chats by reducing the time it takes to access data, as it
can retrieve data from memory instead of loading it from disk. • KV caching is particularly
effective for large datasets, as it can store and retrieve data in a highly efficient manner. • By
reducing the time it takes to access data, KV caching can significantly improve the overall
performance of long chats, making it a valuable technique for applications that require frequent
data access.

Time: 2.26s | 52.2 tok/s | KV: 4.72 MB

• Key-value (KV) caching is a technique used to store frequently accessed data in memory, reducing
the need for disk I/O. • It speeds up long chats by reducing the time it takes to access data, as it
can retrieve data from memory instead of loading it from disk. • KV caching is particularly
effective for large datasets, as it can store and ret

{'full_text': 'PyramidKV is a variant of key-value caching that uses a hierarchical structure, where each key is a prefix of another key. This allows for more efficient storage and retrieval of data, as it can store data at multiple levels of granularity. PyramidKV is particularly useful for applications that require complex data organization and querying. It is also more scalable and flexible than traditional key-value caching, as it can handle large amounts of data and complex queries. Overall, PyramidKV is a powerful caching solution that offers improved performance and flexibility compared to traditional key-value caching.',
 'pyr_text': 'PyramidKV is a variation of the traditional cache that uses a hierarchical structure, where the cache is divided into multiple levels, allowing for faster lookup and retrieval of data. PyramidKV is designed to handle large amounts of data and provides better performance compared to traditional caches, especially in scenarios where data is distribu