From e51326bfbe3f2fce75e78178823c1f1eb31953a4 Mon Sep 17 00:00:00 2001
From: Andrew Grebenisan
Date: Mon, 6 Oct 2025 13:46:12 -0700
Subject: [PATCH 1/4] Fixed assumption on out_shift for quantized linear
 (#14789)

Summary: out_shift should be int32, not int64.

Reviewed By: hsharma35

Differential Revision: D83875670
---
 backends/cadence/aot/ref_implementations.py    |  4 ++--
 .../aot/tests/test_ref_implementations.py      | 16 ++++++++--------
 2 files changed, 10 insertions(+), 10 deletions(-)

diff --git a/backends/cadence/aot/ref_implementations.py b/backends/cadence/aot/ref_implementations.py
index 2642340679e..ad1abb3ce4b 100644
--- a/backends/cadence/aot/ref_implementations.py
+++ b/backends/cadence/aot/ref_implementations.py
@@ -330,8 +330,8 @@ def variant(
         if out_shift.numel() != 1:
             raise ValueError("out_shift must be a scalar")
 
-        if out_shift.dtype != torch.int64:
-            raise ValueError("out_shift must be an int64")
+        if out_shift.dtype != torch.int32:
+            raise ValueError("out_shift must be an int32")
 
         _out_shift = int(out_shift.item())
         _out_multiplier = int(out_multiplier[0].item())

diff --git a/backends/cadence/aot/tests/test_ref_implementations.py b/backends/cadence/aot/tests/test_ref_implementations.py
index f78d2292e7b..d8a79454097 100644
--- a/backends/cadence/aot/tests/test_ref_implementations.py
+++ b/backends/cadence/aot/tests/test_ref_implementations.py
@@ -172,7 +172,7 @@ def test_quantized_add(
                 torch.tensor(
                     [1073741824], dtype=torch.int32
                 ),  # out_multiplier (0.5 * 2^31)
-                torch.tensor([0], dtype=torch.int64),  # out_shift
+                torch.tensor([0], dtype=torch.int32),  # out_shift
                 0,  # out_zero_point
                 torch.tensor([[0]], dtype=dtype),  # expected_output
                 per_tensor,
@@ -197,7 +197,7 @@ def test_quantized_add(
                 torch.tensor(
                     [1073741824], dtype=torch.int32
                 ),  # out_multiplier (0.5 * 2^31)
-                torch.tensor([0], dtype=torch.int64),  # out_shift
+                torch.tensor([0], dtype=torch.int32),  # out_shift
                 0,  # out_zero_point
                 torch.tensor([[-2, -8]], dtype=dtype),  # expected_output
                 per_tensor,
@@ -220,7 +220,7 @@ def test_quantized_add(
                 torch.tensor(
                     [1073741824], dtype=torch.int32
                 ),  # out_multiplier (0.5 * 2^31)
-                torch.tensor([0], dtype=torch.int64),  # out_shift
+                torch.tensor([0], dtype=torch.int32),  # out_shift
                 0,  # out_zero_point
                 torch.tensor([[0, 0]], dtype=dtype),  # expected_output
                 per_tensor,
@@ -244,7 +244,7 @@ def test_quantized_add(
                 torch.tensor(
                     [1073741824], dtype=torch.int32
                 ),  # out_multiplier (0.5 * 2^31)
-                torch.tensor([0], dtype=torch.int64),  # out_shift
+                torch.tensor([0], dtype=torch.int32),  # out_shift
                 0,  # out_zero_point
                 torch.tensor(
                     [[[0, -2, -4], [-2, -7, -12]]], dtype=dtype
                 ),  # expected_output
                 per_tensor,
@@ -270,7 +270,7 @@ def test_quantized_add(
                 torch.tensor(
                     [268435456], dtype=torch.int32
                 ),  # out_multiplier (1.0 * 2^31)
-                torch.tensor([0], dtype=torch.int64),  # out_shift
+                torch.tensor([0], dtype=torch.int32),  # out_shift
                 1,  # out_zero_point
                 torch.tensor([[1, 1]], dtype=dtype),  # expected_output
                 per_tensor,
@@ -295,7 +295,7 @@ def test_quantized_add(
                 torch.tensor(
                     [268435456], dtype=torch.int32
                 ),  # out_multiplier (1.0 * 2^31)
-                torch.tensor([0], dtype=torch.int64),  # out_shift
+                torch.tensor([0], dtype=torch.int32),  # out_shift
                 1,  # out_zero_point
                 torch.tensor([[1, 1]], dtype=dtype),  # expected_output
                 False,
@@ -317,7 +317,7 @@ def test_quantized_add(
                     [268435456], dtype=torch.int32
                 ),  # out_multiplier (0.125 * 2^31)
                 torch.tensor(
-                    [1], dtype=torch.int64
+                    [1], dtype=torch.int32
                 ),  # out_shift (shift=1, doubles the scale)
                 1,  # out_zero_point
                 torch.tensor([[1, 2]], dtype=dtype),  # expected_output
@@ -339,7 +339,7 @@ def test_quantized_add(
                     [268435456], dtype=torch.int32
                 ),  # out_multiplier (0.125 * 2^31)
                 torch.tensor(
-                    [1], dtype=torch.int64
+                    [1], dtype=torch.int32
                ),  # out_shift (shift=1, doubles the scale)
                 1,  # out_zero_point
                 torch.tensor([[1, 2]], dtype=dtype),  # expected_output
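
Note: for readers tracing why int32 is the right dtype here, the following is a
minimal sketch of the requantize convention the test comments above encode
(effective scale = (out_multiplier / 2^31) * 2^out_shift, so shift=1 doubles
the scale). The helper name requantize_sketch is hypothetical; this is not the
Cadence reference kernel itself.

    import torch

    def requantize_sketch(
        acc: torch.Tensor,             # int32 accumulator values
        out_multiplier: torch.Tensor,  # scalar int32, Q31 fixed-point multiplier
        out_shift: torch.Tensor,       # scalar int32 after this patch (was int64)
        out_zero_point: int,
        out_dtype: torch.dtype = torch.int8,
    ) -> torch.Tensor:
        if out_shift.dtype != torch.int32:
            raise ValueError("out_shift must be an int32")
        # effective scale: Q31 multiplier scaled by a power of two
        scale = (int(out_multiplier.item()) / 2**31) * 2 ** int(out_shift.item())
        out = torch.round(acc.to(torch.float64) * scale) + out_zero_point
        info = torch.iinfo(out_dtype)
        return out.clamp(info.min, info.max).to(out_dtype)

    # 1073741824 == 0.5 * 2^31 with shift 0: halves the accumulator.
    print(requantize_sketch(
        torch.tensor([4], dtype=torch.int32),
        torch.tensor([1073741824], dtype=torch.int32),
        torch.tensor([0], dtype=torch.int32),
        0,
    ))  # tensor([2], dtype=torch.int8)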
From e13bcf07136b5e7d998d634340ff2b41cf4b5dbf Mon Sep 17 00:00:00 2001
From: Andrew Grebenisan
Date: Mon, 6 Oct 2025 13:46:12 -0700
Subject: [PATCH 2/4] Replace quantized conv and relu non-tensor variants with
 per tensor variants

Summary: Call the per-tensor variants directly for quantized conv and
quantized relu, since those are the only variants we support.

Differential Revision: D83873738
---
 backends/cadence/aot/quantizer/fusion_pass.py | 55 +++----------------
 backends/cadence/aot/quantizer/patterns.py    |  8 +--
 2 files changed, 11 insertions(+), 52 deletions(-)

diff --git a/backends/cadence/aot/quantizer/fusion_pass.py b/backends/cadence/aot/quantizer/fusion_pass.py
index 0461c03ccb7..c1b1e21299b 100644
--- a/backends/cadence/aot/quantizer/fusion_pass.py
+++ b/backends/cadence/aot/quantizer/fusion_pass.py
@@ -306,31 +306,6 @@ def get_args_and_kwargs_conv(
 
     (out_multiplier, out_shift) = quantize_tensor_multiplier(requantize_scale_t)
 
-    out_multiplier_ = graph_module.graph.call_function(
-        torch.ops.aten.full.default,
-        ([1], out_multiplier[0].item()),
-        {"dtype": torch.int32},
-    )
-    out_shift_ = graph_module.graph.call_function(
-        torch.ops.aten.full.default,
-        ([1], out_shift[0].item()),
-        {"dtype": torch.int32},
-    )
-
-    # Create a single element tensor for the weight zero point
-    weight_zero_point_tensor = graph_module.graph.call_function(
-        torch.ops.aten.full.default,
-        ([1], weight_zero_point),
-        {"dtype": torch.int32},
-    )
-
-    # Create a single element tensor for the bias scale
-    bias_scale_tensor = graph_module.graph.call_function(
-        torch.ops.aten.full.default,
-        ([1], bias_scale),
-        {"dtype": torch.float32},
-    )
-
     # Make the args and kwargs for the replacement op
     args = tuple(inputs_inputs + weights_inputs + [bias])
     kwargs = {
@@ -339,12 +314,12 @@ def get_args_and_kwargs_conv(
         "stride": stride,
         "padding": padding,
         "dilation": dilation,
         "groups": groups,
         "input_zero_point": dequants_inputs[0].args[2],
-        "weight_zero_point": weight_zero_point_tensor,
-        "bias_scale": bias_scale_tensor,
+        "weight_zero_point": weight_zero_point,
+        "bias_scale": bias_scale,
         "out_scale": quant_node.args[1],
         "out_zero_point": quant_node.args[2],
-        "out_multiplier": out_multiplier_,
-        "out_shift": out_shift_,
+        "out_multiplier": out_multiplier[0].item(),
+        "out_shift": out_shift[0].item(),
     }
 
     return args, kwargs
@@ -365,27 +340,11 @@ def get_args_and_kwargs_relu(
     # Make the args and kwargs for the replacement op
     args = tuple(inputs_inputs)
 
-    X_zero_point = graph_module.graph.call_function(
-        torch.ops.aten.full.default,
-        ([1], dequants_inputs[0].args[2]),
-        {"dtype": torch.int32},
-    )
-    out_multiplier_ = graph_module.graph.call_function(
-        torch.ops.aten.full.default,
-        ([1], out_multiplier[0].item()),
-        {"dtype": torch.int32},
-    )
-    out_shift_ = graph_module.graph.call_function(
-        torch.ops.aten.full.default,
-        ([1], out_shift[0].item()),
-        {"dtype": torch.int32},
-    )
-
     kwargs = {
-        "X_zero_point": X_zero_point,
+        "X_zero_point": dequants_inputs[0].args[2],
         "out_zero_point": quant_node.args[2],
-        "out_multiplier": out_multiplier_,
-        "out_shift": out_shift_,
+        "out_multiplier": out_multiplier[0].item(),
+        "out_shift": out_shift[0].item(),
     }
 
     return args, kwargs

diff --git a/backends/cadence/aot/quantizer/patterns.py b/backends/cadence/aot/quantizer/patterns.py
index 4eae55502d7..13f3a42ec7e 100644
--- a/backends/cadence/aot/quantizer/patterns.py
+++ b/backends/cadence/aot/quantizer/patterns.py
@@ -265,7 +265,7 @@ def get_anchors(
         )
 
     def replacement_op(self) -> OpOverload:
-        return torch.ops.cadence.quantized_conv2d_nchw.default
+        return torch.ops.cadence.quantized_conv2d_nchw.per_tensor
 
 
 class Conv2dPattern(QuantizationPattern):
@@ -307,7 +307,7 @@ def get_anchors(
         )
 
     def replacement_op(self) -> OpOverload:
-        return torch.ops.cadence.quantized_conv2d_nchw.default
+        return torch.ops.cadence.quantized_conv2d_nchw.per_tensor
 
 
 class LayerNormPattern(QuantizationPattern):
@@ -437,7 +437,7 @@ def get_anchors(
         )
 
     def replacement_op(self) -> OpOverload:
-        return torch.ops.cadence.quantized_relu.default
+        return torch.ops.cadence.quantized_relu.per_tensor
 
 
 # Regular relu op
@@ -496,7 +496,7 @@ def get_anchors(
         )
 
     def replacement_op(self) -> OpOverload:
-        return torch.ops.cadence.quantized_conv2d_nchw.default
+        return torch.ops.cadence.quantized_conv2d_nchw.per_tensor
 
 
 # Conv1d + regular relu op fusion
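
Note: the practical effect of this patch is that quantization parameters reach
the fused .per_tensor ops as plain Python scalars rather than as [1]-shaped
aten.full graph constants. A standalone sketch of that unwrapping, assuming
(as the diff shows) that quantize_tensor_multiplier returns 1-element tensors;
make_per_tensor_relu_kwargs is a hypothetical helper, not part of the pass:

    import torch

    def make_per_tensor_relu_kwargs(
        x_zero_point: int,
        out_zero_point: int,
        out_multiplier: torch.Tensor,  # 1-element tensor
        out_shift: torch.Tensor,       # 1-element tensor
    ) -> dict:
        # .item() unwraps the 1-element tensors to plain ints; no aten.full
        # graph nodes are needed for a .per_tensor op.
        return {
            "X_zero_point": x_zero_point,
            "out_zero_point": out_zero_point,
            "out_multiplier": int(out_multiplier[0].item()),
            "out_shift": int(out_shift[0].item()),
        }

    kwargs = make_per_tensor_relu_kwargs(
        0, 0, torch.tensor([1073741824]), torch.tensor([0])
    )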
From 0a6aba0860284865b859231d3728f89743748c0e Mon Sep 17 00:00:00 2001
From: Andrew Grebenisan
Date: Mon, 6 Oct 2025 13:46:12 -0700
Subject: [PATCH 3/4] Fix avg_pool2d replace ops pass

Summary: The original pass didn't fetch the user-provided zero point if one
existed; it just assumed a hard-coded zero point. Fixed now.

Reviewed By: ethansfng

Differential Revision: D83873937
---
 backends/cadence/aot/replace_ops.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/backends/cadence/aot/replace_ops.py b/backends/cadence/aot/replace_ops.py
index 24390da5e16..7025159e443 100644
--- a/backends/cadence/aot/replace_ops.py
+++ b/backends/cadence/aot/replace_ops.py
@@ -1644,7 +1644,7 @@ def call_operator(self, op, args, kwargs, meta):
         ceil_mode = args[4] if len(args) >= 5 else False
         count_include_pad = args[5] if len(args) >= 6 else True
         divisor_override = args[6] if len(args) >= 7 else None
-        zero_point = torch.tensor(0, dtype=torch.int32)
+        zero_point = args[7] if len(args) >= 8 else None
 
         # If the op is avg_pool1d, then we need to reshape the 3d input to a 4d
         # tensor.
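
Note: the fix extends the defensive trailing-optional-argument unpacking
already used on the preceding lines. A sketch of the full pattern follows;
positions 4-7 appear in the diff, while positions 1-3 and their defaults are
assumptions based on the aten.avg_pool2d schema, and unpack_avg_pool_args is a
hypothetical helper:

    from typing import Any, Sequence

    def unpack_avg_pool_args(args: Sequence[Any]) -> tuple:
        # args[0] is the input tensor; everything after it is an optional
        # positional with a schema default.
        kernel_size = args[1]
        stride = args[2] if len(args) >= 3 else []   # empty -> kernel_size downstream
        padding = args[3] if len(args) >= 4 else 0
        ceil_mode = args[4] if len(args) >= 5 else False
        count_include_pad = args[5] if len(args) >= 6 else True
        divisor_override = args[6] if len(args) >= 7 else None
        # Before the fix this was hard-coded to torch.tensor(0, dtype=torch.int32),
        # silently ignoring any zero point the caller passed in position 7.
        zero_point = args[7] if len(args) >= 8 else None
        return (kernel_size, stride, padding, ceil_mode,
                count_include_pad, divisor_override, zero_point)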
From b59a820c0a8310694221f5afd830dba81fc68a8a Mon Sep 17 00:00:00 2001
From: Andrew Grebenisan
Date: Mon, 6 Oct 2025 13:46:12 -0700
Subject: [PATCH 4/4] Removed support for non-per-tensor quantized relu
 (#14788)

Summary: The default (non-per-tensor) quantized relu variant is not supported,
so remove it from ref_implementations.

Differential Revision: D83874866
---
 backends/cadence/aot/ref_implementations.py | 49 +++++----------------
 1 file changed, 10 insertions(+), 39 deletions(-)

diff --git a/backends/cadence/aot/ref_implementations.py b/backends/cadence/aot/ref_implementations.py
index ad1abb3ce4b..1d4338af21b 100644
--- a/backends/cadence/aot/ref_implementations.py
+++ b/backends/cadence/aot/ref_implementations.py
@@ -1125,7 +1125,6 @@ def quantized_relu_common(
 
 
 def quantized_relu_variant(
-    per_tensor: bool,
     dtype: torch.dtype | None = None,
 ) -> Callable[[Callable[..., torch.Tensor]], Callable[..., torch.Tensor]]:
     """Create a quantized relu variant with type checking."""
@@ -1133,43 +1132,20 @@
     def decorator(_: Callable[..., torch.Tensor]) -> Callable[..., torch.Tensor]:
         def variant(
             X: torch.Tensor,
-            X_zero_point: torch.Tensor | int,
+            X_zero_point: int,
             out_zero_point: int,
-            out_multiplier: torch.Tensor | int,
-            out_shift: torch.Tensor | int,
+            out_multiplier: int,
+            out_shift: int,
         ) -> torch.Tensor:
-            if per_tensor:
-                if dtype and X.dtype != dtype:
-                    raise ValueError(f"X dtype must be {dtype}. Got {X.dtype}")
-
-                assert isinstance(out_shift, int)
-                assert isinstance(out_multiplier, int)
-                _out_shift = out_shift
-                _out_multiplier = out_multiplier
-            else:
-                assert isinstance(out_multiplier, torch.Tensor)
-                if out_multiplier.numel() > 1:
-                    raise ValueError("Only scalar out_multiplier is supported")
-
-                assert isinstance(out_shift, torch.Tensor)
-                if out_shift.numel() > 1:
-                    raise ValueError("Only scalar out_shift is supported")
-
-                assert isinstance(X_zero_point, torch.Tensor)
-                if X_zero_point.shape != X.shape:
-                    raise ValueError(
-                        f"X_zero_point shape must be {X.shape}. Got {X_zero_point.shape}"
-                    )
-
-                _out_multiplier = int(out_multiplier.item())
-                _out_shift = int(out_shift.item())
+            if dtype and X.dtype != dtype:
+                raise ValueError(f"X dtype must be {dtype}. Got {X.dtype}")
 
             return quantized_relu_common(
                 X,
                 X_zero_point,
                 out_zero_point,
-                _out_multiplier,
-                _out_shift,
+                out_multiplier,
+                out_shift,
             )
 
         return variant
@@ -1177,23 +1153,18 @@
     return decorator
 
 
-@impl(m, "quantized_relu")
-@quantized_relu_variant(False)
-def quantized_relu() -> torch.Tensor: ...
-
-
 @impl(m, "quantized_relu.per_tensor")
-@quantized_relu_variant(True)
+@quantized_relu_variant()
 def quantized_relu_per_tensor() -> torch.Tensor: ...
 
 
 @impl(m, "quantized_relu_asym8s_asym8s.per_tensor")
-@quantized_relu_variant(True, torch.int8)
+@quantized_relu_variant(torch.int8)
 def quantized_relu_asym8s_asym8s_per_tensor() -> torch.Tensor: ...
 
 
 @impl(m, "quantized_relu_asym8u_asym8u.per_tensor")
-@quantized_relu_variant(True, torch.uint8)
+@quantized_relu_variant(torch.uint8)
 def quantized_relu_asym8u_asym8u_per_tensor() -> torch.Tensor: ...
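
Note: what survives this patch is a single per-tensor decorator path. Below is
a self-contained sketch of that factory pattern; relu_variant is a hypothetical
stand-in, its clamp-then-requantize body only approximates the elided
quantized_relu_common, and the real @impl registration is omitted:

    import torch
    from typing import Callable, Optional

    def relu_variant(dtype: Optional[torch.dtype] = None) -> Callable:
        def decorator(_: Callable[..., torch.Tensor]) -> Callable[..., torch.Tensor]:
            def variant(
                X: torch.Tensor,
                X_zero_point: int,
                out_zero_point: int,
                out_multiplier: int,
                out_shift: int,
            ) -> torch.Tensor:
                if dtype and X.dtype != dtype:
                    raise ValueError(f"X dtype must be {dtype}. Got {X.dtype}")
                # Stand-in for quantized_relu_common: clamp at the input zero
                # point, then requantize with the Q31 multiplier and shift.
                scale = (out_multiplier / 2**31) * 2**out_shift
                y = (X.to(torch.int32) - X_zero_point).clamp(min=0)
                return (torch.round(y * scale) + out_zero_point).to(X.dtype)
            return variant
        return decorator

    @relu_variant(torch.int8)
    def quantized_relu_asym8s() -> torch.Tensor: ...

    out = quantized_relu_asym8s(
        torch.tensor([-5, 10], dtype=torch.int8), 0, 0, 2**30, 1
    )  # tensor([ 0, 10], dtype=torch.int8)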