In [1]:
import transformers
import torch
import time
import shutil
from tqdm import tqdm, trange
from transformers import AutoTokenizer, AutoModelForCausalLM, GenerationConfig
from transformers.cache_utils import Cache, DynamicCache, StaticCache, OffloadedCache, OffloadedStaticCache

# Load the model
ckpt = "models/Llama-3-8B-Instruct-Gradient-1048k"
tokenizer = AutoTokenizer.from_pretrained(ckpt, use_fast=False)
model = AutoModelForCausalLM.from_pretrained(
    ckpt,
    torch_dtype=torch.bfloat16,
    low_cpu_mem_usage=True,
    attn_implementation="eager",
).to("cuda")


generation_config = GenerationConfig.from_pretrained(ckpt)
eos_token_ids = generation_config.eos_token_id
if not isinstance(eos_token_ids, list):
    eos_token_ids = [eos_token_ids]

# add some tokens like "</user>" and </s> to eos ids
eos_token_ids += tokenizer.encode("</user>", add_special_tokens=False)
eos_token_ids += tokenizer.encode("</s>", add_special_tokens=False)
eos_token_ids += tokenizer.encode("</", add_special_tokens=False)

Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

In [2]:
from duo_attn.utils import load_attn_pattern, sparsify_attention_heads
from duo_attn.patch import enable_duo_attention_eval

# Load the attention pattern
attn_heads, sink_size, recent_size = load_attn_pattern(
    "attn_patterns/Llama-3-8B-Instruct-Gradient-1048k/lr=0.02-reg=0.05-ctx=1000_32000-multi_passkey10"
)

print(attn_heads.shape)
print(sink_size)
print(recent_size)

# Sparsify attention heads
attn_heads, sparsity = sparsify_attention_heads(attn_heads, sparsity=0.5)

print(attn_heads, sparsity)

enable_duo_attention_eval(
    model,
    attn_heads,
    sink_size=64,
    recent_size=256,
)

(32, 8)
128
256
[[0. 0. 1. 0. 0. 0. 0. 0.]
 [0. 0. 0. 1. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0. 1. 1. 0.]
 [0. 0. 0. 0. 0. 1. 1. 0.]
 [0. 0. 0. 0. 1. 1. 0. 0.]
 [0. 1. 1. 0. 1. 0. 0. 1.]
 [0. 0. 0. 0. 0. 1. 1. 0.]
 [1. 0. 0. 1. 0. 1. 1. 0.]
 [1. 1. 1. 0. 1. 0. 1. 1.]
 [0. 0. 0. 1. 0. 1. 1. 1.]
 [1. 0. 0. 1. 1. 0. 1. 1.]
 [1. 1. 0. 0. 0. 0. 0. 1.]
 [1. 0. 0. 0. 0. 1. 0. 0.]
 [1. 1. 1. 0. 1. 0. 1. 1.]
 [0. 1. 0. 0. 1. 1. 1. 1.]
 [1. 0. 1. 0. 1. 0. 1. 1.]
 [1. 1. 1. 0. 0. 1. 1. 0.]
 [1. 1. 0. 1. 0. 1. 1. 1.]
 [0. 1. 0. 1. 0. 1. 0. 0.]
 [1. 0. 1. 1. 0. 1. 0. 1.]
 [1. 0. 1. 1. 1. 1. 1. 0.]
 [1. 0. 1. 1. 0. 0. 0. 0.]
 [0. 0. 0. 1. 1. 0. 0. 1.]
 [1. 1. 0. 1. 0. 1. 1. 1.]
 [1. 0. 0. 0. 1. 1. 1. 0.]
 [1. 1. 0. 1. 0. 1. 0. 0.]
 [1. 0. 0. 1. 0. 0. 0. 1.]
 [0. 1. 0. 0. 0. 1. 1. 1.]
 [1. 1. 0. 1. 1. 1. 1. 0.]
 [0. 0. 1. 0. 1. 1. 1. 1.]
 [1. 1. 1. 1. 1. 1. 1. 1.]
 [1. 0. 0. 0. 0. 1. 0. 1.]] 0.5
