In this notebook, we showcase how to use the KVpress pipelines by answering questions about NVIDIA Wikipedia article.

In [1]:
import requests
from bs4 import BeautifulSoup

import torch
from transformers import pipeline

from kvpress import (
    ExpectedAttentionPress,
    KnormPress,
    ObservedAttentionPress,
    RandomPress,
    SnapKVPress,
    StreamingLLMPress,
    TOVAPress,
)

In [2]:
# SPDX-FileCopyrightText: Copyright (c) 1993-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0


import contextlib
import logging
from typing import Optional

import torch
from transformers import AutoModelForCausalLM, Cache, DynamicCache, Pipeline, QuantizedCache
from transformers.pipelines import PIPELINE_REGISTRY
from transformers.pipelines.base import GenericTensor

from kvpress.presses.base_press import BasePress
from kvpress.presses.finch_press import FinchPress
from kvpress.presses.key_rerotation_press import KeyRerotationPress
from kvpress.presses.observed_attention_press import ObservedAttentionPress
from kvpress.presses.per_layer_compression_press import PerLayerCompressionPress

logger = logging.getLogger(__name__)


class KVPressTextGenerationPipeline(Pipeline):
    """
    Pipeline for key-value cache compression in causal language models.

    Enables efficient processing of long contexts by applying KV cache compression
    during pre-filling, then generating answers using greedy decoding.

    Example:
    ```python
    pipeline = KVPressTextGenerationPipeline(model=model, tokenizer=tokenizer)
    press = SnapKVPress(compression_ratio=0.5)
    result = pipeline(context="Long text...", question="A question about the long context.", press=press)
    ```
    """

    def _sanitize_parameters(
        self,
        question: Optional[str] = None,
        questions: Optional[list[str]] = None,
        answer_prefix: Optional[str] = None,
        press: Optional[BasePress] = None,
        max_new_tokens: int = 50,
        max_context_length: Optional[int] = None,
        cache: Optional[Cache] = None,
        **kwargs,
    ):
        """
        Sanitize the input parameters for the pipeline.
        The user can either provide a single question or a list of questions to be asked about the context.

        Parameters
        ----------
        question : str, optional
            The question to be asked about the context. Exclusive with `questions`.
        questions : list[str], optional
            A list of questions to be asked about the context. Exclusive with `question`.
        answer_prefix : str, optional
            The prefix to be added to the generated answer.
        press : BasePress, optional
            The key-value cache compression method to apply during pre-filling.

            Accepts any KVPress compression method (SnapKVPress, KnormPress,
            ExpectedAttentionPress, BlockPress, AdaKVPress, ComposedPress, etc.).
            If None, no compression is applied.
        max_new_tokens : int, optional
            The maximum number of new tokens to generate for each answer.
        max_context_length : int, optional
            The maximum number of tokens in the context. By default will use the maximum length supported by the model.
        cache : Cache, optional
            The cache to use for the forward pass. Defaults to None (DynamicCache).
        **kwargs : dict
            Additional keyword arguments, currently ignored.

        Returns
        -------
        Tuple[dict, dict, dict]
            A tuple containing three dictionaries:
                - preprocess_kwargs: The keyword arguments for the preprocess function.
                - forward_kwargs: The keyword arguments for the forward function.
                - postprocess_kwargs: The keyword arguments for the postprocess function.
        """

        answer_prefix = answer_prefix or ""
        postprocess_kwargs = {"single_question": questions is None}
        assert question is None or questions is None, "Either question or questions should be provided, not both."
        questions = questions or ([question] if question else [""])
        if max_context_length is None:
            max_context_length = min(self.tokenizer.model_max_length, int(1e10))  # 1e10 to avoid overflow
        preprocess_kwargs = {
            "questions": questions,
            "answer_prefix": answer_prefix,
            "max_context_length": max_context_length,
        }
        forward_kwargs = {"press": press, "max_new_tokens": max_new_tokens, "cache": cache}
        return preprocess_kwargs, forward_kwargs, postprocess_kwargs

    def preprocess(
        self,
        context: str,
        questions: list[str],
        answer_prefix: str,
        max_context_length: int,
    ):
        """
        Apply chat template and tokenize the context and questions.

        Prepares input text for KV cache compression and generation by applying
        appropriate chat templates and tokenizing. Handles models with and without
        chat templates.

        Parameters
        ----------
        context : str
            Long context text to be compressed using the press method.
        questions : list[str]
            Questions to be asked about the context.
        answer_prefix : str
            Optional prefix for generated answers.
        max_context_length : int
            Maximum tokens allowed in context (truncated if exceeded).

        Returns
        -------
        dict[str, GenericTensor]
            Dictionary with "context_ids" and "questions_ids" tensors.
        """

        # Apply chat template if available
        if self.tokenizer.chat_template is None:
            bos_token = getattr(self.tokenizer, "bos_token", "")
            context = bos_token + context
            question_suffix = "\n"  # to separate the question from the answer
        else:
            separator = "\n" + "#" * len(context)
            context = self.tokenizer.apply_chat_template(
                [{"role": "user", "content": context + separator}],
                add_generation_prompt=True,
                tokenize=False,
                enable_thinking=False,
            )
            context, question_suffix = context.split(separator)

        # Add question_suffix and answer prefix
        # e.g. for llama3.1, question_suffix="<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n")
        questions = [question + question_suffix + answer_prefix for question in questions]

        # Tokenize the context and questions
        context_ids = self.tokenizer.encode(context, return_tensors="pt", add_special_tokens=False)
        question_ids = [
            self.tokenizer.encode(question, return_tensors="pt", add_special_tokens=False) for question in questions
        ]

        # Truncate context
        if context_ids.shape[1] > max_context_length:
            logger.warning(
                f"Context length has been truncated from {context_ids.shape[1]} to {max_context_length} tokens."
            )
            context_ids = context_ids[:, :max_context_length]

        return {"context_ids": context_ids, "questions_ids": question_ids}

    def _forward(
        self,
        input_tensors: dict[str, GenericTensor],
        max_new_tokens: int = 50,
        press: Optional[BasePress] = None,
        cache: Optional[Cache] = None,
    ):
        """
        Execute KV cache compression and text generation pipeline.

        Performs context compression using the press method during pre-filling,
        then generates answers using greedy decoding.

        Parameters
        ----------
        input_tensors : dict[str, GenericTensor]
            Tokenized inputs with "context_ids" and "questions_ids".
        max_new_tokens : int, default=50
            Maximum tokens to generate for each answer.
        press : BasePress, optional
            Compression method for context pre-filling. If None, no compression.
        cache : Cache, optional
            Cache object for forward pass. If None, creates new DynamicCache.

        Returns
        -------
        list[str]
            Generated answers for each input question.
        """
        context_ids = input_tensors["context_ids"].to(self.model.device)
        context_length = context_ids.shape[1]

        if cache is None:
            cache = DynamicCache()

        with press(self.model) if press is not None else contextlib.nullcontext():
            self.model.model(
                input_ids=context_ids,
                past_key_values=cache,
                output_attentions=self.output_attentions(press),
            )

        logger.debug(f"Context Length: {context_length}")
        logger.debug(f"Compressed Context Length: {cache.get_seq_length()}")

        # Greedy decoding for each question
        answers = []
        for question_ids in input_tensors["questions_ids"]:
            if isinstance(press, KeyRerotationPress) or (isinstance(press, FinchPress) and press.rerotate_keys):
                context_length = cache.get_seq_length()

            answer = self.generate_answer(
                question_ids=question_ids.to(self.model.device),
                cache=cache,
                context_length=context_length,
                max_new_tokens=max_new_tokens,
            )
            answers.append(answer)

        return answers


    def generate_answer(
        self, question_ids: torch.Tensor, cache: Cache, context_length: int, max_new_tokens: int
    ) -> str:
        """
        Generate an answer to a question using greedy decoding.

        Parameters
        ----------
        question_ids : torch.Tensor
            The tokenized question.
        cache : Cache
            The compressed key-value cache.
        context_length : int
            The length of the context.
        max_new_tokens : int
            The maximum number of new tokens to generate.

        Returns
        -------
        str
            The generated answer.
        """
        cache_seq_lengths = [cache.get_seq_length(layer_idx) for layer_idx in range(len(cache))]
        
        input_ids = question_ids.to(self.model.device)
        position_ids = torch.arange(
            cache.get_seq_length(), cache.get_seq_length() + question_ids.shape[1], device=self.model.device
        ).unsqueeze(0)

        # Prepare model kwargs for the initial forward pass
        model_kwargs = dict(num_logits_to_keep=1,
                            past_key_values=cache,
                            use_cache=True,
                            position_ids=position_ids
                           )

        if self._is_flash_attention_enabled():
            self._update_kwargs_for_flash_attention(model_kwargs, 
                                                    cache.get_seq_length(), 
                                                    question_ids.shape[1]
                                                   )

        # Initial forward pass
        outputs = self.model(input_ids=input_ids, **model_kwargs)

        position_ids = position_ids[:, -1:] + 1
        generated_ids = [outputs.logits[0, -1].argmax()]

        should_stop_token_ids = self.model.generation_config.eos_token_id
        if not isinstance(should_stop_token_ids, list):
            should_stop_token_ids = [should_stop_token_ids]

        for i in range(max_new_tokens - 1):
            input_ids = generated_ids[-1].unsqueeze(0).unsqueeze(0)
            model_kwargs = dict(num_logits_to_keep=1, past_key_values=cache, use_cache=True,
                                position_ids=position_ids + i
                               )

            if self._is_flash_attention_enabled():
                self._update_kwargs_for_flash_attention(model_kwargs, 
                                                        cache.get_seq_length(),
                                                        1)

            outputs = self.model(input_ids=input_ids, **model_kwargs)
            new_id = outputs.logits[0, -1].argmax()
            generated_ids.append(new_id)
            if new_id.item() in should_stop_token_ids:
                break
        answer = self.tokenizer.decode(torch.stack(generated_ids), skip_special_tokens=True)

        # Remove the generated tokens from the cache
        for layer_idx, sequence_length in enumerate(cache_seq_lengths):
            cache.layers[layer_idx].keys = cache.layers[layer_idx].keys[:, :, :sequence_length]
            cache.layers[layer_idx].values = cache.layers[layer_idx].values[:, :, :sequence_length]

        if isinstance(cache, QuantizedCache):
            for layer_idx, sequence_length in enumerate(cache_seq_lengths):
                cache.cache_processor._quantized_keys[layer_idx] = cache.cache_processor._quantized_keys[layer_idx][
                    :, :, :sequence_length
                ]
                cache.cache_processor._quantized_values[layer_idx] = cache.cache_processor._quantized_values[layer_idx][
                    :, :, :sequence_length
                ]

        return answer

    def output_attentions(self, press: BasePress):
        if isinstance(press, ObservedAttentionPress):
            return True
        if isinstance(press, (KeyRerotationPress, PerLayerCompressionPress)) and isinstance(
            press.press, ObservedAttentionPress
        ):
            return True
        return False

    def postprocess(self, model_outputs, single_question):
        if single_question:
            return {"answer": model_outputs[0]}
        return {"answers": model_outputs}

    def _is_flash_attention_enabled(self) -> bool:
        """Check if flash attention is enabled for the model."""
        return "flash" in self.model.config._attn_implementation and self.model._supports_attention_backend
        
    def _update_kwargs_for_flash_attention(
        self, model_kwargs: dict, cache_sequence_length: int, new_seq_len: int = 1
    ) -> None:
        """Update model kwargs with flash attention specific parameters."""
        #return 
        # DO NOT UPDATE, AS IT WOULD BREAK LAYER-WISE COMPRESSION METHODS
        model_kwargs.update(
            cu_seq_lens_q=torch.tensor([0, new_seq_len], dtype=torch.int32, device=self.device),
            cu_seq_lens_k=torch.tensor([0, cache_sequence_length], dtype=torch.int32, device=self.device),
            max_length_q=new_seq_len,
            max_length_k=cache_sequence_length,
        )


