examples/models/llama/source_transformation/attention_sink.py (120 additions, 1 deletion)
@@ -11,7 +11,7 @@

import torch

from executorch.examples.models.llama.llama_transformer import ModelArgs, Rope
from executorch.examples.models.llama.llama_transformer import KVCache, ModelArgs, Rope
from executorch.examples.models.llama.rope import (
apply_rotary_emb_to_k,
hf_apply_rotary_emb_to_k,
@@ -87,3 +87,122 @@ def rerotate_k(
)

return self.apply_rotary_emb_to_k(k, rerotation_cos, rerotation_sin)


class KVCacheWithAttentionSink(KVCache):
"""
    KV cache that supports attention sink. It keeps the first few tokens as the attention sink
    and uses a sliding window over the remaining positions so that only the most recent tokens are kept.

Parameters:
window_size: the size of the sliding window
sink_size: the number of initial tokens to keep as attention sink
        eviction_batch_size: the number of tokens to evict in a batch when there is not enough space in the KV cache
"""

def __init__(
self,
n_heads: int,
head_dim: int,
transpose_cache: bool,
enable_dynamic_shape: bool,
rope: RopeWithAttentionSink,
window_size: int,
sink_size: int,
eviction_batch_size: int,
max_batch_size: int = 1,
dtype=torch.float32,
):
super().__init__(
max_batch_size=max_batch_size,
max_seq_length=window_size + sink_size,
n_heads=n_heads,
head_dim=head_dim,
transpose_cache=transpose_cache,
enable_dynamic_shape=enable_dynamic_shape,
dtype=dtype,
)
self.rope = rope
self.window_size = window_size
self.sink_size = sink_size
self.eviction_batch_size = eviction_batch_size
self.position_shift = 0

def evict_tokens(self, input_pos: torch.Tensor, seq_len: int) -> int:
"""
        Evict old tokens from the cache to make room for new tokens.

Parameters:
input_pos: the start position of the incoming token in the actual sequence
seq_len: the length of the incoming sequence

        Returns:
            the cumulative position shift for incoming tokens, which decreases by the
            number of tokens evicted from the cache
"""
input_pos_item = input_pos.item()
torch._check_is_size(input_pos_item)
if input_pos_item + self.position_shift + seq_len > self.max_seq_length:
            # There is not enough space in the cache to store the new tokens.
# We need to evict some old tokens and shift some recent tokens.
num_to_evict = max(
input_pos_item + self.position_shift - self.max_seq_length + seq_len,
self.eviction_batch_size,
)
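            # Non-sink tokens currently in the cache that survive the eviction.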
num_to_keep = (
input_pos_item + self.position_shift - self.sink_size - num_to_evict
)
num_empty_space = self.window_size - num_to_keep
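            # The sequence dimension of the cache tensors: (bs, n_heads, seq_len, head_dim)
            # when transpose_cache is set, (bs, seq_len, n_heads, head_dim) otherwise.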
dim_to_slice = 2 if self.transpose_cache else 1
k_to_keep = self.k_cache.narrow(
dim_to_slice,
self.sink_size + num_to_evict, # pyre-ignore [6]
num_to_keep, # pyre-ignore [6]
)
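            # Re-rotate the kept keys so their rotary embedding matches their new,
            # shifted positions closer to the sink.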
if self.transpose_cache:
k_to_keep = self.rope.rerotate_k(
k=k_to_keep.transpose(1, 2),
original_position=( # pyre-ignore [6]
self.sink_size + num_to_evict
),
new_position=self.sink_size,
).transpose(1, 2)
else:
k_to_keep = self.rope.rerotate_k(
k=k_to_keep,
original_position=( # pyre-ignore [6]
self.sink_size + num_to_evict
),
new_position=self.sink_size,
)
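            # Rebuild the k cache as [sink tokens | re-rotated kept tokens | zero padding].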
self.k_cache = torch.cat(
[
self.k_cache.narrow(dim_to_slice, 0, self.sink_size),
k_to_keep,
torch.zeros_like(
self.k_cache.narrow(
dim_to_slice, 0, num_empty_space # pyre-ignore [6]
)
),
],
dim=dim_to_slice,
)
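            # Rebuild the v cache the same way; values carry no positional encoding,
            # so they are copied without re-rotation.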
self.v_cache = torch.cat(
[
self.v_cache.narrow(dim_to_slice, 0, self.sink_size),
self.v_cache.narrow(
dim_to_slice,
self.sink_size + num_to_evict, # pyre-ignore [6]
num_to_keep, # pyre-ignore [6]
),
torch.zeros_like(
self.v_cache.narrow(
dim_to_slice, 0, num_empty_space # pyre-ignore [6]
)
),
],
dim=dim_to_slice,
)
self.position_shift -= num_to_evict # pyre-ignore [8]
return self.position_shift
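
For reference, below is a minimal standalone sketch of the eviction arithmetic above, written with plain Python integers instead of cache tensors. The function name eviction_plan and the example numbers are illustrative only and are not part of this PR.

def eviction_plan(
    input_pos: int,
    position_shift: int,
    seq_len: int,
    sink_size: int,
    window_size: int,
    eviction_batch_size: int,
):
    """Mirror the integer arithmetic of KVCacheWithAttentionSink.evict_tokens."""
    max_seq_length = sink_size + window_size
    if input_pos + position_shift + seq_len <= max_seq_length:
        # Enough space left; nothing is evicted and the shift is unchanged.
        return 0, position_shift
    # Evict at least enough tokens to fit the incoming sequence, but never
    # fewer than one eviction batch.
    num_to_evict = max(
        input_pos + position_shift - max_seq_length + seq_len,
        eviction_batch_size,
    )
    # Non-sink tokens that survive; the cache becomes
    # [sink | kept tokens | (window_size - num_to_keep) empty slots].
    num_to_keep = input_pos + position_shift - sink_size - num_to_evict
    return num_to_evict, position_shift - num_to_evict

# Example: sink_size=4, window_size=124 (cache length 128), cache already full
# at position 128, one new token, eviction_batch_size=1: one non-sink token is
# evicted and incoming positions shift back by 1.
print(eviction_plan(128, 0, 1, 4, 124, 1))  # -> (1, -1)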