diff --git a/examples/models/llama2/TARGETS b/examples/models/llama2/TARGETS index 7ed858a33c5..fcdd093bb0b 100644 --- a/examples/models/llama2/TARGETS +++ b/examples/models/llama2/TARGETS @@ -73,6 +73,7 @@ runtime.python_library( "source_transformation/apply_spin_quant_r1_r2.py", "source_transformation/prune_output.py", "source_transformation/quantize.py", + "source_transformation/quantized_kv_cache.py", "source_transformation/rms_norm.py", "source_transformation/rope.py", "source_transformation/sdpa.py", diff --git a/examples/models/llama2/export_llama_lib.py b/examples/models/llama2/export_llama_lib.py index 04ccbcdea08..90a8ec1e80d 100644 --- a/examples/models/llama2/export_llama_lib.py +++ b/examples/models/llama2/export_llama_lib.py @@ -53,7 +53,11 @@ get_quant_embedding_transform, get_quant_weight_transform, ) +from .source_transformation.quantized_kv_cache import ( + replace_kv_cache_with_quantized_kv_cache, +) from .source_transformation.rms_norm import replace_rms_norm_with_native_rms_norm + from .source_transformation.rope import materialze_broadcast_of_rope_freq_cis from .source_transformation.sdpa import ( replace_causal_mask, @@ -206,6 +210,12 @@ def build_args_parser() -> argparse.ArgumentParser: action="store_true", help="Whether or not to export a model using kv cache", ) + parser.add_argument( + "--quantize_kv_cache", + default=False, + action="store_true", + help="Whether or not to export a model using quantized kv cache", + ) parser.add_argument( "--num_sharding", type=int, @@ -428,7 +438,6 @@ def _prepare_for_llama_export(modelname: str, args) -> LLMEdgeManager: Returns a LLMEdgeManager prior to calling export_to_edge with quantizers """ - # load model from checkpoint and params.json checkpoint_path = canonical_path(args.checkpoint) if args.checkpoint else None checkpoint_dir = ( @@ -446,6 +455,41 @@ def _prepare_for_llama_export(modelname: str, args) -> LLMEdgeManager: else: dtype_override = None + # source transforms + transforms = [] + if args.quantization_mode: + modelname = f"{modelname}_q" + transforms.append( + get_quant_weight_transform(args, dtype_override, verbose_export()) + ) + + if args.embedding_quantize: + modelname = f"{modelname}_e" + transforms.append(get_quant_embedding_transform(args)) + + if args.expand_rope_table: + transforms.append(materialze_broadcast_of_rope_freq_cis) + + if args.use_sdpa_with_kv_cache: + transforms.append(replace_sdpa_with_custom_op) + + if args.quantize_kv_cache: + assert ( + args.use_kv_cache and not args.use_sdpa_with_kv_cache + ), "quantize_kv_cache requires use_kv_cache=True and use_sdpa_with_kv_cache=False" + transforms.append(replace_kv_cache_with_quantized_kv_cache) + + if args.use_kv_cache: + if args.qnn: + transforms.append(replace_kv_cache_with_simple_kv_cache) + transforms.append(replace_sdpa_with_flex_sdpa) + transforms.append(replace_causal_mask) + + elif args.coreml or args.mps: + # Currently qnn/coreml/mps doesn't support sdpa op, use the simpler decomposition + # to get free perf gain. 
+                transforms.append(replace_sdpa_with_simple_sdpa)
+                transforms.append(replace_causal_mask)
     return (
         _load_llama_model(
             modelname=modelname,
diff --git a/examples/models/llama2/source_transformation/TARGETS b/examples/models/llama2/source_transformation/TARGETS
new file mode 100644
index 00000000000..71687b8e1ff
--- /dev/null
+++ b/examples/models/llama2/source_transformation/TARGETS
@@ -0,0 +1,28 @@
+load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "runtime")
+
+oncall("executorch")
+
+runtime.python_library(
+    name = "quantized_kv_cache",
+    srcs = [
+        "quantized_kv_cache.py",
+    ],
+    _is_external_target = True,
+    base_module = "executorch.examples.models.llama2.source_transformation",
+    visibility = ["//executorch/..."],
+    deps = [
+        "//caffe2:torch",
+    ],
+)
+
+runtime.python_test(
+    name = "quantized_kv_cache_test",
+    srcs = [
+        "test_quantized_kv_cache.py",
+    ],
+    deps = [
+        ":quantized_kv_cache",
+        "//caffe2:torch",
+        "//executorch/examples/models/llama2:llama_transformer",
+    ],
+)
diff --git a/examples/models/llama2/source_transformation/quantized_kv_cache.py b/examples/models/llama2/source_transformation/quantized_kv_cache.py
new file mode 100644
index 00000000000..c46f4696252
--- /dev/null
+++ b/examples/models/llama2/source_transformation/quantized_kv_cache.py
@@ -0,0 +1,204 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import logging
+from enum import Enum
+
+import torch
+import torch.nn as nn
+from executorch.examples.models.llama2.llama_transformer import KVCache
+from torch.ao.quantization.fx._decomposed import quantized_decomposed_lib  # noqa: F401
+
+
+"""
+ Heavily "inspired" by AO's implementation of the same in torchao/_models/llama/model.py
+"""
+
+
+# Doesn't have to abide by affine quantization laws.
+# However, if we do implement quantized sdpa, then this might be handy.
+class QuantizedCacheType(Enum):
+    AffineSymmetric = 0
+    AffineAsymmetric = 1
+    AffineSymmetricGroupWise = 2
+    AffineAsymmetricGroupWise = 3
+
+
+class QuantizedKVCache(nn.Module):
+    def __init__(
+        self,
+        max_batch_size,
+        max_seq_length,
+        n_heads,
+        head_dim,
+        cache_type: QuantizedCacheType = QuantizedCacheType.AffineSymmetric,
+        tranposed=False,
+        enable_dynamic_shape=False,
+    ):
+        super().__init__()
+        if cache_type not in (
+            QuantizedCacheType.AffineSymmetric,
+            QuantizedCacheType.AffineAsymmetric,
+        ):
+
+            raise ValueError(
+                f"Only affine symmetric and asymmetric cache types are supported: got {cache_type}"
+            )
+        # For now supporting int8 only
+        self.quantized_cache_dtype = torch.int8
+        self.cache_fp_type = torch.float32
+        self.is_transposed = tranposed
+        self.enable_dynamic_shape = enable_dynamic_shape
+        if self.is_transposed:
+            cache_shape = (max_batch_size, n_heads, max_seq_length, head_dim)
+            scale_shape = (max_batch_size, n_heads, max_seq_length, 1)
+        else:
+            cache_shape = (max_batch_size, max_seq_length, n_heads, head_dim)
+            scale_shape = (max_batch_size, max_seq_length, n_heads, 1)
+        self.register_buffer("k_cache", torch.zeros(cache_shape, dtype=torch.int8))
+        self.register_buffer("v_cache", torch.zeros(cache_shape, dtype=torch.int8))
+        self.register_buffer(
+            "k_cache_scales", torch.ones(scale_shape, dtype=torch.double)
+        )
+        self.register_buffer(
+            "v_cache_scales", torch.ones(scale_shape, dtype=torch.double)
+        )
+        if cache_type == QuantizedCacheType.AffineAsymmetric:
+            self.register_buffer(
+
"k_cache_zero_points", torch.ones(scale_shape, dtype=torch.int64) + ) + self.register_buffer( + "v_cache_zero_points", torch.ones(scale_shape, dtype=torch.int64) + ) + + def update(self, input_pos, k_val, v_val): + # quantize current k_val and store it in the cache + k_scales, k_zero_points = ( + torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default( + k_val, torch.int8 # no other value is supported by this op anyway + ) + ) + quantized_k_val = torch.ops.quantized_decomposed.quantize_per_token( + k_val, + k_scales, + k_zero_points, + torch.iinfo(torch.int8).min, + torch.iinfo(torch.int8).max, + torch.int8, + ) + + v_scales, v_zero_points = ( + torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric( + v_val, torch.int8 + ) + ) + quantized_v_val = torch.ops.quantized_decomposed.quantize_per_token( + v_val, + v_scales, + v_zero_points, + torch.iinfo(torch.int8).min, + torch.iinfo(torch.int8).max, + torch.int8, + ) + + if self.enable_dynamic_shape: + start_pos = input_pos[0].item() + torch._check_is_size(start_pos) + if self.is_transposed: + dim_to_slice = 2 + else: + dim_to_slice = 1 + torch._check(start_pos < self.k_cache.size(dim_to_slice)) + seq_length = k_val.size(dim_to_slice) + narrowed_k = self.k_cache.narrow(dim_to_slice, start_pos, seq_length) + narrowed_k_scales = self.k_cache_scales.narrow( + dim_to_slice, start_pos, seq_length + ) + narrowed_k_zp = self.k_cache_zero_points.narrow( + dim_to_slice, start_pos, seq_length + ) + narrowed_k.copy_(quantized_k_val) + narrowed_k_scales.copy_(k_scales) + narrowed_k_zp.copy_(k_zero_points) + # pyre-ignore: Incompatible parameter type [6] + narrowed_v = self.v_cache.narrow(dim_to_slice, start_pos, seq_length) + narrowed_v_scales = self.v_cache_scales.narrow( + dim_to_slice, start_pos, seq_length + ) + narrowed_v_zp = self.v_cache_zero_points.narrow( + dim_to_slice, start_pos, seq_length + ) + narrowed_v.copy_(quantized_v_val) + narrowed_v_scales.copy_(v_scales) + narrowed_v_zp.copy_(v_zero_points) + else: + if self.is_transposed: + self.k_cache[:, :, input_pos] = quantized_k_val + self.k_cache_scales[:, :, input_pos] = k_scales + self.k_cache_zero_points[:, :, input_pos] = k_zero_points + self.v_cache[:, :, input_pos] = quantized_v_val + self.v_cache_scales[:, :, input_pos] = v_scales + self.v_cache_zero_points[:, :, input_pos] = v_zero_points + else: + self.k_cache[:, input_pos] = quantized_k_val + self.k_cache_scales[:, input_pos] = k_scales + self.k_cache_zero_points[:, input_pos] = k_zero_points + self.v_cache[:, input_pos] = quantized_v_val + self.v_cache_scales[:, input_pos] = v_scales + self.v_cache_zero_points[:, input_pos] = v_zero_points + + k_out = torch.ops.quantized_decomposed.dequantize_per_token( + self.k_cache, + self.k_cache_scales, + self.k_cache_zero_points, + torch.iinfo(torch.int8).min, + torch.iinfo(torch.int8).max, + self.quantized_cache_dtype, + self.cache_fp_type, + ) + v_out = torch.ops.quantized_decomposed.dequantize_per_token( + self.v_cache, + self.v_cache_scales, + self.v_cache_zero_points, + torch.iinfo(torch.int8).min, + torch.iinfo(torch.int8).max, + self.quantized_cache_dtype, + self.cache_fp_type, + ) + return k_out, v_out + + @classmethod + def from_float(cls, kv_cache, cache_type: QuantizedCacheType): + cache_shape = kv_cache.k_cache.shape + if kv_cache.is_tranposed: + max_batch_size, n_heads, max_seq_length, head_dim = cache_shape + else: + max_batch_size, max_seq_length, n_heads, head_dim = cache_shape + return cls( + max_batch_size, + max_seq_length, + n_heads, + 
head_dim, + cache_type, + kv_cache.is_tranposed, + kv_cache.enable_dynamic_shape, + ) + + +def replace_kv_cache_with_quantized_kv_cache(module): + logging.warning( + "Replacing KVCache with QuantizedKVCache. This modifies the model in place." + ) + for name, child in module.named_children(): + if isinstance(child, KVCache): + setattr( + module, + name, + QuantizedKVCache.from_float(child, QuantizedCacheType.AffineAsymmetric), + ) + else: + replace_kv_cache_with_quantized_kv_cache(child) + return module diff --git a/examples/models/llama2/source_transformation/test_quantized_kv_cache.py b/examples/models/llama2/source_transformation/test_quantized_kv_cache.py new file mode 100644 index 00000000000..5fa5d1958de --- /dev/null +++ b/examples/models/llama2/source_transformation/test_quantized_kv_cache.py @@ -0,0 +1,128 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import unittest + +import torch + +from executorch.examples.models.llama2.llama_transformer import KVCache + +from executorch.examples.models.llama2.source_transformation.quantized_kv_cache import ( + QuantizedCacheType, + QuantizedKVCache, +) + + +class QuantizedKVCacheTest(unittest.TestCase): + + def _init_cache(self): + self.kv_cache = KVCache( + self.max_batch_size, + self.max_seq_len, + self.n_kv_heads, + self.head_dim, + self.transpose_kv_cache, + self.enable_dynamic_shape, + dtype=self.dtype, + ) + + def _init_kv(self): + if self.transpose_kv_cache: + shape = (1, self.n_kv_heads, self.seq_len, self.head_dim) + else: + shape = (1, self.seq_len, self.n_kv_heads, self.head_dim) + k = torch.rand(shape, dtype=self.dtype) + v = torch.rand(shape, dtype=self.dtype) + return k, v + + def setUp(self): + torch.manual_seed(42) + self.max_batch_size = 1 + self.max_seq_len = 5 + self.n_kv_heads = 8 + self.head_dim = 17 + self.enable_dynamic_shape = False + self.transpose_kv_cache = False + self.dtype = torch.float32 + + def _test_simple_update_fetch(self, is_tranposed=False, is_dynamic_shape=False): + self.transpose_kv_cache = is_tranposed + self.enable_dynamic_shape = is_dynamic_shape + input_pos = torch.tensor([0, 1, 2]) + self.seq_len = input_pos.size(0) + self._init_cache() + k, v = self._init_kv() + quantized_kv_cache = QuantizedKVCache.from_float( + self.kv_cache, QuantizedCacheType.AffineAsymmetric + ) + updated_k_cache, updated_v_cache = self.kv_cache.update(input_pos, k, v) + updated_dequantized_k_cache, updated_dequantized_v_cache = ( + quantized_kv_cache.update(input_pos, k, v) + ) + + def index(t, input_pos): + if self.transpose_kv_cache: + return t[:, :, input_pos, :] + else: + return t[:, input_pos, :, :] + + sliced_k_cache = index(updated_k_cache, input_pos) + sliced_v_cache = index(updated_v_cache, input_pos) + + sliced_dequantized_k_cache = index(updated_dequantized_k_cache, input_pos) + sliced_dequantized_v_cache = index(updated_dequantized_v_cache, input_pos) + + torch.testing.assert_close( + sliced_k_cache, + sliced_dequantized_k_cache, + rtol=1e-02, + atol=1e-02, + ) + torch.testing.assert_close( + sliced_v_cache, + sliced_dequantized_v_cache, + rtol=1e-02, + atol=1e-02, + ) + + input_pos = torch.tensor([3]) + self.seq_len = input_pos.size(0) + k, v = self._init_kv() + pos_to_check = torch.tensor([0, 1, 2, 3]) + updated_k_cache, updated_v_cache = self.kv_cache.update(input_pos, k, v) + updated_dequantized_k_cache, updated_dequantized_v_cache = ( + 
quantized_kv_cache.update(input_pos, k, v) + ) + sliced_k_cache = index(updated_k_cache, pos_to_check) + sliced_v_cache = index(updated_v_cache, pos_to_check) + + sliced_dequantized_k_cache = index(updated_dequantized_k_cache, pos_to_check) + sliced_dequantized_v_cache = index(updated_dequantized_v_cache, pos_to_check) + + torch.testing.assert_close( + sliced_k_cache, + sliced_dequantized_k_cache, + rtol=1e-02, + atol=1e-02, + ) + torch.testing.assert_close( + sliced_v_cache, + sliced_dequantized_v_cache, + rtol=1e-02, + atol=1e-02, + ) + + def test_simple_update_fetch_not_transposed(self): + self._test_simple_update_fetch() + + def test_simple_update_fetch_not_transposed_dynamic_shape(self): + self._test_simple_update_fetch(is_dynamic_shape=True) + + def test_simple_update_fetch_transposed(self): + self._test_simple_update_fetch(is_tranposed=True) + + def test_simple_update_fetch_transposed_dynamic_shape(self): + self._test_simple_update_fetch(is_tranposed=True, is_dynamic_shape=True) diff --git a/extension/llm/export/builder.py b/extension/llm/export/builder.py index 91ee2dc733b..68a0965e1ac 100644 --- a/extension/llm/export/builder.py +++ b/extension/llm/export/builder.py @@ -146,6 +146,7 @@ def source_transform( if self.verbose: logging.info(f"Applied source transforms: {self.applied_source_transforms}") + logging.info(f"Model after source transforms: {self.model}") return self def _get_dynamic_shape(self) -> Any: diff --git a/kernels/quantized/cpu/op_choose_qparams.cpp b/kernels/quantized/cpu/op_choose_qparams.cpp index 13b26e0b738..9afb713c9f0 100644 --- a/kernels/quantized/cpu/op_choose_qparams.cpp +++ b/kernels/quantized/cpu/op_choose_qparams.cpp @@ -36,7 +36,8 @@ void check_quantize_per_tensor_args( int64_t qmax, ScalarType dtype, Tensor& scale_out, - Tensor& zero_point_out) { + Tensor& zero_point_out, + bool is_per_token = false) { (void)dtype; ET_CHECK_MSG( qmin < qmax, @@ -56,27 +57,49 @@ void check_quantize_per_tensor_args( zero_point_out.scalar_type() == ScalarType::Long, "Expected scale to be Long tensor received: %" PRId8, static_cast(zero_point_out.scalar_type())); - ET_CHECK_MSG( - scale_out.numel() == 1, - "Exepcted scale to only have one element received: %zd", - ssize_t(scale_out.numel())); - ET_CHECK_MSG( - zero_point_out.numel() == 1, - "Exepcted zero_point to only have one element received: %zd", - ssize_t(zero_point_out.numel())); + + if (is_per_token) { + for (auto i = 0; i < input.dim() - 1; i++) { + ET_CHECK_MSG( + scale_out.size(i) == input.size(i), + "Exepcted scale to have the same number of elements at dimentions %d got %zd", + i, + scale_out.size(i)); + ET_CHECK_MSG( + zero_point_out.size(i) == input.size(i), + "Exepcted zero pont to have the same number of elements at dimentions %d got %zd", + i, + zero_point_out.size(i)); + } + ET_CHECK_MSG( + scale_out.size(input.dim() - 1) == 1, + "Exepcted scale to have only one element at dimentions %ld but got %zd", + input.dim() - 1, + scale_out.size(input.dim() - 1)); + ET_CHECK_MSG( + zero_point_out.size(input.dim() - 1) == 1, + "Exepcted zero point to have only one element at dimentions %ld but got %zd", + input.dim() - 1, + zero_point_out.size(input.dim() - 1)); + } else { + ET_CHECK_MSG( + scale_out.numel() == 1, + "Exepcted scale to only have one element received: %zd", + ssize_t(scale_out.numel())); + ET_CHECK_MSG( + zero_point_out.numel() == 1, + "Exepcted zero_point to only have one element received: %zd", + ssize_t(zero_point_out.numel())); + } } -void choose_qparams( - const Tensor& input, +void 
calculate_scale_and_zero_point( + float min, + float max, int32_t qmin, int32_t qmax, - Tensor& scale_out, - Tensor& zero_point_out) { - const float* x_fp32 = input.const_data_ptr(); - // Compute x_min, x_max and q_params (scale, zero_point) - float min = torch::executor::vec_minf(x_fp32, input.numel()); - float max = torch::executor::vec_maxf(x_fp32, input.numel()); - + double& scale, + int32_t& zero_point) { // We extend the [min, max] interval to ensure that it contains 0. // Otherwise, we would not meet the requirement that 0 be an exactly // representable value. @@ -85,7 +108,7 @@ void choose_qparams( // Use double precision for intermediate computation but use single precision // in final number to reflect the actual number used during quantization. - double scale = (static_cast(max) - min) / (qmax - qmin); + scale = (static_cast(max) - min) / (qmax - qmin); // If scale is 0 or too small so its reciprocal is infinity, we arbitrary // adjust the scale to 0.1 . We want to avoid scale's reciprocal being // infinity because some of fbgemm code pre-computes scale's reciprocal to do @@ -143,9 +166,54 @@ void choose_qparams( } else { nudged_zero_point = nearbyint(static_cast(initial_zero_point)); } + zero_point = nudged_zero_point; + return; +} + +void choose_qparams( + const Tensor& input, + int32_t qmin, + int32_t qmax, + Tensor& scale_out, + Tensor& zero_point_out) { + const float* x_fp32 = input.const_data_ptr(); + // Compute x_min, x_max and q_params (scale, zero_point) + float min = torch::executor::vec_minf(x_fp32, input.numel()); + float max = torch::executor::vec_maxf(x_fp32, input.numel()); + + double scale; + int32_t zero_point; + calculate_scale_and_zero_point(min, max, qmin, qmax, scale, zero_point); scale_out.mutable_data_ptr()[0] = scale; - zero_point_out.mutable_data_ptr()[0] = nudged_zero_point; + zero_point_out.mutable_data_ptr()[0] = zero_point; +} + +void choose_qparams_per_token( + const Tensor& input, + int32_t qmin, + int32_t qmax, + Tensor& scale_out, + Tensor& zero_point_out) { + const float* x_fp32 = input.const_data_ptr(); + // Compute x_min, x_max and q_params (scale, zero_point) + auto num_tokens = 1; + for (auto i = 0; i < input.dim() - 1; i++) { + num_tokens *= input.size(i); + } + auto token_dim_size = input.size(input.dim() - 1); + for (auto i = 0; i < num_tokens; i++) { + // vec_minf uses std::min_element. Check if it actually + // gets vectorized. 
+ float min = torch::executor::vec_minf(x_fp32, token_dim_size); + float max = torch::executor::vec_maxf(x_fp32, token_dim_size); + double scale; + int32_t zero_point; + calculate_scale_and_zero_point(min, max, qmin, qmax, scale, zero_point); + scale_out.mutable_data_ptr()[i] = scale; + zero_point_out.mutable_data_ptr()[i] = zero_point; + x_fp32 += token_dim_size; + } } } // namespace @@ -180,6 +248,54 @@ ::std::tuple choose_qparams_tensor_out( input, quant_min, quant_max, eps, dtype, scale_out, zero_point_out); } +std::tuple choose_qparams_per_token_asymmetric_out( + const Tensor& input, + ScalarType dtype, + Tensor& scale_out, + Tensor& zero_point_out) { + int64_t quant_min = -128; + int64_t quant_max = 127; + exec_aten::SizesType output_sizes[kTensorDimensionLimit]; + for (ssize_t i = 0; i < input.dim() - 1; i++) { + output_sizes[i] = input.size(i); + } + output_sizes[input.dim() - 1] = 1; + size_t output_dim = input.dim(); + torch::executor::Error err = + resize_tensor(scale_out, {output_sizes, output_dim}); + ET_CHECK_MSG( + err == torch::executor::Error::Ok, + "Failed to resize scale_out Tensor in choose_qparams"); + err = resize_tensor(zero_point_out, {output_sizes, output_dim}); + ET_CHECK_MSG( + err == torch::executor::Error::Ok, + "Failed to resize zero_point_out Tensor in choose_qparams"); + + check_quantize_per_tensor_args( + input, + quant_min, + quant_max, + dtype, + scale_out, + zero_point_out, + true /* is_per_token*/); + + choose_qparams_per_token( + input, quant_min, quant_max, scale_out, zero_point_out); + return {scale_out, zero_point_out}; +} + +::std::tuple choose_qparams_per_token_asymmetric_out( + RuntimeContext& context, + const Tensor& input, + ScalarType dtype, + Tensor& scale_out, + Tensor& zero_point_out) { + (void)context; + return choose_qparams_per_token_asymmetric_out( + input, dtype, scale_out, zero_point_out); +} + } // namespace native } // namespace executor } // namespace torch diff --git a/kernels/quantized/cpu/op_dequantize.cpp b/kernels/quantized/cpu/op_dequantize.cpp index 87f65fc9cf0..dc94b520740 100644 --- a/kernels/quantized/cpu/op_dequantize.cpp +++ b/kernels/quantized/cpu/op_dequantize.cpp @@ -178,8 +178,6 @@ Tensor& dequantize_per_channel_out( ScalarType dtype, exec_aten::optional out_dtype, Tensor& out) { - torch::executor::Error err = resize_tensor(out, input.sizes()); - // normalize axis ET_CHECK_MSG( tensor_has_dim(input, axis), @@ -191,10 +189,6 @@ Tensor& dequantize_per_channel_out( axis += nonzero_dim(input); } - ET_CHECK_MSG( - err == torch::executor::Error::Ok, - "Failed to resize out Tensor in dequantize_per_channel_out"); - ET_CHECK_MSG( scale.scalar_type() == ScalarType::Double, "scale.scalar_type() %" PRId8 " is not double type", @@ -335,6 +329,11 @@ Tensor& dequantize_per_channel_out( exec_aten::optional out_dtype, Tensor& out) { (void)context; + torch::executor::Error err = resize_tensor(out, input.sizes()); + ET_CHECK_MSG( + err == torch::executor::Error::Ok, + "Failed to resize out Tensor in dequantize_per_channel_out"); + return dequantize_per_channel_out( input, scale, @@ -381,6 +380,77 @@ Tensor& dequantize_per_tensor_tensor_args_out( input, scale, zero_point, quant_min, quant_max, dtype, out_dtype, out); } +Tensor& dequantize_per_token_out( + const Tensor& input, + const Tensor& scale, + const Tensor& zero_points, + int64_t quant_min, + int64_t quant_max, + ScalarType dtype, + ScalarType out_dtype, + Tensor& out) { + // Refactor this into a util + size_t num_channels = 1; + for (size_t i = 0; i < input.dim() - 1; i++) { + 
num_channels *= input.size(i); + } +// This unfortunate change is needed because we compile op_quantize for aten +// mode as well +#ifdef USE_ATEN_LIB + const std::array sizes = {{num_channels, input.dim() - 1}}; + Tensor reshaped_input = at::from_blob( + input.mutable_data_ptr(), sizes, at::TensorOptions(input.scalar_type())); +#else + std::array input_dim_order{0, 1}; + std::array input_sizes; + input_sizes[0] = num_channels; + input_sizes[1] = input.size(input.dim() - 1); + std::array input_strides; + dim_order_to_stride_nocheck( + input_sizes.data(), input_dim_order.data(), 2, input_strides.data()); + void* input_data = input.mutable_data_ptr(); + TensorImpl reshaped_input_impl = TensorImpl( + input.scalar_type(), + 2, + input_sizes.data(), + input_data, + input_dim_order.data(), + input_strides.data(), + TensorShapeDynamism::STATIC); + Tensor reshaped_input(&reshaped_input_impl); + torch::executor::Error err = resize_tensor(out, input.sizes()); + ET_CHECK_MSG( + err == torch::executor::Error::Ok, + "Failed to resize out Tensor in dequantize_per_channel_out"); +#endif + + return dequantize_per_channel_out( + reshaped_input, + scale, + zero_points, + 0, + quant_min, + quant_max, + dtype, + out_dtype, + out); +} + +Tensor& dequantize_per_token_out( + RuntimeContext& context, + const Tensor& input, + const Tensor& scale, + const Tensor& zero_points, + int64_t quant_min, + int64_t quant_max, + ScalarType dtype, + ScalarType out_dtype, + Tensor& out) { + (void)context; + return dequantize_per_token_out( + input, scale, zero_points, quant_min, quant_max, dtype, out_dtype, out); +} + } // namespace native } // namespace executor } // namespace torch diff --git a/kernels/quantized/cpu/op_quantize.cpp b/kernels/quantized/cpu/op_quantize.cpp index 065dc743d92..9e95b11d592 100644 --- a/kernels/quantized/cpu/op_quantize.cpp +++ b/kernels/quantized/cpu/op_quantize.cpp @@ -241,8 +241,6 @@ Tensor& quantize_per_channel_out( int64_t quant_max, ScalarType dtype, Tensor& out) { - torch::executor::Error err = resize_tensor(out, input.sizes()); - // normalize axis ET_CHECK_MSG( tensor_has_dim(input, axis), @@ -254,10 +252,6 @@ Tensor& quantize_per_channel_out( axis += nonzero_dim(input); } - ET_CHECK_MSG( - err == torch::executor::Error::Ok, - "Failed to resize out Tensor in quantize_per_channel_out"); - ET_CHECK_MSG( scale.scalar_type() == ScalarType::Double, "scale.scalar_type() %" PRId8 " is not double type", @@ -368,9 +362,76 @@ Tensor& quantize_per_channel_out( ScalarType dtype, Tensor& out) { (void)context; + torch::executor::Error err = resize_tensor(out, input.sizes()); + ET_CHECK_MSG( + err == torch::executor::Error::Ok, + "Failed to resize out Tensor in quantize_per_channel_out"); + return quantize_per_channel_out( input, scale, zero_point, axis, quant_min, quant_max, dtype, out); } + +Tensor& quantize_per_token_out( + const Tensor& input, + const Tensor& scale, + const Tensor& zero_point, + int64_t quant_min, + int64_t quant_max, + ScalarType dtype, + Tensor& out) { + size_t num_tokens = 1; + for (size_t i = 0; i < input.dim() - 1; i++) { + num_tokens *= input.size(i); + } +// This unfortunate change is needed because we compile op_quantize for aten +// mode as well +#ifdef USE_ATEN_LIB + std::vector sizes(2); + sizes[0] = num_tokens; + sizes[1] = input.size(input.dim() - 1); + Tensor reshaped_input = at::from_blob( + input.mutable_data_ptr(), sizes, at::TensorOptions(input.scalar_type())); +#else + std::array input_dim_order{0, 1}; + std::array input_sizes; + input_sizes[0] = num_tokens; + 
input_sizes[1] = input.size(input.dim() - 1); + std::array input_strides; + dim_order_to_stride_nocheck( + input_sizes.data(), input_dim_order.data(), 2, input_strides.data()); + void* input_data = input.mutable_data_ptr(); + TensorImpl reshaped_input_impl = TensorImpl( + input.scalar_type(), + 2, + input_sizes.data(), + input_data, + input_dim_order.data(), + input_strides.data(), + TensorShapeDynamism::STATIC); + Tensor reshaped_input(&reshaped_input_impl); + torch::executor::Error err = resize_tensor(out, input.sizes()); + ET_CHECK_MSG( + err == torch::executor::Error::Ok, + "Failed to resize out Tensor in quantize_per_channel_out"); +#endif + + return quantize_per_channel_out( + reshaped_input, scale, zero_point, 0, quant_min, quant_max, dtype, out); +} + +Tensor& quantize_per_token_out( + RuntimeContext& context, + const Tensor& input, + const Tensor& scale, + const Tensor& zero_point, + int64_t quant_min, + int64_t quant_max, + ScalarType dtype, + Tensor& out) { + (void)context; + return quantize_per_token_out( + input, scale, zero_point, quant_min, quant_max, dtype, out); +} } // namespace native } // namespace executor } // namespace torch diff --git a/kernels/quantized/quantized.yaml b/kernels/quantized/quantized.yaml index ca2360b7d80..eb7586bad77 100644 --- a/kernels/quantized/quantized.yaml +++ b/kernels/quantized/quantized.yaml @@ -81,3 +81,21 @@ kernels: - arg_meta: null kernel_name: torch::executor::quantize_per_tensor_tensor_args_out + +- func: quantized_decomposed::choose_qparams_per_token_asymmetric.out(Tensor input, ScalarType dtype, *, Tensor(a!) scale_out, Tensor(b!) zero_point_out) -> (Tensor(a!), Tensor(b!)) + variants: function + kernels: + - arg_meta: null + kernel_name: torch::executor::choose_qparams_per_token_asymmetric_out + +- func: quantized_decomposed::quantize_per_token.out(Tensor input, Tensor scales, Tensor zero_points, int quant_min, int quant_max, ScalarType dtype, *, Tensor(a!) out) -> Tensor(a!) + variants: function + kernels: + - arg_meta: null + kernel_name: torch::executor::quantize_per_token_out + +- func: quantized_decomposed::dequantize_per_token.out(Tensor input, Tensor scales, Tensor zero_points, int quant_min, int quant_max, ScalarType dtype, ScalarType output_dtype, *, Tensor(a!) out) -> Tensor(a!) 
+ variants: function + kernels: + - arg_meta: null + kernel_name: torch::executor::dequantize_per_token_out diff --git a/kernels/quantized/targets.bzl b/kernels/quantized/targets.bzl index fd35ad3728d..829cd7e9aeb 100644 --- a/kernels/quantized/targets.bzl +++ b/kernels/quantized/targets.bzl @@ -16,14 +16,17 @@ def define_common_targets(): ops = [ "quantized_decomposed::add.out", "quantized_decomposed::choose_qparams.Tensor_out", + "quantized_decomposed::choose_qparams_per_token_asymmetric.out", "quantized_decomposed::dequantize_per_channel.out", "quantized_decomposed::dequantize_per_tensor.out", "quantized_decomposed::dequantize_per_tensor.Tensor_out", + "quantized_decomposed::dequantize_per_token.out", "quantized_decomposed::mixed_linear.out", "quantized_decomposed::mixed_mm.out", "quantized_decomposed::quantize_per_channel.out", "quantized_decomposed::quantize_per_tensor.out", "quantized_decomposed::quantize_per_tensor.Tensor_out", + "quantized_decomposed::quantize_per_token.out", ], define_static_targets = True, ) diff --git a/kernels/quantized/test/TARGETS b/kernels/quantized/test/TARGETS index ec1ddacfc41..a820e3da3fa 100644 --- a/kernels/quantized/test/TARGETS +++ b/kernels/quantized/test/TARGETS @@ -1,4 +1,5 @@ load("@fbcode_macros//build_defs:python_unittest.bzl", "python_unittest") +load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "runtime") load(":targets.bzl", "define_common_targets") oncall("executorch") @@ -18,3 +19,36 @@ python_unittest( "//executorch/kernels/quantized:quantized_ops_lib", ], ) + +runtime.cxx_library( + name = "quantized_ops_for_test_lib", + srcs = [ + "quantized_ops_aot_register.cpp", + ], + visibility = [ + "//executorch/...", + "@EXECUTORCH_CLIENTS", + ], + deps = [ + "//executorch/extension/aten_util:aten_bridge", + "//executorch/kernels/quantized/cpu:op_dequantize", + "//executorch/kernels/quantized/cpu:op_quantize", + "//executorch/runtime/core/exec_aten:lib", + ], + external_deps = [ + "libtorch", + ], +) + +python_unittest( + name = "test_quant_dequant_per_token", + srcs = [ + "test_quant_dequant_per_token.py", + ], + preload_deps = [ + ":quantized_ops_for_test_lib", + ], + deps = [ + "//caffe2:torch", + ], +) diff --git a/kernels/quantized/test/op_choose_qparams_test.cpp b/kernels/quantized/test/op_choose_qparams_test.cpp index e7acfb0cf46..5cc3fc21169 100644 --- a/kernels/quantized/test/op_choose_qparams_test.cpp +++ b/kernels/quantized/test/op_choose_qparams_test.cpp @@ -11,6 +11,7 @@ #include #include #include + #include #include @@ -21,6 +22,7 @@ using exec_aten::ArrayRef; using exec_aten::Scalar; using exec_aten::ScalarType; using exec_aten::Tensor; +using torch::executor::native::choose_qparams_per_token_asymmetric_out; using torch::executor::native::choose_qparams_tensor_out; using torch::executor::testing::TensorFactory; @@ -28,6 +30,7 @@ using torch::executor::testing::TensorFactory; /// zeros(). 
template void test_dtype() { + et_pal_init(); TensorFactory tf_float; TensorFactory tf_double; TensorFactory tf_long; @@ -48,6 +51,115 @@ void test_dtype() { EXPECT_TENSOR_EQ(zero_point_out, expected_zero_point); } -TEST(OpChooseQparamsTensorOutTest, AllDtypesSupported) { - test_dtype(); +TEST(OpChooseQparamsPerTokenAsymmetricTensorOutTest, Float) { + et_pal_init(); + TensorFactory tf_float; + TensorFactory tf_double; + TensorFactory tf_long; + + Tensor input = tf_float.make({2, 3}, {-0.5, 0.3, 1.2, 0.1, -0.8, 2.1}); + Tensor scale_out = tf_double.zeros({2, 1}); + Tensor zero_point_out = tf_long.zeros({2, 1}); + Tensor expected_scale = tf_double.make({2, 1}, {0.00666667, 0.0113725485}); + Tensor expected_zero_point = tf_long.make({2, 1}, {-53, -58}); + + choose_qparams_per_token_asymmetric_out( + input, ScalarType::Float, scale_out, zero_point_out); + + EXPECT_TENSOR_CLOSE_WITH_TOL(scale_out, expected_scale, 1e-4, 1e-4); + EXPECT_TENSOR_EQ(zero_point_out, expected_zero_point); +} + +TEST(OpChooseQparamsPerTokenAsymmetricTensorOutTest, ExtraDimFloat) { + et_pal_init(); + TensorFactory tf_float; + TensorFactory tf_double; + TensorFactory tf_long; + + Tensor input = tf_float.make({1, 2, 3}, {-0.5, 0.3, 1.2, 0.1, -0.8, 2.1}); + Tensor scale_out = tf_double.zeros({1, 2, 1}); + Tensor zero_point_out = tf_long.zeros({1, 2, 1}); + Tensor expected_scale = tf_double.make({1, 2, 1}, {0.00666667, 0.0113725485}); + Tensor expected_zero_point = tf_long.make({1, 2, 1}, {-53, -58}); + + choose_qparams_per_token_asymmetric_out( + input, ScalarType::Float, scale_out, zero_point_out); + + EXPECT_TENSOR_CLOSE_WITH_TOL(scale_out, expected_scale, 1e-4, 1e-4); + EXPECT_TENSOR_EQ(zero_point_out, expected_zero_point); +} + +TEST(OpChooseQparamsPerTokenAsymmetricTensorOutTest, LargeArray) { + et_pal_init(); + TensorFactory tf_float; + TensorFactory tf_double; + TensorFactory tf_long; + + Tensor input = tf_float.make( + {5, 17}, + {0.41654, 0.26599, 0.4141, 0.83809, 0.02938, 0.12199, 0.53667, + 0.799, 0.6606, 0.46657, 0.66142, 0.71787, 0.56098, 0.30202, + 0.059377, 0.85473, 0.8017, 0.2703, 0.44299, 0.49045, 0.75581, + 0.24429, 0.43906, 0.78652, 0.83885, 0.31034, 0.76534, 0.74422, + 0.62549, 0.80006, 0.38144, 0.70652, 0.33553, 0.89136, 0.49126, + 0.072916, 0.75654, 0.82057, 0.083848, 0.29753, 0.62718, 0.95579, + 0.83097, 0.47293, 0.15666, 0.6248, 0.21672, 0.14626, 0.71834, + 0.93664, 0.23382, 0.68931, 0.70866, 0.60545, 0.98648, 0.30335, + 0.62439, 0.19195, 0.1923, 0.75638, 0.81114, 0.34778, 0.0070671, + 0.50918, 0.19698, 0.19969, 0.57687, 0.062786, 0.18447, 0.22961, + 0.29656, 0.25486, 0.75965, 0.11328, 0.86468, 0.21264, 0.99591, + 0.75231, 0.97834, 0.042441, 0.39978, 0.9633, 0.9297, 0.12188, + 0.73564}); + Tensor scale_out = tf_double.zeros({5, 1}); + Tensor zero_point_out = tf_long.zeros({5, 1}); + Tensor expected_scale = tf_double.make( + {5, 1}, {0.0033519, 0.0034955, 0.0037482, 0.0038685, 0.0039055}); + Tensor expected_zero_point = + tf_long.make({5, 1}, {-128, -128, -128, -128, -128}); + + choose_qparams_per_token_asymmetric_out( + input, ScalarType::Float, scale_out, zero_point_out); + + EXPECT_TENSOR_CLOSE_WITH_TOL(scale_out, expected_scale, 1e-5, 1e-5); + EXPECT_TENSOR_EQ(zero_point_out, expected_zero_point); +} + +TEST(OpChooseQparamsPerTokenAsymmetricTensorOutTest, DynamicShapeFloat) { + et_pal_init(); + TensorFactory tf_float; + TensorFactory tf_double; + TensorFactory tf_long; + + Tensor input = tf_float.make({1, 2, 3}, {-0.5, 0.3, 1.2, 0.1, -0.8, 2.1}); + Tensor scale_out = tf_double.zeros( + {1, 5, 1}, 
torch::executor::TensorShapeDynamism::DYNAMIC_BOUND); + Tensor zero_point_out = tf_long.zeros( + {1, 5, 1}, torch::executor::TensorShapeDynamism::DYNAMIC_BOUND); + Tensor expected_scale = tf_double.make({1, 2, 1}, {0.00666667, 0.0113725485}); + Tensor expected_zero_point = tf_long.make({1, 2, 1}, {-53, -58}); + + choose_qparams_per_token_asymmetric_out( + input, ScalarType::Float, scale_out, zero_point_out); + + EXPECT_TENSOR_CLOSE_WITH_TOL(scale_out, expected_scale, 1e-4, 1e-4); + EXPECT_TENSOR_EQ(zero_point_out, expected_zero_point); + + Tensor new_input = tf_float.make( + {1, 5, 8}, + {5.2254, 5.6041, 5.7653, -1.0126, -0.86126, -0.1606, -0.99196, + -1.067, 5.5913, 5.7713, 5.4901, -0.43128, -1.1759, -0.60466, + -0.82913, -0.73623, 5.4588, 5.4066, 5.2644, -0.89692, -0.16866, + -0.63169, -0.42352, -0.48866, 5.594, 5.5223, 5.5277, -0.17658, + -0.30669, -1.1777, -0.65389, -0.36422, 5.6375, 5.1857, 5.0743, + -0.46654, -0.43817, -0.41506, -0.94515, -0.60247}); + Tensor new_expected_scale = tf_double.make( + {1, 5, 1}, {0.026793, 0.027244, 0.024924, 0.026556, 0.025814}); + Tensor new_expected_zero_point = + tf_long.make({1, 5, 1}, {-88, -85, -92, -84, -91}); + + choose_qparams_per_token_asymmetric_out( + new_input, ScalarType::Float, scale_out, zero_point_out); + + EXPECT_TENSOR_CLOSE_WITH_TOL(scale_out, new_expected_scale, 1e-4, 1e-4); + EXPECT_TENSOR_EQ(zero_point_out, new_expected_zero_point); } diff --git a/kernels/quantized/test/quantized_ops_aot_register.cpp b/kernels/quantized/test/quantized_ops_aot_register.cpp new file mode 100644 index 00000000000..e20f719c1e5 --- /dev/null +++ b/kernels/quantized/test/quantized_ops_aot_register.cpp @@ -0,0 +1,137 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ + +#include +#include +#include +#include + +#include + +namespace torch { +namespace executor { + +namespace native { + +Tensor& quantize_per_token_out( + RuntimeContext& context, + const Tensor& input, + const Tensor& scale, + const Tensor& zero_point, + int64_t quant_min, + int64_t quant_max, + ScalarType dtype, + Tensor& out); + +Tensor& quantize_per_token_out_no_context( + const Tensor& input, + const Tensor& scale, + const Tensor& zero_point, + int64_t quant_min, + int64_t quant_max, + ScalarType dtype, + Tensor& out) { + exec_aten::RuntimeContext context{}; + ::torch::executor::runtime_init(); + quantize_per_token_out( + context, input, scale, zero_point, quant_min, quant_max, dtype, out); + return out; +} + +at::Tensor quantize_per_token_aten( + const at::Tensor& input, + const at::Tensor& scale, + const at::Tensor& zero_point, + int64_t quant_min, + int64_t quant_max, + c10::ScalarType dtype) { + auto sizes = input.sizes().vec(); + auto output = at::zeros(sizes, dtype); + TORCH_CHECK(dtype == c10::ScalarType::Char, "dtype must be char"); + WRAP_TO_ATEN(quantize_per_token_out_no_context, 6) + (input, scale, zero_point, quant_min, quant_max, ScalarType::Char, output); + return output; +} + +Tensor& dequantize_per_token_out( + RuntimeContext& context, + const Tensor& input, + const Tensor& scale, + const Tensor& zero_point, + int64_t quant_min, + int64_t quant_max, + ScalarType dtype, + ScalarType out_dtype, + Tensor& out); + +Tensor& dequantize_per_token_out_no_context( + const Tensor& input, + const Tensor& scale, + const Tensor& zero_point, + int64_t quant_min, + int64_t quant_max, + ScalarType dtype, + ScalarType out_dtype, + Tensor& out) { + exec_aten::RuntimeContext context{}; + ::torch::executor::runtime_init(); + dequantize_per_token_out( + context, + input, + scale, + zero_point, + quant_min, + quant_max, + dtype, + out_dtype, + out); + return out; +} + +at::Tensor dequantize_per_token_aten( + const at::Tensor& input, + const at::Tensor& scale, + const at::Tensor& zero_point, + int64_t quant_min, + int64_t quant_max, + c10::ScalarType dtype, + c10::ScalarType out_dtype) { + auto sizes = input.sizes().vec(); + auto output = at::zeros(sizes, out_dtype); + TORCH_CHECK(dtype == c10::ScalarType::Char, "dtype must be char"); + TORCH_CHECK(out_dtype == c10::ScalarType::Float, "out_dtype must be float"); + WRAP_TO_ATEN(dequantize_per_token_out_no_context, 7) + (input, + scale, + zero_point, + quant_min, + quant_max, + ScalarType::Char, + ScalarType::Float, + output); + return output; +} + +} // namespace native +} // namespace executor +} // namespace torch + +TORCH_LIBRARY(et_quant_test, m) { + m.def( + "quantize_per_token(Tensor input, Tensor scale, Tensor zero_points, int quant_min, int quant_max, ScalarType dtype) -> Tensor"); + m.def( + "dequantize_per_token(Tensor input, Tensor scale, Tensor zero_points, int quant_min, int quant_max, ScalarType dtype, ScalarType out_dtype) -> Tensor"); +} + +TORCH_LIBRARY_IMPL(et_quant_test, CompositeExplicitAutograd, m) { + m.impl( + "quantize_per_token", torch::executor::native::quantize_per_token_aten); + m.impl( + "dequantize_per_token", + torch::executor::native::dequantize_per_token_aten); +} diff --git a/kernels/quantized/test/test_quant_dequant_per_token.py b/kernels/quantized/test/test_quant_dequant_per_token.py new file mode 100644 index 00000000000..8eb7e3e85ac --- /dev/null +++ b/kernels/quantized/test/test_quant_dequant_per_token.py @@ -0,0 +1,152 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. 
+# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-unsafe + +import unittest + +import torch +from torch.ao.quantization.fx._decomposed import quantized_decomposed_lib + + +class QuantizePerTokenTest(unittest.TestCase): + + def setUp(self): + pass + + def test_quantize_per_token(self): + input_tensor = torch.tensor( + [[-0.5, 0.3, 1.2], [0.1, -0.8, 2.1], [-5, 1, 2]], dtype=torch.float32 + ) + scale = torch.tensor([0.5, 0.8, 1.0], dtype=torch.float64) + scale = scale.unsqueeze(-1) + zero_point = torch.tensor([-1, -2, 0]) + zero_point = zero_point.unsqueeze(-1) + quantized_tensor = torch.ops.quantized_decomposed.quantize_per_token( + input_tensor, scale, zero_point, -128, 127, torch.int8 + ) + expected_quantized_tensor = torch.ops.et_quant_test.quantize_per_token( + input_tensor, scale, zero_point, -128, 127, torch.int8 + ) + + self.assertTrue(torch.equal(quantized_tensor, expected_quantized_tensor)) + + def test_quantize_per_token_large_tensor(self): + input_tensor = torch.rand((8, 32)) + scale = torch.rand((8, 1), dtype=torch.float64) + zero_point = torch.randint(0, 10, (8, 1)) + quantized_tensor = torch.ops.quantized_decomposed.quantize_per_token( + input_tensor, scale, zero_point, -128, 127, torch.int8 + ) + expected_quantized_tensor = torch.ops.et_quant_test.quantize_per_token( + input_tensor, scale, zero_point, -128, 127, torch.int8 + ) + + self.assertTrue(torch.equal(quantized_tensor, expected_quantized_tensor)) + + def test_quantize_per_token_high_rank(self): + input_tensor = torch.rand((1, 3, 8, 32)) + scale = torch.rand((1, 3, 8, 1), dtype=torch.float64) + zero_point = torch.randint(0, 10, (1, 3, 8, 1)) + quantized_tensor = torch.ops.quantized_decomposed.quantize_per_token( + input_tensor, scale, zero_point, -128, 127, torch.int8 + ) + expected_quantized_tensor = torch.ops.et_quant_test.quantize_per_token( + input_tensor, scale, zero_point, -128, 127, torch.int8 + ) + + self.assertTrue(torch.equal(quantized_tensor, expected_quantized_tensor)) + + def test_quantize_per_token_dynamic(self): + input_tensor = torch.rand((1, 1, 8, 1)) + scale = torch.rand((1, 1, 8, 1), dtype=torch.float64) + zero_point = torch.randint(0, 10, (1, 1, 8, 1)) + quantized_tensor = torch.ops.quantized_decomposed.quantize_per_token( + input_tensor, scale, zero_point, -128, 127, torch.int8 + ) + expected_quantized_tensor = torch.ops.et_quant_test.quantize_per_token( + input_tensor, scale, zero_point, -128, 127, torch.int8 + ) + + self.assertTrue(torch.equal(quantized_tensor, expected_quantized_tensor)) + + input_tensor = torch.rand((1, 3, 8, 1)) + scale = torch.rand((1, 3, 8, 1), dtype=torch.float64) + zero_point = torch.randint(0, 10, (1, 3, 8, 1)) + quantized_tensor = torch.ops.quantized_decomposed.quantize_per_token( + input_tensor, scale, zero_point, -128, 127, torch.int8 + ) + expected_quantized_tensor = torch.ops.et_quant_test.quantize_per_token( + input_tensor, scale, zero_point, -128, 127, torch.int8 + ) + + self.assertTrue(torch.equal(quantized_tensor, expected_quantized_tensor)) + + def test_dequantize_per_token(self): + input_tensor = torch.randint(-50, 120, (3, 3), dtype=torch.int8) + scale = torch.tensor([0.5, 0.8, 1.0], dtype=torch.float64) + scale = scale.unsqueeze(-1) + zero_point = torch.tensor([-1, -2, 0]) + zero_point = zero_point.unsqueeze(-1) + dequantized_tensor = torch.ops.quantized_decomposed.dequantize_per_token( + input_tensor, scale, zero_point, -128, 127, torch.int8, 
torch.float32 + ) + expected_dequantized_tensor = torch.ops.et_quant_test.dequantize_per_token( + input_tensor, scale, zero_point, -128, 127, torch.int8, torch.float32 + ) + + self.assertTrue(torch.allclose(dequantized_tensor, expected_dequantized_tensor)) + + def test_dequantize_per_token_large_tensor(self): + input_tensor = torch.randint(-50, 120, (8, 32), dtype=torch.int8) + scale = torch.rand((8, 1), dtype=torch.float64) + zero_point = torch.randint(0, 10, (8, 1)) + dequantized_tensor = torch.ops.quantized_decomposed.dequantize_per_token( + input_tensor, scale, zero_point, -128, 127, torch.int8, torch.float32 + ) + expected_dequantized_tensor = torch.ops.et_quant_test.dequantize_per_token( + input_tensor, scale, zero_point, -128, 127, torch.int8, torch.float32 + ) + + self.assertTrue(torch.allclose(dequantized_tensor, expected_dequantized_tensor)) + + def test_dequantize_per_token_high_rank(self): + input_tensor = torch.randint(-50, 120, (1, 3, 8, 32), dtype=torch.int8) + scale = torch.rand((1, 3, 8, 1), dtype=torch.float64) + zero_point = torch.randint(0, 10, (1, 3, 8, 1)) + dequantized_tensor = torch.ops.quantized_decomposed.dequantize_per_token( + input_tensor, scale, zero_point, -128, 127, torch.int8, torch.float32 + ) + expected_dequantized_tensor = torch.ops.et_quant_test.dequantize_per_token( + input_tensor, scale, zero_point, -128, 127, torch.int8, torch.float32 + ) + + self.assertTrue(torch.allclose(dequantized_tensor, expected_dequantized_tensor)) + + def test_dequantize_per_token_dynamic(self): + input_tensor = torch.randint(-50, 120, (1, 1, 8, 32), dtype=torch.int8) + scale = torch.rand((1, 1, 8, 1), dtype=torch.float64) + zero_point = torch.randint(0, 10, (1, 1, 8, 1)) + dequantized_tensor = torch.ops.quantized_decomposed.dequantize_per_token( + input_tensor, scale, zero_point, -128, 127, torch.int8, torch.float32 + ) + expected_dequantized_tensor = torch.ops.et_quant_test.dequantize_per_token( + input_tensor, scale, zero_point, -128, 127, torch.int8, torch.float32 + ) + + self.assertTrue(torch.allclose(dequantized_tensor, expected_dequantized_tensor)) + + input_tensor = torch.randint(-50, 120, (1, 3, 8, 32), dtype=torch.int8) + scale = torch.rand((1, 3, 8, 1), dtype=torch.float64) + zero_point = torch.randint(0, 10, (1, 3, 8, 1)) + dequantized_tensor = torch.ops.quantized_decomposed.dequantize_per_token( + input_tensor, scale, zero_point, -128, 127, torch.int8, torch.float32 + ) + expected_dequantized_tensor = torch.ops.et_quant_test.dequantize_per_token( + input_tensor, scale, zero_point, -128, 127, torch.int8, torch.float32 + ) + + self.assertTrue(torch.allclose(dequantized_tensor, expected_dequantized_tensor))
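
Below is a minimal, standalone sketch (illustrative, not part of the diff) of the per-token int8 round trip that QuantizedKVCache.update() applies to each incoming k/v slice, using the same quantized_decomposed ops exercised by the tests above. The tensor shape and tolerances are assumptions chosen to mirror the unit tests, not values taken from the PR.

# Hypothetical standalone sketch; assumes a torch build that ships the
# quantized_decomposed ops used elsewhere in this PR.
import torch

# Importing the decomposed lib registers the quantized_decomposed ops.
from torch.ao.quantization.fx._decomposed import quantized_decomposed_lib  # noqa: F401

# Assumed layout: (batch, seq_len, n_heads, head_dim), i.e. the non-transposed cache.
k_val = torch.rand(1, 3, 8, 17, dtype=torch.float32)

# One (scale, zero_point) pair per token; the last dimension is reduced to 1.
scales, zero_points = (
    torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric(
        k_val, torch.int8
    )
)

qmin, qmax = torch.iinfo(torch.int8).min, torch.iinfo(torch.int8).max
k_int8 = torch.ops.quantized_decomposed.quantize_per_token(
    k_val, scales, zero_points, qmin, qmax, torch.int8
)
k_back = torch.ops.quantized_decomposed.dequantize_per_token(
    k_int8, scales, zero_points, qmin, qmax, torch.int8, torch.float32
)

# int8 per-token quantization should round-trip to within ~1e-2, the same
# tolerance used in test_quantized_kv_cache.py.
torch.testing.assert_close(k_back, k_val, rtol=1e-2, atol=1e-2)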