From 6445107cc8a9ad16ca1bb982a358c150de53131e Mon Sep 17 00:00:00 2001
From: Kimish Patel
Date: Fri, 15 Nov 2024 16:12:02 -0800
Subject: [PATCH] [Executorch][llama] Rename update_quantized_cache to update_cache

Because it is really just an in-place update op, with nothing
quantization-specific about it.

Differential Revision: [D66041160](https://our.internmc.facebook.com/intern/diff/D66041160/)

[ghstack-poisoned]
---
 .../quantized_kv_cache.py                     | 24 +++++++------------
 extension/llm/custom_ops/TARGETS              |  4 ++--
 extension/llm/custom_ops/op_sdpa_aot.cpp      | 23 ++++++++----------
 ...uantized_cache.cpp => op_update_cache.cpp} |  8 +++----
 ...te_quantized_cache.h => op_update_cache.h} |  2 +-
 .../llm/custom_ops/sdpa_with_kv_cache.py      |  4 ++--
 extension/llm/custom_ops/targets.bzl          |  4 ++--
 ...uantized_cache.py => test_update_cache.py} | 16 +++++--------
 8 files changed, 35 insertions(+), 50 deletions(-)
 rename extension/llm/custom_ops/{op_update_quantized_cache.cpp => op_update_cache.cpp} (95%)
 rename extension/llm/custom_ops/{op_update_quantized_cache.h => op_update_cache.h} (93%)
 rename extension/llm/custom_ops/{test_update_quantized_cache.py => test_update_cache.py} (93%)

diff --git a/examples/models/llama/source_transformation/quantized_kv_cache.py b/examples/models/llama/source_transformation/quantized_kv_cache.py
index 668ff378340..26567f3d52c 100644
--- a/examples/models/llama/source_transformation/quantized_kv_cache.py
+++ b/examples/models/llama/source_transformation/quantized_kv_cache.py
@@ -151,22 +151,14 @@ def update(self, input_pos, k_val, v_val):
         # instead of quantizing on their own.
         # But until then, opting for code simplicity.
         start_pos = input_pos[0].item()
-        _ = torch.ops.llama.update_quantized_cache(
-            quantized_k_val, self.k_cache, start_pos
-        )
-        _ = torch.ops.llama.update_quantized_cache(
-            k_scales, self.k_cache_scales, start_pos
-        )
-        _ = torch.ops.llama.update_quantized_cache(
+        _ = torch.ops.llama.update_cache(quantized_k_val, self.k_cache, start_pos)
+        _ = torch.ops.llama.update_cache(k_scales, self.k_cache_scales, start_pos)
+        _ = torch.ops.llama.update_cache(
             k_zero_points, self.k_cache_zero_points, start_pos
         )
-        _ = torch.ops.llama.update_quantized_cache(
-            quantized_v_val, self.v_cache, start_pos
-        )
-        _ = torch.ops.llama.update_quantized_cache(
-            v_scales, self.v_cache_scales, start_pos
-        )
-        _ = torch.ops.llama.update_quantized_cache(
+        _ = torch.ops.llama.update_cache(quantized_v_val, self.v_cache, start_pos)
+        _ = torch.ops.llama.update_cache(v_scales, self.v_cache_scales, start_pos)
+        _ = torch.ops.llama.update_cache(
             v_zero_points, self.v_cache_zero_points, start_pos
         )
 
@@ -206,8 +198,8 @@ def update(self, input_pos, k_val, v_val):
             v_out[:, :, input_pos] = v_val
         else:
             start_pos = input_pos[0].item()
-            _ = torch.ops.llama.update_quantized_cache(k_val, k_out, start_pos)
-            _ = torch.ops.llama.update_quantized_cache(v_val, v_out, start_pos)
+            _ = torch.ops.llama.update_cache(k_val, k_out, start_pos)
+            _ = torch.ops.llama.update_cache(v_val, v_out, start_pos)
 
         return k_out, v_out
 
diff --git a/extension/llm/custom_ops/TARGETS b/extension/llm/custom_ops/TARGETS
index c12795fd249..5d0c0490506 100644
--- a/extension/llm/custom_ops/TARGETS
+++ b/extension/llm/custom_ops/TARGETS
@@ -23,9 +23,9 @@ runtime.python_test(
 )
 
 runtime.python_test(
-    name = "test_update_quantized_cache",
+    name = "test_update_cache",
     srcs = [
-        "test_update_quantized_cache.py",
+        "test_update_cache.py",
     ],
     preload_deps = [
         ":custom_ops_aot_lib",
diff --git a/extension/llm/custom_ops/op_sdpa_aot.cpp b/extension/llm/custom_ops/op_sdpa_aot.cpp
index 5d93df4a75d..b957a580787 100644
--- a/extension/llm/custom_ops/op_sdpa_aot.cpp
+++ b/extension/llm/custom_ops/op_sdpa_aot.cpp
@@ -9,7 +9,7 @@
 #include
 #include
 #include
-#include <executorch/extension/llm/custom_ops/op_update_quantized_cache.h>
+#include <executorch/extension/llm/custom_ops/op_update_cache.h>
 
 #include
 
@@ -127,22 +127,22 @@ at::Tensor custom_sdpa_aten(
   return output;
 }
 
-Tensor& update_quantized_cache_out_no_context(
+Tensor& update_cache_out_no_context(
     const Tensor& value,
     Tensor& cache,
     const int64_t start_pos,
     Tensor& output) {
   exec_aten::RuntimeContext context{};
-  return torch::executor::native::update_quantized_cache_out(
+  return torch::executor::native::update_cache_out(
       context, value, cache, start_pos, output);
 }
 
-at::Tensor update_quantized_cache_aten(
+at::Tensor update_cache_aten(
     const at::Tensor& value,
     at::Tensor& cache,
     const int64_t start_pos) {
   auto output = at::empty({1});
-  WRAP_TO_ATEN(update_quantized_cache_out_no_context, 3)
+  WRAP_TO_ATEN(update_cache_out_no_context, 3)
   (value, cache, start_pos, output);
   return output;
 }
@@ -169,10 +169,10 @@ TORCH_LIBRARY_FRAGMENT(llama, m) {
       "Tensor? attn_mask=None, float dropout_p=0.0, bool is_causal=False, "
       "float? scale=None, *, Tensor(a!) out) -> Tensor(a!)");
   m.def(
-      "update_quantized_cache(Tensor value, Tensor(a!) cache, "
+      "update_cache(Tensor value, Tensor(a!) cache, "
       "SymInt start_pos) -> Tensor");
   m.def(
-      "update_quantized_cache.out(Tensor value, Tensor(a!) cache, "
+      "update_cache.out(Tensor value, Tensor(a!) cache, "
      "SymInt start_pos, *, Tensor(b!) out) -> Tensor(b!)");
 }
 
@@ -188,11 +188,8 @@ TORCH_LIBRARY_IMPL(llama, CompositeExplicitAutograd, m) {
   m.impl(
       "custom_sdpa.out",
       WRAP_TO_ATEN(torch::executor::native::custom_sdpa_out_no_context, 8));
+  m.impl("update_cache", torch::executor::native::update_cache_aten);
   m.impl(
-      "update_quantized_cache",
-      torch::executor::native::update_quantized_cache_aten);
-  m.impl(
-      "update_quantized_cache.out",
-      WRAP_TO_ATEN(
-          torch::executor::native::update_quantized_cache_out_no_context, 3));
+      "update_cache.out",
+      WRAP_TO_ATEN(torch::executor::native::update_cache_out_no_context, 3));
 }
diff --git a/extension/llm/custom_ops/op_update_quantized_cache.cpp b/extension/llm/custom_ops/op_update_cache.cpp
similarity index 95%
rename from extension/llm/custom_ops/op_update_quantized_cache.cpp
rename to extension/llm/custom_ops/op_update_cache.cpp
index 54ec999cb8f..740a0c6cd7e 100644
--- a/extension/llm/custom_ops/op_update_quantized_cache.cpp
+++ b/extension/llm/custom_ops/op_update_cache.cpp
@@ -6,7 +6,7 @@
  * LICENSE file in the root directory of this source tree.
  */
 
-#include <executorch/extension/llm/custom_ops/op_update_quantized_cache.h>
+#include <executorch/extension/llm/custom_ops/op_update_cache.h>
 
 #include // @lint-ignore CLANGTIDY facebook-unused-include-check
 
@@ -60,7 +60,7 @@ bool validate_cache_params(
   }
 } // anonymous namespace
 
-Tensor& update_quantized_cache_out(
+Tensor& update_cache_out(
     RuntimeContext& ctx,
     const Tensor& value,
     Tensor& cache,
@@ -139,5 +139,5 @@ Tensor& update_quantized_cache_out(
 // In later diffs will rename this to update_cache.
 EXECUTORCH_LIBRARY(
     llama,
-    "update_quantized_cache.out",
-    torch::executor::native::update_quantized_cache_out);
+    "update_cache.out",
+    torch::executor::native::update_cache_out);
diff --git a/extension/llm/custom_ops/op_update_quantized_cache.h b/extension/llm/custom_ops/op_update_cache.h
similarity index 93%
rename from extension/llm/custom_ops/op_update_quantized_cache.h
rename to extension/llm/custom_ops/op_update_cache.h
index 9cd8090839a..cf518b4e108 100644
--- a/extension/llm/custom_ops/op_update_quantized_cache.h
+++ b/extension/llm/custom_ops/op_update_cache.h
@@ -15,7 +15,7 @@
 namespace executor {
 namespace native {
 
-Tensor& update_quantized_cache_out(
+Tensor& update_cache_out(
     RuntimeContext& ctx,
     const Tensor& value,
     Tensor& cache,
diff --git a/extension/llm/custom_ops/sdpa_with_kv_cache.py b/extension/llm/custom_ops/sdpa_with_kv_cache.py
index 85021266b59..be71425582c 100644
--- a/extension/llm/custom_ops/sdpa_with_kv_cache.py
+++ b/extension/llm/custom_ops/sdpa_with_kv_cache.py
@@ -203,8 +203,8 @@ def _validate_update_cache_params(
     ), f"Start position + length = {start_pos + seq_len} must be less than sequence length {cache.size(1)}"
 
 
-@impl(custom_ops_lib, "update_quantized_cache", "Meta")
-def update_quantized_cache_meta(
+@impl(custom_ops_lib, "update_cache", "Meta")
+def update_cache_meta(
     value,
     cache,
     start_pos,
diff --git a/extension/llm/custom_ops/targets.bzl b/extension/llm/custom_ops/targets.bzl
index 781225afedc..bb59f48a279 100644
--- a/extension/llm/custom_ops/targets.bzl
+++ b/extension/llm/custom_ops/targets.bzl
@@ -22,13 +22,13 @@ def define_common_targets():
             "op_fallback.cpp",
             "op_fast_hadamard_transform.cpp",
             "op_sdpa.cpp",
-            "op_update_quantized_cache.cpp",
+            "op_update_cache.cpp",
         ],
         exported_headers = [
             "op_fallback.h",
             "op_fast_hadamard_transform.h",
             "op_sdpa.h",
-            "op_update_quantized_cache.h",
+            "op_update_cache.h",
         ],
         preprocessor_flags = get_vec_preprocessor_flags(),
         exported_deps = [
diff --git a/extension/llm/custom_ops/test_update_quantized_cache.py b/extension/llm/custom_ops/test_update_cache.py
similarity index 93%
rename from extension/llm/custom_ops/test_update_quantized_cache.py
rename to extension/llm/custom_ops/test_update_cache.py
index 75e1f4cc6ae..1d2f392c129 100644
--- a/extension/llm/custom_ops/test_update_quantized_cache.py
+++ b/extension/llm/custom_ops/test_update_cache.py
@@ -67,17 +67,13 @@ def _update_and_validate(
         self._update_k(start_pos, k, k_scales, k_zero_points)
         self._update_v(start_pos, v, v_scales, v_zero_points)
 
-        torch.ops.llama.update_quantized_cache(k, k_cache, start_pos)
-        torch.ops.llama.update_quantized_cache(k_scales, k_scales_cache, start_pos)
-        torch.ops.llama.update_quantized_cache(
-            k_zero_points, k_zero_points_cache, start_pos
-        )
+        torch.ops.llama.update_cache(k, k_cache, start_pos)
+        torch.ops.llama.update_cache(k_scales, k_scales_cache, start_pos)
+        torch.ops.llama.update_cache(k_zero_points, k_zero_points_cache, start_pos)
 
-        torch.ops.llama.update_quantized_cache(v, v_cache, start_pos)
-        torch.ops.llama.update_quantized_cache(v_scales, v_scales_cache, start_pos)
-        torch.ops.llama.update_quantized_cache(
-            v_zero_points, v_zero_points_cache, start_pos
-        )
+        torch.ops.llama.update_cache(v, v_cache, start_pos)
+        torch.ops.llama.update_cache(v_scales, v_scales_cache, start_pos)
+        torch.ops.llama.update_cache(v_zero_points, v_zero_points_cache, start_pos)
 
         self.assertTrue(torch.allclose(k_cache, self.quantized_k_cache))
         self.assertTrue(torch.allclose(v_cache, self.quantized_v_cache))
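Reviewer note: a minimal sketch (not part of this patch) of what the renamed op does when called from eager Python. It assumes the custom-op AOT library has already been registered by importing the `sdpa_with_kv_cache` module touched above (the exact import path is an assumption); the shapes and dtype are illustrative, mirroring the [batch, seq_len, n_heads, head_dim] cache layout used in test_update_cache.py.

```python
import torch

# Assumption: importing this module loads the shared library that registers
# torch.ops.llama.update_cache, as the tests in this patch do.
import executorch.extension.llm.custom_ops.sdpa_with_kv_cache  # noqa: F401

# Illustrative cache layout: [batch, max_seq_len, n_heads, head_dim].
cache = torch.zeros(1, 16, 8, 4, dtype=torch.int8)
value = torch.randint(-128, 127, (1, 3, 8, 4), dtype=torch.int8)

start_pos = 5
# In-place update: positions [start_pos, start_pos + seq_len) along dim 1 of
# `cache` are overwritten with `value`. The returned tensor is a dummy; the
# op's contract is the side effect on `cache`.
torch.ops.llama.update_cache(value, cache, start_pos)

assert torch.equal(cache[:, start_pos : start_pos + 3], value)
```

The dummy return (the aten wrapper allocates `at::empty({1})`) keeps the op expressible as an out-variant for ExecuTorch while the real effect is the in-place cache mutation, which is exactly why the quantization-free name fits: the same kernel updates int8 caches, fp scale caches, and unquantized caches alike.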