# Model-Dependent Tests for lib/kv_cache.py

These tests require a loaded model and run in a Jupyter kernel that already has
the model in GPU memory. They verify:

1. RoPE correction matches HuggingFace's actual rotary embedding
2. RoPE roundtrip is identity in float64, bounded error in float16
3. `apply_rope_roundtrip_noise` preserves BOS, introduces small noise
4. `correct_rope_positions_with_bos` preserves BOS, modifies doc keys
5. `score_answer_with_cache` NLL matches full-sequence forward pass
6. Truncated+corrected keys match bare cache keys
7. Truncated values differ from bare values (prefix priming works)
8. `build_truncated_kv_cache_corrected` produces correct length
9. `build_truncated_cache_variable_prefix` correctness
10. BPE boundary effect detection
11. Suffix cache passage portion matches bare cache

In [None]:
import sys, os, torch, numpy as np
sys.path.insert(0, '..')

from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, DynamicCache
from lib.config import ExperimentConfig
from lib.kv_cache import (
    _rotate_half, _build_rope_correction, _get_rope_theta, _ensure_dynamic_cache,
    _get_cache_keys, _get_cache_values, _set_cache_keys, _set_cache_values,
    build_kv_cache, extract_and_truncate_cache, extract_and_truncate_cache_with_bos,
    correct_rope_positions, correct_rope_positions_with_bos,
    build_hybrid_cache, swap_bos_entry, apply_rope_roundtrip_noise,
    replace_values_at_layers, score_answer_with_cache,
    build_truncated_kv_cache, build_truncated_kv_cache_corrected,
    build_truncated_cache_variable_prefix, build_suffix_kv_cache,
)

# One seed drives both the experiment config and torch's global RNG, so the
# synthetic torch.randn tensors used by the tests below are reproducible
# from run to run instead of changing on every kernel restart.
SEED = 42
config = ExperimentConfig(num_samples=10, seed=SEED)
torch.manual_seed(SEED)

# Load the model quantized to 4-bit NF4 with float16 compute.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True, bnb_4bit_compute_dtype=torch.float16, bnb_4bit_quant_type='nf4',
)
model_name = config.model_name
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(
    model_name, quantization_config=bnb_config, device_map='auto',
)
model.eval()

# Architecture constants shared by every test cell.
NUM_LAYERS = model.config.num_hidden_layers
# Some configs declare the per-head width explicitly via `head_dim`, and it
# need not equal hidden_size // num_attention_heads. Prefer the explicit value
# and fall back to the quotient only when it is absent or None.
HEAD_DIM = getattr(model.config, 'head_dim', None) or (
    model.config.hidden_size // model.config.num_attention_heads
)
ROPE_THETA = _get_rope_theta(model.config)

# Fixed passage / query / answer triple used by the scoring and cache tests.
PASSAGE = (
    'The Amazon rainforest produces approximately 20 percent of the world\'s oxygen. '
    'It covers over 5.5 million square kilometers and is home to roughly 10 percent '
    'of all species on Earth.'
)
QUERY = 'how much oxygen does the amazon produce'
ANSWER = 'approximately 20 percent'

# Running tallies shared by every test cell; `check` updates them in place.
passed = 0
failed = 0
total = 0


def check(name, condition, detail=''):
    """Record one test outcome and print a PASS/FAIL line for it.

    Args:
        name: Label printed for the test.
        condition: Truthy when the test passed, falsy otherwise.
        detail: Optional extra text appended after the label.
    """
    global passed, failed, total
    total += 1
    ok = bool(condition)
    if ok:
        passed += 1
    else:
        failed += 1
    label = 'PASS' if ok else 'FAIL'
    print(f'  {label}: {name}  {detail}')

# Report which model was loaded and the architecture constants in use.
for info_line in (
    f'Model loaded: {model_name}',
    f'Layers: {NUM_LAYERS}, Head dim: {HEAD_DIM}, RoPE theta: {ROPE_THETA}',
):
    print(info_line)

In [None]:
print('=' * 70)
print('TEST 1: RoPE correction matches HuggingFace rotary embedding')
print('=' * 70)

