In this notebook, we measure the memory and speed improvements obtained by compressing the KV cache with kvpress (tested on an NVIDIA A100 GPU with 80 GB of memory).

In [1]:
import warnings
warnings.filterwarnings("ignore")
from time import time

import torch
from transformers import AutoModelForCausalLM

from kvpress import ExpectedAttentionPress

In [2]:
# Load model

device = "cuda:0"
ckpt = "meta-llama/Meta-Llama-3.1-8B-Instruct"
n_tokens = 128_000
model = AutoModelForCausalLM.from_pretrained(ckpt, torch_dtype="auto", attn_implementation="flash_attention_2").to(device)

initial_peak_memory = torch.cuda.max_memory_allocated()
print(f"Initial peak memory usage: {initial_peak_memory / 1024**3:.2f} GB")


Initial peak memory usage: 14.96 GB
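The initial peak of 14.96 GB is essentially the model weights. As a rough sanity check, the sketch below counts the parameters of a Llama-3.1-8B-style decoder and multiplies by 2 bytes per bfloat16 weight. The configuration values (vocabulary of 128,256, hidden size 4,096, MLP size 14,336, 32 layers, 32 query heads, 8 KV heads, untied input/output embeddings) are assumed from the public checkpoint config; non-parameter buffers are ignored.

```python
# Assumed Llama-3.1-8B config values (from the public checkpoint config)
vocab_size = 128_256
hidden_size = 4_096
intermediate_size = 14_336
num_layers = 32
num_heads = 32
num_kv_heads = 8
head_dim = hidden_size // num_heads  # 128


def count_params() -> int:
    """Count parameters of a Llama-style decoder with untied input/output embeddings."""
    kv_dim = num_kv_heads * head_dim
    # q and o projections are hidden x hidden; k and v are hidden x kv_dim (no biases)
    attn = 2 * hidden_size * hidden_size + 2 * hidden_size * kv_dim
    mlp = 3 * hidden_size * intermediate_size  # gate, up and down projections
    norms = 2 * hidden_size  # two RMSNorm weight vectors per layer
    per_layer = attn + mlp + norms
    embeddings = 2 * vocab_size * hidden_size  # embed_tokens + lm_head
    return num_layers * per_layer + embeddings + hidden_size  # + final RMSNorm


n_params = count_params()
weights_gb = n_params * 2 / 1024**3  # 2 bytes per bfloat16 parameter
print(f"{n_params:,} parameters -> {weights_gb:.2f} GB")
```

The result lines up with the measured initial peak, which supports reading that number as "weights only".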


In [3]:
# Compute cache size and prefilling time

inputs = torch.randint(0, model.config.vocab_size, (1, n_tokens)).to(device)

# Model warmup (for better prefilling time estimation)
with torch.no_grad():
    model(inputs[:, :100])
torch.cuda.synchronize()

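# Prefill the full input once, then measure the KV cache size and elapsed time
start = time()
with torch.no_grad():
    # num_logits_to_keep=1: only compute logits for the last position to save memory
    cache = model(inputs, num_logits_to_keep=1).past_key_values
    # Keys and values (factor 2) for every layer; element_size() is 2 bytes for bfloat16/float16
    cache_size = 2 * len(cache) * cache[0][0].numel() * cache[0][0].element_size()
    del cache
torch.cuda.synchronize()  # wait for all GPU kernels to finish before stopping the clock
prefilling_time = time() - start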

peak_memory = torch.cuda.max_memory_allocated()
torch.cuda.empty_cache()
print(f"Cache size for {n_tokens} tokens: {cache_size / 1024**3:.3f} GB")
print(f"Peak memory usage: {peak_memory / 1024**3:.2f} GB")
print(f"Peak memory w/o weights and KV cache: {(peak_memory - cache_size - initial_peak_memory) / 1024**3:.2f} GB")
print(f"Prefilling time: {prefilling_time:.2f}s")


Cache size for 128000 tokens: 15.625 GB
Peak memory usage: 44.81 GB
Peak memory w/o weights and KV cache: 14.23 GB
Prefilling time: 28.56s
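The reported cache size can be recovered by hand: keys and values (factor 2) for every layer, KV head, head dimension and token, at 2 bytes per bfloat16 element. The layer and head counts used as defaults below are assumed from the Llama-3.1-8B config; grouped-query attention (8 KV heads instead of 32 query heads) is what keeps the cache at this size.

```python
def kv_cache_bytes(n_tokens, num_layers=32, num_kv_heads=8, head_dim=128, bytes_per_elem=2):
    """KV cache size in bytes: keys + values for every layer, KV head and token."""
    return 2 * num_layers * num_kv_heads * head_dim * n_tokens * bytes_per_elem


size_gb = kv_cache_bytes(128_000) / 1024**3
print(f"Cache size for 128000 tokens: {size_gb:.3f} GB")
print(f"Per-token cache: {kv_cache_bytes(1) / 1024:.0f} KiB")
```

At 128 KiB per token, the cache grows linearly with context length, which is why it dominates memory at 128k tokens.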


In [4]:
# Compute peak memory usage and generation time for different compression ratios

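def timer(press, max_new_tokens=100):
    """Return the wall-clock time (s) of prefilling + greedy decoding with `press` applied."""
    start = time()
    with press(model):
        outputs = model.generate(inputs, max_new_tokens=max_new_tokens, do_sample=False, temperature=None, top_p=None, pad_token_id=-1)
        assert outputs.shape == (1, n_tokens + max_new_tokens)
    torch.cuda.synchronize()  # make sure all GPU work is done before stopping the clock
    return time() - start
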
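# compression_ratio is the fraction of KV pairs removed by the press; 0.0 keeps the full cache (baseline)
for compression_ratio in [0.75, 0.5, 0.25, 0.0]: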
    print(f"\nCompression ratio: {compression_ratio}")
    press = ExpectedAttentionPress(compression_ratio)
    torch.cuda.reset_peak_memory_stats()
    total_time = timer(press)
    print(f"Total time: {total_time:.2f}s. Estimated generation time: {total_time - prefilling_time:.2f}s")
    print(f"Peak memory usage: {torch.cuda.max_memory_allocated() / 1024**3:.2f} GB")


Compression ratio: 0.75


Total time: 35.04s. Estimated generation time: 6.48s
Peak memory usage: 33.10 GB

Compression ratio: 0.5
Total time: 36.27s. Estimated generation time: 7.71s
Peak memory usage: 37.00 GB

Compression ratio: 0.25
Total time: 37.63s. Estimated generation time: 9.08s
Peak memory usage: 40.91 GB

Compression ratio: 0.0
Total time: 37.92s. Estimated generation time: 9.36s
Peak memory usage: 44.82 GB
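Since the press removes a fraction `compression_ratio` of the KV pairs, one would expect peak memory to drop by roughly `compression_ratio` × 15.625 GB relative to the uncompressed run. The short check below compares that linear estimate with the peak-memory values printed above (copied from this notebook's output); the linear model is an assumption for illustration, not a guarantee made by kvpress.

```python
cache_size_gb = 15.625  # full KV cache for 128k tokens, measured above
baseline_peak_gb = 44.82  # peak memory with compression_ratio = 0.0
measured_peak_gb = {0.75: 33.10, 0.5: 37.00, 0.25: 40.91}  # from the output above

for ratio, peak in measured_peak_gb.items():
    saved = baseline_peak_gb - peak  # memory saved vs. the uncompressed run
    expected = ratio * cache_size_gb  # pruned share of the KV cache
    print(f"ratio={ratio}: saved {saved:.2f} GB, expected {expected:.2f} GB")
```

In these runs the savings match the pruned share of the cache to within about 0.01 GB, i.e. the reduction in peak memory comes from the smaller KV cache.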
