From 8b5cd86314c0159f3e41da00055c1427997b2b5e Mon Sep 17 00:00:00 2001
From: Kimish Patel
Date: Thu, 26 Sep 2024 16:58:04 -0700
Subject: [PATCH] Don't quantize the current token for attention

Differential Revision: [D63497872](https://our.internmc.facebook.com/intern/diff/D63497872/)

[ghstack-poisoned]
---
 .../quantized_kv_cache.py | 21 +++++++++++++++++++
 1 file changed, 21 insertions(+)

diff --git a/examples/models/llama2/source_transformation/quantized_kv_cache.py b/examples/models/llama2/source_transformation/quantized_kv_cache.py
index 97583b05ab8..2627afc8e4e 100644
--- a/examples/models/llama2/source_transformation/quantized_kv_cache.py
+++ b/examples/models/llama2/source_transformation/quantized_kv_cache.py
@@ -189,6 +189,27 @@ def update(self, input_pos, k_val, v_val):
             self.quantized_cache_dtype,
             self.cache_fp_type,
         )
+
+        if self.is_transposed:
+            if self.enable_dynamic_shape:
+                start_pos = input_pos[0].item()
+                torch._check_is_size(start_pos)
+                dim_to_slice = 2 if self.is_transposed else 1
+                torch._check(start_pos < self.k_cache.size(dim_to_slice))
+                seq_length = k_val.size(dim_to_slice)
+                narrowed_k = k_out.narrow(dim_to_slice, start_pos, seq_length)
+                narrowed_k.copy_(k_val)
+                # pyre-ignore: Incompatible parameter type [6]
+                narrowed_v = v_out.narrow(dim_to_slice, start_pos, seq_length)
+                narrowed_v.copy_(v_val)
+            else:
+                k_out[:, :, input_pos] = k_val
+                v_out[:, :, input_pos] = v_val
+        else:
+            start_pos = input_pos[0].item()
+            _ = torch.ops.llama.update_quantized_cache(k_val, k_out, start_pos)
+            _ = torch.ops.llama.update_quantized_cache(v_val, v_out, start_pos)
+
         return k_out, v_out
 
     @classmethod
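
Note on the change: `update()` returns dequantized copies of the K/V caches for attention, and with this patch the current token's `k_val`/`v_val` are written into those outputs in their original floating-point form rather than going through the quantize/dequantize round trip, so only previously cached tokens carry quantization error. Below is a minimal, self-contained sketch of that idea, not the ExecuTorch implementation: `fake_quant_dequant` is a hypothetical helper standing in for the per-token quantize/dequantize ops, and the layout is assumed to be the transposed `(bsz, n_heads, seq_len, head_dim)` case that the patch slices along dim 2.

```python
# Minimal sketch of "don't quantize the current token", assuming a
# symmetric per-tensor int8 round trip as a stand-in for the real
# per-token quantize/dequantize ops.
import torch


def fake_quant_dequant(x: torch.Tensor) -> torch.Tensor:
    # Quantize to int8 and immediately dequantize, so the result
    # carries the same kind of error a quantized cache would.
    scale = x.abs().amax() / 127
    q = torch.clamp(torch.round(x / scale), -128, 127)
    return q * scale


torch.manual_seed(0)
k_cache = torch.randn(1, 8, 16, 64)  # persistent cache (fp stand-in)
k_val = torch.randn(1, 8, 1, 64)     # K for the token being processed
start_pos = 5                        # current write position

# Dequantized view of the cache: past tokens carry round-trip error.
k_out = fake_quant_dequant(k_cache)

# The patch's point: overwrite the current position with the original
# fp values, so attention sees the current token exactly.
k_out[:, :, start_pos : start_pos + 1] = k_val

assert torch.equal(k_out[:, :, start_pos : start_pos + 1], k_val)
# Past positions still differ from their originals due to quantization.
assert not torch.equal(k_out[:, :, :start_pos], k_cache[:, :, :start_pos])
```

In the non-transposed path the diff reaches the same result through the `torch.ops.llama.update_quantized_cache` custom op, which per the hunk copies `k_val`/`v_val` into the returned output buffers at `start_pos`.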