From abbd16ed21e09904743b47d5183ea351959622fb Mon Sep 17 00:00:00 2001
From: Andrew Grebenisan
Date: Thu, 28 Aug 2025 11:48:33 -0700
Subject: [PATCH] Add backend-agnostic implementation for quantize_per_tensor
 (#13769)

Summary: Part of an ongoing task to provide backend-agnostic implementations
for all of the Cadence custom ops.

Differential Revision: D81187339
---
 backends/cadence/aot/TARGETS                | 14 ++++++
 backends/cadence/aot/ref_implementations.py | 36 +++++++++++++
 .../aot/tests/test_ref_implementations.py   | 50 +++++++++++++++++++
 3 files changed, 100 insertions(+)
 create mode 100644 backends/cadence/aot/tests/test_ref_implementations.py

diff --git a/backends/cadence/aot/TARGETS b/backends/cadence/aot/TARGETS
index e257df37c8a..1a2c5a9709f 100644
--- a/backends/cadence/aot/TARGETS
+++ b/backends/cadence/aot/TARGETS
@@ -604,3 +604,17 @@ python_unittest(
         "//later:lib",
     ],
 )
+
+python_unittest(
+    name = "test_ref_implementations",
+    srcs = [
+        "tests/test_ref_implementations.py",
+    ],
+    supports_static_listing = False,
+    typing = True,
+    deps = [
+        ":typing_stubs",
+        "//executorch/backends/cadence/aot:ref_implementations",
+        "//caffe2:torch",
+    ]
+)
diff --git a/backends/cadence/aot/ref_implementations.py b/backends/cadence/aot/ref_implementations.py
index 9eaac004bcf..10078f09c54 100644
--- a/backends/cadence/aot/ref_implementations.py
+++ b/backends/cadence/aot/ref_implementations.py
@@ -20,6 +20,42 @@
 }
 
 
+@impl(m, "quantize_per_tensor")
+def quantize_per_tensor(
+    input: torch.Tensor,
+    scale: float,
+    zero_point: int,
+    quant_min: int,
+    quant_max: int,
+    dtype: torch.dtype,
+) -> torch.Tensor:
+    """
+    Quantizes a floating-point tensor to an integral tensor.
+
+    Args:
+    - input (Tensor): The input tensor.
+    - scale (float): Quantization scale. Derived from the ratio between
+        the span of the floating-point range and the span of the
+        quantized range.
+    - zero_point (int): The point that represents 0 in the quantized
+        range. For example, consider the floating-point range [-1., 2.] and
+        the quantized integer range [-7, 7]. In this case, 0 is 1/3 of the
+        way from -1. to 2., so the point that represents 0 in the quantized
+        range should be 1/3 of the way from -7 to 7, which is -2 in the integer space.
+    - quant_min (int): The smallest value in the quantized domain. Unused since scale
+        is already provided.
+    - quant_max (int): The largest value in the quantized domain. Unused since scale
+        is already provided.
+    - dtype (torch.dtype): The type of the output tensor.
+    """
+    supported_quant_types = [torch.int8, torch.int16, torch.int32]
+    if dtype not in supported_quant_types:
+        raise ValueError(
+            f"Unsupported dtype {dtype}. Supported dtypes: {supported_quant_types}"
+        )
+    return torch.round(input / scale + zero_point).to(dtype)
+
+
 @impl(m, "requantize")
 def requantize(
     input: torch.Tensor,
diff --git a/backends/cadence/aot/tests/test_ref_implementations.py b/backends/cadence/aot/tests/test_ref_implementations.py
new file mode 100644
index 00000000000..f9d327d8589
--- /dev/null
+++ b/backends/cadence/aot/tests/test_ref_implementations.py
@@ -0,0 +1,50 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+# pyre-strict
+
+import unittest
+
+import torch
+
+from executorch.backends.cadence.aot.ref_implementations import quantize_per_tensor
+from executorch.backends.cadence.aot.typing_stubs import expand
+
+
+class TestRefImplementations(unittest.TestCase):
+    @expand(
+        [
+            ("basic_int8", 0.42, -1.0, 2.0, -7, 7, torch.int8, 0),
+            ("basic_int16", 0.42, -1.0, 5.0, -6, 7, torch.int16, -3),
+        ]
+    )
+    def test_quantize_per_tensor(
+        self,
+        name: str,
+        input_value: float,
+        f_min: float,
+        f_max: float,
+        q_min: int,
+        q_max: int,
+        target_dtype: torch.dtype,
+        expected_value: int,
+    ) -> None:
+        input_tensor = torch.tensor([input_value])
+        scale = (f_max - f_min) / (q_max - q_min)
+        zero_point = round(-f_min / scale) + q_min
+        expected_output = torch.tensor([expected_value], dtype=target_dtype)
+
+        output = quantize_per_tensor(
+            input_tensor, scale, zero_point, q_min, q_max, target_dtype
+        )
+
+        self.assertEqual(
+            output.dtype, expected_output.dtype, f"Dtype mismatch in {name}"
+        )
+        self.assertTrue(
+            torch.equal(output, expected_output),
+            f"Values don't match in {name}: got {output}, expected {expected_output}",
+        )
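
A quick sanity check of the scale/zero-point arithmetic above, as a standalone
sketch: the values mirror the basic_int8 test case, and the variable names are
illustrative only, not part of the patch.

import torch

f_min, f_max = -1.0, 2.0  # floating-point range
q_min, q_max = -7, 7      # quantized integer range

# Scale: span of the float range over the span of the integer range.
scale = (f_max - f_min) / (q_max - q_min)  # 3 / 14 ~= 0.2143

# Zero point: 0 sits 1/3 of the way from f_min to f_max, so its image
# sits 1/3 of the way from q_min to q_max: round(4.667) - 7 = -2.
zero_point = round(-f_min / scale) + q_min

# Quantize 0.42 the same way the reference implementation does.
x = torch.tensor([0.42])
q = torch.round(x / scale + zero_point).to(torch.int8)
print(q)  # tensor([0], dtype=torch.int8), matching the test's expected output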