From 8f5940b9abf36a73a87f8f7cd9f8f8f731cfdd7a Mon Sep 17 00:00:00 2001
From: Andrew Grebenisan
Date: Fri, 29 Aug 2025 16:28:56 -0700
Subject: [PATCH] Add backend-agnostic implementation for dequantize_per_tensor
 (#13777)

Summary: Part of an ongoing task to provide backend-agnostic implementations
for all of the Cadence custom ops.

Reviewed By: hsharma35

Differential Revision: D81266532
---
 backends/cadence/aot/ref_implementations.py   | 59 +++++++++++++++++++
 .../aot/tests/test_ref_implementations.py     | 49 ++++++++++++++++-
 2 files changed, 107 insertions(+), 1 deletion(-)

diff --git a/backends/cadence/aot/ref_implementations.py b/backends/cadence/aot/ref_implementations.py
index 10078f09c54..21d1efe6398 100644
--- a/backends/cadence/aot/ref_implementations.py
+++ b/backends/cadence/aot/ref_implementations.py
@@ -56,6 +56,65 @@ def quantize_per_tensor(
     return torch.round(input / scale + zero_point).to(dtype)
 
 
+@impl(m, "dequantize_per_tensor")
+def dequantize_per_tensor(
+    input_tensor: torch.Tensor,
+    scale: float,
+    zero_point: int,
+    quant_min: int,
+    quant_max: int,
+    dtype: torch.dtype,
+) -> torch.Tensor:
+    """
+    Dequantizes an integral tensor to a floating-point tensor.
+
+    Args:
+    - input_tensor (Tensor): The quantized integer tensor to dequantize.
+    - scale (float): Quantization scale. Derived from the ratio between
+      the span of the floating-point range and the span of the quantized
+      integer range.
+    - zero_point (int): The point which represents 0 in the quantized
+      range. For example, consider the floating-point range [-1., 2.] and
+      the quantized integer range [-7, 7]. Here, 0 is 1/3 of the way from
+      -1. to 2., so the point that represents 0 should be 1/3 of the way
+      through [-7, 7], which rounds to -2 in the integer space.
+    - quant_min (int): The smallest value in the quantized domain. Unused
+      since scale is already provided.
+    - quant_max (int): The largest value in the quantized domain. Unused
+      since scale is already provided.
+    - dtype (torch.dtype): The type of the output tensor. Must be a
+      floating-point type.
+    """
+    supported_quant_types = [
+        torch.int8,
+        torch.int16,
+        torch.int32,
+        torch.uint8,
+        torch.uint16,
+    ]
+    if input_tensor.dtype not in supported_quant_types:
+        raise ValueError(f"Input dtype must be one of {supported_quant_types}")
+    supported_dequant_types = [
+        torch.float,
+        torch.float32,
+        torch.float16,
+        torch.bfloat16,
+    ]
+    if dtype not in supported_dequant_types:
+        raise ValueError(
+            f"Unsupported dtype to dequantize to. Output dtype must be one of {supported_dequant_types}"
+        )
+
+    # Cast unsigned input to a wider signed type to prevent underflow
+    # (wraparound) when the zero_point is larger than the quantized value.
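+    # For example, a uint8 value of 3 with zero_point 4 would wrap to 255
+    # under unsigned arithmetic instead of yielding the intended -1.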
+    if not input_tensor.dtype.is_signed:
+        input_tensor = input_tensor.to(torch.int32)
+
+    return (input_tensor - zero_point).to(dtype) * scale
+
+
 @impl(m, "requantize")
 def requantize(
     input: torch.Tensor,
diff --git a/backends/cadence/aot/tests/test_ref_implementations.py b/backends/cadence/aot/tests/test_ref_implementations.py
index f9d327d8589..538798711f5 100644
--- a/backends/cadence/aot/tests/test_ref_implementations.py
+++ b/backends/cadence/aot/tests/test_ref_implementations.py
@@ -10,7 +10,10 @@
 
 import torch
 
-from executorch.backends.cadence.aot.ref_implementations import quantize_per_tensor
+from executorch.backends.cadence.aot.ref_implementations import (
+    dequantize_per_tensor,
+    quantize_per_tensor,
+)
 from executorch.backends.cadence.aot.typing_stubs import expand
 
 
@@ -48,3 +51,47 @@ def test_quantize_per_tensor(
             torch.equal(output, expected_output),
             f"Values don't match in {name}: got {output}, expected {expected_output}",
         )
+
+    @expand(
+        [
+            # Signed quantization ranges
+            ("signed_range0_int8", 0, -1.0, 2.0, -7, 7, torch.int8, 0.428),
+            ("signed_range0_int16", 0, -1.0, 2.0, -7, 7, torch.int16, 0.428),
+            ("signed_range0_int32", 0, -1.0, 2.0, -7, 7, torch.int32, 0.428),
+            ("signed_range1_int8", -3, -1.0, 5.0, -6, 7, torch.int8, 0.461),
+            ("signed_range1_int16", -3, -1.0, 5.0, -6, 7, torch.int16, 0.461),
+            ("signed_range1_int32", -3, -1.0, 5.0, -6, 7, torch.int32, 0.461),
+            # Unsigned quantization ranges
+            ("unsigned_range0_uint8", 3, -1.0, 2.0, 0, 7, torch.uint8, 0.428),
+            ("unsigned_range0_uint16", 3, -1.0, 2.0, 0, 7, torch.uint16, 0.428),
+            ("unsigned_range1_uint8", 4, -1.0, 5.0, 3, 7, torch.uint8, 0.0),
+            ("unsigned_range1_uint16", 4, -1.0, 5.0, 3, 7, torch.uint16, 0.0),
+        ]
+    )
+    def test_dequantize_per_tensor(
+        self,
+        name: str,
+        input_value: int,
+        f_min: float,
+        f_max: float,
+        q_min: int,
+        q_max: int,
+        input_dtype: torch.dtype,
+        expected_value: float,
+    ) -> None:
+        input_tensor = torch.tensor([input_value], dtype=input_dtype)
+        scale = (f_max - f_min) / (q_max - q_min)
+        zero_point = round(-f_min / scale) + q_min
+        expected_output = torch.tensor([expected_value], dtype=torch.float32)
+
+        output = dequantize_per_tensor(
+            input_tensor, scale, zero_point, q_min, q_max, torch.float32
+        )
+
+        self.assertEqual(
+            output.dtype, expected_output.dtype, f"Dtype mismatch in {name}"
+        )
+        self.assertTrue(
+            torch.allclose(output, expected_output, rtol=0.001, atol=0.001),
+            f"Values don't match in {name}: got {output}, expected {expected_output}",
+        )
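
Illustrative round trip (a minimal sketch, not part of the patch): this
assumes quantize_per_tensor takes the same (input, scale, zero_point,
quant_min, quant_max, dtype) argument order as dequantize_per_tensor above,
and reuses the [-1., 2.] -> [-7, 7] example from the docstring; the input
values are made up.

    import torch

    from executorch.backends.cadence.aot.ref_implementations import (
        dequantize_per_tensor,
        quantize_per_tensor,
    )

    # Map the float range [-1., 2.] onto the quantized range [-7, 7],
    # mirroring the example in the dequantize_per_tensor docstring.
    scale = (2.0 - (-1.0)) / (7 - (-7))  # 3/14 ~= 0.214
    zero_point = round(1.0 / scale) - 7  # -2

    x = torch.tensor([-1.0, 0.0, 2.0])
    q = quantize_per_tensor(x, scale, zero_point, -7, 7, torch.int8)
    y = dequantize_per_tensor(q, scale, zero_point, -7, 7, torch.float32)
    # y recovers x only up to one quantization step (~0.214).

Dequantization recovers the original values only up to one quantization
step, which is why the new tests compare against precomputed expected
values with torch.allclose rather than torch.equal.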