# Creating a new press

In this guide, we will walk you through the process of creating a new press.

In [1]:
from dataclasses import dataclass
from contextlib import contextmanager

import torch
from torch import nn
from transformers import pipeline

from kvpress import BasePress, KnormPress, ScorerPress

In [2]:
# Load pipeline

device = "cuda:0"
ckpt = "Qwen/Qwen2.5-1.5B-Instruct"
attn_implementation = "flash_attention_2"
pipe = pipeline("kv-press-text-generation", model=ckpt, device=device, torch_dtype="auto", model_kwargs={"attn_implementation":attn_implementation})

Device set to use cuda:0


In [3]:
# Load data

context = "In this step-by-step guide, you will learn how to create a new press in kvpress !"
question = "\nWhat is the purpose of this guide?"
tokens = pipe.tokenizer(context, return_tensors="pt").to(device)

## 1. Understanding how a press works

A press registers a forward hook on each attention layer during the pre-filling phase. The hook is called immediately after the layer's forward pass and compresses the KV cache.
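The hook mechanism itself is plain PyTorch. The sketch below is not the kvpress implementation: `toy_press` is a hypothetical helper that registers a forward hook on every `nn.Linear` layer of a small model and removes the hooks when the `with` block exits, which mirrors how a press is only active inside its context manager.

```python
from contextlib import contextmanager

import torch
from torch import nn


@contextmanager
def toy_press(model: nn.Module, log: list):
    """Register a forward hook on every nn.Linear layer and remove it on exit."""

    def hook(module, args, output):
        # Called right after module.forward; a real press would compress the KV cache here
        log.append(tuple(output.shape))

    handles = [m.register_forward_hook(hook) for m in model.modules() if isinstance(m, nn.Linear)]
    try:
        yield
    finally:
        for handle in handles:
            handle.remove()


model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))
log = []

with torch.no_grad(), toy_press(model, log):
    model(torch.randn(3, 4))  # both Linear layers trigger the hook
calls_inside = len(log)

with torch.no_grad():
    model(torch.randn(3, 4))  # hooks were removed, nothing is logged
calls_after = len(log)

print(calls_inside, calls_after)
```

Removing the handles in the `finally` clause guarantees that the model is left untouched even if the forward pass raises an exception.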

In [4]:
compression_ratio = 0.25
press = KnormPress(compression_ratio)

with torch.no_grad():
    outputs_without_press = pipe.model(**tokens, output_hidden_states=True)

with torch.no_grad(), press(pipe.model):
    output_with_press = pipe.model(**tokens)

print(f"Cache shape w/o press: {outputs_without_press.past_key_values[0][0].shape}")
print(f"Cache shape w/ press:  {output_with_press.past_key_values[0][0].shape}\n")

# The `KVPressTextGenerationPipeline` simply applies the `press` as above on the context tokens (see `_forward` method for more details).
print(pipe(context, question=question, press=press)["answer"])



Cache shape w/o press: torch.Size([1, 2, 20, 128])
Cache shape w/ press:  torch.Size([1, 2, 15, 128])

The purpose of this step-by-step guide is to provide instructions on how to create a new press in kvpress. The guide is designed to help users understand the process of setting up a new press in the kvpress platform.


## 2. Creating your own press


### 2.1 Updating the `score` method

The easiest way to create a new press is to create a class that inherits from `ScorerPress` and implements a `score` method that computes a score for each key-value pair.

The arguments of the `score` method are obtained from the forward hook:
- `module`: the attention layer
- `hidden_states`: the input of the attention layer
- `keys` and `values`: the key-value pairs from the attention layer
- `attentions`: the attention weights, only available with `attn_implementation="eager"`

In this first example, we will reproduce the `KnormPress`, where the score of a key-value pair is simply the negative of the norm of the key vector.

In [None]:
class MyKnormPress(ScorerPress):
    def score(
        self,
        module: nn.Module,
        hidden_states: torch.Tensor,
        keys: torch.Tensor,
        values: torch.Tensor,
        attentions: torch.Tensor,
        kwargs: dict,
    ) -> torch.Tensor:

        scores = -keys.norm(dim=-1)

        # For demonstration, we show some details on the shape for the first layer
        if module.layer_idx == 0:
            print(f"module: {module}")
            print(f"Number of key value heads: {module.config.num_key_value_heads}")
            print(f"Sequence length: {hidden_states.shape[1]}")
            print()
            print(f"hidden_states shape: {hidden_states.shape}")
            print(f"keys shape:          {keys.shape}") # shape (bhnd)
            print(f"values shape:        {values.shape}") # shape (bhnd)
            print(f"score shape:         {scores.shape}") # shape (bhn)
            print()
        
        return scores


press = MyKnormPress(compression_ratio)
print(pipe(context, question=question, press=press)["answer"])
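To see how per-pair scores can turn into a smaller cache, here is a minimal sketch of score-based pruning. The helper `prune_with_scores` is hypothetical, and the rule `n_kept = int(q_len * (1 - compression_ratio))` is an assumption (it is consistent with the 20 to 15 token reduction shown in section 1); refer to the `ScorerPress` source for the actual logic.

```python
import torch


def prune_with_scores(keys, values, scores, compression_ratio):
    """Keep the highest-scoring key-value pairs of each head (illustrative sketch)."""
    q_len = keys.shape[2]
    n_kept = int(q_len * (1 - compression_ratio))  # assumption: number of pairs kept per head
    indices = scores.topk(n_kept, dim=-1).indices  # shape (batch, heads, n_kept)
    indices = indices.unsqueeze(-1).expand(-1, -1, -1, keys.shape[-1])
    return keys.gather(2, indices), values.gather(2, indices)


torch.manual_seed(0)
keys = torch.randn(1, 2, 20, 128)  # shape (bhnd)
values = torch.randn(1, 2, 20, 128)
scores = -keys.norm(dim=-1)  # same score as MyKnormPress, shape (bhn)

pruned_keys, pruned_values = prune_with_scores(keys, values, scores, compression_ratio=0.25)
print(pruned_keys.shape)  # torch.Size([1, 2, 15, 128])
```

Because `topk` is applied along the sequence dimension independently for each head, different heads may keep different positions, but every head keeps the same number of pairs.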

### 2.2 Updating the `compress` method 

The `compress` method defined in `BasePress` contains the core logic of the compression and returns the compressed keys and values. For instance, in `ScorerPress`, `compress` calls the `score` method (which is specific to `ScorerPress`) and prunes the key-value pairs based on the scores.

The following example will show how it works. We will re-implement the `StreamingLLMPress` in a more compact way.

In [None]:
@dataclass
class MyStreamingLLMPress(BasePress):
    n_first: int = 1
    n_last: int = 8

    def compress(
        self,
        module: nn.Module,
        hidden_states: torch.Tensor,
        keys: torch.Tensor,
        values: torch.Tensor,
        attentions: torch.Tensor,
        kwargs: dict,
    ) -> tuple[torch.Tensor, torch.Tensor]:

        mask = torch.ones(keys.shape[-2], dtype=torch.bool, device=keys.device)
        mask[self.n_first : -self.n_last] = False
        return keys[:, :, mask, :], values[:, :, mask, :]


for n_last in [2, 4, 8]:
    press = MyStreamingLLMPress(n_last=n_last)
    print(f"\nn_last: {n_last}")
    print(f"Last tokens seen by the model: {pipe.tokenizer.decode(tokens.input_ids[0, -n_last:])}")
    print(f"Answer: {pipe(context, question=question, press=press)['answer']}")
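To make the masking step concrete, the sketch below builds the same kind of boolean mask on its own and lists the positions that survive. `streaming_mask` is a hypothetical helper, not part of kvpress. It uses an explicit end index (`seq_len - n_last`) rather than a negative one, because in Python the slice `mask[n_first:-0]` is empty, so a negative end index would silently keep every token when `n_last=0`.

```python
import torch


def streaming_mask(seq_len: int, n_first: int, n_last: int) -> torch.Tensor:
    """Boolean mask over positions: True for the tokens that are kept."""
    mask = torch.ones(seq_len, dtype=torch.bool)
    # Explicit end index, so that n_last=0 drops every token after the first n_first
    mask[n_first : seq_len - n_last] = False
    return mask


kept = streaming_mask(seq_len=20, n_first=1, n_last=8).nonzero().flatten().tolist()
print(kept)  # [0, 12, 13, 14, 15, 16, 17, 18, 19]

kept_no_tail = streaming_mask(seq_len=20, n_first=1, n_last=0).nonzero().flatten().tolist()
print(kept_no_tail)  # [0]
```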

Note that the `compress` method is itself called by the `forward_hook` method, which ensures that quantization is handled properly and that compression is only performed during prefilling. While we don't recommend changing the `forward_hook` method directly, you can still modify it if you need to.

### 2.3 Head-wise compression

Since 0.2.0, kvpress supports head-wise compression, where the KV cache of each head may be compressed with a different compression ratio.

Proper head-wise compression would require a new attention kernel along with a custom cache class. Instead, the current implementation fakes head-wise compression by replacing the pruned keys with a fake key, so that they do not affect the output of the attention layer. This is implemented through `kvpress.attention_patch.patch_attention_functions`.
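The principle can be checked on a toy example: if a key receives zero attention weight, the attention output is the same as if the key had been removed from the cache. The sketch below illustrates this with a `-inf` logit mask on a single head. It is only an illustration of the idea and does not reproduce the fake-key mechanism used in `kvpress.attention_patch`; the `attention` helper is hypothetical.

```python
import torch


def attention(q, k, v, keep=None):
    """Scaled dot-product attention with an optional boolean mask over the keys."""
    logits = q @ k.transpose(-1, -2) / k.shape[-1] ** 0.5
    if keep is not None:
        # Masked keys get a -inf logit, hence a zero attention weight
        logits = logits.masked_fill(~keep, float("-inf"))
    return torch.softmax(logits, dim=-1) @ v


torch.manual_seed(0)
q = torch.randn(1, 4, 8)  # (batch, queries, head dim)
k = torch.randn(1, 6, 8)  # (batch, keys, head dim)
v = torch.randn(1, 6, 8)
keep = torch.tensor([True, False, True, True, False, True])

out_pruned = attention(q, k[:, keep], v[:, keep])  # keys physically removed
out_masked = attention(q, k, v, keep=keep)  # keys kept in memory but ignored
print(torch.allclose(out_pruned, out_masked, atol=1e-6))  # True
```

The masked variant keeps the cache tensors rectangular, which is why each head can ignore a different number of keys without a custom cache class. The trade-off is that no memory is actually freed.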

To implement a method that compresses the KV cache head-wise, set the `masked_key_indices` attribute of the attention module, as outlined below.

In [4]:
@dataclass
class RandomHeadPress(BasePress):

    compression_ratio: float = 0.0

    def compress(self, module, hidden_states, keys, values, attentions, kwargs):
        assert keys.shape[0] == 1, "Only batch size 1 is supported"
        scores = torch.rand(keys.shape[:-1], device=keys.device)
        mask = scores < torch.quantile(scores, self.compression_ratio)
        module.masked_key_indices = torch.nonzero(mask, as_tuple=True)
        
        return keys, values

for compression_ratio in [0, 0.25, 0.9]:
    press = RandomHeadPress(compression_ratio)
    print(f"\ncompression_ratio: {compression_ratio}")
    print(f"Answer: {pipe(context, question=question, press=press)['answer']}")


compression_ratio: 0


Answer: The purpose of this step-by-step guide is to provide a comprehensive and easy-to-follow tutorial on how to create a new press in the KVPress platform. The guide is designed to help users understand the process of setting up a new press, including the

compression_ratio: 0.25
Answer: The purpose of this guide is to provide a step-by-step process for creating a new press in KVPRESS, which is a popular open-source web server. The guide will cover the necessary steps to set up and configure a new press, including installing

compression_ratio: 0.9
Answer: This guide is not a guide. It is a guide. It is a guide. It is a guide. It is a guide. It is a guide. It is a guide. It is a guide. It is a guide. It is a
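To inspect what `RandomHeadPress` stores in `masked_key_indices`, the sketch below repeats the mask computation on a standalone tensor. With `as_tuple=True`, `torch.nonzero` returns one index tensor per dimension (batch, head, position), and counting the entries per head shows that the number of masked keys may differ from one head to the next. The tensor sizes are arbitrary choices for the illustration.

```python
import torch

torch.manual_seed(0)
scores = torch.rand(1, 2, 20)  # (batch, kv heads, sequence length)
mask = scores < torch.quantile(scores, 0.25)  # global threshold shared by all heads

# One index tensor per dimension, all of the same length
batch_idx, head_idx, seq_idx = torch.nonzero(mask, as_tuple=True)

# Number of masked keys in each head; only the total is fixed by the quantile
per_head = torch.bincount(head_idx, minlength=scores.shape[1])
print(per_head.tolist(), int(per_head.sum()))
```

Because the threshold is computed over all heads at once, the overall fraction of masked pairs is about `compression_ratio`, while the split between heads is left free.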


## 3. Contributing to kvpress

All presses should be stored in the `presses` directory. Before opening a pull request with your new press, make sure to:
- register it in the `__init__.py` file of the repository
- register the press in [default_presses.py](tests/default_presses.py)
- update the README