# The model's own rotary embedding module serves as the reference.
rotary_emb = model.model.rotary_emb

# Placeholder activations passed to rotary_emb alongside the position ids.
dummy_x = torch.zeros(1, 1, 1, HEAD_DIM, device=config.device, dtype=torch.float16)


def _hf_cos_sin(position):
    """Return the model's (cos, sin) for one position, squeezed and cast to float32."""
    position_ids = torch.tensor([[position]], device=config.device)
    cos, sin = rotary_emb(dummy_x, position_ids)
    return cos.squeeze().float(), sin.squeeze().float()


def _apply_rotation(key, cos, sin):
    """Apply the RoPE rotation formula to `key` with the given cos/sin."""
    return key * cos + _rotate_half(key) * sin


def _rope_recovery_error(key, target_pos, shift):
    """Max abs error after undoing `shift` positions with our correction.

    The same pre-RoPE key is rotated with HF's cos/sin at `target_pos` and at
    `target_pos + shift`; the shifted key is then passed through
    _build_rope_correction(shift) and compared with the key at `target_pos`.
    """
    cos_t, sin_t = _hf_cos_sin(target_pos)
    cos_s, sin_s = _hf_cos_sin(target_pos + shift)
    key_target = _apply_rotation(key, cos_t, sin_t)
    key_shifted = _apply_rotation(key, cos_s, sin_s)

    cos_fix, sin_fix = (
        t.to(device=config.device, dtype=torch.float32)
        for t in _build_rope_correction(shift, HEAD_DIM, ROPE_THETA)
    )
    restored = _apply_rotation(key_shifted, cos_fix, sin_fix)
    return (restored - key_target).abs().max().item()


S = 20  # offset
key_pre = torch.randn(1, 1, 1, HEAD_DIM, device=config.device, dtype=torch.float32)

# Correcting the key at position 1+S by S should recover the key at position 1.
max_err = _rope_recovery_error(key_pre, 1, S)
# Tolerance is 1e-3 because HF stores inv_freq as a buffer (computed once)
# while we recompute from scratch — float32 accumulation differs slightly
check('RoPE correction vs HF rotary_emb', max_err < 1e-3, f'max_err={max_err:.2e}')

# Repeat for several offsets, this time targeting position 3.
for S_test in (1, 5, 50, 200):
    err = _rope_recovery_error(key_pre, 3, S_test)
    check(f'  offset={S_test}', err < 1e-3, f'max_err={err:.2e}')

In [None]:
print('=' * 70)
print('TEST 2: RoPE roundtrip identity (float64) and bounded error (float16)')
print('=' * 70)

offset = 15


def _rope_roundtrip(keys, shift, cast):
    """Rotate `keys` with correction(-shift), then undo it with correction(+shift).

    `cast` converts the cos/sin tensors returned by _build_rope_correction to
    the dtype/device that matches `keys`.
    """
    result = keys
    # correction(-shift) acts as the forward RoPE(+shift); correction(+shift) inverts it.
    for signed_shift in (-shift, shift):
        cos_t, sin_t = map(cast, _build_rope_correction(signed_shift, HEAD_DIM, ROPE_THETA))
        result = result * cos_t + _rotate_half(result) * sin_t
    return result


# Float64
keys_f64 = torch.randn(1, 8, 10, HEAD_DIM, dtype=torch.float64)
recovered = _rope_roundtrip(keys_f64, offset, lambda t: t.double())
err_f64 = (recovered - keys_f64).abs().max().item()
# _build_rope_correction computes angles in float64 but returns float32,
# so roundtrip through double has float32-level precision
check('RoPE roundtrip identity (float64)', err_f64 < 1e-6, f'max_err={err_f64:.2e}')

# Float16
keys_f16 = torch.randn(1, 8, 10, HEAD_DIM, dtype=torch.float16, device=config.device)
recovered_16 = _rope_roundtrip(
    keys_f16, offset, lambda t: t.to(device=config.device, dtype=torch.float16)
)
err_f16 = (recovered_16.float() - keys_f16.float()).abs().max().item()
check('RoPE roundtrip float16: nonzero error', err_f16 > 0, f'err={err_f16:.6f}')
check('RoPE roundtrip float16: bounded error', err_f16 < 0.1, f'err={err_f16:.6f}')