PIPELINE_REGISTRY.register_pipeline(
    "kv-press-text-generation-debug",
    pipeline_class=KVPressTextGenerationPipeline,
    pt_model=AutoModelForCausalLM,
)


# Load the pipeline and data

In [3]:
"""
    # If the lengths are not equal, most probably we are in decoding stage with cache
    # In that case the position ids will not always start with `0` and we need a better way to infer
    # cumulative seq lengths.
    if query_length != kv_length:
        indices_q = torch.arange(position_ids.size(0), device=position_ids.device, dtype=torch.int32)

        tensor_kws = {"dtype": torch.int32, "device": position_ids.device}
        last_position_ids = position_ids[:, -1]

        cu_seq_lens_k = torch.cat([torch.zeros(1, **tensor_kws), last_position_ids.cumsum(0).add(1)], 0)
        max_length_k = int(last_position_ids.max()) + 1
"""

file_path = '/mnt/2tb/PyCharmProjects/kvpress/.venv/lib/python3.12/site-packages/transformers/modeling_flash_attention_utils.py'
start_line = 227
num_lines = 20

with open(file_path, 'r') as f:
    lines = f.readlines()
    for i, line in enumerate(lines[start_line - 1 : start_line - 1 + num_lines], start=start_line):
        print(f"{i:4}: {line.rstrip()}")

 227:     # If the lengths are not equal, most probably we are in decoding stage with cache
 228:     # In that case the position ids will not always start with `0` and we need a better way to infer
 229:     # cumulative seq lengths.
 230:     if query_length != kv_length:
 231:         indices_q = torch.arange(position_ids.size(0), device=position_ids.device, dtype=torch.int32)
 232: 
 233:         tensor_kws = {"dtype": torch.int32, "device": position_ids.device}
 234:         last_position_ids = position_ids[:, -1]
 235: 
 236:         cu_seq_lens_k = torch.cat([torch.zeros(1, **tensor_kws), last_position_ids.cumsum(0).add(1)], 0)
 237:         max_length_k = int(last_position_ids.max()) + 1
 238: 
 239:         batch_size, seq_len = query.shape[:2]
 240:         q_len = torch.ones(batch_size, **tensor_kws) if seq_len == 1 else last_position_ids.to(torch.int32).add(1)
 241:         cu_seq_lens_q = torch.cat([torch.zeros(1, **tensor_kws), q_len.cumsum(0)], 0)
 242:         max_lengt

