From 3de8619bfc11a99b42975b41ad588faeeebd948c Mon Sep 17 00:00:00 2001
From: Kimish Patel
Date: Fri, 20 Sep 2024 15:50:42 -0700
Subject: [PATCH] Add update_quantized_cache op

Why?
- Functionalization results in a ton of copies.
- Without such a custom in-place op, mutable buffer support results in
  giant copies at the end.
- Making true in-place ops work will likely take longer, and there is no
  clear, safe path to it.

Differential Revision: [D62301838](https://our.internmc.facebook.com/intern/diff/D62301838/)

[ghstack-poisoned]
---
 extension/llm/custom_ops/TARGETS              |  13 ++
 extension/llm/custom_ops/op_sdpa_aot.cpp      |  41 ++++-
 .../custom_ops/op_update_quantized_cache.cpp  | 114 +++++++++++++
 .../custom_ops/op_update_quantized_cache.h    |  26 +++
 .../llm/custom_ops/sdpa_with_kv_cache.py      |  52 ++++++
 extension/llm/custom_ops/targets.bzl          |  21 +++
 .../custom_ops/test_update_quantized_cache.py | 150 ++++++++++++++++++
 7 files changed, 415 insertions(+), 2 deletions(-)
 create mode 100644 extension/llm/custom_ops/op_update_quantized_cache.cpp
 create mode 100644 extension/llm/custom_ops/op_update_quantized_cache.h
 create mode 100644 extension/llm/custom_ops/test_update_quantized_cache.py

diff --git a/extension/llm/custom_ops/TARGETS b/extension/llm/custom_ops/TARGETS
index 8fe776ab095..c12795fd249 100644
--- a/extension/llm/custom_ops/TARGETS
+++ b/extension/llm/custom_ops/TARGETS
@@ -22,6 +22,19 @@ runtime.python_test(
     ],
 )
 
+runtime.python_test(
+    name = "test_update_quantized_cache",
+    srcs = [
+        "test_update_quantized_cache.py",
+    ],
+    preload_deps = [
+        ":custom_ops_aot_lib",
+    ],
+    deps = [
+        "//caffe2:torch",
+    ],
+)
+
 runtime.python_test(
     name = "test_preprocess_custom_ops",
     srcs = [
diff --git a/extension/llm/custom_ops/op_sdpa_aot.cpp b/extension/llm/custom_ops/op_sdpa_aot.cpp
index 79a6fa4dd9e..f3674088fd7 100644
--- a/extension/llm/custom_ops/op_sdpa_aot.cpp
+++ b/extension/llm/custom_ops/op_sdpa_aot.cpp
@@ -9,6 +9,7 @@
 #include
 #include
 #include
+#include
 
 #include
 
@@ -16,7 +17,6 @@
 namespace torch {
 namespace executor {
 namespace native {
-namespace {
 Tensor& sdpa_with_kv_cache_out_no_context(
     const Tensor& q_projected,
     const Tensor& k_projected,
@@ -81,7 +81,27 @@ at::Tensor sdpa_with_kv_cache_aten(
       output);
   return output;
 }
-} // namespace
+
+Tensor& update_quantized_cache_out_no_context(
+    const Tensor& value,
+    Tensor& cache,
+    const int64_t start_pos,
+    Tensor& output) {
+  exec_aten::RuntimeContext context{};
+  return torch::executor::native::update_quantized_cache_out(
+      context, value, cache, start_pos, output);
+}
+
+at::Tensor update_quantized_cache_aten(
+    const at::Tensor& value,
+    at::Tensor& cache,
+    const int64_t start_pos) {
+  auto output = at::empty({1});
+  WRAP_TO_ATEN(update_quantized_cache_out_no_context, 3)
+  (value, cache, start_pos, output);
+  return output;
+}
+
 } // namespace native
 } // namespace executor
 } // namespace torch
@@ -95,6 +115,12 @@ TORCH_LIBRARY_FRAGMENT(llama, m) {
       "sdpa_with_kv_cache.out(Tensor query, Tensor key, Tensor value, Tensor(a!) key_cache, "
      "Tensor(b!) value_cache, SymInt start_pos, SymInt seq_len, Tensor? attn_mask=None, "
       "float drpout_p=0.0, bool is_causal=False, float? scale=None, *, Tensor(c!) out) -> Tensor(c!)");
+  m.def(
+      "update_quantized_cache(Tensor value, Tensor(a!) cache, "
+      "SymInt start_pos) -> Tensor");
+  m.def(
+      "update_quantized_cache.out(Tensor value, Tensor(a!) cache, "
+      "SymInt start_pos, *, Tensor(b!) out) -> Tensor(b!)");
 }
 
 TORCH_LIBRARY_IMPL(llama, CompositeExplicitAutograd, m) {
@@ -105,3 +131,14 @@ TORCH_LIBRARY_IMPL(llama, CompositeExplicitAutograd, m) {
   m.impl(
       "sdpa_with_kv_cache.out",
       WRAP_TO_ATEN(
           torch::executor::native::sdpa_with_kv_cache_out_no_context, 11));
 }
+
+// TODO: Rename this file to op_custom_ops_aot.cpp
+TORCH_LIBRARY_IMPL(llama, CompositeExplicitAutograd, m) {
+  m.impl(
+      "update_quantized_cache",
+      torch::executor::native::update_quantized_cache_aten);
+  m.impl(
+      "update_quantized_cache.out",
+      WRAP_TO_ATEN(
+          torch::executor::native::update_quantized_cache_out_no_context, 3));
+}
diff --git a/extension/llm/custom_ops/op_update_quantized_cache.cpp b/extension/llm/custom_ops/op_update_quantized_cache.cpp
new file mode 100644
index 00000000000..fa4f9c48b17
--- /dev/null
+++ b/extension/llm/custom_ops/op_update_quantized_cache.cpp
@@ -0,0 +1,114 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#include
+
+#include
+// @lint-ignore CLANGTIDY facebook-unused-include-check
+#include
+
+#include
+// patternlint-disable-next-line executorch-cpp-nostdinc
+#include
+
+#include
+
+namespace torch {
+namespace executor {
+
+namespace native {
+
+namespace {
+bool validate_cache_params(
+    const Tensor& quantized_value,
+    const Tensor& quantized_cache,
+    int64_t start_pos,
+    int64_t seq_length) {
+  ET_LOG_MSG_AND_RETURN_IF_FALSE(
+      quantized_cache.dim() == 4, "quantized cache must be a 4D tensor");
+
+  ET_LOG_MSG_AND_RETURN_IF_FALSE(
+      quantized_value.dim() == 4, "quantized_value must be a 4D tensor");
+
+  ET_LOG_MSG_AND_RETURN_IF_FALSE(
+      start_pos < quantized_cache.size(1),
+      "start_pos must be less than cache size at dim 1");
+
+  ET_LOG_MSG_AND_RETURN_IF_FALSE(
+      (start_pos + seq_length) <= quantized_cache.size(1),
+      "start_pos + seq_length must not exceed the max seq length supported by the cache. "
+      "start pos: %" PRId64 ", seq_length: %" PRId64
+      ". cache size: %zd",
+      start_pos,
+      seq_length,
+      quantized_cache.size(1));
+
+  // Make sure they are in contiguous dim order
+  ET_LOG_MSG_AND_RETURN_IF_FALSE(
+      is_contiguous_dim_order(
+          quantized_cache.dim_order().data(), quantized_cache.dim()),
+      "quantized cache must be in contiguous dim order");
+
+  ET_LOG_MSG_AND_RETURN_IF_FALSE(
+      is_contiguous_dim_order(
+          quantized_value.dim_order().data(), quantized_value.dim()),
+      "quantized value must be in contiguous dim order");
+
+  return true;
+}
+} // anonymous namespace
+
+Tensor& update_quantized_cache_out(
+    RuntimeContext& ctx,
+    const Tensor& value,
+    Tensor& cache,
+    const int64_t start_pos,
+    Tensor& output) {
+  (void)ctx;
+  int64_t seq_len = value.size(1);
+  ET_KERNEL_CHECK(
+      ctx,
+      validate_cache_params(value, cache, start_pos, seq_len),
+      InvalidArgument,
+      output);
+
+  ET_CHECK_MSG(value.dim() == 4, "value must be a 4D tensor");
+  ET_CHECK_MSG(value.size(0) == 1, "value must have batch size of 1");
+  ET_CHECK_MSG(cache.size(0) == 1, "cache must have batch size of 1");
+
+  const void* value_data = value.const_data_ptr();
+  void* cache_data = cache.mutable_data_ptr();
+
+  ET_CHECK_MSG(value_data != nullptr, "value data is null");
+  ET_CHECK_MSG(cache_data, "cache data is null");
+
+  auto strides = cache.strides();
+  exec_aten::StridesType seq_dim_stride = strides[1];
+  exec_aten::SizesType pos_offset = start_pos * seq_dim_stride;
+  exec_aten::SizesType pos_offset_bytes = pos_offset * value.element_size();
+  exec_aten::SizesType num_bytes = value.numel() * value.element_size();
+  // NOLINTNEXTLINE
+  std::memcpy((uint8_t*)cache_data + pos_offset_bytes, value_data, num_bytes);
+
+  // No one uses the output; it is just a placeholder.
+  return output;
+}
+} // namespace native
+} // namespace executor
+} // namespace torch
+
+// Really this is just an in-place tensor update op that makes assumptions
+// about the rank of the tensor and its dim order (memory layout). It further
+// assumes that indexing is along the sequence dimension (dim 1) of the
+// tensor. A later diff will rename this to update_cache.
+EXECUTORCH_LIBRARY(
+    llama,
+    "update_quantized_cache.out",
+    torch::executor::native::update_quantized_cache_out);
diff --git a/extension/llm/custom_ops/op_update_quantized_cache.h b/extension/llm/custom_ops/op_update_quantized_cache.h
new file mode 100644
index 00000000000..9cd8090839a
--- /dev/null
+++ b/extension/llm/custom_ops/op_update_quantized_cache.h
@@ -0,0 +1,26 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#pragma once
+
+#include
+
+namespace torch {
+namespace executor {
+
+namespace native {
+
+Tensor& update_quantized_cache_out(
+    RuntimeContext& ctx,
+    const Tensor& value,
+    Tensor& cache,
+    const int64_t start_pos,
+    Tensor& output);
+} // namespace native
+} // namespace executor
+} // namespace torch
diff --git a/extension/llm/custom_ops/sdpa_with_kv_cache.py b/extension/llm/custom_ops/sdpa_with_kv_cache.py
index 3de034fa6b5..d6c7fbab6f4 100644
--- a/extension/llm/custom_ops/sdpa_with_kv_cache.py
+++ b/extension/llm/custom_ops/sdpa_with_kv_cache.py
@@ -17,6 +17,7 @@
 
 from torch.library import impl
 
+# TODO rename this file to custom_ops_meta_registration.py
 try:
     op = torch.ops.llama.sdpa_with_kv_cache.default
     assert op is not None
@@ -138,3 +139,54 @@ def fast_hadamard_transform_meta(mat):
     # assert(mat.shape[-1] == 128 or mat.shape[-1] == 14336, "unexpected input size for llama3 demo!")
     # assert(mat.is_contiguous(), "input matrix must be contiguous currently!")
     return torch.empty_like(mat)
+
+
+def _validate_update_cache_params(
+    value,
+    cache,
+    start_pos,
+):
+    seq_len = value.size(1)
+    assert (
+        value.dim() == 4
+    ), f"Expected value to be 4 dimensional but got {value.dim()} dimensions."
+
+    assert (
+        value.dtype == cache.dtype
+    ), f"Expected value and cache to be of the same type but got value type {value.dtype} and cache type {cache.dtype}"
+
+    for i in [0, 2, 3]:
+        assert value.size(i) == cache.size(
+            i
+        ), f"Expected value and cache to have same size in dimension {i} but got {value.size(i)} and {cache.size(i)}"
+
+    torch._check_is_size(start_pos)
+    # Setting to arbitrary limit of 256 for now since there is no way
+    # to plumb this information from model config
+    torch._check(start_pos < cache.size(1))
+    assert start_pos < cache.size(
+        1
+    ), f"Start position {start_pos} must be less than sequence length {cache.size(1)}"
+
+    torch._check((start_pos + seq_len) < cache.size(1))
+    assert (start_pos + seq_len) < cache.size(
+        1
+    ), f"Start position + length = {start_pos + seq_len} must be less than sequence length {cache.size(1)}"
+
+
+@impl(custom_ops_lib, "update_quantized_cache", "Meta")
+def update_quantized_cache_meta(
+    value,
+    cache,
+    start_pos,
+):
+    _validate_update_cache_params(
+        value,
+        cache,
+        start_pos,
+    )
+
+    # update_quantized_cache doesn't really return anything, but there is no
+    # better workaround right now. Should we just return cache instead? That
+    # might result in an extra memory allocation.
+    return torch.empty((1,), dtype=value.dtype, device="meta")
diff --git a/extension/llm/custom_ops/targets.bzl b/extension/llm/custom_ops/targets.bzl
index 488f214e2bf..503e4a0c7bd 100644
--- a/extension/llm/custom_ops/targets.bzl
+++ b/extension/llm/custom_ops/targets.bzl
@@ -62,10 +62,31 @@ def define_common_targets():
         ],
         deps = [
             ":custom_ops" + mkl_dep,
+            ":update_quantized_cache",
             "//executorch/extension/aten_util:aten_bridge",
         ],
     )
 
+    runtime.cxx_library(
+        name = "update_quantized_cache",
+        srcs = ["op_update_quantized_cache.cpp"],
+        exported_headers = ["op_update_quantized_cache.h"],
+        exported_deps = [
+            "//executorch/runtime/kernel:kernel_includes",
+            "//executorch/kernels/portable/cpu:scalar_utils",
+            "//executorch/extension/kernel_util:kernel_util",
+        ],
+        compiler_flags = ["-Wno-missing-prototypes", "-Wno-global-constructors"],
+        visibility = [
+            "//executorch/...",
+            "//executorch/extension/llm/custom_ops/...",
+            "@EXECUTORCH_CLIENTS",
+        ],
+        # @lint-ignore BUCKLINT link_whole
+        link_whole = True,
+        force_static = True,
+    )
+
     runtime.python_library(
         name = "custom_ops_aot_py",
         srcs = [
diff --git a/extension/llm/custom_ops/test_update_quantized_cache.py b/extension/llm/custom_ops/test_update_quantized_cache.py
new file mode 100644
index 00000000000..584048a563b
--- /dev/null
+++ b/extension/llm/custom_ops/test_update_quantized_cache.py
@@ -0,0 +1,150 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+# pyre-unsafe
+
+import unittest
+
+import torch
+
+
+class UpdateQuantizedKVCacheTest(unittest.TestCase):
+
+    def _reset(self):
+        self.quantized_k_cache = torch.zeros(
+            (1, self.seq_len, self.num_heads, self.head_dim), dtype=torch.int8
+        )
+        self.quantized_v_cache = torch.zeros(
+            (1, self.seq_len, self.num_heads, self.head_dim), dtype=torch.int8
+        )
+        self.k_scales_cache = torch.zeros(
+            (1, self.seq_len, self.num_heads, 1), dtype=torch.float64
+        )
+        self.v_scales_cache = torch.zeros(
+            (1, self.seq_len, self.num_heads, 1), dtype=torch.float64
+        )
+        self.k_zero_points_cache = torch.zeros(
+            (1, self.seq_len, self.num_heads, 1), dtype=torch.int64
+        )
+        self.v_zero_points_cache = torch.zeros(
+            (1, self.seq_len, self.num_heads, 1), dtype=torch.int64
+        )
+
+    def setUp(self):
+        torch.manual_seed(42)
+        self.seq_len = 10
+        self.num_heads = 8
+        self.head_dim = 4
+        self._reset()
+
+    def _update(self, start_pos, value, scales, zero_points, update_v=False):
+        seq_len = value.size(1)
+        if update_v:
+            self.quantized_v_cache[:, start_pos : start_pos + seq_len, :, :] = value
+            self.v_scales_cache[:, start_pos : start_pos + seq_len, :, :] = scales
+            self.v_zero_points_cache[:, start_pos : start_pos + seq_len, :, :] = (
+                zero_points
+            )
+        else:
+            self.quantized_k_cache[:, start_pos : start_pos + seq_len, :, :] = value
+            self.k_scales_cache[:, start_pos : start_pos + seq_len, :, :] = scales
+            self.k_zero_points_cache[:, start_pos : start_pos + seq_len, :, :] = (
+                zero_points
+            )
+
+    def _update_and_validate(
+        self, k, v, k_scales, v_scales, k_zero_points, v_zero_points, start_pos
+    ):
+        k_cache = self.quantized_k_cache.clone()
+        v_cache = self.quantized_v_cache.clone()
+        k_scales_cache = self.k_scales_cache.clone()
+        v_scales_cache = self.v_scales_cache.clone()
+        k_zero_points_cache = self.k_zero_points_cache.clone()
+        v_zero_points_cache = self.v_zero_points_cache.clone()
+        self._update(start_pos, k, k_scales, k_zero_points, update_v=False)
+        self._update(start_pos, v, v_scales, v_zero_points, update_v=True)
+
+        torch.ops.llama.update_quantized_cache(k, k_cache, start_pos)
+        torch.ops.llama.update_quantized_cache(k_scales, k_scales_cache, start_pos)
+        torch.ops.llama.update_quantized_cache(
+            k_zero_points, k_zero_points_cache, start_pos
+        )
+
+        torch.ops.llama.update_quantized_cache(v, v_cache, start_pos)
+        torch.ops.llama.update_quantized_cache(v_scales, v_scales_cache, start_pos)
+        torch.ops.llama.update_quantized_cache(
+            v_zero_points, v_zero_points_cache, start_pos
+        )
+
+        self.assertTrue(torch.allclose(k_cache, self.quantized_k_cache))
+        self.assertTrue(torch.allclose(v_cache, self.quantized_v_cache))
+        self.assertTrue(torch.allclose(k_scales_cache, self.k_scales_cache))
+        self.assertTrue(torch.allclose(v_scales_cache, self.v_scales_cache))
+        self.assertTrue(torch.allclose(k_zero_points_cache, self.k_zero_points_cache))
+        self.assertTrue(torch.allclose(v_zero_points_cache, self.v_zero_points_cache))
+
+    def test_update_kv_cache_simple(self):
+        k = torch.randint(0, 50, (1, 1, 8, 4), dtype=torch.int8)
+        v = torch.randint(0, 50, (1, 1, 8, 4), dtype=torch.int8)
+        k_scales = torch.rand((1, 1, 8, 1), dtype=torch.float64)
+        v_scales = torch.rand((1, 1, 8, 1), dtype=torch.float64)
+        k_zero_points = torch.randint(0, 20, (1, 1, 8, 1), dtype=torch.int64)
+        v_zero_points = torch.randint(0, 20, (1, 1, 8, 1), dtype=torch.int64)
+        start_pos = 0
+        self._update_and_validate(
+            k, v, k_scales, v_scales, k_zero_points, v_zero_points, start_pos
+        )
+
+    def test_update_kv_cache_large_update(self):
+        self._reset()
+        k = torch.randint(0, 50, (1, 3, 8, 4), dtype=torch.int8)
+        v = torch.randint(0, 50, (1, 3, 8, 4), dtype=torch.int8)
+        k_scales = torch.rand((1, 3, 8, 1), dtype=torch.float64)
+        v_scales = torch.rand((1, 3, 8, 1), dtype=torch.float64)
+        k_zero_points = torch.randint(0, 20, (1, 3, 8, 1), dtype=torch.int64)
+        v_zero_points = torch.randint(0, 20, (1, 3, 8, 1), dtype=torch.int64)
+        start_pos = 0
+        self._update_and_validate(
+            k, v, k_scales, v_scales, k_zero_points, v_zero_points, start_pos
+        )
+
+    def test_update_kv_cache_update_nonzero_offset(self):
+        self._reset()
+        k = torch.randint(0, 50, (1, 1, 8, 4), dtype=torch.int8)
+        v = torch.randint(0, 50, (1, 1, 8, 4), dtype=torch.int8)
+        k_scales = torch.rand((1, 1, 8, 1), dtype=torch.float64)
+        v_scales = torch.rand((1, 1, 8, 1), dtype=torch.float64)
+        k_zero_points = torch.randint(0, 20, (1, 1, 8, 1), dtype=torch.int64)
+        v_zero_points = torch.randint(0, 20, (1, 1, 8, 1), dtype=torch.int64)
+        start_pos = 2
+        self._update_and_validate(
+            k, v, k_scales, v_scales, k_zero_points, v_zero_points, start_pos
+        )
+
+    def test_update_kv_cache_more_updates(self):
+        self._reset()
+        k = torch.randint(0, 50, (1, 1, 8, 4), dtype=torch.int8)
+        v = torch.randint(0, 50, (1, 1, 8, 4), dtype=torch.int8)
+        k_scales = torch.rand((1, 1, 8, 1), dtype=torch.float64)
+        v_scales = torch.rand((1, 1, 8, 1), dtype=torch.float64)
+        k_zero_points = torch.randint(0, 20, (1, 1, 8, 1), dtype=torch.int64)
+        v_zero_points = torch.randint(0, 20, (1, 1, 8, 1), dtype=torch.int64)
+        start_pos = 2
+        self._update_and_validate(
+            k, v, k_scales, v_scales, k_zero_points, v_zero_points, start_pos
+        )
+
+        k = torch.randint(0, 50, (1, 1, 8, 4), dtype=torch.int8)
+        v = torch.randint(0, 50, (1, 1, 8, 4), dtype=torch.int8)
+        k_scales = torch.rand((1, 1, 8, 1), dtype=torch.float64)
+        v_scales = torch.rand((1, 1, 8, 1), dtype=torch.float64)
+        k_zero_points = torch.randint(0, 20, (1, 1, 8, 1), dtype=torch.int64)
+        v_zero_points = torch.randint(0, 20, (1, 1, 8, 1), dtype=torch.int64)
+        start_pos = 4
+
+        self._update_and_validate(
+            k, v, k_scales, v_scales, k_zero_points, v_zero_points, start_pos
+        )
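
Usage sketch (not part of the diff): the snippet below drives the new op from eager PyTorch and checks it against a plain slice assignment, mirroring the `_update` reference helper in the test above. It assumes the AOT library that registers `torch.ops.llama.*` has already been loaded; under Buck that happens through the `:custom_ops_aot_lib` preload_deps, and the commented-out load_library path is only a placeholder.

import torch

# torch.ops.load_library("path/to/custom_ops_aot_lib.so")  # placeholder path

seq_len, num_heads, head_dim = 10, 8, 4
cache = torch.zeros((1, seq_len, num_heads, head_dim), dtype=torch.int8)
value = torch.randint(0, 50, (1, 3, num_heads, head_dim), dtype=torch.int8)
start_pos = 2

# Reference semantics: write `value` into the cache along the sequence dim.
expected = cache.clone()
expected[:, start_pos : start_pos + value.size(1)] = value

# The custom op mutates `cache` in place; its return value is a placeholder.
torch.ops.llama.update_quantized_cache(value, cache, start_pos)
assert torch.equal(cache, expected)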
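
Kernel arithmetic sketch (not part of the diff): the out-variant kernel reduces the update to a single memcpy. With a contiguous, batch-size-1 cache, the destination starts at start_pos * cache.stride(1) * element_size bytes and the copy length is value.numel() * element_size bytes. The snippet below restates that arithmetic with torch's stride/element_size APIs under the same layout assumptions; it is an illustration, not the kernel itself.

import torch

cache = torch.zeros((1, 10, 8, 4), dtype=torch.int8)
value = torch.ones((1, 3, 8, 4), dtype=torch.int8)
start_pos = 2

# What the slice update should produce.
expected = cache.clone()
expected[:, start_pos : start_pos + value.size(1)] = value

# Byte-level equivalent of the kernel's memcpy into the cache storage.
pos_offset_bytes = start_pos * cache.stride(1) * cache.element_size()
num_bytes = value.numel() * value.element_size()
cache_bytes = cache.view(torch.uint8).view(-1)
cache_bytes[pos_offset_bytes : pos_offset_bytes + num_bytes] = (
    value.view(torch.uint8).view(-1)
)

assert torch.equal(cache, expected)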
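
Meta-kernel sketch (not part of the diff): the Meta registration added to sdpa_with_kv_cache.py is what lets the op run under meta/fake tensors, so it can be traced and exported without touching the C++ kernel. A minimal sketch, assuming the Python module that performs the @impl(..., "Meta") registration has been imported (the import path below may differ depending on packaging):

import torch

# from executorch.extension.llm.custom_ops import sdpa_with_kv_cache  # registers the Meta impl

value = torch.empty((1, 1, 8, 4), dtype=torch.int8, device="meta")
cache = torch.empty((1, 10, 8, 4), dtype=torch.int8, device="meta")

# Dispatches to update_quantized_cache_meta; no real data is read or written.
out = torch.ops.llama.update_quantized_cache(value, cache, 0)
print(out.shape, out.device)  # torch.Size([1]) meta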