In [None]:
print('=' * 70)
print('TEST 3: apply_rope_roundtrip_noise')
print('=' * 70)

# Fill a synthetic cache with random float16 keys/values, one entry per layer,
# keeping a copy of every layer's keys for the before/after comparison.
cache = DynamicCache()
keys_before = []
for layer_idx in range(NUM_LAYERS):
    layer_k = torch.randn(1, 8, 10, HEAD_DIM, device=config.device, dtype=torch.float16)
    layer_v = torch.randn(1, 8, 10, HEAD_DIM, device=config.device, dtype=torch.float16)
    keys_before.append(layer_k.clone())
    cache.update(layer_k, layer_v, layer_idx)

# Snapshot of the position-0 (BOS) key in every layer.
bos_before = [
    _get_cache_keys(cache, layer_idx)[:, :, :1, :].clone()
    for layer_idx in range(NUM_LAYERS)
]
apply_rope_roundtrip_noise(cache, offset=10, model=model)

bos_preserved = all(
    torch.allclose(_get_cache_keys(cache, layer_idx)[:, :, :1, :], saved_bos, atol=0)
    for layer_idx, saved_bos in enumerate(bos_before)
)
check('roundtrip_noise: BOS preserved', bos_preserved)

# Largest absolute change of the non-BOS keys, per layer.
max_errors = [
    (
        _get_cache_keys(cache, layer_idx)[:, :, 1:, :].float()
        - original_keys[:, :, 1:, :].float()
    ).abs().max().item()
    for layer_idx, original_keys in enumerate(keys_before)
]

mean_max = np.mean(max_errors)
check('roundtrip_noise: nonzero perturbation', mean_max > 1e-5, f'mean_max_err={mean_max:.6f}')
check('roundtrip_noise: bounded perturbation', mean_max < 0.1, f'mean_max_err={mean_max:.6f}')

In [None]:
print('=' * 70)
print('TEST 4: correct_rope_positions_with_bos')
print('=' * 70)


def _random_fp16_block():
    """Return a fresh random float16 tensor shaped like one layer's keys/values."""
    return torch.randn(1, 8, 10, HEAD_DIM, device=config.device, dtype=torch.float16)


# --- BOS preservation -------------------------------------------------------
cache4 = DynamicCache()
for layer_idx in range(NUM_LAYERS):
    layer_k = _random_fp16_block()
    layer_v = _random_fp16_block()
    cache4.update(layer_k, layer_v, layer_idx)

# Snapshots of position 0 (BOS) and of the remaining positions before correction.
bos4 = []
doc4 = []
for layer_idx in range(NUM_LAYERS):
    layer_keys = _get_cache_keys(cache4, layer_idx)
    bos4.append(layer_keys[:, :, :1, :].clone())
    doc4.append(layer_keys[:, :, 1:, :].clone())

correct_rope_positions_with_bos(cache4, offset=10, model=model)

bos_ok = all(
    torch.allclose(_get_cache_keys(cache4, layer_idx)[:, :, :1, :], saved, atol=0)
    for layer_idx, saved in enumerate(bos4)
)
check('correct_rope_with_bos: BOS preserved', bos_ok)

doc_changed = all(
    not torch.allclose(_get_cache_keys(cache4, layer_idx)[:, :, 1:, :], saved, atol=1e-6)
    for layer_idx, saved in enumerate(doc4)
)
check('correct_rope_with_bos: doc keys modified', doc_changed)

# --- Zero offset = noop -----------------------------------------------------
cache4b = DynamicCache()
k4b = _random_fp16_block()
v4b = _random_fp16_block()
cache4b.update(k4b.clone(), v4b.clone(), 0)
correct_rope_positions_with_bos(cache4b, offset=0, model=model)
keys_after_noop = _get_cache_keys(cache4b, 0)
check('correct_rope_with_bos: offset=0 is noop', torch.allclose(keys_after_noop, k4b))