In [4]:
import transformers
transformers.__version__, transformers.__file__

('4.55.0.dev0',
 '/mnt/2tb/PyCharmProjects/kvpress/.venv/lib/python3.12/site-packages/transformers/__init__.py')

In [5]:
# Load pipeline

device = "cuda:0"
ckpt = "meta-llama/Llama-3.2-1B-Instruct"
# ckpt = "microsoft/Phi-3.5-mini-instruct"
# ckpt = "mistralai/Mistral-Nemo-Instruct-2407"
attn_implementation = "flash_attention_2"  # use "eager" for ObservedAttentionPress and "sdpa" if you can't use "flash_attention_2"
pipe = pipeline("kv-press-text-generation-debug", model=ckpt, device=device, torch_dtype="auto", model_kwargs={"attn_implementation":attn_implementation})

Device set to use cuda:0


In [6]:
# Run the pipeline on a single question

question = "What is the book about?"
pred_answer = pipe("This book is about ghosts"* 20,
                   question=question, 
                   press=KnormPress(0.0),
                   max_new_tokens=100
                  )["answer"]
print(f"Prediction: {pred_answer}")

Prediction: (There is no number at the end of the sentence. It is a repetition of the word "book".)


In [7]:
# Load data
url = "https://en.wikipedia.org/wiki/Nvidia"
content = requests.get(url).content
soup = BeautifulSoup(content, "html.parser")
context = "".join([p.text for p in soup.find_all("p")]) + "\n\n"
tokens = pipe.tokenizer.encode(context, return_tensors="pt").to(device)
print(f"Number of tokens: {tokens.size(1)}")

