# Creating a new press

In this guide, we will walk you through the process of creating a new press.

In [1]:
import torch
from torch import nn
from transformers import pipeline

from kvpress import BasePress, KnormPress

In [2]:
# Load pipeline

device = "cuda:0"
ckpt = "Qwen/Qwen2.5-1.5B-Instruct"
attn_implementation = "flash_attention_2"
pipe = pipeline("kv-press-text-generation", model=ckpt, device=device, torch_dtype="auto", model_kwargs={"attn_implementation":attn_implementation})



In [3]:
# Load data

context = "In this step-by-step guide, you will learn how to create a new press in kvpress !"
question = "\nWhat is the purpose of this guide?"
tokens = pipe.tokenizer(context, return_tensors="pt").to(device)

## 1. Understanding how presses work


A press registers a forward hook on each attention layer during the pre-filling phase:
1. Immediately after the layer's forward pass, the hook is called and computes a score for each key-value pair using the `press.score` method.
2. The key-value pairs with the lowest scores are then removed from the cache. The `compression_ratio` parameter sets the fraction of pairs that are removed (e.g. 0.25 removes a quarter of them).
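The hook mechanism itself is plain PyTorch. The sketch below is not kvpress's code. It uses a hypothetical `ToyAttention` layer that returns `(hidden_states, keys)` and a hypothetical `hooks_on` context manager, loosely mirroring `with press(model):`. A hook registered with `register_forward_hook(..., with_kwargs=True)` (PyTorch 2.0 or later) receives the layer's positional inputs, keyword arguments and output right after the forward pass. Here it scores each position by its negative key norm and returns a pruned output, which replaces the original one.

```python
from contextlib import contextmanager

import torch
from torch import nn

COMPRESSION_RATIO = 0.25  # fraction of positions removed (assumption for this toy)


class ToyAttention(nn.Module):
    """Hypothetical stand-in for an attention layer that returns (hidden_states, keys)."""

    def __init__(self, dim: int = 4):
        super().__init__()
        self.k_proj = nn.Linear(dim, dim, bias=False)

    def forward(self, hidden_states: torch.Tensor):
        return hidden_states, self.k_proj(hidden_states)


def pruning_hook(module, args, kwargs, output):
    # Runs right after module.forward; with_kwargs=True also passes the kwargs dict
    hidden_states, keys = output                       # keys: (batch, seq_len, dim)
    scores = -keys.norm(dim=-1)                        # one score per position
    n_kept = int(keys.shape[1] * (1 - COMPRESSION_RATIO))
    # Best positions, re-sorted so that token order is preserved
    idx = scores.topk(n_kept, dim=-1).indices.sort(dim=-1).values
    kept = keys.gather(1, idx.unsqueeze(-1).expand(-1, -1, keys.shape[-1]))
    return hidden_states, kept                         # a returned value replaces the output


@contextmanager
def hooks_on(model: nn.Module):
    # Register the hook on every ToyAttention layer and always remove it on exit
    handles = [
        m.register_forward_hook(pruning_hook, with_kwargs=True)
        for m in model.modules()
        if isinstance(m, ToyAttention)
    ]
    try:
        yield model
    finally:
        for handle in handles:
            handle.remove()


torch.manual_seed(0)
model = nn.Sequential(ToyAttention())
x = torch.randn(1, 8, 4)

with torch.no_grad():
    before = model(x)[1].shape
    with hooks_on(model):
        during = model(x)[1].shape
    after = model(x)[1].shape

# 8 positions, then 6 while the hook is registered, then 8 again once it is removed
print(before, during, after)
```

Removing the handles in a `finally` block guarantees that the model is left unmodified even if the forward pass raises.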

In [4]:
compression_ratio = 0.25
press = KnormPress(compression_ratio)

# Forward pass on the context without compression
with torch.no_grad():
    output_without_press = pipe.model(**tokens)

# Same forward pass with the press hooks registered on the attention layers
with torch.no_grad(), press(pipe.model):
    output_with_press = pipe.model(**tokens)

# Cache of the first layer: (batch, num_kv_heads, seq_len, head_dim)
print(f"Cache shape w/o press: {output_without_press.past_key_values[0][0].shape}")
print(f"Cache shape w/ press:  {output_with_press.past_key_values[0][0].shape}\n")

# The `KVPressTextGenerationPipeline` simply applies the `press` as above on the context tokens (see `_forward` method for more details).
print(pipe(context, question=question, press=press)["answer"])



Cache shape w/o press: torch.Size([1, 2, 20, 128])
Cache shape w/ press:  torch.Size([1, 2, 15, 128])

The purpose of this step-by-step guide is to provide instructions on how to create a new press in kvpress. The guide is designed to help users understand the process of setting up a new press in the kvpress platform.


## 2. Creating your own press


### 2.1 Updating the `score` method


The easiest way to create a new press is to define a class that inherits from `BasePress` and implements a `score` method that computes a score for each key-value pair.

The arguments of the `score` method are provided by the forward hook:
- `module`: the attention layer
- `hidden_states`: the input to the attention layer
- `keys` and `values`: the key and value tensors stored in the cache for this layer
- `attentions`: the attention weights, only available with `attn_implementation="eager"`
- `kwargs`: the keyword arguments passed to the attention layer's forward call

In this first example, we reproduce `KnormPress`, where the score of a key-value pair is simply the negative L2 norm of its key vector: the pairs whose keys have the largest norms are pruned first.
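To make the scoring rule concrete, here is a tiny hand-built example (the numbers are made up for illustration). With keys shaped `(batch, num_kv_heads, seq_len, head_dim)`, `-keys.norm(dim=-1)` yields one score per key-value pair, and the key with the largest norm gets the lowest score.

```python
import torch

# One batch, one KV head, four positions, head_dim = 2 (toy values)
keys = torch.tensor([[[[3.0, 4.0],     # norm 5
                       [0.6, 0.8],     # norm 1
                       [0.0, 2.0],     # norm 2
                       [6.0, 8.0]]]])  # norm 10

scores = -keys.norm(dim=-1)            # shape (1, 1, 4): one score per position
print(scores)

# With compression_ratio = 0.25, one of the four pairs is dropped:
# the one with the lowest score, i.e. the key with the largest norm (index 3)
lowest = scores.argmin(dim=-1)
print(lowest.item())
```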

In [5]:
class MyKnormPress(BasePress):
    def score(
        self,
        module: nn.Module,
        hidden_states: torch.Tensor,
        keys: torch.Tensor,
        values: torch.Tensor,
        attentions: torch.Tensor,
        kwargs: dict,
    ) -> torch.Tensor:

        scores = -keys.norm(dim=-1)

        # For demonstration, we show some details on the shape for the first layer
        if module.layer_idx == 0:
            print(f"module: {module}")
            print(f"Number of key value heads: {module.num_key_value_heads}")
            print(f"Sequence length: {hidden_states.shape[1]}")
            print()
            print(f"hidden_states shape: {hidden_states.shape}")
            print(f"keys shape:          {keys.shape}")    # (batch, num_kv_heads, seq_len, head_dim)
            print(f"values shape:        {values.shape}")  # (batch, num_kv_heads, seq_len, head_dim)
            print(f"score shape:         {scores.shape}")  # (batch, num_kv_heads, seq_len)
            print()
        
        return scores


press = MyKnormPress(compression_ratio)
print(pipe(context, question=question, press=press)["answer"])

module: Qwen2FlashAttention2(
  (q_proj): Linear(in_features=1536, out_features=1536, bias=True)
  (k_proj): Linear(in_features=1536, out_features=256, bias=True)
  (v_proj): Linear(in_features=1536, out_features=256, bias=True)
  (o_proj): Linear(in_features=1536, out_features=1536, bias=False)
  (rotary_emb): Qwen2RotaryEmbedding()
)
Number of key value heads: 2
Sequence length: 44

hidden_states shape: torch.Size([1, 44, 1536])
keys shape:          torch.Size([1, 2, 44, 128])
values shape:        torch.Size([1, 2, 44, 128])
score shape:         torch.Size([1, 2, 44])



The purpose of this step-by-step guide is to provide instructions on how to create a new press in kvpress. The guide is designed to help users understand the process of setting up a new press in the kvpress platform.


### 2.2 Updating the `forward_hook` method 

The `forward_hook` method defined in the `BasePress` class roughly works as follows:

1. Compute the scores with the `score` method
2. Prune the key-value pairs in the cache, keeping the highest-scoring ones according to `compression_ratio`
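The two steps above can be sketched in plain PyTorch. This is a simplified illustration, not the actual `BasePress.forward_hook` code: the helper `prune_kv` is hypothetical, and it assumes that `compression_ratio` is the fraction of pairs removed, so `int(seq_len * (1 - compression_ratio))` pairs are kept per head. With 20 context tokens and a ratio of 0.25 this keeps 15, matching the cache shapes printed in section 1.

```python
import torch


def prune_kv(keys: torch.Tensor, values: torch.Tensor, scores: torch.Tensor, compression_ratio: float):
    """Keep the highest-scoring key-value pairs for each head (hypothetical helper).

    keys, values: (batch, num_kv_heads, seq_len, head_dim)
    scores:       (batch, num_kv_heads, seq_len)
    """
    seq_len = keys.shape[2]
    n_kept = int(seq_len * (1 - compression_ratio))
    # Indices of the n_kept best pairs per head, re-sorted to preserve token order
    idx = scores.topk(n_kept, dim=-1).indices.sort(dim=-1).values
    idx = idx.unsqueeze(-1).expand(-1, -1, -1, keys.shape[-1])
    return keys.gather(2, idx), values.gather(2, idx)


torch.manual_seed(0)
keys = torch.randn(1, 2, 20, 128)
values = torch.randn(1, 2, 20, 128)
scores = -keys.norm(dim=-1)  # KnormPress-style scores

pruned_keys, pruned_values = prune_kv(keys, values, scores, compression_ratio=0.25)
print(pruned_keys.shape)  # torch.Size([1, 2, 15, 128])
```

Because selection happens independently per head, different heads may keep different token positions while all ending up with the same cache length.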

While we generally do not recommend modifying this method, the following example shows how it works by re-implementing `StreamingLLMPress` without the `compression_ratio` parameter. In StreamingLLM, only the first `n_first` and the last `n_last` key-value pairs are kept.
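The keep-first/keep-last rule reduces to a boolean mask over positions. The sketch below uses a helper named `streaming_mask`, which is ours and purely illustrative, to build the mask for a 10-token sequence. Computing the slice end as `max(seq_len - n_last, 0)` rather than `-n_last` keeps the `n_last=0` case correct, since `mask[n_first:-0]` is an empty slice and would prune nothing.

```python
import torch


def streaming_mask(seq_len: int, n_first: int, n_last: int) -> torch.Tensor:
    """Boolean mask keeping the first n_first and last n_last positions (hypothetical helper)."""
    mask = torch.ones(seq_len, dtype=torch.bool)
    # max(..., 0) avoids a negative end index when n_last > seq_len
    mask[n_first : max(seq_len - n_last, 0)] = False
    return mask


print(streaming_mask(10, n_first=1, n_last=3).int().tolist())
# [1, 0, 0, 0, 0, 0, 0, 1, 1, 1]
print(streaming_mask(10, n_first=1, n_last=0).int().tolist())
# [1, 0, 0, 0, 0, 0, 0, 0, 0, 0]
```

When `n_first + n_last` is at least `seq_len`, the slice is empty and every position is kept, which is the expected behaviour for short contexts.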

In [6]:
class MyStreamingLLMPress(BasePress):

    def __init__(self, n_first=1, n_last=8):
        self.n_first = n_first
        self.n_last = n_last

    def forward_hook(self, module: nn.Module, input: list[torch.Tensor], kwargs: dict, output: list):

        # Get the cache (transformers.cache_utils.DynamicCache object)
        cache = output[-1]
        i = module.layer_idx
        keys, values = cache.key_cache[i], cache.value_cache[i]

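        # Update the cache to only keep the first and last tokens
        seq_len = keys.shape[-2]
        mask = torch.ones(seq_len, dtype=torch.bool, device=keys.device)
        # Explicit end index: a slice ending at -0 would be empty when n_last=0
        mask[self.n_first : max(seq_len - self.n_last, 0)] = False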
        cache.key_cache[i] = keys[:, :, mask, :]
        cache.value_cache[i] = values[:, :, mask, :]

        # Return the updated output (output[-1] has been modified in-place)
        return output


for n_last in [2, 4, 8]:
    press = MyStreamingLLMPress(n_last=n_last)
    print(f"\nn_last: {n_last}")
    print(f"Last tokens seen by the model: {pipe.tokenizer.decode(tokens.input_ids[0, -n_last:])}")
    print(f"Answer: {pipe(context, question=question, press=press)['answer']}")


n_last: 2
Last tokens seen by the model: press !


Answer: The purpose of this guide is to provide instructions and information on how to use the software or application called "Pulse" or "Pulse 2". Pulse is a popular music production software that allows users to create, edit, and mix music tracks

n_last: 4
Last tokens seen by the model:  in kvpress !
Answer: The purpose of this guide is to provide instructions on how to create a new content management system (CMS) called KVPress. KVPress is a content management system that allows users to easily create, edit, and publish content on their website. The guide

n_last: 8
Last tokens seen by the model:  create a new press in kvpress !
Answer: The purpose of this guide is to provide instructions on how to create a new press in kvpress, a software tool for managing and publishing content. The guide likely covers topics such as setting up the press, configuring settings, adding content, and publishing articles


## 3. Contributing to kvpress

All presses should be stored in the `presses` directory. Before opening a pull request with your new press, make sure to register it in the `__init__.py` file of the repository and to add it to [test_presses.py](tests/presses/test_presses.py). We recommend not overriding the `forward_hook` or `__call__` methods unless necessary.