# --- correct_rope_positions (no BOS) modifies ALL positions -----------------
cache4c = DynamicCache()
k4c = _random_fp16_block()
cache4c.update(k4c.clone(), torch.randn_like(k4c), 0)
correct_rope_positions(cache4c, offset=10, model=model)
keys_after_full = _get_cache_keys(cache4c, 0)
check('correct_rope (no BOS): modifies all positions including 0',
      not torch.allclose(keys_after_full, k4c, atol=1e-6))

In [None]:
print('=' * 70)
print('TEST 5: score_answer_with_cache NLL matches full forward pass')
print('=' * 70)


def _encode_on_device(text, add_special_tokens):
    """Tokenize `text` and return its input ids moved to config.device."""
    encoded = tokenizer(text, return_tensors='pt', add_special_tokens=add_special_tokens)
    return encoded['input_ids'].to(config.device)


# Cached path: build the context cache once, then score the answer against it.
context = f'Document:\n{PASSAGE}'
ctx_len, cache5 = build_kv_cache(context, model, tokenizer, config)
query_prompt = config.query_template.format(query=QUERY)
nll_cached = score_answer_with_cache(
    cache5, ctx_len, query_prompt, ANSWER, model, tokenizer, config
)

# Reference path: one forward pass over context + query + answer.
full_text = context + query_prompt + ANSWER
full_ids = _encode_on_device(full_text, True)
answer_ids = _encode_on_device(ANSWER, False)
answer_len = answer_ids.shape[1]

with torch.no_grad():
    outputs = model(input_ids=full_ids, attention_mask=torch.ones_like(full_ids), return_dict=True)

# Logits at position i predict token i+1, so this window scores answer tokens
# 2..N and normalizes by answer_len - 1.
# NOTE(review): the first answer token is not scored here; presumably this
# mirrors the convention inside score_answer_with_cache — confirm.
answer_start = full_ids.shape[1] - answer_len
answer_logits = outputs.logits[:, answer_start:-1, :]
answer_labels = full_ids[:, answer_start + 1:]

summed_nll = torch.nn.functional.cross_entropy(
    answer_logits.reshape(-1, answer_logits.size(-1)),
    answer_labels.reshape(-1),
    reduction='sum',
)
nll_full = summed_nll.item() / (answer_len - 1)

rel_err = abs(nll_cached - nll_full) / max(abs(nll_full), 1e-8)
# 4-bit quantized models may have small numerical differences between
# cached and full forward passes due to non-deterministic matmul ordering
check('NLL cached vs full', rel_err < 0.05,
      f'cached={nll_cached:.6f}, full={nll_full:.6f}, rel_err={rel_err:.6f}')

# The correct answer should score a lower NLL than an unrelated one.
# NOTE(review): cache5 is reused for this second scoring call; this assumes
# score_answer_with_cache leaves the passed-in cache unchanged — confirm.
nll_wrong = score_answer_with_cache(
    cache5, ctx_len, query_prompt, 'purple elephants flying sideways',
    model, tokenizer, config
)
check('correct < wrong answer NLL', nll_cached < nll_wrong,
      f'correct={nll_cached:.4f}, wrong={nll_wrong:.4f}')

In [None]:
print('=' * 70)
print('TEST 6: Truncated+corrected keys match bare cache keys')
print('=' * 70)

prefix_text = 'Some irrelevant prefix text here. '
document_text = f'Document:\n{PASSAGE}'
prefix_with_sep = prefix_text + ' '

# Token count of the prefix alone (with BOS) marks where the document starts
# inside the full prefix+document sequence.
prefix_enc = tokenizer(prefix_with_sep, return_tensors='pt', add_special_tokens=True)
prefix_len = prefix_enc['input_ids'].shape[1]

full_context = prefix_with_sep + document_text
full_enc = tokenizer(full_context, return_tensors='pt', add_special_tokens=True)
full_ids = full_enc['input_ids'].to(config.device)
doc_len = full_ids.shape[1] - prefix_len

# The bare sequence is BOS followed by exactly the document tokens taken from
# the full sequence, so both caches cover identical token ids.
bare_ids = torch.cat([full_ids[:, :1], full_ids[:, prefix_len:]], dim=1)