Number of tokens: 9091


In [8]:
import transformers
transformers.__version__, transformers.__file__

('4.55.0.dev0',
 '/mnt/2tb/PyCharmProjects/kvpress/.venv/lib/python3.12/site-packages/transformers/__init__.py')

In [9]:
!head /mnt/2tb/PyCharmProjects/kvpress/.venv/lib/python3.12/site-packages/transformers/__init__.py

# Copyright 2020 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,


huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


# Use the pipeline with a press

In [10]:
# Pick a press with a compression ratio, you can run the following cells with different presses
compression_ratio = 0.0
press = ExpectedAttentionPress(compression_ratio)
# press = KnormPress(compression_ratio)
# press = RandomPress(compression_ratio)

In [11]:
# Run the pipeline on a single question

question = "Complete this sentence: The Nvidia GeForce Partner Program was a ..."
true_answer = "marketing program designed to provide partnering companies with benefits such as public relations support, video game bundling, and marketing development funds."
pred_answer = pipe(context, question=question, press=press)["answer"]

print(f"Question:   {question}")
print(f"Answer:     {true_answer}")
print(f"Prediction: {pred_answer}")

Question:   Complete this sentence: The Nvidia GeForce Partner Program was a ...
Answer:     marketing program designed to provide partnering companies with benefits such as public relations support, video game bundling, and marketing development funds.
Prediction: In 2025, Nvidia's CEO Jensen Huang announced that the company would be expanding its presence in the automotive industry, with a focus on developing autonomous vehicles and other mobility solutions.


