From 54c52c38c12b3cb78b478e5f8410a145fcff68f1 Mon Sep 17 00:00:00 2001
From: Kimish Patel
Date: Wed, 25 Sep 2024 14:42:17 -0700
Subject: [PATCH] [ExecuTorch] Some updates to kv cache

Update the kv cache implementation to handle an untransposed cache
layout ([B, S, H, D]) in addition to the transposed one ([B, H, S, D]).

Differential Revision: [D62301843](https://our.internmc.facebook.com/intern/diff/D62301843/)

[ghstack-poisoned]
---
 examples/models/llama2/llama_transformer.py | 19 +++++++++++++------
 1 file changed, 13 insertions(+), 6 deletions(-)

diff --git a/examples/models/llama2/llama_transformer.py b/examples/models/llama2/llama_transformer.py
index 65090e2fe5a..8e17013ae3d 100644
--- a/examples/models/llama2/llama_transformer.py
+++ b/examples/models/llama2/llama_transformer.py
@@ -151,6 +151,7 @@ def __init__(
     ):
         super().__init__()
         self.max_seq_length = max_seq_length
+        self.transpose_cache = transpose_cache
         if transpose_cache:
             cache_shape = (max_batch_size, n_heads, max_seq_length, head_dim)
         else:
@@ -173,19 +174,21 @@ def update(
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         # input_pos: [S], k_val: [B, H, S, D] or [B, S, H, D] depending on transpose_cache
         if self.enable_dynamic_shape:
-            start_pos = input_pos[-1].item()
+            start_pos = input_pos[0].item()
             torch._check_is_size(start_pos)
             torch._check(start_pos < self.max_seq_length)
-            seq_length = k_val.size(2)
+            dim_to_slice = 2 if self.transpose_cache else 1
+            seq_length = k_val.size(dim_to_slice)
             # Replace the entry in the cache for this token
             # The following lines are equivalent to:
             # cache_k[:bsz, start_pos : start_pos + seqlen] = xk
             # cache_v[:bsz, start_pos : start_pos + seqlen] = xv
+            # when dim_to_slice is 1
             # We use .narrow() here to make the compiler happy
             # pyre-ignore: Incompatible parameter type [6]
-            narrowed_k = self.k_cache.narrow(2, start_pos, seq_length)
+            narrowed_k = self.k_cache.narrow(dim_to_slice, start_pos, seq_length)
             # pyre-ignore: Incompatible parameter type [6]
-            narrowed_v = self.v_cache.narrow(2, start_pos, seq_length)
+            narrowed_v = self.v_cache.narrow(dim_to_slice, start_pos, seq_length)
             narrowed_k.copy_(k_val)
             narrowed_v.copy_(v_val)
 
@@ -193,8 +196,12 @@ def update(
         else:
             k_out = self.k_cache
             v_out = self.v_cache
-            k_out[:, :, input_pos] = k_val
-            v_out[:, :, input_pos] = v_val
+            if self.transpose_cache:
+                k_out[:, :, input_pos] = k_val
+                v_out[:, :, input_pos] = v_val
+            else:
+                k_out[:, input_pos] = k_val
+                v_out[:, input_pos] = v_val
 
         return k_out, v_out
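
For reviewers, a minimal runnable sketch of the non-dynamic-shape update path above. `KVCacheSketch` is a hypothetical standalone class with illustrative sizes (batch 1, 2 heads, head_dim 4, max_seq_length 8), not the module in this file, and the `.narrow()`-based dynamic-shape path is omitted.

```python
import torch


class KVCacheSketch(torch.nn.Module):
    """Hypothetical standalone mirror of the patched eager update path."""

    def __init__(self, max_batch_size, max_seq_length, n_heads, head_dim, transpose_cache):
        super().__init__()
        self.transpose_cache = transpose_cache
        if transpose_cache:
            # [B, H, S, D]: the sequence dimension is dim 2.
            cache_shape = (max_batch_size, n_heads, max_seq_length, head_dim)
        else:
            # [B, S, H, D]: the sequence dimension is dim 1.
            cache_shape = (max_batch_size, max_seq_length, n_heads, head_dim)
        self.register_buffer("k_cache", torch.zeros(cache_shape))
        self.register_buffer("v_cache", torch.zeros(cache_shape))

    def update(self, input_pos, k_val, v_val):
        # Index whichever dimension holds the sequence for this layout.
        if self.transpose_cache:
            self.k_cache[:, :, input_pos] = k_val
            self.v_cache[:, :, input_pos] = v_val
        else:
            self.k_cache[:, input_pos] = k_val
            self.v_cache[:, input_pos] = v_val
        return self.k_cache, self.v_cache


# Writing one token at position 3 under both layouts; only the indexed
# dimension differs, and the values land in the same logical slot.
for transpose in (True, False):
    cache = KVCacheSketch(1, 8, 2, 4, transpose)
    kv_shape = (1, 2, 1, 4) if transpose else (1, 1, 2, 4)
    k, v = cache.update(torch.tensor([3]), torch.ones(kv_shape), torch.ones(kv_shape))
    print(transpose, k.nonzero().size(0))  # 8 nonzero elements either way
```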