with torch.no_grad():
    bare_out = model(input_ids=bare_ids, attention_mask=torch.ones_like(bare_ids),
                     use_cache=True, return_dict=True)
    full_out = model(input_ids=full_ids, attention_mask=torch.ones_like(full_ids),
                     use_cache=True, return_dict=True)

bare_cache = _ensure_dynamic_cache(bare_out.past_key_values)
trunc_cache = extract_and_truncate_cache_with_bos(full_out.past_key_values, doc_len)
offset = prefix_len - 1
correct_rope_positions_with_bos(trunc_cache, offset, model)


def _per_layer_key_gap(positions):
    """Max abs key difference (bare vs truncated) per layer over `positions`."""
    gaps = []
    for layer_idx in range(NUM_LAYERS):
        bare_k = _get_cache_keys(bare_cache, layer_idx)[:, :, positions, :].float()
        trunc_k = _get_cache_keys(trunc_cache, layer_idx)[:, :, positions, :].float()
        gaps.append((bare_k - trunc_k).abs().max().item())
    return gaps


max_errors = _per_layer_key_gap(slice(1, None))  # document positions
bos_errs = _per_layer_key_gap(slice(0, 1))       # BOS position only

# Layer 0: input is pure token embeddings (identical regardless of prefix),
# so keys = RoPE(W_K * embed(token), pos). After correction, should match bare.
# NOTE: 4-bit quantized matmul produces slightly different results for different
# total sequence lengths, even for the same tokens. Tolerance reflects this.
err_layer0 = max_errors[0]
check('layer 0 corrected keys match bare (4-bit tol)', err_layer0 < 0.1,
      f'max_err={err_layer0:.6f}')

# Layers >0: hidden states differ because doc tokens attended to prefix tokens,
# so pre-RoPE keys differ. RoPE correction only fixes positional encoding.
# We expect INCREASING divergence with layer depth.
check('key divergence increases with depth',
      max_errors[0] < max_errors[-1],
      f'layer0={max_errors[0]:.6f}, last={max_errors[-1]:.6f}')

# BOS keys should be identical (position 0, same input embedding, causal = no attention)
mean_bos_err = np.mean(bos_errs)
check('BOS keys match between bare and trunc', mean_bos_err < 1e-4,
      f'mean_bos_err={mean_bos_err:.6f}')

In [None]:
print('=' * 70)
print('TEST 7: Values differ between bare and truncated (prefix priming)')
print('=' * 70)


def _value_rms_gap(layer_idx):
    """RMS difference between bare and truncated values at the non-BOS positions."""
    bare_v = _get_cache_values(bare_cache, layer_idx)[:, :, 1:, :].float()
    trunc_v = _get_cache_values(trunc_cache, layer_idx)[:, :, 1:, :].float()
    return torch.norm(bare_v - trunc_v).item() / bare_v.numel()**0.5


# Uses bare_cache / trunc_cache produced by the TEST 6 cell.
divergences = [_value_rms_gap(layer_idx) for layer_idx in range(NUM_LAYERS)]

check('values differ (prefix priming works)', max(divergences) > 0.01,
      f'min={min(divergences):.6f}, max={max(divergences):.6f}, mean={np.mean(divergences):.6f}')

# The first layer is expected to be least affected by the prefix, with the
# gap growing toward the last layer.
check('layer 0 divergence < last layer divergence',
      divergences[0] < divergences[-1],
      f'layer0={divergences[0]:.6f}, last={divergences[-1]:.6f}')

In [None]:
print('=' * 70)
print('TEST 8: build_truncated_kv_cache_corrected output length')
print('=' * 70)


def _token_count(text):
    """Number of tokens in `text`, special tokens included."""
    return tokenizer(text, return_tensors='pt', add_special_tokens=True)['input_ids'].shape[1]


surrogate = 'how much oxygen'
keep_len8, cache8 = build_truncated_kv_cache_corrected(
    surrogate, PASSAGE, model, tokenizer, config
)

