From d32ee2335805734250200de280a16e5ac0f4e06f Mon Sep 17 00:00:00 2001 From: Andrew Grebenisan Date: Sun, 14 Sep 2025 11:52:06 -0700 Subject: [PATCH] Add int8/uint8 specialized variants of quantized_add_per_tensor (#14094) Summary: Add type specialized variants of quantized_add_per_tensor Reviewed By: hsharma35 Differential Revision: D81951110 --- backends/cadence/aot/ref_implementations.py | 42 +++++++++++++++++++ .../aot/tests/test_ref_implementations.py | 23 +++++++++- 2 files changed, 64 insertions(+), 1 deletion(-) diff --git a/backends/cadence/aot/ref_implementations.py b/backends/cadence/aot/ref_implementations.py index f3856d2fd1c..483d8f18241 100644 --- a/backends/cadence/aot/ref_implementations.py +++ b/backends/cadence/aot/ref_implementations.py @@ -193,6 +193,48 @@ def quantized_add_per_tensor( ) +@impl(m, "quantized_add_asym8sxasym8s_asym8s.per_tensor") +def quantized_add_asym8sxasym8s_asym8s_per_tensor( + X: torch.Tensor, + X_scale: float, + X_zero_point: int, + Y: torch.Tensor, + Y_scale: float, + Y_zero_point: int, + out_scale: float, + out_zero_point: int, +) -> torch.Tensor: + if X.dtype != torch.int8: + raise ValueError("X dtype must be torch.int8") + if Y.dtype != torch.int8: + raise ValueError("Y dtype must be torch.int8") + + return quantized_add_per_tensor( + X, X_scale, X_zero_point, Y, Y_scale, Y_zero_point, out_scale, out_zero_point + ) + + +@impl(m, "quantized_add_asym8uxasym8u_asym8u.per_tensor") +def quantized_add_asym8uxasym8u_asym8u_per_tensor( + X: torch.Tensor, + X_scale: float, + X_zero_point: int, + Y: torch.Tensor, + Y_scale: float, + Y_zero_point: int, + out_scale: float, + out_zero_point: int, +) -> torch.Tensor: + if X.dtype != torch.uint8: + raise ValueError("X dtype must be torch.uint8") + if Y.dtype != torch.uint8: + raise ValueError("Y dtype must be torch.uint8") + + return quantized_add_per_tensor( + X, X_scale, X_zero_point, Y, Y_scale, Y_zero_point, out_scale, out_zero_point + ) + + def quantized_linear_common( src: 
torch.Tensor, weight: torch.Tensor, diff --git a/backends/cadence/aot/tests/test_ref_implementations.py b/backends/cadence/aot/tests/test_ref_implementations.py index 53ed526f759..03de587c3be 100644 --- a/backends/cadence/aot/tests/test_ref_implementations.py +++ b/backends/cadence/aot/tests/test_ref_implementations.py @@ -100,7 +100,7 @@ def test_dequantize_per_tensor( [ # Only these types need to be tested as per ET_FORALL_JARVIS_QUANTIZED_TYPES in # on_device_ai/Assistant/Jarvis/min_runtime/operators/generic/operators.h - ("int16", 5, 0.8, 4, 5, 0.8, 4, 0.8, 4, 6, torch.int8), + ("int8", 5, 0.8, 4, 5, 0.8, 4, 0.8, 4, 6, torch.int8), ("uint8", 5, 0.8, 4, 5, 0.8, 4, 0.8, 4, 6, torch.uint8), ] ) @@ -122,6 +122,27 @@ def test_quantized_add( Y_tensor = torch.tensor([Y], dtype=dtype) expected_output = torch.tensor([expected_value], dtype=dtype) + quantized_add = ( + torch.ops.cadence.quantized_add_asym8sxasym8s_asym8s.per_tensor + if dtype == torch.int8 + else torch.ops.cadence.quantized_add_asym8uxasym8u_asym8u.per_tensor + ) + output = quantized_add( + X_tensor, + X_scale, + X_zero_point, + Y_tensor, + Y_scale, + Y_zero_point, + out_scale, + out_zero_point, + ) + + self.assertTrue( + torch.equal(output, expected_output), + f"Values don't match in {name}: got {output}, expected {expected_output}", + ) + output = torch.ops.cadence.quantized_add( X_tensor, X_scale,