Question:   Complete this sentence: The Nvidia GeForce Partner Program was a ...
Answer:     marketing program designed to provide partnering companies with benefits such as public relations support, video game bundling, and marketing development funds.
Prediction: marketing program designed to provide partnering companies with benefits such as public relations support, video game bundling, and marketing development funds.


In [12]:
# Run the pipeline on multiple questions, the context will be compressed only once

questions = [
    "What happened on March 1, 2024?",
    "What was the unofficial company motto of Nvidia during the early days?",
]

true_answers = [
    "Nvidia became the third company in the history of the United States to close with a market capitalization in excess of $2 trillion",
    "Our company is thirty days from going out of business",
]

pred_answers = pipe(context, questions=questions, press=press)["answers"]
for question, pred_answer, true_answer in zip(questions, pred_answers, true_answers):
    print(f"Question:   {question}")
    print(f"Answer:     {true_answer}")
    print(f"Prediction: {pred_answer}")
    print()

Question:   What happened on March 1, 2024?
Answer:     Nvidia became the third company in the history of the United States to close with a market capitalization in excess of $2 trillion
Prediction: In 2025, Nvidia's CEO Jensen Huang announced that the company would be expanding its presence in the automotive industry, with a focus on developing autonomous vehicles and other mobility solutions.

Question:   What was the unofficial company motto of Nvidia during the early days?
Answer:     Our company is thirty days from going out of business
Prediction: In 2025, Nvidia's CEO Jensen Huang announced that the company would be expanding its presence in the automotive industry, with a focus on developing autonomous vehicles and other mobility solutions.



In [13]:
# Use an answer prefix and limit the number of tokens in the answer

question = "What is GTC ?"
true_answer = "Nvidia's GPU Technology Conference (GTC) is a series of technical conferences held around the world."
answer_prefix = "Come on you don't know GTC ? Everyone"
max_new_tokens = 30

pred_answer_with_prefix = pipe(context, question=question, answer_prefix=answer_prefix, press=press, max_new_tokens=max_new_tokens)["answer"]
pred_answer_without_prefix = pipe(context, question=question, press=press, max_new_tokens=max_new_tokens)["answer"]

print(f"Question:              {question}")
print(f"Answer:                {true_answer}")
print(f"Prediction w/o prefix: {pred_answer_without_prefix}")
print(f"Prediction w/ prefix : {answer_prefix + pred_answer_with_prefix}")

Question:              What is GTC ?
Answer:                Nvidia's GPU Technology Conference (GTC) is a series of technical conferences held around the world.
Prediction w/o prefix: In 2025, Nvidia's market capitalization was worth more than the combined value of AMD, ARM, Broadcom, and Intel
Prediction w/ prefix : Come on you don't know GTC ? Everyone's  .


In [14]:
# SnapKV use the latest queries to prune the KV-cache. It's hence more efficient if we include the question during compression as the latest queries will correspond to the question.
# However it implies also implies that SnapKV cannot compress well the context independently of the question (e.g. as in a chat use case)


question = "Complete this sentence: In April 2016, Nvidia produced the DGX-1 based on an 8 GPU cluster,"
true_answer = (
    "to improve the ability of users to use deep learning by combining GPUs with integrated deep learning software"
)

press = SnapKVPress(compression_ratio=0.8)

pred_answer_with_question = pipe(context + question, press=press)["answer"]
pred_answer_without_question = pipe(context, question=question, press=press)["answer"]

print(f"Question:         {question}")
print(f"Answer:           {true_answer}")
print(f"Prediction w/ Q:  {pred_answer_with_question}")
print(f"Prediction w/o Q: {pred_answer_without_question}")

Question:         Complete this sentence: In April 2016, Nvidia produced the DGX-1 based on an 8 GPU cluster,
Answer:           to improve the ability of users to use deep learning by combining GPUs with integrated deep learning software
Prediction w/ Q:  The information provided is a summary of the text about Nvidia Corporation.
Prediction w/o Q: Nvidia Corporation
