diff --git a/examples/models/llama2/TARGETS b/examples/models/llama2/TARGETS
index ae372bc6767..a30592080a5 100644
--- a/examples/models/llama2/TARGETS
+++ b/examples/models/llama2/TARGETS
@@ -160,13 +160,44 @@ runtime.python_library(
     ],
 )
 
+runtime.python_library(
+    name = "sdpa",
+    srcs = [
+        "source_transformation/sdpa.py",
+    ],
+    _is_external_target = True,
+    visibility = ["//executorch/..."],
+    deps = [
+        "//caffe2:torch",
+    ],
+)
+
 runtime.python_test(
     name = "quantized_kv_cache_test",
     srcs = [
         "source_transformation/test_quantized_kv_cache.py",
     ],
+    preload_deps = [
+        "//executorch/extension/llm/custom_ops:custom_ops_aot_lib",
+    ],
+    deps = [
+        ":quantized_kv_cache",
+        "//caffe2:torch",
+        "//executorch/examples/models/llama2:llama_transformer",
+    ],
+)
+
+runtime.python_test(
+    name = "quantized_sdpa_with_kv_cache_test",
+    srcs = [
+        "source_transformation/test_sdpa_with_quantized_kv_cache.py",
+    ],
+    preload_deps = [
+        "//executorch/extension/llm/custom_ops:custom_ops_aot_lib",
+    ],
     deps = [
         ":quantized_kv_cache",
+        ":sdpa",
         "//caffe2:torch",
         "//executorch/examples/models/llama2:llama_transformer",
     ],
diff --git a/examples/models/llama2/export_llama_lib.py b/examples/models/llama2/export_llama_lib.py
index fa04a13a72f..2b43274760a 100644
--- a/examples/models/llama2/export_llama_lib.py
+++ b/examples/models/llama2/export_llama_lib.py
@@ -890,9 +890,7 @@ def _get_source_transforms(  # noqa
             transforms.append(replace_sdpa_with_custom_op)
 
     if args.quantize_kv_cache:
-        assert (
-            args.use_kv_cache and not args.use_sdpa_with_kv_cache
-        ), "quantize_kv_cache requires use_kv_cache=True and use_sdpa_with_kv_cache=False"
+        assert args.use_kv_cache, "quantize_kv_cache requires use_kv_cache=True"
         transforms.append(replace_kv_cache_with_quantized_kv_cache)
 
     if args.use_kv_cache:
diff --git a/examples/models/llama2/source_transformation/quantized_kv_cache.py b/examples/models/llama2/source_transformation/quantized_kv_cache.py
index cffe37fd7ae..8eec7846d3c 100644
--- a/examples/models/llama2/source_transformation/quantized_kv_cache.py
+++ b/examples/models/llama2/source_transformation/quantized_kv_cache.py
@@ -47,6 +47,7 @@ def __init__(
             raise ValueError(
                 f"Only affine symmetric and asymmetric cache types are supported: got {cache_type}"
             )
+        # For now supporting int8 only
        self.quantized_cache_dtype = torch.int8
         self.cache_fp_type = torch.float32
 
@@ -65,10 +66,10 @@ def __init__(
             "v_cache", torch.zeros(cache_shape, dtype=self.quantized_cache_dtype)
         )
         self.register_buffer(
-            "k_cache_scales", torch.ones(scale_shape, dtype=torch.double)
+            "k_cache_scales", torch.ones(scale_shape, dtype=torch.float64)
         )
         self.register_buffer(
-            "v_cache_scales", torch.ones(scale_shape, dtype=torch.double)
+            "v_cache_scales", torch.ones(scale_shape, dtype=torch.float64)
         )
         if cache_type == QuantizedCacheType.AffineAsymmetric:
             self.register_buffer(
@@ -100,47 +101,74 @@ def update(self, input_pos, k_val, v_val):
 
         quantized_v_val, v_scales, v_zero_points = self._quantize(v_val)
 
-        if self.enable_dynamic_shape:
-            start_pos = input_pos[0].item()
-            torch._check_is_size(start_pos)
-            dim_to_slice = 2 if self.is_transposed else 1
-            torch._check(start_pos < self.k_cache.size(dim_to_slice))
-            seq_length = k_val.size(dim_to_slice)
-            narrowed_k = self.k_cache.narrow(dim_to_slice, start_pos, seq_length)
-            narrowed_k_scales = self.k_cache_scales.narrow(
-                dim_to_slice, start_pos, seq_length
-            )
-            narrowed_k_zp = self.k_cache_zero_points.narrow(
-                dim_to_slice, start_pos, seq_length
-            )
-            narrowed_k.copy_(quantized_k_val)
-            narrowed_k_scales.copy_(k_scales)
-            narrowed_k_zp.copy_(k_zero_points)
-            narrowed_v = self.v_cache.narrow(dim_to_slice, start_pos, seq_length)
-            narrowed_v_scales = self.v_cache_scales.narrow(
-                dim_to_slice, start_pos, seq_length
-            )
-            narrowed_v_zp = self.v_cache_zero_points.narrow(
-                dim_to_slice, start_pos, seq_length
-            )
-            narrowed_v.copy_(quantized_v_val)
-            narrowed_v_scales.copy_(v_scales)
-            narrowed_v_zp.copy_(v_zero_points)
-        else:
-            if self.is_transposed:
+        if self.is_transposed:
+            # We cannot use the update_cache op at the moment
+            # if the cache is transposed.
+            # Also note that we should not need separate paths
+            # for dynamic shape vs. not.
+            # The only reason it is done this way is to accommodate
+            # the lowering pains of backends that work better
+            # with the index_put op.
+            if self.enable_dynamic_shape:
+                start_pos = input_pos[0].item()
+                torch._check_is_size(start_pos)
+                dim_to_slice = 2 if self.is_transposed else 1
+                torch._check(start_pos < self.k_cache.size(dim_to_slice))
+                seq_length = k_val.size(dim_to_slice)
+                narrowed_k = self.k_cache.narrow(dim_to_slice, start_pos, seq_length)
+                narrowed_k_scales = self.k_cache_scales.narrow(
+                    dim_to_slice, start_pos, seq_length
+                )
+                narrowed_k_zp = self.k_cache_zero_points.narrow(
+                    dim_to_slice, start_pos, seq_length
+                )
+                narrowed_k.copy_(quantized_k_val)
+                narrowed_k_scales.copy_(k_scales)
+                narrowed_k_zp.copy_(k_zero_points)
+                narrowed_v = self.v_cache.narrow(dim_to_slice, start_pos, seq_length)
+                narrowed_v_scales = self.v_cache_scales.narrow(
+                    dim_to_slice, start_pos, seq_length
+                )
+                narrowed_v_zp = self.v_cache_zero_points.narrow(
+                    dim_to_slice, start_pos, seq_length
+                )
+                narrowed_v.copy_(quantized_v_val)
+                narrowed_v_scales.copy_(v_scales)
+                narrowed_v_zp.copy_(v_zero_points)
+            else:
                 self.k_cache[:, :, input_pos] = quantized_k_val
                 self.k_cache_scales[:, :, input_pos] = k_scales
                 self.k_cache_zero_points[:, :, input_pos] = k_zero_points
                 self.v_cache[:, :, input_pos] = quantized_v_val
                 self.v_cache_scales[:, :, input_pos] = v_scales
                 self.v_cache_zero_points[:, :, input_pos] = v_zero_points
-            else:
-                self.k_cache[:, input_pos] = quantized_k_val
-                self.k_cache_scales[:, input_pos] = k_scales
-                self.k_cache_zero_points[:, input_pos] = k_zero_points
-                self.v_cache[:, input_pos] = quantized_v_val
-                self.v_cache_scales[:, input_pos] = v_scales
-                self.v_cache_zero_points[:, input_pos] = v_zero_points
+        else:
+            # Right now using custom ops on this path.
+            # In the future we can update the custom op to handle the
+            # transposed cache as well.
+            # Note that we may have to revert this change if other ET
+            # backends such as QNN want to use the quantized cache with
+            # dynamic shapes, instead of quantizing on their own.
+            # But until then, opting for code simplicity.
+            start_pos = input_pos[0].item()
+            _ = torch.ops.llama.update_quantized_cache(
+                quantized_k_val, self.k_cache, start_pos
+            )
+            _ = torch.ops.llama.update_quantized_cache(
+                k_scales, self.k_cache_scales, start_pos
+            )
+            _ = torch.ops.llama.update_quantized_cache(
+                k_zero_points, self.k_cache_zero_points, start_pos
+            )
+            _ = torch.ops.llama.update_quantized_cache(
+                quantized_v_val, self.v_cache, start_pos
+            )
+            _ = torch.ops.llama.update_quantized_cache(
+                v_scales, self.v_cache_scales, start_pos
+            )
+            _ = torch.ops.llama.update_quantized_cache(
+                v_zero_points, self.v_cache_zero_points, start_pos
+            )
 
         k_out = torch.ops.quantized_decomposed.dequantize_per_token(
             self.k_cache,
diff --git a/examples/models/llama2/source_transformation/sdpa.py b/examples/models/llama2/source_transformation/sdpa.py
index 83f02623892..02ff7ee08f4 100644
--- a/examples/models/llama2/source_transformation/sdpa.py
+++ b/examples/models/llama2/source_transformation/sdpa.py
@@ -9,23 +9,32 @@
 # Example script for exporting Llama2 to flatbuffer
 
 import math
-from typing import Tuple
+from typing import Tuple, Union
 
 import torch
 
 from executorch.examples.models.llama2.llama_transformer import KVCache, SDPA
+from executorch.examples.models.llama2.source_transformation.quantized_kv_cache import (
+    QuantizedKVCache,
+)
 
 
 class SDPACustom(torch.nn.Module):
     def __init__(
         self,
-        kv_cache: KVCache,
+        kv_cache: Union[KVCache, QuantizedKVCache],
         dim: int,
     ):
         super().__init__()
         # Custom op only supports float32 currently. Converting to/from float32 is
         # faster than not having the op.
-        self.kv_cache = kv_cache.to(torch.float)
+        self.kv_cache = kv_cache
+        if not isinstance(kv_cache, QuantizedKVCache):
+            self.kv_cache = kv_cache.to(torch.float)
+        else:
+            assert (
+                kv_cache.cache_fp_type == torch.float32
+            ), "Only float32 is supported for custom SDPA"
         self.dim = dim
 
     def forward(
@@ -44,12 +53,27 @@ def forward(
         q = q.to(dtype=torch.float)
         k = k.to(dtype=torch.float)
         v = v.to(dtype=torch.float)
+
+        k_cache = self.kv_cache.k_cache
+        v_cache = self.kv_cache.v_cache
+        if isinstance(self.kv_cache, QuantizedKVCache):
+            # Updates the quantized cache, scales and zero points,
+            # and returns the dequantized kv cache.
+            # Not the most optimal. Optimizations to follow next.
+            k_cache, v_cache = self.kv_cache.update(input_pos, k, v)
+        # Note that this path will still in-place mutate k_cache and v_cache.
+        # When we are not using the quantized kv cache, this will just mutate
+        # the original kv cache.
+        # When we are using the quantized kv cache, this will mutate the
+        # k_cache and v_cache that are returned from the cache update operation.
+        # That operation just dequantizes the cache and returns it.
+        # Future diffs will optimize this.
         output = torch.ops.llama.sdpa_with_kv_cache(
             q,
             k,
             v,
-            self.kv_cache.k_cache,
-            self.kv_cache.v_cache,
+            k_cache,
+            v_cache,
             input_pos[-1].item(),
             seqlen,
             None,  # Attention mask
diff --git a/examples/models/llama2/source_transformation/test_sdpa_with_quantized_kv_cache.py b/examples/models/llama2/source_transformation/test_sdpa_with_quantized_kv_cache.py
new file mode 100644
index 00000000000..8be34b2182b
--- /dev/null
+++ b/examples/models/llama2/source_transformation/test_sdpa_with_quantized_kv_cache.py
@@ -0,0 +1,81 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import unittest
+
+import torch
+
+from executorch.examples.models.llama2.llama_transformer import KVCache
+
+from executorch.examples.models.llama2.source_transformation.quantized_kv_cache import (
+    QuantizedCacheType,
+    QuantizedKVCache,
+)
+
+from executorch.examples.models.llama2.source_transformation.sdpa import SDPACustom
+
+
+class SDPAWithQuantizedKVCacheTest(unittest.TestCase):
+
+    def _init_cache(self):
+        self.kv_cache = KVCache(
+            self.max_batch_size,
+            self.max_seq_len,
+            self.n_kv_heads,
+            self.head_dim,
+            False,
+            self.enable_dynamic_shape,
+            dtype=self.dtype,
+        )
+        self.quantized_kv_cache = QuantizedKVCache.from_float(
+            self.kv_cache, QuantizedCacheType.AffineAsymmetric
+        )
+
+    def _init_kv(self):
+        kv_shape = (1, self.seq_len, self.n_kv_heads, self.head_dim)
+        q_shape = (1, self.seq_len, self.n_heads, self.head_dim)
+        q = torch.rand(q_shape, dtype=self.dtype)
+        k = torch.rand(kv_shape, dtype=self.dtype)
+        v = torch.rand(kv_shape, dtype=self.dtype)
+        return q, k, v
+
+    def setUp(self):
+        torch.manual_seed(42)
+        self.max_batch_size = 1
+        self.max_seq_len = 5
+        self.n_kv_heads = 4
+        self.n_heads = 8
+        self.head_dim = 17
+        self.dim = self.n_heads * self.head_dim
+        self.enable_dynamic_shape = False
+        self.dtype = torch.float32
+
+    def test_simple(self, is_dynamic_shape=False):
+        self.enable_dynamic_shape = is_dynamic_shape
+        input_pos = torch.tensor([0], dtype=torch.int64)
+        self.seq_len = 3
+        self._init_cache()
+        q, k, v = self._init_kv()
+        self.float_sdpa = SDPACustom(self.kv_cache, self.dim)
+        self.quantized_sdpa = SDPACustom(self.quantized_kv_cache, self.dim)
+        float_out = self.float_sdpa(input_pos, q, k, v, 1, self.seq_len, None)
+        quantized_out = self.quantized_sdpa(input_pos, q, k, v, 1, self.seq_len, None)
+        torch.testing.assert_close(
+            float_out,
+            quantized_out,
+        )
+
+        input_pos = torch.tensor([3], dtype=torch.int64)
+        self.seq_len = 1
+        q, k, v = self._init_kv()
+        float_out = self.float_sdpa(input_pos, q, k, v, 1, self.seq_len, None)
+        quantized_out = self.quantized_sdpa(input_pos, q, k, v, 1, self.seq_len, None)
+        torch.testing.assert_close(
+            float_out,
+            quantized_out,
+            rtol=1e-03,
+            atol=1e-03,
+        )
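
For reference, a minimal usage sketch (not part of the patch) of how the pieces above compose: a float KVCache is wrapped into an int8 QuantizedKVCache via from_float, and SDPACustom now accepts either cache type, routing the quantized one through QuantizedKVCache.update() before calling torch.ops.llama.sdpa_with_kv_cache. Shapes and sizes are illustrative, and the sketch assumes the custom ops under torch.ops.llama.* are already registered (the new test targets preload //executorch/extension/llm/custom_ops:custom_ops_aot_lib for this).

import torch

from executorch.examples.models.llama2.llama_transformer import KVCache
from executorch.examples.models.llama2.source_transformation.quantized_kv_cache import (
    QuantizedCacheType,
    QuantizedKVCache,
)
from executorch.examples.models.llama2.source_transformation.sdpa import SDPACustom

# Illustrative sizes; assumes torch.ops.llama.* custom ops are already loaded.
n_kv_heads, n_heads, head_dim, max_seq_len, seq_len = 4, 8, 17, 5, 3

# Float cache; the two False flags mirror the unit test above
# (non-transposed cache, dynamic shape disabled).
kv_cache = KVCache(
    1, max_seq_len, n_kv_heads, head_dim, False, False, dtype=torch.float32
)
# Wrap it into an int8, affine-asymmetric quantized cache.
quantized_kv_cache = QuantizedKVCache.from_float(
    kv_cache, QuantizedCacheType.AffineAsymmetric
)

# SDPACustom takes either cache type; with the quantized cache it calls
# QuantizedKVCache.update() and feeds the dequantized k/v into sdpa_with_kv_cache.
sdpa = SDPACustom(quantized_kv_cache, n_heads * head_dim)

input_pos = torch.tensor([0], dtype=torch.int64)
q = torch.rand(1, seq_len, n_heads, head_dim)
k = torch.rand(1, seq_len, n_kv_heads, head_dim)
v = torch.rand(1, seq_len, n_kv_heads, head_dim)
out = sdpa(input_pos, q, k, v, 1, seq_len, None)  # bsz=1, no attention mask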