From bb7df529a4abf9353f6cc0c28a8da5c6762e0ca5 Mon Sep 17 00:00:00 2001
From: agrebenisan
Date: Thu, 4 Sep 2025 11:29:26 -0700
Subject: [PATCH 1/2] Update backend-agnostic quantize_per_tensor to use
 scale_inv to match the quantize_per_tensor_out implementation (#13846)

Summary:
Pull Request resolved: https://github.com/pytorch/executorch/pull/13846

Fixes an interface mismatch between the quantize_per_tensor_out
implementation and the old quantize_per_tensor Python implementation.

Differential Revision: D81459313
---
 backends/cadence/aot/ref_implementations.py | 32 +++++++++++++++----
 .../aot/tests/test_ref_implementations.py   |  5 +--
 2 files changed, 28 insertions(+), 9 deletions(-)
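
A minimal sketch (illustrative only, not part of the diff) of the convention
this patch standardizes on: callers pass the inverse of the quantization
scale, and the rounded result is clamped to [quant_min, quant_max] before the
dtype cast. The tensor values below are made up:

    import torch

    x = torch.tensor([-1.0, 0.5, 3.0])  # hypothetical float activations
    scale = 2.0 / 255.0                 # float range [-1, 1] -> int8 range
    scale_inv = 1.0 / scale             # quantize_per_tensor now takes this
    zero_point = 0
    q = torch.round(x * scale_inv + zero_point)
    q = q.clamp(-128, 127).to(torch.int8)  # clamp first, then cast
    # 3.0 saturates to 127 instead of wrapping around after the cast
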
diff --git a/backends/cadence/aot/ref_implementations.py b/backends/cadence/aot/ref_implementations.py
index bda8f9425ea..abbfc41889b 100644
--- a/backends/cadence/aot/ref_implementations.py
+++ b/backends/cadence/aot/ref_implementations.py
@@ -24,7 +24,7 @@

 @impl(m, "quantize_per_tensor")
 def quantize_per_tensor(
-    input: torch.Tensor,
+    input_tensor: torch.Tensor,
     scale: float,
     zero_point: int,
     quant_min: int,
@@ -35,10 +35,10 @@ def quantize_per_tensor(
     Quantizes a floating-point tensor to an integral tensor.

     Args:
-    - input (Tensor): input tensor
-    - scale (float): Quantization scale. Derived from the ratio
+    - input_tensor (Tensor): input tensor
+    - scale (float): Inverse of the quantization scale, derived from the ratio
       between the min/max of the floating-point tensor and the
-      min/max of the quantized range.
+      min/max of the quantized range, and then inverted.
     - zero_point (int): The point which represents 0 in the quantized
       range. For example, consider the floating point range [-1., 2.] and
       quantized integer range [-7, 7]. In this case, 0 is 1/3 of way from
@@ -61,7 +61,12 @@
         raise ValueError(
             f"Unsupported dtype to quantize to. Supported dtypes must be one of {supported_quant_types}"
         )
-    return torch.round(input / scale + zero_point).to(dtype)
+
+    quantized = torch.round(input_tensor * scale + zero_point)
+    return torch.max(
+        torch.min(quantized, torch.tensor(quant_max)),
+        torch.tensor(quant_min),
+    ).to(dtype)


 @impl(m, "dequantize_per_tensor")
@@ -173,9 +178,16 @@ def quantized_add(
     dequant_X = X_scale * (X - X_zero_point)
     dequant_Y = Y_scale * (Y - Y_zero_point)

+    out_scale_inv = 1 / out_scale
+
-    # q_min/q_max are unused args
+    # quant_min/quant_max clamp the result to the output dtype's full range
     return quantize_per_tensor(
-        dequant_X + dequant_Y, out_scale, out_zero_point, -128, 127, dtype
+        dequant_X + dequant_Y,
+        out_scale_inv,
+        out_zero_point,
+        torch.iinfo(dtype).min,
+        torch.iinfo(dtype).max,
+        dtype,
     )


@@ -206,6 +218,7 @@ def quantized_linear(
     - offset (Tensor): Unused
     """
     out_scale = -out_multiplier * (1 / (1 << 31)) * (2 ** out_shift[0])
+    out_scale_inv = 1 / out_scale

     N, K = weight.shape

@@ -223,7 +236,12 @@ def quantized_linear(
         src - in_zero_point, weight - weight_zero_point, bias
     )
     return quantize_per_tensor(
-        out, out_scale, out_zero_point, -128, 127, dtype
+        out,
+        out_scale_inv,
+        out_zero_point,
+        torch.iinfo(dtype).min,
+        torch.iinfo(dtype).max,
+        dtype,
     ).reshape(*leading_dims, N)


diff --git a/backends/cadence/aot/tests/test_ref_implementations.py b/backends/cadence/aot/tests/test_ref_implementations.py
index 1b02926b3f8..95e1a374463 100644
--- a/backends/cadence/aot/tests/test_ref_implementations.py
+++ b/backends/cadence/aot/tests/test_ref_implementations.py
@@ -40,11 +40,12 @@ def test_quantize_per_tensor(
     ) -> None:
         input_tensor = torch.tensor([input_value])
         scale = (f_max - f_min) / (q_max - q_min)
-        zero_point = round(-f_min / scale) + q_min
+        inv_scale = 1.0 / scale
+        zero_point = round(-f_min * inv_scale) + q_min
         expected_output = torch.tensor([expected_value], dtype=target_dtype)

         output = quantize_per_tensor(
-            input_tensor, scale, zero_point, q_min, q_max, target_dtype
+            input_tensor, inv_scale, zero_point, q_min, q_max, target_dtype
         )

         self.assertEqual(

From 45370f828fb4ae3c56ace1eb87702d3ab3685167 Mon Sep 17 00:00:00 2001
From: Andrew Grebenisan
Date: Thu, 4 Sep 2025 11:42:01 -0700
Subject: [PATCH 2/2] Backend-agnostic implementation of
 quantized_layer_norm_per_tensor (#13847)

Summary:
Pull Request resolved: https://github.com/pytorch/executorch/pull/13847

Continuing support for backend-agnostic Cadence custom ops.

Reviewed By: hsharma35

Differential Revision: D81459333
---
 backends/cadence/aot/ref_implementations.py | 49 ++++++++++
 .../aot/tests/test_ref_implementations.py   | 95 +++++++++++++++++++
 2 files changed, 144 insertions(+)
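
A minimal sketch (illustrative only, not part of the diff) of the reference
flow added below: dequantize to float, run torch.nn.functional.layer_norm,
then requantize with the inverse output scale. The values mirror the first
test case:

    import torch

    x_q = torch.tensor([[-1, 1]], dtype=torch.int8)
    x_scale, x_zero_point = 0.1, 0
    x = x_scale * (x_q.to(torch.float32) - x_zero_point)   # [[-0.1, 0.1]]
    y = torch.nn.functional.layer_norm(x, (2,), eps=1e-5)  # ~[[-1.0, 1.0]]
    out_scale, out_zero_point = 0.1, 0
    y_q = torch.round(y * (1 / out_scale) + out_zero_point)
    y_q = y_q.clamp(-128, 127).to(torch.int8)              # [[-10, 10]]
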
diff --git a/backends/cadence/aot/ref_implementations.py b/backends/cadence/aot/ref_implementations.py
index abbfc41889b..5595d74c9c8 100644
--- a/backends/cadence/aot/ref_implementations.py
+++ b/backends/cadence/aot/ref_implementations.py
@@ -245,6 +245,55 @@
     ).reshape(*leading_dims, N)


+@impl(m, "quantized_layer_norm_per_tensor")
+def quantized_layer_norm_per_tensor(
+    input_tensor: torch.Tensor,
+    X_scale: float,
+    X_zero_point: int,
+    normalized_shape: int,
+    weight: torch.Tensor,
+    bias: torch.Tensor,
+    eps: float,
+    output_scale: float,
+    output_zero_point: int,
+) -> torch.Tensor:
+    """
+    Quantized layer norm operation.
+
+    Args:
+    - input_tensor (Tensor): The activations tensor
+    - X_scale (float): The scale of the input
+    - X_zero_point (int): The zero point of the input
+    - normalized_shape (int): The size of the last dimension to normalize over
+    - weight (Tensor): The weight tensor
+    - bias (Tensor): The bias tensor
+    - eps (float): The epsilon value
+    - output_scale (float): The scale of the output
+    - output_zero_point (int): The zero point of the output
+    """
+    supported_dtypes = [torch.int8, torch.uint8]
+    if input_tensor.dtype not in supported_dtypes:
+        raise ValueError(
+            f"Input dtype must be one of {supported_dtypes}. Got {input_tensor.dtype}"
+        )
+
+    float_input_tensor = dequantize_per_tensor(
+        input_tensor, X_scale, X_zero_point, -128, 127, torch.float32
+    )
+    out = torch.nn.functional.layer_norm(
+        float_input_tensor, (normalized_shape,), weight, bias, eps=eps
+    )
+
+    return quantize_per_tensor(
+        out,
+        1 / output_scale,
+        output_zero_point,
+        torch.iinfo(input_tensor.dtype).min,
+        torch.iinfo(input_tensor.dtype).max,
+        input_tensor.dtype,
+    )
+
+
 @impl(m, "requantize")
 def requantize(
     input: torch.Tensor,
diff --git a/backends/cadence/aot/tests/test_ref_implementations.py b/backends/cadence/aot/tests/test_ref_implementations.py
index 95e1a374463..52f6c308da8 100644
--- a/backends/cadence/aot/tests/test_ref_implementations.py
+++ b/backends/cadence/aot/tests/test_ref_implementations.py
@@ -15,6 +15,7 @@
     dequantize_per_tensor,
     quantize_per_tensor,
     quantized_add,
+    quantized_layer_norm_per_tensor,
     quantized_linear,
 )
 from executorch.backends.cadence.aot.typing_stubs import expand
@@ -240,3 +241,97 @@ def test_quantized_linear(
             torch.equal(output, expected_output),
             f"Values don't match: got {output}, expected {expected_output}",
         )
+
+    @expand(
+        [
+            # Test case 1: Simple case with int8, zero mean input
+            (
+                torch.tensor(
+                    [[-1, 1]], dtype=torch.int8
+                ),  # input: dequantized to [-0.1, 0.1]
+                0.1,  # X_scale
+                0,  # X_zero_point
+                2,  # normalized_shape (last dimension)
+                torch.tensor([1.0, 1.0]),  # weight
+                torch.tensor([0.0, 0.0]),  # bias
+                1e-5,  # eps
+                0.1,  # output_scale
+                0,  # output_zero_point
+                torch.int8,  # dtype
+                torch.tensor([[-10, 10]], dtype=torch.int8),  # expected_output
+            ),
+            # Test case 2: uint8 with zero_point offset
+            (
+                torch.tensor(
+                    [[127, 129]], dtype=torch.uint8
+                ),  # input: dequantized to [-0.05, 0.05]
+                0.05,  # X_scale
+                128,  # X_zero_point
+                2,  # normalized_shape (last dimension)
+                torch.tensor([1.0, 1.0]),  # weight
+                torch.tensor([0.0, 0.0]),  # bias
+                1e-5,  # eps
+                0.05,  # output_scale
+                128,  # output_zero_point
+                torch.uint8,  # dtype
+                torch.tensor([[108, 148]], dtype=torch.uint8),  # expected_output
+            ),
+            # Test case 3: Test with weight and bias scaling
+            (
+                torch.tensor(
+                    [[-2, 2]], dtype=torch.int8
+                ),  # input: dequantized to [-0.2, 0.2]
+                0.1,  # X_scale
+                0,  # X_zero_point
+                2,  # normalized_shape (last dimension)
+                torch.tensor(
+                    [2.0, 0.5]
+                ),  # weight: scale first element by 2, second by 0.5
+                torch.tensor(
+                    [0.1, -0.1]
+                ),  # bias: add 0.1 to first, subtract 0.1 from second
+                1e-5,  # eps
+                0.1,  # output_scale
+                0,  # output_zero_point
+                torch.int8,  # dtype
+                torch.tensor([[-19, 4]], dtype=torch.int8),  # expected_output
+            ),
+        ]
+    )
+    def test_quantized_layer_norm_per_tensor(
+        self,
+        input_tensor: torch.Tensor,
+        X_scale: float,
+        X_zero_point: int,
+        normalized_shape: int,
+        weight: torch.Tensor,
+        bias: torch.Tensor,
+        eps: float,
+        output_scale: float,
+        output_zero_point: int,
+        dtype: torch.dtype,
+        expected_output: torch.Tensor,
+    ) -> None:
+        output = quantized_layer_norm_per_tensor(
+            input_tensor,
+            X_scale,
+            X_zero_point,
+            normalized_shape,
+            weight,
+            bias,
+            eps,
+            output_scale,
+            output_zero_point,
+        )
+
+        # Verify output properties
+        self.assertEqual(output.dtype, dtype, f"Output dtype should be {dtype}")
+        self.assertEqual(
+            output.shape, input_tensor.shape, "Output shape should match input shape"
+        )
+
+        # Verify output matches expected values
+        self.assertTrue(
+            torch.equal(output, expected_output),
+            f"Output values don't match expected. Got {output}, expected {expected_output}",
+        )
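
A quick hand check of the uint8 case above, runnable standalone (illustrative
only; it reimplements the reference math rather than calling the op):

    import torch

    x = 0.05 * (torch.tensor([[127.0, 129.0]]) - 128)  # [[-0.05, 0.05]]
    y = torch.nn.functional.layer_norm(x, (2,))        # ~[[-1.0, 1.0]]
    q = torch.round(y / 0.05 + 128).clamp(0, 255).to(torch.uint8)
    assert torch.equal(q, torch.tensor([[108, 148]], dtype=torch.uint8))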