Enabling DuoAttention evaluation using sink size 64 and recent size 256
Enabling tuple KV cache for Llama


In [5]:
# Manually perform inference using KV cache

inputs = tokenizer("Fun fact: The shortest", return_tensors="pt").to(model.device)
max_new_tokens = 23
generated_tokens = []
input_ids = inputs["input_ids"]

# Initialize past_key_values to None
past_key_values = None

for _ in range(max_new_tokens):
    with torch.no_grad():
        outputs = model(input_ids=input_ids, past_key_values=past_key_values, use_cache=True)
        
        # Extract the logits and past_key_values (the cache)
        next_token_logits = outputs.logits[:, -1, :]  # Logits of the last token
        past_key_values = outputs.past_key_values  # KV cache to be reused in the next step

        print(past_key_values)

        # Greedy decoding: get the token with the highest probability
        next_token = torch.argmax(next_token_logits, dim=-1)
        generated_tokens.append(next_token.item())

        # Only pass the new token for the next iteration
        input_ids = next_token.unsqueeze(-1)

# Convert generated token ids to text
output_text = "Fun fact: The shortest" + tokenizer.decode(generated_tokens, skip_special_tokens=True)
print(output_text)

((tensor([[[[-5.0049e-02,  1.5918e-01, -9.5215e-03,  ...,  1.8652e-01,
            1.4766e+00, -5.5078e-01],
          [-9.1406e-01, -6.4062e-01,  1.0469e+00,  ...,  2.7148e-01,
           -1.4531e+00,  9.8438e-01],
          [-8.1250e-01, -7.8125e-02,  7.0312e-01,  ...,  1.1641e+00,
           -1.9766e+00,  4.5312e-01],
          [ 1.2422e+00, -3.5312e+00,  4.8750e+00,  ...,  1.2266e+00,
           -1.7891e+00,  1.5703e+00],
          [ 8.9062e-01, -5.9082e-02,  4.8828e-01,  ..., -1.3359e+00,
           -5.3516e-01,  8.3984e-01],
          [-1.0645e-01,  1.1016e+00, -6.2500e-01,  ...,  1.1172e+00,
           -2.6406e+00,  8.6328e-01]]],


        [[[ 2.1820e-03, -5.1575e-03, -2.9297e-03,  ..., -5.4932e-04,
            2.3499e-03, -5.0049e-03],
          [-5.1025e-02,  1.3550e-02, -3.2471e-02,  ...,  1.8555e-02,
            4.1199e-04, -8.7280e-03],
          [ 1.0071e-02, -3.3691e-02, -4.0771e-02,  ...,  1.8433e-02,
           -1.1230e-01, -8.5449e-02],
          [-1.1353e-02,  7.6904

In [4]:
from duo_attn.patch.static_kv_cache import DuoAttentionStaticKVCache


context = "A quick brown fox jumps over the lazy dog. \n"
with open("demo/duo_attention.txt", "r") as f:
    needle = f.read()

insertion_point = 0.75
num_tokens_context = len(tokenizer.encode(context, add_special_tokens=False))
num_repetitions = 100 // num_tokens_context
text = (
    "This is a very long story book: <book> "
    + context * int(num_repetitions * insertion_point)
    + needle
    + context * int(num_repetitions * (1 - insertion_point))
    + "</book>\n Based on the content of the book, please briefly tell me about DuoAttention.\nAnswer:"
)


def generate_with_kv_cache(model, kv_cache, pred_token_idx, eos_token_ids, tokenizer):
    total_latency = 0
    generated_content = [pred_token_idx.item()]
    # This will keep track of how many lines were previously printed
    previous_lines = 0
    print("Generated text (Mem: N/A | Time: N/A):", end=" ", flush=True)
    
    for _ in range(500):
        start = torch.cuda.Event(enable_timing=True)
        end = torch.cuda.Event(enable_timing=True)
        start.record()
        output = model(
            input_ids=pred_token_idx,
            kv_cache=kv_cache,
        )
        end.record()
        torch.cuda.synchronize()
        elapsed_time = start.elapsed_time(end)
        total_latency += elapsed_time

        pred_token_idx = output.logits[:, -1, :].argmax(dim=-1).unsqueeze(1)
        if pred_token_idx.item() in eos_token_ids:
            break
        generated_content += [pred_token_idx.item()]
        # Capture memory usage using torch.cuda.max_memory_allocated()
        used_mem = torch.cuda.max_memory_allocated() / (1024**3)  # Convert to GB
        latency_per_token = total_latency / (
            len(generated_content) - 1
        )  # Latency in ms

        generated_text = tokenizer.decode(
            generated_content,
            skip_special_tokens=True,
            clean_up_tokenization_spaces=True,
            spaces_between_special_tokens=False,
        ).strip()

        output = f"Decoding (Mem: {used_mem:.1f} GB | Latency: {latency_per_token:.1f} ms/tok): {generated_text}"

        # Get the terminal width dynamically
        terminal_width = shutil.get_terminal_size().columns

        # Count how many lines the generated text will take up based on terminal width
        lines = (len(output) + terminal_width - 1) // terminal_width - 1

        # Clear the previous output lines
        print("\r" + "\033[K", end="")  # Clear current line
        for _ in range(previous_lines):
            print("\033[F\033[K", end="")  # Move cursor up and clear line

        # Print the updated text
        print(
            output,
            end="",
            flush=True,
        )

        # Update the previous_lines to the new count
        previous_lines = lines

    print(
        f"\n\nPer-token decoding latency: {total_latency / (len(generated_content) - 1):.2f} ms"
    )
    return tokenizer.decode(generated_content, skip_special_tokens=False).strip()
    
def test_with_chunked_prefilling(chunk_size=32000):
    
    kv_cache = DynamicCache
    start_time = time.time()
    with torch.no_grad():
        pbar = tqdm(
            range(0, input_ids.size(1), chunk_size),
            desc=f"Pre-filling ({0}/{input_ids.size(1)})",
        )
        for i in pbar:
            chunk_input_ids = input_ids[:, i : i + chunk_size]
            print(chunk_input_ids.shape)
            output = model(
                input_ids=chunk_input_ids,
                kv_cache=kv_cache,
            )
        pbar.close()
    end_time = time.time()
    print(f"Pre-filling time: {end_time - start_time:.2f}s\n")
    pred_token_idx = output.logits[:, -1, :].argmax(dim=-1).unsqueeze(1)
    O = generate_with_kv_cache(
        model, kv_cache, pred_token_idx, eos_token_ids, tokenizer
    )

    print(O)

input_ids = tokenizer.encode(text, return_tensors="pt").to("cuda")

print(f"Input sequence length: {input_ids.size(1)}\n")
print(text)
torch.cuda.reset_peak_memory_stats()
used_mem = torch.cuda.max_memory_allocated()
test_with_chunked_prefilling(32000)
print(f"Peak memory: {used_mem / 1024 ** 3:.2f} GB")

Input sequence length: 462

This is a very long story book: <book> A quick brown fox jumps over the lazy dog. 
A quick brown fox jumps over the lazy dog. 
A quick brown fox jumps over the lazy dog. 
A quick brown fox jumps over the lazy dog. 
A quick brown fox jumps over the lazy dog. 
A quick brown fox jumps over the lazy dog. 
Title: DuoAttention: Efficient Long-Context LLM Inference with Retrieval and Streaming Heads
Abstract: Deploying long-context large language models (LLMs) is essential but poses significant computational and memory challenges. Caching all Key and Value (KV) states across all attention heads consumes substantial memory. Existing KV cache pruning methods either damage the long-context capabilities of LLMs or offer only limited efficiency improvements. In this paper, we identify that only a fraction of attention heads, a.k.a, Retrieval Heads, are critical for processing long contexts and require full attention across all tokens. In contrast, all other heads, which

Pre-filling (0/462):   0%|          | 0/1 [00:00<?, ?it/s]

torch.Size([1, 462])





TypeError: old_llama_for_causal_lm_forward() got an unexpected keyword argument 'kv_cache'