diff --git a/extension/llm/modules/__init__.py b/extension/llm/modules/__init__.py
index 38245bf9353..49c141e761a 100644
--- a/extension/llm/modules/__init__.py
+++ b/extension/llm/modules/__init__.py
@@ -8,8 +8,13 @@
     replace_tile_positional_embedding,
     TilePositionalEmbedding,
 )
+from .attention import MultiHeadAttention, replace_mha_with_inference_mha
+from .kv_cache import KVCache

 __all__ = [
     "TilePositionalEmbedding",
     "replace_tile_positional_embedding",
+    "MultiHeadAttention",
+    "replace_mha_with_inference_mha",
+    "KVCache",
 ]
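For orientation, a minimal sketch of how the newly exported API is meant to be consumed (not part of the diff; `model` is a placeholder, and the "rewrite submodules and return the module" behavior of `replace_mha_with_inference_mha` is an assumption based on the existing `replace_tile_positional_embedding` pattern):

```python
import torch.nn as nn

from executorch.extension.llm.modules import replace_mha_with_inference_mha


def make_export_friendly(model: nn.Module) -> nn.Module:
    """Swap torchtune MultiHeadAttention submodules for the export-friendly variant."""
    # Assumption: the helper rewrites matching submodules and returns the module,
    # mirroring replace_tile_positional_embedding.
    return replace_mha_with_inference_mha(model)
```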
diff --git a/extension/llm/modules/mha.py b/extension/llm/modules/attention.py
similarity index 95%
rename from extension/llm/modules/mha.py
rename to extension/llm/modules/attention.py
index 0bfa4eb20ce..74e14076b37 100644
--- a/extension/llm/modules/mha.py
+++ b/extension/llm/modules/attention.py
@@ -9,6 +9,7 @@

 import torch
 import torchtune.modules.attention as TorchTuneAttention
+from executorch.extension.llm.modules.kv_cache import KVCache as InferenceKVCache
 from torch import nn
 from torchtune.modules.attention_utils import _MaskType, _sdpa_or_flex_attention
 from torchtune.modules.kv_cache import KVCache
@@ -148,7 +149,6 @@ def __init__(
             num_kv_heads=self.num_kv_heads,
             num_heads=self.num_heads,
             head_dim=self.head_dim,
-            q_per_kv=self.num_heads // self.num_kv_heads,
             attn_dropout=self.attn_dropout if self.training else 0.0,
             is_causal=self.is_causal,
             attention_fn=self._attention_call,
@@ -177,12 +177,13 @@ def setup_cache(
                 "Key value caches are already setup. You cannot call ``setup_caches()`` twice. Skipping."
             )
         else:
-            self.kv_cache = KVCache(
+            self.kv_cache = InferenceKVCache(
                 batch_size=batch_size,
                 max_seq_len=max_seq_len,
-                num_heads=self.num_heads,
+                num_kv_heads=self.num_kv_heads,
                 head_dim=self.head_dim,
                 dtype=dtype,
+                transpose_cache=False,
             )
             self._sdpa.kv_cache = self.kv_cache
             self.cache_enabled = True
@@ -307,7 +308,6 @@ def __init__(
         num_kv_heads: int,
         num_heads: int,
         head_dim: int,
-        q_per_kv: int,
         attn_dropout: float,
         is_causal: bool,
         attention_fn,
@@ -317,7 +317,7 @@ def __init__(
         self.num_kv_heads = num_kv_heads
         self.num_heads = num_heads
         self.head_dim = head_dim
-        self.q_per_kv = q_per_kv
+        self.q_per_kv = self.num_heads // self.num_kv_heads
         self.attn_dropout = attn_dropout
         self.is_causal = is_causal
         self._attention_fn = attention_fn
@@ -330,25 +330,25 @@ def forward(
         v: torch.Tensor,  # [b, s, n_kv, h_d]
         bsz: int,
         seq_len: int,
-        mask: torch.Tensor = None,
+        mask: Optional[_MaskType] = None,
     ) -> torch.Tensor:
         # View + expand + reshape bring num_kv_heads to num_heads for k and v
         # to match q.

         # k: [bsz, seq_len, n_kv, 1, h_d]
         # v: [bsz, seq_len, n_kv, 1, h_d]
-        k = k.view(bsz, seq_len, self.num_kv_heads, 1, self.head_dim)
-        v = v.view(bsz, seq_len, self.num_kv_heads, 1, self.head_dim)
+        k = k.view(bsz, -1, self.num_kv_heads, 1, self.head_dim)
+        v = v.view(bsz, -1, self.num_kv_heads, 1, self.head_dim)

         # Expand the key and value tensors to have the same shape
         # as the query tensor by copying values across the relevant dim
         if self.num_heads != self.num_kv_heads:
-            k = k.expand(bsz, seq_len, self.num_kv_heads, self.q_per_kv, self.head_dim)
-            v = v.expand(bsz, seq_len, self.num_kv_heads, self.q_per_kv, self.head_dim)
+            k = k.expand(bsz, -1, self.num_kv_heads, self.q_per_kv, self.head_dim)
+            v = v.expand(bsz, -1, self.num_kv_heads, self.q_per_kv, self.head_dim)

         # [bsz, s, n_h, h_d]
-        k = k.reshape(bsz, seq_len, -1, self.head_dim)
-        v = v.reshape(bsz, seq_len, -1, self.head_dim)
+        k = k.reshape(bsz, -1, self.num_heads, self.head_dim)
+        v = v.reshape(bsz, -1, self.num_heads, self.head_dim)

         # [bsz, n_h, s, h_d]
         q = q.transpose(1, 2)
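The switch from a captured `seq_len` to `-1` in the sequence dimension is what keeps this view → expand → reshape path export-friendly when `seq_len` is a dynamic shape. A standalone sketch of the shape math, with illustrative sizes:

```python
import torch

# Repeat each kv head q_per_kv times so k/v match q's head count (grouped-query
# attention), using -1 for the dynamic sequence dimension as in the diff above.
bsz, seq_len, num_heads, num_kv_heads, head_dim = 1, 10, 8, 2, 16
q_per_kv = num_heads // num_kv_heads

k = torch.randn(bsz, seq_len, num_kv_heads, head_dim)
k = k.view(bsz, -1, num_kv_heads, 1, head_dim)                  # [b, s, n_kv, 1, h_d]
k = k.expand(bsz, -1, num_kv_heads, q_per_kv, head_dim)         # [b, s, n_kv, q_per_kv, h_d]
k = k.reshape(bsz, -1, num_heads, head_dim)                     # [b, s, n_h, h_d]
assert k.shape == (bsz, seq_len, num_heads, head_dim)
```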
diff --git a/extension/llm/modules/kv_cache.py b/extension/llm/modules/kv_cache.py
new file mode 100644
index 00000000000..827078a40a8
--- /dev/null
+++ b/extension/llm/modules/kv_cache.py
@@ -0,0 +1,130 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+from typing import Tuple
+
+import torch
+from torchtune.modules.kv_cache import KVCache as TuneKVCache
+
+
+class KVCache(TuneKVCache):
+    """
+    An export-friendly KVCache implementation adapted from the torchtune KVCache:
+    https://github.com/pytorch/torchtune/blob/main/torchtune/modules/kv_cache.py
+    It supports both transposed and un-transposed cache layouts.
+    Standalone ``nn.Module`` containing a kv-cache to cache past keys and values during inference.
+
+    Args:
+        batch_size (int): batch size model will be run with
+        max_seq_len (int): maximum sequence length model will be run with
+        num_kv_heads (int): number of key/value heads
+        head_dim (int): per-attention head embedding dimension
+        dtype (torch.dtype): dtype for the caches
+        transpose_cache (bool): whether the cache is stored in the transposed [B, H, S, D]
+            layout (``True``) or the un-transposed [B, S, H, D] layout (``False``).
+    """
+
+    def __init__(
+        self,
+        batch_size: int,
+        max_seq_len: int,
+        num_kv_heads: int,
+        head_dim: int,
+        dtype: torch.dtype,
+        transpose_cache: bool = True,
+    ) -> None:
+        super().__init__(
+            batch_size=batch_size,
+            max_seq_len=max_seq_len,
+            num_kv_heads=num_kv_heads,
+            head_dim=head_dim,
+            dtype=dtype,
+        )
+        self.transpose_cache = transpose_cache
+        self.max_seq_len = max_seq_len
+        if self.transpose_cache:
+            cache_shape = (batch_size, num_kv_heads, max_seq_len, head_dim)
+        else:
+            cache_shape = (batch_size, max_seq_len, num_kv_heads, head_dim)
+
+        self.register_buffer(
+            "k_cache", torch.zeros(cache_shape, dtype=dtype), persistent=False
+        )
+        self.register_buffer(
+            "v_cache", torch.zeros(cache_shape, dtype=dtype), persistent=False
+        )
+        self.register_buffer(
+            "cache_pos", torch.arange(0, self.max_seq_len), persistent=False
+        )
+        self.batch_size = batch_size
+
+    def update(
+        self, k_val: torch.Tensor, v_val: torch.Tensor
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        """Update KV cache with the new ``k_val``, ``v_val`` and return the updated cache.
+
+        Note:
+            When updating the KV cache, it is assumed that subsequent updates should update key-value
+            positions in consecutive sequence positions. If you wish to update cache values which have
+            already been filled, use ``.reset()``, which will reset the cache to the zero-th position.
+
+        Example:
+            >>> cache = KVCache(batch_size=2, max_seq_len=16, num_kv_heads=4, head_dim=32, dtype=torch.bfloat16)
+            >>> keys, values = torch.ones((2, 4, 8, 32)), torch.ones((2, 4, 8, 32))
+            >>> cache.update(keys, values)
+            >>> # now positions 0 through 7 are filled
+            >>> cache.size
+            >>> 8
+            >>> keys, values = torch.ones((2, 4, 1, 32)), torch.ones((2, 4, 1, 32))
+            >>> cache.update(keys, values)
+            >>> # this will fill at position 8
+            >>> cache.size
+            >>> 9
+
+        Args:
+            k_val (torch.Tensor): Current key tensor with shape [B, H, S, D]
+                (or [B, S, H, D] if ``transpose_cache=False``)
+            v_val (torch.Tensor): Current value tensor with shape [B, H, S, D]
+                (or [B, S, H, D] if ``transpose_cache=False``)
+
+        Returns:
+            Tuple[torch.Tensor, torch.Tensor]: Updated key and value cache tensors, respectively.
+
+        Raises:
+            AssertionError: if the sequence length of ``k_val`` is longer than the maximum cache sequence length.
+            ValueError: if the batch size of the new key (or value) tensor is greater than the batch size
+                used during cache setup.
+        """
+        if self.transpose_cache:
+            bsz, _, seq_len, _ = k_val.shape
+        else:
+            bsz, seq_len, _, _ = k_val.shape
+        if bsz > self.k_cache.shape[0]:
+            raise ValueError(
+                f"The current cache has been setup with a batch size of {self.k_cache.shape[0]}"
+                f", but found new key tensors with batch size {k_val.shape[0]}!"
+            )
+
+        assert (
+            self.cache_pos[0] + seq_len
+        ) <= self.max_seq_len, f"self.cache_pos[0]: {self.cache_pos[0]} + seq_len: {seq_len} > self.max_seq_len: {self.max_seq_len}"
+        k_out = self.k_cache
+        v_out = self.v_cache
+
+        if self.transpose_cache:
+            k_out[:, :, self.cache_pos[:seq_len]] = k_val
+            v_out[:, :, self.cache_pos[:seq_len]] = v_val
+        else:
+            k_out[:, self.cache_pos[:seq_len]] = k_val
+            v_out[:, self.cache_pos[:seq_len]] = v_val
+
+        # Advance cache_pos by seq_len positions.
+        # cache_pos starts at (0, 1, 2, 3, 4, 5, ...); an update of seq_len = 5 tokens
+        # brings it to (5, 6, 7, 8, 9, ...). This lets us track the current position in
+        # the cache after the last update in a compile-friendly way without any dynamism,
+        # e.g. relying on an int size tracker or re-creating cache_pos every time.
+        self.cache_pos.add_(seq_len)
+
+        return k_out, v_out
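The class docstring's example covers the default transposed layout, so here is a short sketch of the un-transposed path enabled by the new `transpose_cache=False` flag (illustrative sizes; assumes executorch and torchtune are installed):

```python
import torch

from executorch.extension.llm.modules import KVCache

# With transpose_cache=False, k_val/v_val are passed as [B, S, H, D]
# instead of torchtune's default [B, H, S, D].
cache = KVCache(
    batch_size=1,
    max_seq_len=16,
    num_kv_heads=4,
    head_dim=32,
    dtype=torch.float32,
    transpose_cache=False,
)
k = torch.randn(1, 10, 4, 32)  # [B, S, H, D]
v = torch.randn(1, 10, 4, 32)
k_out, v_out = cache.update(k, v)  # fills positions 0..9; cache_pos advances by 10
print(k_out.shape)                 # torch.Size([1, 16, 4, 32]) -- the full cache buffer
```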
diff --git a/extension/llm/modules/test/test_mha.py b/extension/llm/modules/test/test_attention.py
similarity index 70%
rename from extension/llm/modules/test/test_mha.py
rename to extension/llm/modules/test/test_attention.py
index 0dc7cba6858..9ae136a2137 100644
--- a/extension/llm/modules/test/test_mha.py
+++ b/extension/llm/modules/test/test_attention.py
@@ -9,7 +9,7 @@

 import torch
 from executorch.exir import EdgeCompileConfig, to_edge
-from executorch.extension.llm.modules.mha import (
+from executorch.extension.llm.modules.attention import (
     MultiHeadAttention as ETMultiHeadAttention,
 )
 from executorch.runtime import Runtime
@@ -82,10 +82,12 @@ def setUp(self):
         # Common inputs.
         seq_len = 10
         self.x = torch.randn(1, seq_len, self.embed_dim)
+        self.input_pos = torch.arange(seq_len).unsqueeze(0)  # shape [1, seq_len]
         seq_len_dim = torch.export.Dim("seq_len", min=1, max=100)
         self.dynamic_shapes = (
             {0: torch.export.Dim.STATIC, 1: seq_len_dim, 2: torch.export.Dim.STATIC},
             {0: torch.export.Dim.STATIC, 1: seq_len_dim, 2: torch.export.Dim.STATIC},
+            {0: torch.export.Dim.STATIC, 1: seq_len_dim},
         )

     def test_attention_eager(self):
@@ -94,25 +96,46 @@ def test_attention_eager(self):

         self.assertTrue(torch.allclose(et_res, tt_res))

-        # TODO: KV cache.
-        # self.et_mha.setup_cache(1, dtype=torch.float16, max_seq_len=20)
-        # self.tt_mha.setup_cache(1, dtype=torch.float16, max_seq_len=20)
+        # test with kv cache
+        self.et_mha.setup_cache(1, dtype=torch.float32, max_seq_len=20)
+        self.tt_mha.setup_cache(1, dtype=torch.float32, max_seq_len=20)

-        # et_res = self.et_mha(self.x, self.x) # Self attention.
-        # tt_res = self.tt_mha(self.x, self.x) # Self attention.
+        et_res = self.et_mha(self.x, self.x)  # Self attention.
+        tt_res = self.tt_mha(self.x, self.x)  # Self attention.
+
+        self.assertTrue(torch.allclose(et_res, tt_res))
+        self.et_mha.reset_cache()
+        self.tt_mha.reset_cache()

-        # self.assertTrue(torch.allclose(et_res, tt_res))
+        et_res = self.et_mha(
+            self.x, self.x, input_pos=self.input_pos
+        )  # Self attention with input pos.
+        tt_res = self.tt_mha(
+            self.x, self.x, input_pos=self.input_pos
+        )  # Self attention with input pos.
+
+        self.assertTrue(torch.allclose(et_res, tt_res))
+
+        # test kv cache read. Input pos can be [10, 11, ..., 19]
+        next_input_pos = torch.arange(10, 20).unsqueeze(0)
+        et_res = self.et_mha(
+            self.x, self.x, input_pos=next_input_pos
+        )  # Self attention with input pos.
+        tt_res = self.tt_mha(
+            self.x, self.x, input_pos=next_input_pos
+        )  # Self attention with input pos.
+        self.assertTrue(torch.allclose(et_res, tt_res))

     def test_attention_export(self):
         # Self attention.
         et_mha_ep = torch.export.export(
             self.et_mha,
             (self.x, self.x),
-            kwargs=None,
+            kwargs={"input_pos": self.input_pos},
             dynamic_shapes=self.dynamic_shapes,
         )
-        et_res = et_mha_ep.module()(self.x, self.x)
-        tt_res = self.tt_mha(self.x, self.x)
+        et_res = et_mha_ep.module()(self.x, self.x, input_pos=self.input_pos)
+        tt_res = self.tt_mha(self.x, self.x, input_pos=self.input_pos)

         self.assertTrue(torch.allclose(et_res, tt_res))

         # TODO: KV cache.
@@ -126,7 +149,7 @@ def test_attention_executorch(self):
         et_mha_ep = torch.export.export(
             self.et_mha,
             (self.x, self.x),
-            kwargs=None,
+            kwargs={"input_pos": self.input_pos},
             dynamic_shapes=self.dynamic_shapes,
         )
         et_program = to_edge(
@@ -136,8 +159,8 @@ def test_attention_executorch(self):
         runtime = Runtime.get()
         program = runtime.load_program(et_program.buffer)
         method = program.load_method("forward")
-        et_res = method.execute((self.x, self.x))
-        tt_res = self.tt_mha(self.x, self.x)
+        et_res = method.execute((self.x, self.x, self.input_pos))
+        tt_res = self.tt_mha(self.x, self.x, input_pos=self.input_pos)

         self.assertTrue(torch.allclose(et_res[0], tt_res, atol=1e-06))
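Finally, a sketch of the export wiring these tests exercise, showing how `input_pos` rides along as a kwarg with its own dynamic-shape spec. The attention module itself is built in the tests' `setUp`, which this diff does not show, so it is passed in here as a parameter rather than reconstructed:

```python
import torch


def export_attention(et_mha: torch.nn.Module, embed_dim: int):
    """Sketch of the kwargs/dynamic-shape wiring used by test_attention_export."""
    seq_len = 10
    x = torch.randn(1, seq_len, embed_dim)
    input_pos = torch.arange(seq_len).unsqueeze(0)  # [1, seq_len]

    seq_len_dim = torch.export.Dim("seq_len", min=1, max=100)
    dynamic_shapes = (
        {0: torch.export.Dim.STATIC, 1: seq_len_dim, 2: torch.export.Dim.STATIC},  # x
        {0: torch.export.Dim.STATIC, 1: seq_len_dim, 2: torch.export.Dim.STATIC},  # y (= x for self attention)
        {0: torch.export.Dim.STATIC, 1: seq_len_dim},                              # input_pos
    )

    ep = torch.export.export(
        et_mha,
        (x, x),
        kwargs={"input_pos": input_pos},
        dynamic_shapes=dynamic_shapes,
    )
    # The exported module can be called eagerly for a parity check against torchtune.
    return ep.module()(x, x, input_pos=input_pos)
```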