From 47dcd57506d0d21c02c34bd666f7f3d271d27043 Mon Sep 17 00:00:00 2001 From: Lunwen He Date: Wed, 30 Oct 2024 14:30:49 -0700 Subject: [PATCH 1/2] add attention_sink.py This PR adds `KVCacheWithAttentionSink`, which is required for `AttentionSink`. It keeps the first `sink_size` tokens as attention sinks and maintains a sliding window with `window_size` for new tokens. Note: I am trying to implement and verify `AttentionSink` in eager mode first. So the current implementation may still have some lower errors or performance issue. For example, it does not support the case when dynamic shape is disabled. Will leave these problems to resolve when we are ready to deploy `AttentionSink` to edge. Differential Revision: [D65235798](https://our.internmc.facebook.com/intern/diff/D65235798/) [ghstack-poisoned] --- examples/models/llama/TARGETS | 12 ++ .../source_transformation/attention_sink.py | 114 +++++++++++ .../test_attention_sink.py | 178 ++++++++++++++++++ 3 files changed, 304 insertions(+) create mode 100644 examples/models/llama/source_transformation/attention_sink.py create mode 100644 examples/models/llama/source_transformation/test_attention_sink.py diff --git a/examples/models/llama/TARGETS b/examples/models/llama/TARGETS index d328adffbf7..02aed580224 100644 --- a/examples/models/llama/TARGETS +++ b/examples/models/llama/TARGETS @@ -92,6 +92,7 @@ runtime.python_library( "source_transformation/sdpa.py", "source_transformation/spin_quant.py", "source_transformation/vulkan_rope.py", + "source_transformation/attention_sink.py", ], _is_external_target = True, base_module = "executorch.examples.models.llama", @@ -212,3 +213,14 @@ runtime.python_test( "//executorch/examples/models/llama:llama_transformer", ], ) + +runtime.python_test( + name = "attention_sink_test", + srcs = [ + "source_transformation/test_attention_sink.py", + ], + deps = [ + "//caffe2:torch", + ":export_library", + ], +) diff --git a/examples/models/llama/source_transformation/attention_sink.py b/examples/models/llama/source_transformation/attention_sink.py new file mode 100644 index 00000000000..856c4f70f50 --- /dev/null +++ b/examples/models/llama/source_transformation/attention_sink.py @@ -0,0 +1,114 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# Components for supporting Attention Sink. See +# https://arxiv.org/abs/2309.17453 for more details about Attention Sink. + +from typing import Tuple + +import torch + +from torch import nn + + +class KVCacheWithAttentionSink(nn.Module): + """ + KV cache that supports attention sink. It keeps the initial few tokens as attention sink. + For other tokens, it uses a sliding window to keep the most recent tokens. + + Parameters: + window_size: the size of the sliding window + sink_size: the number of initial tokens to keep as attention sink + """ + + def __init__( + self, + max_batch_size: int, + window_size: int, + sink_size: int, + n_heads: int, + head_dim: int, + transpose_cache: bool, + dtype=torch.float32, + ): + super().__init__() + self.window_size = window_size + self.sink_size = sink_size + self.cache_size = window_size + sink_size + self.is_transposed = transpose_cache + if transpose_cache: + cache_shape = (max_batch_size, n_heads, self.cache_size, head_dim) + else: + cache_shape = (max_batch_size, self.cache_size, n_heads, head_dim) + + self.max_batch_size = max_batch_size + self.n_heads = n_heads + self.head_dim = head_dim + self.transpose_cache = transpose_cache + self.register_buffer( + "k_cache", torch.zeros(cache_shape, dtype=dtype, device="cpu") + ) + self.register_buffer( + "v_cache", torch.zeros(cache_shape, dtype=dtype, device="cpu") + ) + + def update( + self, input_pos: torch.Tensor, k_val: torch.Tensor, v_val: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor]: + start_pos = input_pos[0].item() + torch._check_is_size(start_pos) + dim_to_slice = 2 if self.transpose_cache else 1 + seq_length = k_val.size(dim_to_slice) + + if start_pos + seq_length <= self.cache_size: + # There are still enough spaces in the cache to store the new tokens. + # No need to shift the existing tokens. + # pyre-ignore: Incompatible parameter type [6] + narrowed_k = self.k_cache.narrow(dim_to_slice, start_pos, seq_length) + # pyre-ignore: Incompatible parameter type [6] + narrowed_v = self.v_cache.narrow(dim_to_slice, start_pos, seq_length) + + narrowed_k.copy_(k_val) + narrowed_v.copy_(v_val) + else: + # There are not enough spaces in the cache to store the new tokens. + # We need to shift the existing tokens. + num_to_evict = min(start_pos + seq_length - self.cache_size, seq_length) + + # Shift the existing entries to the left + # pyre-ignore: Incompatible parameter type [6] + k_to_keep = self.k_cache.narrow( + dim_to_slice, + self.sink_size + num_to_evict, + self.window_size - num_to_evict, + ).clone() + # pyre-ignore: Incompatible parameter type [6] + v_to_keep = self.v_cache.narrow( + dim_to_slice, + self.sink_size + num_to_evict, + self.window_size - num_to_evict, + ).clone() + # pyre-ignore: Incompatible parameter type [6] + k_new_position = self.k_cache.narrow( + dim_to_slice, self.sink_size, self.window_size - num_to_evict + ) + # pyre-ignore: Incompatible parameter type [6] + v_new_position = self.v_cache.narrow( + dim_to_slice, self.sink_size, self.window_size - num_to_evict + ) + k_new_position.copy_(k_to_keep) + v_new_position.copy_(v_to_keep) + + # Appending new entries + narrowed_k = self.k_cache.narrow( + dim_to_slice, self.cache_size - seq_length, seq_length + ) + narrowed_v = self.v_cache.narrow( + dim_to_slice, self.cache_size - seq_length, seq_length + ) + narrowed_k.copy_(k_val) + narrowed_v.copy_(v_val) + return self.k_cache, self.v_cache diff --git a/examples/models/llama/source_transformation/test_attention_sink.py b/examples/models/llama/source_transformation/test_attention_sink.py new file mode 100644 index 00000000000..b6604f619ad --- /dev/null +++ b/examples/models/llama/source_transformation/test_attention_sink.py @@ -0,0 +1,178 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import unittest + +import torch + +from executorch.examples.models.llama.source_transformation.attention_sink import ( + KVCacheWithAttentionSink, +) + + +class KVCacheWithAttentionSinkTest(unittest.TestCase): + + def _init_cache(self): + self.kv_cache = KVCacheWithAttentionSink( + max_batch_size=self.max_batch_size, + window_size=self.window_size, + sink_size=self.sink_size, + n_heads=self.n_heads, + head_dim=self.head_dim, + transpose_cache=self.transpose_cache, + dtype=self.dtype, + ) + + def setUp(self): + torch.manual_seed(42) + self.max_batch_size = 1 + self.window_size = 28 + self.sink_size = 4 + self.n_heads = 8 + self.head_dim = 16 + self.transpose_cache = False + self.dtype = torch.float32 + self._init_cache() + + def test_update_empty_cache(self): + # KV cache is empty, update will fill sink tokens + input_pos = torch.tensor([0], dtype=torch.int32) + k = torch.ones((1, 1, 8, 16), dtype=self.dtype) + v = torch.ones((1, 1, 8, 16), dtype=self.dtype) + + k_out, v_out = self.kv_cache.update(input_pos, k, v) + + expected_k_out = torch.cat( + [ + torch.ones((1, 1, 8, 16), dtype=self.dtype), + torch.zeros((1, 31, 8, 16), dtype=self.dtype), + ], + dim=1, + ) + expected_v_out = torch.cat( + [ + torch.ones((1, 1, 8, 16), dtype=self.dtype), + torch.zeros((1, 31, 8, 16), dtype=self.dtype), + ], + dim=1, + ) + + torch.testing.assert_close(k_out, expected_k_out) + torch.testing.assert_close(v_out, expected_v_out) + + def test_update_without_shift(self): + # KV cache has enough spaces for new tokens, no shift + input_pos = torch.tensor([0], dtype=torch.int32) + k = torch.ones((1, 5, 8, 16), dtype=self.dtype) + v = torch.ones((1, 5, 8, 16), dtype=self.dtype) + + self.kv_cache.update(input_pos, k, v) + + input_pos = torch.tensor([5], dtype=torch.int32) + k = torch.full((1, 5, 8, 16), 2, dtype=self.dtype) + v = torch.full((1, 5, 8, 16), 2, dtype=self.dtype) + + k_out, v_out = self.kv_cache.update(input_pos, k, v) + + expected_k_out = torch.cat( + [ + torch.ones((1, 5, 8, 16), dtype=self.dtype), + torch.full((1, 5, 8, 16), 2, dtype=self.dtype), + torch.zeros((1, 22, 8, 16), dtype=self.dtype), + ], + dim=1, + ) + expected_v_out = torch.cat( + [ + torch.ones((1, 5, 8, 16), dtype=self.dtype), + torch.full((1, 5, 8, 16), 2, dtype=self.dtype), + torch.zeros((1, 22, 8, 16), dtype=self.dtype), + ], + dim=1, + ) + + torch.testing.assert_close(k_out, expected_k_out) + torch.testing.assert_close(v_out, expected_v_out) + + def test_update_with_some_shift(self): + # KV cache has some spaces for new tokens but not all, shift some tokens + input_pos = torch.tensor([0], dtype=torch.int32) + k = torch.ones((1, 5, 8, 16), dtype=self.dtype) + v = torch.ones((1, 5, 8, 16), dtype=self.dtype) + + self.kv_cache.update(input_pos, k, v) + + input_pos = torch.tensor([5], dtype=torch.int32) + k = torch.full((1, 5, 8, 16), 2, dtype=self.dtype) + v = torch.full((1, 5, 8, 16), 2, dtype=self.dtype) + + self.kv_cache.update(input_pos, k, v) + + input_pos = torch.tensor([10], dtype=torch.int32) + k = torch.full((1, 24, 8, 16), 3, dtype=self.dtype) + v = torch.full((1, 24, 8, 16), 3, dtype=self.dtype) + + k_out, v_out = self.kv_cache.update(input_pos, k, v) + + expected_k_out = torch.cat( + [ + torch.ones((1, 4, 8, 16), dtype=self.dtype), + torch.full((1, 4, 8, 16), 2, dtype=self.dtype), + torch.full((1, 24, 8, 16), 3, dtype=self.dtype), + ], + dim=1, + ) + expected_v_out = torch.cat( + [ + torch.ones((1, 4, 8, 16), dtype=self.dtype), + torch.full((1, 4, 8, 16), 2, dtype=self.dtype), + torch.full((1, 24, 8, 16), 3, dtype=self.dtype), + ], + dim=1, + ) + + torch.testing.assert_close(k_out, expected_k_out) + torch.testing.assert_close(v_out, expected_v_out) + + def test_update_with_all_shift(self): + # KV cache has no spaces for new tokens, shift all tokens + input_pos = torch.tensor([0], dtype=torch.int32) + k = torch.ones((1, 5, 8, 16), dtype=self.dtype) + v = torch.ones((1, 5, 8, 16), dtype=self.dtype) + + self.kv_cache.update(input_pos, k, v) + + input_pos = torch.tensor([5], dtype=torch.int32) + k = torch.full((1, 28, 8, 16), 2, dtype=self.dtype) + v = torch.full((1, 28, 8, 16), 2, dtype=self.dtype) + + self.kv_cache.update(input_pos, k, v) + + input_pos = torch.tensor([33], dtype=torch.int32) + k = torch.full((1, 6, 8, 16), 3, dtype=self.dtype) + v = torch.full((1, 6, 8, 16), 3, dtype=self.dtype) + + k_out, v_out = self.kv_cache.update(input_pos, k, v) + + expected_k_out = torch.cat( + [ + torch.ones((1, 4, 8, 16), dtype=self.dtype), + torch.full((1, 22, 8, 16), 2, dtype=self.dtype), + torch.full((1, 6, 8, 16), 3, dtype=self.dtype), + ], + dim=1, + ) + expected_v_out = torch.cat( + [ + torch.ones((1, 4, 8, 16), dtype=self.dtype), + torch.full((1, 22, 8, 16), 2, dtype=self.dtype), + torch.full((1, 6, 8, 16), 3, dtype=self.dtype), + ], + dim=1, + ) + + torch.testing.assert_close(k_out, expected_k_out) + torch.testing.assert_close(v_out, expected_v_out) From 7140decd8a3f79ad8c0ee05c4205b457bdfde523 Mon Sep 17 00:00:00 2001 From: Lunwen He Date: Mon, 4 Nov 2024 15:17:47 -0800 Subject: [PATCH 2/2] Update on "add attention_sink.py" This PR adds `KVCacheWithAttentionSink`, which is required for `AttentionSink`. It keeps the first `sink_size` tokens as attention sinks and maintains a sliding window with `window_size` for new tokens. Note: I am trying to implement and verify `AttentionSink` in eager mode first. So the current implementation may still have some lower errors or performance issue. For example, it does not support the case when dynamic shape is disabled. Will leave these problems to resolve when we are ready to deploy `AttentionSink` to edge. Differential Revision: [D65235798](https://our.internmc.facebook.com/intern/diff/D65235798/) [ghstack-poisoned] --- .../source_transformation/attention_sink.py | 20 +++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/examples/models/llama/source_transformation/attention_sink.py b/examples/models/llama/source_transformation/attention_sink.py index 856c4f70f50..0debf442742 100644 --- a/examples/models/llama/source_transformation/attention_sink.py +++ b/examples/models/llama/source_transformation/attention_sink.py @@ -79,25 +79,25 @@ def update( num_to_evict = min(start_pos + seq_length - self.cache_size, seq_length) # Shift the existing entries to the left - # pyre-ignore: Incompatible parameter type [6] k_to_keep = self.k_cache.narrow( dim_to_slice, - self.sink_size + num_to_evict, - self.window_size - num_to_evict, + self.sink_size + num_to_evict, # pyre-ignore [6] + self.window_size - num_to_evict, # pyre-ignore [6] ).clone() - # pyre-ignore: Incompatible parameter type [6] v_to_keep = self.v_cache.narrow( dim_to_slice, - self.sink_size + num_to_evict, - self.window_size - num_to_evict, + self.sink_size + num_to_evict, # pyre-ignore [6] + self.window_size - num_to_evict, # pyre-ignore [6] ).clone() - # pyre-ignore: Incompatible parameter type [6] k_new_position = self.k_cache.narrow( - dim_to_slice, self.sink_size, self.window_size - num_to_evict + dim_to_slice, + self.sink_size, + self.window_size - num_to_evict, # pyre-ignore [6] ) - # pyre-ignore: Incompatible parameter type [6] v_new_position = self.v_cache.narrow( - dim_to_slice, self.sink_size, self.window_size - num_to_evict + dim_to_slice, + self.sink_size, + self.window_size - num_to_evict, # pyre-ignore [6] ) k_new_position.copy_(k_to_keep) v_new_position.copy_(v_to_keep)