diff --git a/extension/llm/custom_ops/TARGETS b/extension/llm/custom_ops/TARGETS
index 8fe776ab095..c12795fd249 100644
--- a/extension/llm/custom_ops/TARGETS
+++ b/extension/llm/custom_ops/TARGETS
@@ -22,6 +22,19 @@ runtime.python_test(
     ],
 )
 
+runtime.python_test(
+    name = "test_update_quantized_cache",
+    srcs = [
+        "test_update_quantized_cache.py",
+    ],
+    preload_deps = [
+        ":custom_ops_aot_lib",
+    ],
+    deps = [
+        "//caffe2:torch",
+    ],
+)
+
 runtime.python_test(
     name = "test_preprocess_custom_ops",
     srcs = [
diff --git a/extension/llm/custom_ops/op_sdpa_aot.cpp b/extension/llm/custom_ops/op_sdpa_aot.cpp
index d9aa429fff0..1d31633fd32 100644
--- a/extension/llm/custom_ops/op_sdpa_aot.cpp
+++ b/extension/llm/custom_ops/op_sdpa_aot.cpp
@@ -9,6 +9,7 @@
 #include
 #include
 #include
+#include
 
 #include
 
@@ -16,7 +17,6 @@
 namespace torch {
 namespace executor {
 namespace native {
-namespace {
 Tensor& sdpa_with_kv_cache_out_no_context(
     const Tensor& q_projected,
     const Tensor& k_projected,
@@ -81,7 +81,27 @@ at::Tensor sdpa_with_kv_cache_aten(
       output);
   return output;
 }
-} // namespace
+
+Tensor& update_quantized_cache_out_no_context(
+    const Tensor& value,
+    Tensor& cache,
+    const int64_t start_pos,
+    Tensor& output) {
+  exec_aten::RuntimeContext context{};
+  return torch::executor::native::update_quantized_cache_out(
+      context, value, cache, start_pos, output);
+}
+
+at::Tensor update_quantized_cache_aten(
+    const at::Tensor& value,
+    at::Tensor& cache,
+    const int64_t start_pos) {
+  auto output = at::empty({1});
+  WRAP_TO_ATEN(update_quantized_cache_out_no_context, 3)
+  (value, cache, start_pos, output);
+  return output;
+}
+
 } // namespace native
 } // namespace executor
 } // namespace torch
@@ -95,6 +115,12 @@ TORCH_LIBRARY_FRAGMENT(llama, m) {
       "sdpa_with_kv_cache.out(Tensor query, Tensor key, Tensor value, Tensor(a!) key_cache, "
       "Tensor(b!) value_cache, SymInt start_pos, SymInt seq_len, Tensor? attn_mask=None, "
      "float drpout_p=0.0, bool is_causal=False, float? scale=None, *, Tensor(c!) out) -> Tensor(c!)");
+  m.def(
+      "update_quantized_cache(Tensor value, Tensor(a!) cache, "
+      "SymInt start_pos) -> Tensor");
+  m.def(
+      "update_quantized_cache.out(Tensor value, Tensor(a!) cache, "
+      "SymInt start_pos, *, Tensor(b!) out) -> Tensor(b!)");
 }
 
 TORCH_LIBRARY_IMPL(llama, CompositeExplicitAutograd, m) {
@@ -105,3 +131,14 @@ TORCH_LIBRARY_IMPL(llama, CompositeExplicitAutograd, m) {
       WRAP_TO_ATEN(
           torch::executor::native::sdpa_with_kv_cache_out_no_context, 11));
 }
+
+// TODO: Rename this file to op_custom_ops_aot.cpp
+TORCH_LIBRARY_IMPL(llama, CompositeExplicitAutograd, m) {
+  m.impl(
+      "update_quantized_cache",
+      torch::executor::native::update_quantized_cache_aten);
+  m.impl(
+      "update_quantized_cache.out",
+      WRAP_TO_ATEN(
+          torch::executor::native::update_quantized_cache_out_no_context, 3));
+}
diff --git a/extension/llm/custom_ops/op_update_quantized_cache.cpp b/extension/llm/custom_ops/op_update_quantized_cache.cpp
new file mode 100644
index 00000000000..54ec999cb8f
--- /dev/null
+++ b/extension/llm/custom_ops/op_update_quantized_cache.cpp
@@ -0,0 +1,143 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#include
+
+#include
+// @lint-ignore CLANGTIDY facebook-unused-include-check
+#include
+
+#include
+
+namespace torch {
+namespace executor {
+
+namespace native {
+
+namespace {
+bool validate_cache_params(
+    const Tensor& quantized_value,
+    const Tensor& quantized_cache,
+    int64_t start_pos,
+    int64_t seq_length) {
+  ET_LOG_MSG_AND_RETURN_IF_FALSE(
+      quantized_cache.dim() == 4, "quantized cache must be a 4D tensor");
+
+  ET_LOG_MSG_AND_RETURN_IF_FALSE(
+      quantized_value.dim() == 4, "quantized_value must be a 4D tensor");
+
+  ET_LOG_MSG_AND_RETURN_IF_FALSE(
+      start_pos < quantized_cache.size(1),
+      "start_pos must be less than cache size at dim 1");
+
+  ET_LOG_MSG_AND_RETURN_IF_FALSE(
+      (start_pos + seq_length) <= quantized_cache.size(1),
+      "start_pos + seq_length must not exceed the max seq length supported by the cache. "
+      "start_pos: %" PRId64 ", seq_length: %" PRId64
+      ", "
+      "cache size: %zd",
+      start_pos,
+      seq_length,
+      quantized_cache.size(1));
+
+  // Make sure they are in contiguous dim order
+  ET_LOG_MSG_AND_RETURN_IF_FALSE(
+      is_contiguous_dim_order(
+          quantized_cache.dim_order().data(), quantized_cache.dim()),
+      "quantized cache must be in contiguous dim order");
+
+  ET_LOG_MSG_AND_RETURN_IF_FALSE(
+      is_contiguous_dim_order(
+          quantized_value.dim_order().data(), quantized_value.dim()),
+      "quantized value must be in contiguous dim order");
+
+  return true;
+}
+} // anonymous namespace
+
+Tensor& update_quantized_cache_out(
+    RuntimeContext& ctx,
+    const Tensor& value,
+    Tensor& cache,
+    const int64_t start_pos,
+    Tensor& output) {
+  (void)ctx;
+  int64_t seq_len = value.size(1);
+  ET_KERNEL_CHECK(
+      ctx,
+      validate_cache_params(value, cache, start_pos, seq_len),
+      InvalidArgument,
+      output);
+
+  ET_CHECK_MSG(
+      value.size(0) == cache.size(0),
+      "projected_value batch size should be equal to the cache batch size.");
+  ET_CHECK_MSG(
+      value.size(2) == cache.size(2),
+      "projected_value number of heads should be equal to the cache number of heads.");
+  ET_CHECK_MSG(
+      value.size(3) == cache.size(3),
+      "projected_value embedding dimension should be equal to the cache embedding dimension.");
+  ET_CHECK_MSG(
+      value.element_size() == cache.element_size(),
+      "projected_value data type size should be equal to the cache data type size.");
+
+  ET_CHECK_MSG(
+      is_contiguous_dim_order(value.dim_order().data(), value.dim()),
+      "projected value must be in contiguous dim order");
+  ET_CHECK_MSG(
+      is_contiguous_dim_order(cache.dim_order().data(), cache.dim()),
+      "cache must be in contiguous dim order");
+
+  const void* value_data = value.const_data_ptr();
+  void* cache_data = cache.mutable_data_ptr();
+
+  ET_CHECK_MSG(value_data, "projected_value data is null");
+  ET_CHECK_MSG(cache_data, "cache data is null");
+
+  auto cache_strides = cache.strides();
+  exec_aten::StridesType cache_batch_dim_stride = cache_strides[0];
+  exec_aten::StridesType cache_seq_dim_stride = cache_strides[1];
+
+  auto value_strides = value.strides();
+  exec_aten::StridesType value_batch_dim_stride = value_strides[0];
+
+  exec_aten::SizesType num_bytes_to_copy =
+      (value.numel() / value.size(0)) * value.element_size();
+
+  for (int64_t batch_line = 0; batch_line < value.size(0); ++batch_line) {
+    exec_aten::SizesType cache_pos_offset =
+        (batch_line * cache_batch_dim_stride +
+         start_pos * cache_seq_dim_stride) *
+        cache.element_size();
+    exec_aten::SizesType value_pos_offset =
+        (batch_line * value_batch_dim_stride) * cache.element_size();
+
+    std::memcpy(
+        (uint8_t*)cache_data + cache_pos_offset,
+        (uint8_t*)value_data + value_pos_offset,
+        num_bytes_to_copy);
+  }
+
+  // No one uses the output; it is just a placeholder.
+  return output;
+}
+} // namespace native
+} // namespace executor
+} // namespace torch
+
+// Really this is just an in-place tensor update op that makes assumptions
+// about the rank of the tensor and its dim order (memory layout). It
+// further assumes that indexing happens along the sequence dimension
+// (dim 1) of the tensor. A later diff will rename this to update_cache.
+EXECUTORCH_LIBRARY(
+    llama,
+    "update_quantized_cache.out",
+    torch::executor::native::update_quantized_cache_out);
diff --git a/extension/llm/custom_ops/op_update_quantized_cache.h b/extension/llm/custom_ops/op_update_quantized_cache.h
new file mode 100644
index 00000000000..9cd8090839a
--- /dev/null
+++ b/extension/llm/custom_ops/op_update_quantized_cache.h
@@ -0,0 +1,26 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#pragma once
+
+#include
+
+namespace torch {
+namespace executor {
+
+namespace native {
+
+Tensor& update_quantized_cache_out(
+    RuntimeContext& ctx,
+    const Tensor& value,
+    Tensor& cache,
+    const int64_t start_pos,
+    Tensor& output);
+} // namespace native
+} // namespace executor
+} // namespace torch
diff --git a/extension/llm/custom_ops/sdpa_with_kv_cache.py b/extension/llm/custom_ops/sdpa_with_kv_cache.py
index 3de034fa6b5..d6c7fbab6f4 100644
--- a/extension/llm/custom_ops/sdpa_with_kv_cache.py
+++ b/extension/llm/custom_ops/sdpa_with_kv_cache.py
@@ -17,6 +17,7 @@
 from torch.library import impl
 
+# TODO: rename this file to custom_ops_meta_registration.py
 try:
     op = torch.ops.llama.sdpa_with_kv_cache.default
     assert op is not None
@@ -138,3 +139,54 @@ def fast_hadamard_transform_meta(mat):
     # assert(mat.shape[-1] == 128 or mat.shape[-1] == 14336, "unexpected input size for llama3 demo!")
     # assert(mat.is_contiguous(), "input matrix must be contiguous currently!")
     return torch.empty_like(mat)
+
+
+def _validate_update_cache_params(
+    value,
+    cache,
+    start_pos,
+):
+    seq_len = value.size(1)
+    assert (
+        value.dim() == 4
+    ), f"Expected value to be 4 dimensional but got {value.dim()} dimensions."
+
+    assert (
+        value.dtype == cache.dtype
+    ), f"Expected value and cache to be of the same type but got value type {value.dtype} and cache type {cache.dtype}"
+
+    for i in [0, 2, 3]:
+        assert value.size(i) == cache.size(
+            i
+        ), f"Expected value and cache to have same size in dimension {i} but got {value.size(i)} and {cache.size(i)}"
+
+    torch._check_is_size(start_pos)
+    # Setting to arbitrary limit of 256 for now since there is no way
+    # to plumb this information from model config
+    torch._check(start_pos < cache.size(1))
+    assert start_pos < cache.size(
+        1
+    ), f"Start position {start_pos} must be less than sequence length {cache.size(1)}"
+
+    torch._check((start_pos + seq_len) < cache.size(1))
+    assert (start_pos + seq_len) < cache.size(
+        1
+    ), f"Start position + length = {start_pos + seq_len} must be less than sequence length {cache.size(1)}"
+
+
+@impl(custom_ops_lib, "update_quantized_cache", "Meta")
+def update_quantized_cache_meta(
+    value,
+    cache,
+    start_pos,
+):
+    _validate_update_cache_params(
+        value,
+        cache,
+        start_pos,
+    )
+
+    # update_quantized_cache doesn't really return anything, but we don't know
+    # a better workaround. Should we just return `cache` instead? That might
+    # result in an extra memory allocation.
+    return torch.empty((1,), dtype=value.dtype, device="meta")
diff --git a/extension/llm/custom_ops/targets.bzl b/extension/llm/custom_ops/targets.bzl
index 43fed39a5d5..c2843f5c2f7 100644
--- a/extension/llm/custom_ops/targets.bzl
+++ b/extension/llm/custom_ops/targets.bzl
@@ -13,11 +13,13 @@ def define_common_targets():
             "op_fallback.cpp",
             "op_fast_hadamard_transform.cpp",
             "op_sdpa.cpp",
+            "op_update_quantized_cache.cpp",
         ],
         exported_headers = [
            "op_fallback.h",
            "op_fast_hadamard_transform.h",
            "op_sdpa.h",
+           "op_update_quantized_cache.h",
        ],
        exported_deps = [
            "//executorch/runtime/kernel:kernel_includes",
diff --git a/extension/llm/custom_ops/test_update_quantized_cache.py b/extension/llm/custom_ops/test_update_quantized_cache.py
new file mode 100644
index 00000000000..75e1f4cc6ae
--- /dev/null
+++ b/extension/llm/custom_ops/test_update_quantized_cache.py
@@ -0,0 +1,184 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+# pyre-unsafe
+
+import unittest
+
+import torch
+
+
+class UpdateQuantizedKVCacheTest(unittest.TestCase):
+
+    def _reset(self):
+        self.quantized_k_cache = torch.zeros(
+            (self.batch_size, self.seq_len, self.num_heads, self.head_dim),
+            dtype=torch.int8,
+        )
+        self.quantized_v_cache = torch.zeros(
+            (self.batch_size, self.seq_len, self.num_heads, self.head_dim),
+            dtype=torch.int8,
+        )
+        self.k_scales_cache = torch.zeros(
+            (self.batch_size, self.seq_len, self.num_heads, 1), dtype=torch.float64
+        )
+        self.v_scales_cache = torch.zeros(
+            (self.batch_size, self.seq_len, self.num_heads, 1), dtype=torch.float64
+        )
+        self.k_zero_points_cache = torch.zeros(
+            (self.batch_size, self.seq_len, self.num_heads, 1), dtype=torch.int64
+        )
+        self.v_zero_points_cache = torch.zeros(
+            (self.batch_size, self.seq_len, self.num_heads, 1), dtype=torch.int64
+        )
+
+    def setUp(self):
+        torch.manual_seed(42)
+        self.batch_size = 1
+        self.seq_len = 10
+        self.num_heads = 8
+        self.head_dim = 4
+        self._reset()
+
+    def _update_k(self, start_pos, value, scales, zero_points):
+        seq_len = value.size(1)
+        self.quantized_k_cache[:, start_pos : start_pos + seq_len, :, :] = value
+        self.k_scales_cache[:, start_pos : start_pos + seq_len, :, :] = scales
+        self.k_zero_points_cache[:, start_pos : start_pos + seq_len, :, :] = zero_points
+
+    def _update_v(self, start_pos, value, scales, zero_points):
+        seq_len = value.size(1)
+        self.quantized_v_cache[:, start_pos : start_pos + seq_len, :, :] = value
+        self.v_scales_cache[:, start_pos : start_pos + seq_len, :, :] = scales
+        self.v_zero_points_cache[:, start_pos : start_pos + seq_len, :, :] = zero_points
+
+    def _update_and_validate(
+        self, k, v, k_scales, v_scales, k_zero_points, v_zero_points, start_pos
+    ):
+        k_cache = self.quantized_k_cache.clone()
+        v_cache = self.quantized_v_cache.clone()
+        k_scales_cache = self.k_scales_cache.clone()
+        v_scales_cache = self.v_scales_cache.clone()
+        k_zero_points_cache = self.k_zero_points_cache.clone()
+        v_zero_points_cache = self.v_zero_points_cache.clone()
+        self._update_k(start_pos, k, k_scales, k_zero_points)
+        self._update_v(start_pos, v, v_scales, v_zero_points)
+
+        torch.ops.llama.update_quantized_cache(k, k_cache, start_pos)
+        torch.ops.llama.update_quantized_cache(k_scales, k_scales_cache, start_pos)
+        torch.ops.llama.update_quantized_cache(
+            k_zero_points, k_zero_points_cache, start_pos
+        )
+
+        torch.ops.llama.update_quantized_cache(v, v_cache, start_pos)
+        torch.ops.llama.update_quantized_cache(v_scales, v_scales_cache, start_pos)
+        torch.ops.llama.update_quantized_cache(
+            v_zero_points, v_zero_points_cache, start_pos
+        )
+
+        self.assertTrue(torch.allclose(k_cache, self.quantized_k_cache))
+        self.assertTrue(torch.allclose(v_cache, self.quantized_v_cache))
+        self.assertTrue(torch.allclose(k_scales_cache, self.k_scales_cache))
+        self.assertTrue(torch.allclose(v_scales_cache, self.v_scales_cache))
+        self.assertTrue(torch.allclose(k_zero_points_cache, self.k_zero_points_cache))
+        self.assertTrue(torch.allclose(v_zero_points_cache, self.v_zero_points_cache))
+
+    def test_update_kv_cache_simple(self):
+        k = torch.randint(0, 50, (1, 1, 8, 4), dtype=torch.int8)
+        v = torch.randint(0, 50, (1, 1, 8, 4), dtype=torch.int8)
+        k_scales = torch.rand((1, 1, 8, 1), dtype=torch.float64)
+        v_scales = torch.rand((1, 1, 8, 1), dtype=torch.float64)
+        k_zero_points = torch.randint(0, 20, (1, 1, 8, 1), dtype=torch.int64)
+        v_zero_points = torch.randint(0, 20, (1, 1, 8, 1), dtype=torch.int64)
+        start_pos = 0
+        self._update_and_validate(
+            k, v, k_scales, v_scales, k_zero_points, v_zero_points, start_pos
+        )
+
+    def test_update_kv_cache_large_update(self):
+        self._reset()
+        k = torch.randint(0, 50, (1, 3, 8, 4), dtype=torch.int8)
+        v = torch.randint(0, 50, (1, 3, 8, 4), dtype=torch.int8)
+        k_scales = torch.rand((1, 3, 8, 1), dtype=torch.float64)
+        v_scales = torch.rand((1, 3, 8, 1), dtype=torch.float64)
+        k_zero_points = torch.randint(0, 20, (1, 3, 8, 1), dtype=torch.int64)
+        v_zero_points = torch.randint(0, 20, (1, 3, 8, 1), dtype=torch.int64)
+        start_pos = 0
+        self._update_and_validate(
+            k, v, k_scales, v_scales, k_zero_points, v_zero_points, start_pos
+        )
+
+    def test_update_kv_cache_update_nonzero_offset(self):
+        self._reset()
+        k = torch.randint(0, 50, (1, 1, 8, 4), dtype=torch.int8)
+        v = torch.randint(0, 50, (1, 1, 8, 4), dtype=torch.int8)
+        k_scales = torch.rand((1, 1, 8, 1), dtype=torch.float64)
+        v_scales = torch.rand((1, 1, 8, 1), dtype=torch.float64)
+        k_zero_points = torch.randint(0, 20, (1, 1, 8, 1), dtype=torch.int64)
+        v_zero_points = torch.randint(0, 20, (1, 1, 8, 1), dtype=torch.int64)
+        start_pos = 2
+        self._update_and_validate(
+            k, v, k_scales, v_scales, k_zero_points, v_zero_points, start_pos
+        )
+
+    def test_update_kv_cache_more_updates(self):
+        self._reset()
+        k = torch.randint(0, 50, (1, 1, 8, 4), dtype=torch.int8)
+        v = torch.randint(0, 50, (1, 1, 8, 4), dtype=torch.int8)
+        k_scales = torch.rand((1, 1, 8, 1), dtype=torch.float64)
+        v_scales = torch.rand((1, 1, 8, 1), dtype=torch.float64)
+        k_zero_points = torch.randint(0, 20, (1, 1, 8, 1), dtype=torch.int64)
+        v_zero_points = torch.randint(0, 20, (1, 1, 8, 1), dtype=torch.int64)
+        start_pos = 2
+        self._update_and_validate(
+            k, v, k_scales, v_scales, k_zero_points, v_zero_points, start_pos
+        )
+
+        k = torch.randint(0, 50, (1, 1, 8, 4), dtype=torch.int8)
+        v = torch.randint(0, 50, (1, 1, 8, 4), dtype=torch.int8)
+        k_scales = torch.rand((1, 1, 8, 1), dtype=torch.float64)
+        v_scales = torch.rand((1, 1, 8, 1), dtype=torch.float64)
+        k_zero_points = torch.randint(0, 20, (1, 1, 8, 1), dtype=torch.int64)
+        v_zero_points = torch.randint(0, 20, (1, 1, 8, 1), dtype=torch.int64)
+        start_pos = 4
+
+        self._update_and_validate(
+            k, v, k_scales, v_scales, k_zero_points, v_zero_points, start_pos
+        )
+
+    def test_batched_update_kv_cache_more_updates(self):
+        self.batch_size = 7
+        self._reset()
+        k = torch.randint(0, 50, (self.batch_size, 1, 8, 4), dtype=torch.int8)
+        v = torch.randint(0, 50, (self.batch_size, 1, 8, 4), dtype=torch.int8)
+        k_scales = torch.rand((self.batch_size, 1, 8, 1), dtype=torch.float64)
+        v_scales = torch.rand((self.batch_size, 1, 8, 1), dtype=torch.float64)
+        k_zero_points = torch.randint(
+            0, 20, (self.batch_size, 1, 8, 1), dtype=torch.int64
+        )
+        v_zero_points = torch.randint(
+            0, 20, (self.batch_size, 1, 8, 1), dtype=torch.int64
+        )
+        start_pos = 2
+        self._update_and_validate(
+            k, v, k_scales, v_scales, k_zero_points, v_zero_points, start_pos
+        )
+
+        k = torch.randint(0, 50, (self.batch_size, 1, 8, 4), dtype=torch.int8)
+        v = torch.randint(0, 50, (self.batch_size, 1, 8, 4), dtype=torch.int8)
+        k_scales = torch.rand((self.batch_size, 1, 8, 1), dtype=torch.float64)
+        v_scales = torch.rand((self.batch_size, 1, 8, 1), dtype=torch.float64)
+        k_zero_points = torch.randint(
+            0, 20, (self.batch_size, 1, 8, 1), dtype=torch.int64
+        )
+        v_zero_points = torch.randint(
+            0, 20, (self.batch_size, 1, 8, 1), dtype=torch.int64
+        )
+        start_pos = 4
+
+        self._update_and_validate(
+            k, v, k_scales, v_scales, k_zero_points, v_zero_points, start_pos
+        )
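
For reference, a minimal sketch of how the new op can be driven from eager PyTorch once the AOT custom-op library is registered, mirroring the unit test above. The import path and the cache shapes are assumptions for illustration only; the op itself requires 4-D tensors in contiguous dim order with matching batch, head, and embedding dims, and writes along the sequence dimension (dim 1) starting at start_pos.

import torch

# Assumption: importing this module registers the llama custom ops
# (torch.ops.llama.*), analogous to the test's preload_deps on custom_ops_aot_lib.
import executorch.extension.llm.custom_ops.sdpa_with_kv_cache  # noqa: F401

batch_size, max_seq_len, num_heads, head_dim = 1, 10, 8, 4

# Pre-allocated quantized K cache plus its per-token quantization parameters,
# laid out as (batch, seq, heads, head_dim) and (batch, seq, heads, 1).
k_cache = torch.zeros((batch_size, max_seq_len, num_heads, head_dim), dtype=torch.int8)
k_scales = torch.zeros((batch_size, max_seq_len, num_heads, 1), dtype=torch.float64)
k_zero_points = torch.zeros((batch_size, max_seq_len, num_heads, 1), dtype=torch.int64)

# One new token's worth of quantized keys and its scales / zero points.
new_k = torch.randint(-128, 127, (batch_size, 1, num_heads, head_dim), dtype=torch.int8)
new_scales = torch.rand((batch_size, 1, num_heads, 1), dtype=torch.float64)
new_zero_points = torch.randint(0, 20, (batch_size, 1, num_heads, 1), dtype=torch.int64)

start_pos = 3  # slot along the sequence dimension (dim 1) to write into

# Each call copies its update into the corresponding cache tensor in place;
# the returned tensor is only a placeholder and can be ignored.
torch.ops.llama.update_quantized_cache(new_k, k_cache, start_pos)
torch.ops.llama.update_quantized_cache(new_scales, k_scales, start_pos)
torch.ops.llama.update_quantized_cache(new_zero_points, k_zero_points, start_pos)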