diff --git a/examples/models/llama/source_transformation/attention_sink.py b/examples/models/llama/source_transformation/attention_sink.py
index 8f4fd1ebd25..8450600d2b1 100644
--- a/examples/models/llama/source_transformation/attention_sink.py
+++ b/examples/models/llama/source_transformation/attention_sink.py
@@ -11,7 +11,7 @@
 
 import torch
 
-from executorch.examples.models.llama.llama_transformer import ModelArgs, Rope
+from executorch.examples.models.llama.llama_transformer import KVCache, ModelArgs, Rope
 from executorch.examples.models.llama.rope import (
     apply_rotary_emb_to_k,
     hf_apply_rotary_emb_to_k,
@@ -87,3 +87,122 @@ def rerotate_k(
         )
 
         return self.apply_rotary_emb_to_k(k, rerotation_cos, rerotation_sin)
+
+
+class KVCacheWithAttentionSink(KVCache):
+    """
+    KV cache that supports attention sink. It keeps the initial few tokens as the attention sink.
+    For the remaining tokens, it uses a sliding window that keeps only the most recent ones.
+
+    Parameters:
+        window_size: the size of the sliding window
+        sink_size: the number of initial tokens to keep as the attention sink
+        eviction_batch_size: the number of tokens to evict in one batch when there is not enough space in the KV cache
+    """
+
+    def __init__(
+        self,
+        n_heads: int,
+        head_dim: int,
+        transpose_cache: bool,
+        enable_dynamic_shape: bool,
+        rope: RopeWithAttentionSink,
+        window_size: int,
+        sink_size: int,
+        eviction_batch_size: int,
+        max_batch_size: int = 1,
+        dtype=torch.float32,
+    ):
+        super().__init__(
+            max_batch_size=max_batch_size,
+            max_seq_length=window_size + sink_size,
+            n_heads=n_heads,
+            head_dim=head_dim,
+            transpose_cache=transpose_cache,
+            enable_dynamic_shape=enable_dynamic_shape,
+            dtype=dtype,
+        )
+        self.rope = rope
+        self.window_size = window_size
+        self.sink_size = sink_size
+        self.eviction_batch_size = eviction_batch_size
+        self.position_shift = 0
+
+    def evict_tokens(self, input_pos: torch.Tensor, seq_len: int) -> int:
+        """
+        Evict old tokens from the cache to make room for new tokens.
+
+        Parameters:
+            input_pos: the start position of the incoming tokens in the actual sequence
+            seq_len: the length of the incoming sequence
+
+        Returns:
+            the updated position shift to apply to incoming token positions; it is
+            non-positive and its magnitude equals the total number of tokens evicted so far
+        """
+        input_pos_item = input_pos.item()
+        torch._check_is_size(input_pos_item)
+        if input_pos_item + self.position_shift + seq_len > self.max_seq_length:
+            # There is not enough space in the cache to store the new tokens.
+            # We need to evict some old tokens and shift some recent tokens.
+            num_to_evict = max(
+                input_pos_item + self.position_shift - self.max_seq_length + seq_len,
+                self.eviction_batch_size,
+            )
+            num_to_keep = (
+                input_pos_item + self.position_shift - self.sink_size - num_to_evict
+            )
+            num_empty_space = self.window_size - num_to_keep
+            dim_to_slice = 2 if self.transpose_cache else 1
+            k_to_keep = self.k_cache.narrow(
+                dim_to_slice,
+                self.sink_size + num_to_evict,  # pyre-ignore [6]
+                num_to_keep,  # pyre-ignore [6]
+            )
+            if self.transpose_cache:
+                k_to_keep = self.rope.rerotate_k(
+                    k=k_to_keep.transpose(1, 2),
+                    original_position=(  # pyre-ignore [6]
+                        self.sink_size + num_to_evict
+                    ),
+                    new_position=self.sink_size,
+                ).transpose(1, 2)
+            else:
+                k_to_keep = self.rope.rerotate_k(
+                    k=k_to_keep,
+                    original_position=(  # pyre-ignore [6]
+                        self.sink_size + num_to_evict
+                    ),
+                    new_position=self.sink_size,
+                )
+            self.k_cache = torch.cat(
+                [
+                    self.k_cache.narrow(dim_to_slice, 0, self.sink_size),
+                    k_to_keep,
+                    torch.zeros_like(
+                        self.k_cache.narrow(
+                            dim_to_slice, 0, num_empty_space  # pyre-ignore [6]
+                        )
+                    ),
+                ],
+                dim=dim_to_slice,
+            )
+            self.v_cache = torch.cat(
+                [
+                    self.v_cache.narrow(dim_to_slice, 0, self.sink_size),
+                    self.v_cache.narrow(
+                        dim_to_slice,
+                        self.sink_size + num_to_evict,  # pyre-ignore [6]
+                        num_to_keep,  # pyre-ignore [6]
+                    ),
+                    torch.zeros_like(
+                        self.v_cache.narrow(
+                            dim_to_slice, 0, num_empty_space  # pyre-ignore [6]
+                        )
+                    ),
+                ],
+                dim=dim_to_slice,
+            )
+            self.position_shift -= num_to_evict  # pyre-ignore [8]
+        return self.position_shift
diff --git a/examples/models/llama/source_transformation/test_attention_sink.py b/examples/models/llama/source_transformation/test_attention_sink.py
index 8eaa992dc38..4ffecf1e9c3 100644
--- a/examples/models/llama/source_transformation/test_attention_sink.py
+++ b/examples/models/llama/source_transformation/test_attention_sink.py
@@ -10,6 +10,7 @@
 from executorch.examples.models.llama.llama_transformer import ModelArgs
 
 from executorch.examples.models.llama.source_transformation.attention_sink import (
+    KVCacheWithAttentionSink,
     RopeWithAttentionSink,
 )
 from parameterized import parameterized
@@ -79,14 +80,10 @@ def test_get_freqs(
     def test_rotate(self, original_position, new_position):
         seq_len = 32
 
-        q = torch.rand(
-            1, seq_len, self.params.n_heads, self.params.head_dim, dtype=torch.float32
-        )
+        size = (1, seq_len, self.params.n_heads, self.params.head_dim)
+        q = torch.rand(*size, dtype=torch.float32)
         k = torch.rand(
-            1,
-            seq_len,
-            self.params.n_heads,
-            self.params.head_dim,
+            *size,
             dtype=torch.float32,
         )
         freqs_cos, freqs_sin = self.rope_with_attention_sink.get_freqs(
@@ -118,3 +115,465 @@ def test_rotate(self, original_position, new_position):
         )
 
         torch.testing.assert_close(rerotated_k, expected_k)
+
+
+class KVCacheWithAttentionSinkTest(unittest.TestCase):
+
+    _single_evict_test_cases = [
+        [False, 4, 1],
+        [True, 4, 1],
+    ]
+
+    _batch_evict_test_cases = [
+        [False, 4, 8],
+        [True, 4, 8],
+    ]
+
+    _sliding_window_test_cases = [
+        [False, 0, 1],
+        [True, 0, 1],
+    ]
+
+    def _init_cache(self, transpose_cache, sink_size, eviction_batch_size):
+        self.params = ModelArgs(
+            use_kv_cache=True,
+            enable_dynamic_shape=True,
+            max_seq_len=self.window_size + sink_size,
+        )
+        self.rope_with_attention_sink = RopeWithAttentionSink(
+            params=self.params,
+            window_size=self.window_size,
+            sink_size=sink_size,
+            eviction_batch_size=eviction_batch_size,
+        )
+        self.kv_cache = KVCacheWithAttentionSink(
+            n_heads=self.params.n_heads,
+            head_dim=self.params.head_dim,
+            transpose_cache=transpose_cache,
+            enable_dynamic_shape=self.params.enable_dynamic_shape,
+            rope=self.rope_with_attention_sink,
+            max_batch_size=self.max_batch_size,
+            window_size=self.window_size,
+            sink_size=sink_size,
+            eviction_batch_size=eviction_batch_size,
+            dtype=self.dtype,
+        )
+
+    def _rand_kv_with_length(self, transpose_cache, seq_len):
+        size = (
+            (
+                self.max_batch_size,
+                seq_len,
+                self.params.n_heads,
+                self.params.head_dim,
+            )
+            if not transpose_cache
+            else (
+                self.max_batch_size,
+                self.params.n_heads,
+                seq_len,
+                self.params.head_dim,
+            )
+        )
+        k = torch.rand(*size, dtype=self.dtype)
+        v = torch.rand(*size, dtype=self.dtype)
+        return k, v
+
+    def _zero_kv_with_length(self, transpose_cache, seq_len):
+        size = (
+            (
+                self.max_batch_size,
+                seq_len,
+                self.params.n_heads,
+                self.params.head_dim,
+            )
+            if not transpose_cache
+            else (
+                self.max_batch_size,
+                self.params.n_heads,
+                seq_len,
+                self.params.head_dim,
+            )
+        )
+        k = torch.zeros(*size, dtype=self.dtype)
+        v = torch.zeros(*size, dtype=self.dtype)
+        return k, v
+
+    def _get_dim_to_slice(self, transpose_cache):
+        return 2 if transpose_cache else 1
+
+    def _get_expected_rotated_k(
+        self, transpose_cache, k, original_position, new_position
+    ):
+        if transpose_cache:
+            return self.rope_with_attention_sink.rerotate_k(
+                k=k.transpose(1, 2),
+                original_position=original_position,
+                new_position=new_position,
+            ).transpose(1, 2)
+        else:
+            return self.rope_with_attention_sink.rerotate_k(
+                k=k, original_position=original_position, new_position=new_position
+            )
+
+    def setUp(self):
+        torch.manual_seed(42)
+        self.max_batch_size = 1
+        self.window_size = 28
+        self.dtype = torch.float32
+
+    @parameterized.expand(
+        _single_evict_test_cases + _batch_evict_test_cases + _sliding_window_test_cases
+    )
+    def test_evict_empty_cache(self, transpose_cache, sink_size, eviction_batch_size):
+        self._init_cache(transpose_cache, sink_size, eviction_batch_size)
+
+        # KV cache is empty; evicting does nothing
+        input_pos = torch.tensor([0], dtype=torch.int32)
+        assert self.kv_cache.evict_tokens(input_pos, 1) == 0
+
+        expected_k, expected_v = self._zero_kv_with_length(
+            transpose_cache, self.window_size + sink_size
+        )
+
+        torch.testing.assert_close(self.kv_cache.k_cache, expected_k)
+        torch.testing.assert_close(self.kv_cache.v_cache, expected_v)
+
+    @parameterized.expand(
+        _single_evict_test_cases + _batch_evict_test_cases + _sliding_window_test_cases
+    )
+    def test_evict_without_shift(self, transpose_cache, sink_size, eviction_batch_size):
+        dimension_to_slice = self._get_dim_to_slice(transpose_cache)
+
+        self._init_cache(transpose_cache, sink_size, eviction_batch_size)
+
+        # KV cache has enough space for the new tokens; no shift is needed
+        input_pos = torch.tensor([0], dtype=torch.int32)
+        k, v = self._rand_kv_with_length(transpose_cache, 10)
+
+        self.kv_cache.update(input_pos, k, v)
+
+        input_pos = torch.tensor([10], dtype=torch.int32)
+        assert self.kv_cache.evict_tokens(input_pos, 1) == 0
+
+        zero_k, zero_v = self._zero_kv_with_length(
+            transpose_cache, self.window_size + sink_size - 10
+        )
+
+        expected_k = torch.cat(
+            [
+                k,
+                zero_k,
+            ],
+            dim=dimension_to_slice,
+        )
+        expected_v = torch.cat(
+            [
+                v,
+                zero_v,
+            ],
+            dim=dimension_to_slice,
+        )
+
+        torch.testing.assert_close(self.kv_cache.k_cache, expected_k)
+        torch.testing.assert_close(self.kv_cache.v_cache, expected_v)
+
+    @parameterized.expand(_single_evict_test_cases)
+    def test_evict_with_some_shift(
+        self, transpose_cache, sink_size, eviction_batch_size
+    ):
+        dimension_to_slice = self._get_dim_to_slice(transpose_cache)
+
+        self._init_cache(transpose_cache, sink_size, eviction_batch_size)
+
+        # KV cache has some space for the new tokens but not enough; shift some tokens
+        input_pos = torch.tensor([0], dtype=torch.int32)
+        k, v = self._rand_kv_with_length(transpose_cache, 5)
+
+        self.kv_cache.update(input_pos, k, v)
+
+        input_pos = torch.tensor([5], dtype=torch.int32)
+        k1, v1 = self._rand_kv_with_length(transpose_cache, 5)
+
+        self.kv_cache.update(input_pos, k1, v1)
+
+        input_pos = torch.tensor([10], dtype=torch.int32)
+        assert self.kv_cache.evict_tokens(input_pos, 24) == -2
+
+        zero_k, zero_v = self._zero_kv_with_length(transpose_cache, 24)
+        expected_k = torch.cat(
+            [
+                k.narrow(dimension_to_slice, 0, sink_size),
+                self._get_expected_rotated_k(
+                    transpose_cache, k1.narrow(dimension_to_slice, 1, 4), 6, 4
+                ),
+                zero_k,
+            ],
+            dim=dimension_to_slice,
+        )
+        expected_v = torch.cat(
+            [
+                v.narrow(dimension_to_slice, 0, sink_size),
+                v1.narrow(dimension_to_slice, 1, 4),
+                zero_v,
+            ],
+            dim=dimension_to_slice,
+        )
+
+        torch.testing.assert_close(self.kv_cache.k_cache, expected_k)
+        torch.testing.assert_close(self.kv_cache.v_cache, expected_v)
+
+    @parameterized.expand(_single_evict_test_cases)
+    def test_evict_with_all_shift(
+        self, transpose_cache, sink_size, eviction_batch_size
+    ):
+        dimension_to_slice = self._get_dim_to_slice(transpose_cache)
+
+        self._init_cache(transpose_cache, sink_size, eviction_batch_size)
+
+        # KV cache has no space for the new tokens; shift all tokens
+        input_pos = torch.tensor([0], dtype=torch.int32)
+        k, v = self._rand_kv_with_length(transpose_cache, 5)
+
+        self.kv_cache.update(input_pos, k, v)
+
+        input_pos = torch.tensor([5], dtype=torch.int32)
+        k1, v1 = self._rand_kv_with_length(transpose_cache, 27)
+
+        self.kv_cache.update(input_pos, k1, v1)
+
+        input_pos = torch.tensor([32], dtype=torch.int32)
+        assert self.kv_cache.evict_tokens(input_pos, 6) == -6
+
+        zero_k, zero_v = self._zero_kv_with_length(transpose_cache, 6)
+        expected_k = torch.cat(
+            [
+                k.narrow(dimension_to_slice, 0, sink_size),
+                self._get_expected_rotated_k(
+                    transpose_cache, k1.narrow(dimension_to_slice, 5, 22), 10, 4
+                ),
+                zero_k,
+            ],
+            dim=dimension_to_slice,
+        )
+        expected_v = torch.cat(
+            [
+                v.narrow(dimension_to_slice, 0, sink_size),
+                v1.narrow(dimension_to_slice, 5, 22),
+                zero_v,
+            ],
+            dim=dimension_to_slice,
+        )
+
+        torch.testing.assert_close(self.kv_cache.k_cache, expected_k)
+        torch.testing.assert_close(self.kv_cache.v_cache, expected_v)
+
+    @parameterized.expand(_sliding_window_test_cases)
+    def test_evict_with_some_shift_for_sliding_window(
+        self, transpose_cache, sink_size, eviction_batch_size
+    ):
+        dimension_to_slice = self._get_dim_to_slice(transpose_cache)
+
+        self._init_cache(transpose_cache, sink_size, eviction_batch_size)
+
+        # KV cache has some space for the new tokens but not enough; shift some tokens
+        input_pos = torch.tensor([0], dtype=torch.int32)
+        k, v = self._rand_kv_with_length(transpose_cache, 5)
+
+        self.kv_cache.update(input_pos, k, v)
+
+        input_pos = torch.tensor([5], dtype=torch.int32)
+        k1, v1 = self._rand_kv_with_length(transpose_cache, 5)
+
+        self.kv_cache.update(input_pos, k1, v1)
+
+        input_pos = torch.tensor([10], dtype=torch.int32)
+        assert self.kv_cache.evict_tokens(input_pos, 20) == -2
+
+        zero_k, zero_v = self._zero_kv_with_length(transpose_cache, 20)
+        expected_k = torch.cat(
+            [
+                self._get_expected_rotated_k(
+                    transpose_cache, k.narrow(dimension_to_slice, 2, 3), 2, 0
+                ),
+                self._get_expected_rotated_k(transpose_cache, k1, 5, 3),
+                zero_k,
+            ],
+            dim=dimension_to_slice,
+        )
+        expected_v = torch.cat(
+            [
+                v.narrow(dimension_to_slice, 2, 3),
+                v1,
+                zero_v,
+            ],
+            dim=dimension_to_slice,
+        )
+
+        torch.testing.assert_close(self.kv_cache.k_cache, expected_k)
+        torch.testing.assert_close(self.kv_cache.v_cache, expected_v)
+
+    @parameterized.expand(_sliding_window_test_cases)
+    def test_evict_with_all_shift_for_sliding_window(
+        self, transpose_cache, sink_size, eviction_batch_size
+    ):
+        dimension_to_slice = self._get_dim_to_slice(transpose_cache)
+
+        self._init_cache(transpose_cache, sink_size, eviction_batch_size)
+
+        # KV cache has no space for the new tokens; shift all tokens
+        input_pos = torch.tensor([0], dtype=torch.int32)
+        k, v = self._rand_kv_with_length(transpose_cache, 5)
+
+        self.kv_cache.update(input_pos, k, v)
+
+        input_pos = torch.tensor([5], dtype=torch.int32)
+        k1, v1 = self._rand_kv_with_length(transpose_cache, 23)
+
+        self.kv_cache.update(input_pos, k1, v1)
+
+        input_pos = torch.tensor([28], dtype=torch.int32)
+        assert self.kv_cache.evict_tokens(input_pos, 6) == -6
+
+        zero_k, zero_v = self._zero_kv_with_length(transpose_cache, 6)
+        expected_k = torch.cat(
+            [
+                self._get_expected_rotated_k(
+                    transpose_cache, k1.narrow(dimension_to_slice, 1, 22), 6, 0
+                ),
+                zero_k,
+            ],
+            dim=dimension_to_slice,
+        )
+        expected_v = torch.cat(
+            [
+                v1.narrow(dimension_to_slice, 1, 22),
+                zero_v,
+            ],
+            dim=dimension_to_slice,
+        )
+
+        torch.testing.assert_close(self.kv_cache.k_cache, expected_k)
+        torch.testing.assert_close(self.kv_cache.v_cache, expected_v)
+
+    @parameterized.expand(_batch_evict_test_cases)
+    def test_batch_evict_with_seq_len(
+        self, transpose_cache, sink_size, eviction_batch_size
+    ):
+        dimension_to_slice = self._get_dim_to_slice(transpose_cache)
+
+        self._init_cache(transpose_cache, sink_size, eviction_batch_size)
+
+        # Not enough space for the new tokens; the eviction count is driven by
+        # seq_len because it exceeds the eviction batch size
+        input_pos = torch.tensor([0], dtype=torch.int32)
+        k, v = self._rand_kv_with_length(transpose_cache, 5)
+
+        self.kv_cache.update(input_pos, k, v)
+
+        input_pos = torch.tensor([5], dtype=torch.int32)
+        k1, v1 = self._rand_kv_with_length(transpose_cache, 25)
+
+        self.kv_cache.update(input_pos, k1, v1)
+
+        input_pos = torch.tensor([30], dtype=torch.int32)
+        assert self.kv_cache.evict_tokens(input_pos, 12) == -10
+
+        zero_k, zero_v = self._zero_kv_with_length(transpose_cache, 12)
+        expected_k = torch.cat(
+            [
+                k.narrow(dimension_to_slice, 0, sink_size),
+                self._get_expected_rotated_k(
+                    transpose_cache, k1.narrow(dimension_to_slice, 9, 16), 14, 4
+                ),
+                zero_k,
+            ],
+            dim=dimension_to_slice,
+        )
+        expected_v = torch.cat(
+            [
+                v.narrow(dimension_to_slice, 0, sink_size),
+                v1.narrow(dimension_to_slice, 9, 16),
+                zero_v,
+            ],
+            dim=dimension_to_slice,
+        )
+
+        torch.testing.assert_close(self.kv_cache.k_cache, expected_k)
+        torch.testing.assert_close(self.kv_cache.v_cache, expected_v)
+
+    @parameterized.expand(_batch_evict_test_cases)
+    def test_batch_evict_with_batch_size(
+        self, transpose_cache, sink_size, eviction_batch_size
+    ):
+        dimension_to_slice = self._get_dim_to_slice(transpose_cache)
+
+        self._init_cache(transpose_cache, sink_size, eviction_batch_size)
+
+        # Not enough space for the new tokens; a full eviction batch is evicted
+        # even though fewer tokens would suffice
+        input_pos = torch.tensor([0], dtype=torch.int32)
+        k, v = self._rand_kv_with_length(transpose_cache, 5)
+
+        self.kv_cache.update(input_pos, k, v)
+
+        input_pos = torch.tensor([5], dtype=torch.int32)
+        k1, v1 = self._rand_kv_with_length(transpose_cache, 25)
+
+        self.kv_cache.update(input_pos, k1, v1)
+
+        input_pos = torch.tensor([30], dtype=torch.int32)
+        assert self.kv_cache.evict_tokens(input_pos, 6) == -8
+
+        zero_k, zero_v = self._zero_kv_with_length(transpose_cache, 10)
+        expected_k = torch.cat(
+            [
+                k.narrow(dimension_to_slice, 0, sink_size),
+                self._get_expected_rotated_k(
+                    transpose_cache, k1.narrow(dimension_to_slice, 7, 18), 12, 4
+                ),
+                zero_k,
+            ],
+            dim=dimension_to_slice,
+        )
+        expected_v = torch.cat(
+            [
+                v.narrow(dimension_to_slice, 0, sink_size),
+                v1.narrow(dimension_to_slice, 7, 18),
+                zero_v,
+            ],
+            dim=dimension_to_slice,
+        )
+
+        torch.testing.assert_close(self.kv_cache.k_cache, expected_k)
+        torch.testing.assert_close(self.kv_cache.v_cache, expected_v)
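For context, the sketch below mirrors how the tests above exercise the new cache: build a ModelArgs/RopeWithAttentionSink pair, fill the cache through the base-class update(), then call evict_tokens() once the window is full. The standalone driver code and the reliance on ModelArgs defaults for n_heads/head_dim are assumptions for illustration, not part of this diff; in the runtime the attention module is expected to do the equivalent wiring.

    import torch

    from executorch.examples.models.llama.llama_transformer import ModelArgs
    from executorch.examples.models.llama.source_transformation.attention_sink import (
        KVCacheWithAttentionSink,
        RopeWithAttentionSink,
    )

    window_size, sink_size = 28, 4  # total cache capacity = 32 positions
    params = ModelArgs(
        use_kv_cache=True,
        enable_dynamic_shape=True,
        max_seq_len=window_size + sink_size,
    )
    rope = RopeWithAttentionSink(
        params=params,
        window_size=window_size,
        sink_size=sink_size,
        eviction_batch_size=1,
    )
    kv_cache = KVCacheWithAttentionSink(
        n_heads=params.n_heads,
        head_dim=params.head_dim,
        transpose_cache=False,
        enable_dynamic_shape=params.enable_dynamic_shape,
        rope=rope,
        window_size=window_size,
        sink_size=sink_size,
        eviction_batch_size=1,
    )

    # Fill 30 of the 32 cache positions through the base-class update().
    k = torch.rand(1, 30, params.n_heads, params.head_dim)
    v = torch.rand(1, 30, params.n_heads, params.head_dim)
    kv_cache.update(torch.tensor([0], dtype=torch.int32), k, v)

    # Six more tokens do not fit: num_to_evict = max(30 - 32 + 6, 1) = 4, so four
    # post-sink tokens are dropped, the surviving keys are re-rotated to their new
    # positions, and the returned shift is what the caller adds to input_pos for
    # subsequent updates.
    shift = kv_cache.evict_tokens(torch.tensor([30], dtype=torch.int32), seq_len=6)
    assert shift == -4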