# Rebuild the prefix/document strings here and derive the expected document
# length as (tokens in prefix+document) - (tokens in prefix alone).
surr_prefix = f'This document may be relevant to queries like: {surrogate}\n\n'
doc_text8 = f'Document:\n{PASSAGE}'
prefix_tokens8 = _token_count(surr_prefix)
full_tokens8 = _token_count(surr_prefix + doc_text8)
expected_doc_len = full_tokens8 - prefix_tokens8

# The +1 is the BOS entry kept in front of the document.
check('build_truncated_kv_cache_corrected: keep_len', keep_len8 == 1 + expected_doc_len,
      f'got={keep_len8}, expected={1+expected_doc_len}')

# The returned cache must also be usable for scoring.
query_prompt8 = config.query_template.format(query=QUERY)
nll8 = score_answer_with_cache(
    cache8, keep_len8, query_prompt8, ANSWER, model, tokenizer, config
)
check('build_truncated_kv_cache_corrected: valid NLL', np.isfinite(nll8) and nll8 > 0,
      f'nll={nll8:.6f}')

In [None]:
print('=' * 70)
print('TEST 9: build_truncated_cache_variable_prefix')
print('=' * 70)

prefix9 = 'Random prefix text here'
keep9, cache9, prefix_tok_len9 = build_truncated_cache_variable_prefix(
    prefix9, PASSAGE, model, tokenizer, config
)

# The reported prefix length is compared against tokenizing prefix + ' ' with BOS.
prefix_enc9 = tokenizer(prefix9 + ' ', return_tensors='pt', add_special_tokens=True)
expected_prefix_len9 = prefix_enc9['input_ids'].shape[1]
check('variable_prefix: prefix_token_len',
      prefix_tok_len9 == expected_prefix_len9,
      f'got={prefix_tok_len9}, expected={expected_prefix_len9}')
check('variable_prefix: keep_len > 1', keep9 > 1)

# Different prefixes -> different caches
_, cacheA, _ = build_truncated_cache_variable_prefix(
    'Hello world', PASSAGE, model, tokenizer, config
)
_, cacheB, _ = build_truncated_cache_variable_prefix(
    'Completely different text about quantum physics and black holes',
    PASSAGE, model, tokenizer, config
)

# Compare last-layer values over the sequence length both caches share.
last_layer = NUM_LAYERS - 1
v_a = _get_cache_values(cacheA, last_layer).float()
v_b = _get_cache_values(cacheB, last_layer).float()
min_len = min(v_a.shape[2], v_b.shape[2])
diff9 = (v_a[:, :, :min_len, :] - v_b[:, :, :min_len, :]).abs().mean().item()
check('variable_prefix: different prefixes -> different values', diff9 > 0.001,
      f'mean_diff={diff9:.6f}')

In [None]:
print('=' * 70)
print('TEST 10: BPE boundary effect detection')
print('=' * 70)


def _doc_tokens_alone_vs_prefixed(prefix, doc):
    """Tokenize `doc` alone and after `prefix`; return both document token views.

    Returns (alone_without_bos, doc_slice_from_full): the first drops the
    leading token of the stand-alone encoding, the second drops as many
    leading tokens from the combined encoding as the prefix alone produced.
    """
    def _ids(text):
        return tokenizer(text, return_tensors='pt', add_special_tokens=True)['input_ids'][0]

    alone = _ids(doc)
    prefix_only = _ids(prefix)
    combined = _ids(prefix + doc)
    return alone[1:], combined[len(prefix_only):]


# Case 1: an ordinary sentence after a prefix that ends in a space.
ids_alone_no_bos, doc_ids_from_full = _doc_tokens_alone_vs_prefixed(
    'Some prefix text ', 'Document:\nThe quick brown fox'
)
match = torch.equal(ids_alone_no_bos, doc_ids_from_full)
print(f'  BPE tokens match: {match}')
print(f'    Alone (no BOS): {ids_alone_no_bos.tolist()[:10]}')
print(f'    From full:      {doc_ids_from_full.tolist()[:10]}')
if match:
    print('  -> BPE splits match for this example (may not hold for all inputs).')
else:
    print('  -> BPE splits DIFFER. build_matched_caches is NECESSARY.')

