From 418ede335fba6cdab74085c25ee61742b6cae90a Mon Sep 17 00:00:00 2001 From: Kimish Patel Date: Wed, 27 Nov 2024 19:40:00 -0800 Subject: [PATCH] Dont quantize the current token for attention Pull Request resolved: https://github.com/pytorch/executorch/pull/5715 ghstack-source-id: 255730816 @exported-using-ghexport Differential Revision: [D63497872](https://our.internmc.facebook.com/intern/diff/D63497872/) --- .../quantized_kv_cache.py | 20 +++++++++++++++++++ .../test_sdpa_with_quantized_kv_cache.py | 6 ------ 2 files changed, 20 insertions(+), 6 deletions(-) diff --git a/examples/models/llama/source_transformation/quantized_kv_cache.py b/examples/models/llama/source_transformation/quantized_kv_cache.py index 6d92a45e800..306c7380ecf 100644 --- a/examples/models/llama/source_transformation/quantized_kv_cache.py +++ b/examples/models/llama/source_transformation/quantized_kv_cache.py @@ -188,6 +188,26 @@ def update(self, input_pos, k_val, v_val): self.quantized_cache_dtype, self.cache_fp_type, ) + + if self.is_transposed: + if self.enable_dynamic_shape: + start_pos = input_pos[0].item() + torch._check_is_size(start_pos) + dim_to_slice = 2 if self.is_transposed else 1 + torch._check(start_pos < self.k_cache.size(dim_to_slice)) + seq_length = k_val.size(dim_to_slice) + narrowed_k = k_out.narrow(dim_to_slice, start_pos, seq_length) + narrowed_k.copy_(k_val) + narrowed_v = v_out.narrow(dim_to_slice, start_pos, seq_length) + narrowed_v.copy_(v_val) + else: + k_out[:, :, input_pos] = k_val + v_out[:, :, input_pos] = v_val + else: + start_pos = input_pos[0].item() + _ = torch.ops.llama.update_quantized_cache(k_val, k_out, start_pos) + _ = torch.ops.llama.update_quantized_cache(v_val, v_out, start_pos) + return k_out, v_out @classmethod diff --git a/examples/models/llama/source_transformation/test_sdpa_with_quantized_kv_cache.py b/examples/models/llama/source_transformation/test_sdpa_with_quantized_kv_cache.py index 65c6678ab25..21952d8c211 100644 --- a/examples/models/llama/source_transformation/test_sdpa_with_quantized_kv_cache.py +++ b/examples/models/llama/source_transformation/test_sdpa_with_quantized_kv_cache.py @@ -66,12 +66,6 @@ def test_simple(self, is_dynamic_shape=False): torch.testing.assert_close( float_out, quantized_out, - # had to adjust rtol because switching to using custom_sdpa means we - # will use dequantized k and v instead of original k and v - # this leads to larger differences in the output. - # subsequent diff in the stack will address this issue. - rtol=1e-01, - atol=1e-03, ) input_pos = torch.tensor([3], dtype=torch.int64)