# Case 2: a case more likely to show boundary effects.
alone_no_bos_b, doc_from_full_b = _doc_tokens_alone_vs_prefixed('an ', 'Document:\nelectric')
match_b = torch.equal(alone_no_bos_b, doc_from_full_b)
print(f'  Case 2 BPE match: {match_b}')
print(f'    Alone: {alone_no_bos_b.tolist()}')
print(f'    Full:  {doc_from_full_b.tolist()}')

In [None]:
print('='*70)
print('TEST 11: Suffix cache passage portion matches bare')
print('='*70)


def _max_abs_gap(tensor_a, tensor_b):
    """Largest absolute elementwise difference between two tensors, in float32."""
    return (tensor_a.float() - tensor_b.float()).abs().max().item()


# First: check determinism by building the bare cache twice from the same input.
bare_len11a, bare_cache11a = build_kv_cache(PASSAGE, model, tokenizer, config)
bare_len11b, bare_cache11b = build_kv_cache(PASSAGE, model, tokenizer, config)
bare_cache11a = _ensure_dynamic_cache(bare_cache11a)
bare_cache11b = _ensure_dynamic_cache(bare_cache11b)

det_k_err = max(
    (_max_abs_gap(_get_cache_keys(bare_cache11a, li), _get_cache_keys(bare_cache11b, li))
     for li in range(NUM_LAYERS)),
    default=0,
)
det_v_err = max(
    (_max_abs_gap(_get_cache_values(bare_cache11a, li), _get_cache_values(bare_cache11b, li))
     for li in range(NUM_LAYERS)),
    default=0,
)

print(f'  Determinism (same length): max_k_err={det_k_err:.2e}, max_v_err={det_v_err:.2e}')
check('bare cache is deterministic (same length)', det_k_err < 1e-5 and det_v_err < 1e-5,
      f'k={det_k_err:.2e}, v={det_v_err:.2e}')

# Now test suffix causal invariance: the first bare_len11a positions of the
# suffix cache are compared against the bare cache built from the passage alone.
suffix_len11, suffix_cache11 = build_suffix_kv_cache(
    PASSAGE, 'What is the oxygen percentage?', model, tokenizer, config
)
suffix_cache11 = _ensure_dynamic_cache(suffix_cache11)

# Per-layer errors are gathered once and reused by both checks below.
layer_k_errs = []
layer_v_errs = []
for li in range(NUM_LAYERS):
    suffix_k = _get_cache_keys(suffix_cache11, li)[:, :, :bare_len11a, :]
    suffix_v = _get_cache_values(suffix_cache11, li)[:, :, :bare_len11a, :]
    layer_k_errs.append(_max_abs_gap(_get_cache_keys(bare_cache11a, li), suffix_k))
    layer_v_errs.append(_max_abs_gap(_get_cache_values(bare_cache11a, li), suffix_v))

max_k_err = max(layer_k_errs, default=0)
max_v_err = max(layer_v_errs, default=0)

# 4-bit quantized matmul produces different results for different total
# sequence lengths even with identical prefix tokens and causal masking.
# This is a known bitsandbytes limitation — the dequantization + GEMM
# batching differs. Tolerance of 0.15 accommodates this.
check('suffix passage matches bare (4-bit causal invariance)',
      max_k_err < 0.15 and max_v_err < 0.15,
      f'max_k_err={max_k_err:.2e}, max_v_err={max_v_err:.2e}')

# The errors should be small and uniform across layers (not growing)
# since it's quantization noise, not a systematic error.
# `<=` rather than `<`: when every layer matches exactly (all errors 0.0),
# max == 3 * mean == 0 and a perfect match must not be reported as a failure.
check('suffix errors are uniform (not growing with depth)',
      max(layer_k_errs) <= 3 * np.mean(layer_k_errs),
      f'mean={np.mean(layer_k_errs):.4f}, max={max(layer_k_errs):.4f}')

In [None]:
# Final tally of every check() call made by the cells above.
rule = '=' * 70
print('\n' + rule)
print(f'SUMMARY: {passed}/{total} passed, {failed}/{total} failed')
print(rule)
closing_message = (
    '\n*** FAILURES DETECTED - review output above ***'
    if failed > 0
    else '\nAll tests passed.'
)
print(closing_message)