From ee4368909acb1df6304cf5ac24c1ede3bb3b4243 Mon Sep 17 00:00:00 2001 From: Sicheng Stephen Jia Date: Tue, 11 Nov 2025 15:06:13 -0500 Subject: [PATCH] Revert "Arm backend: Propagate node info from quantizer to backend (#15300)" This reverts commit 6550a373159abaef27e96d043ccf493c822676e0. --- backends/arm/common/annotation_meta.py | 19 ---- .../tosa_supported_operators.py | 76 +------------ backends/arm/quantizer/arm_quantizer_utils.py | 10 +- .../arm/quantizer/quantization_annotator.py | 4 +- backends/arm/test/misc/test_int64.py | 6 +- .../arm/test/misc/test_quant_custom_meta.py | 100 ------------------ .../test_SD3Transformer2DModel.py | 3 +- .../arm/test/models/test_nn_functional.py | 1 + .../arm/test/models/test_torch_functions.py | 4 +- backends/arm/test/ops/test_eye.py | 2 +- backends/arm/tosa/partitioner.py | 45 +------- 11 files changed, 19 insertions(+), 251 deletions(-) delete mode 100644 backends/arm/common/annotation_meta.py delete mode 100644 backends/arm/test/misc/test_quant_custom_meta.py diff --git a/backends/arm/common/annotation_meta.py b/backends/arm/common/annotation_meta.py deleted file mode 100644 index a857e36bb3f..00000000000 --- a/backends/arm/common/annotation_meta.py +++ /dev/null @@ -1,19 +0,0 @@ -# Copyright 2025 Arm Limited and/or its affiliates. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - -from dataclasses import dataclass - - -@dataclass(frozen=True) -class ArmAnnotationInfo: - """ - Data class to carry Arm-specific annotation information through the pipeline. - This is intended to be attached to node.meta['custom'] and propagated - through partitioning and backend stages. As it's propagated through the pipeline, - it's intentionally minimal and only carries whether the node is quantized or not. - """ - - quantized: bool - CUSTOM_META_KEY: str = "_arm_annotation_info" diff --git a/backends/arm/operator_support/tosa_supported_operators.py b/backends/arm/operator_support/tosa_supported_operators.py index 551a2c5c18b..1f8405e8744 100644 --- a/backends/arm/operator_support/tosa_supported_operators.py +++ b/backends/arm/operator_support/tosa_supported_operators.py @@ -21,7 +21,6 @@ FuseQuantizedActivationPass, ) from executorch.backends.arm._passes.insert_table_ops import TableOps -from executorch.backends.arm.common.annotation_meta import ArmAnnotationInfo from executorch.backends.arm.constants import DQ_OPS, MAX_RANK, Q_OPS from executorch.backends.arm.operator_support.ethos_u55_support import ( EthosU55CastCheck, @@ -141,7 +140,6 @@ def tosa_support_factory( ] if not tosa_spec.support_float(): - negative_checks.append(CheckArmQuantized(reporter)) negative_checks.append(CheckProperQuantization(reporter)) if tosa_spec.is_U55_subset: negative_checks.append(EthosU55NotSupported(reporter)) @@ -169,6 +167,7 @@ class TOSAProINTSupportList(OperatorSupportBase): def is_node_supported( self, submodules: typing.Mapping[str, torch.nn.Module], node: fx.Node ) -> bool: + return node.op == "call_function" and node.target in TOSA_PRO_INT_SupportList @@ -181,78 +180,8 @@ class TOSAProFPSupportList(OperatorSupportBase): def is_node_supported( self, submodules: typing.Mapping[str, torch.nn.Module], node: fx.Node ) -> bool: - return node.op == "call_function" and node.target in TOSA_PRO_FP_SupportList - - -class CheckArmQuantized(OperatorSupportBase): - """ - Check if the node was marked as quantized in the Arm backend. - This is used to ensure that nodes that were quantized in the Arm backend - are only partitioned if they are supported by the TOSA backend. - """ - - def __init__(self, reporter: WhyNoPartitionReporter): - self.reporter = reporter - - def _is_quantized(self, node: torch.fx.Node) -> bool: - """Checks if the node is quantized. - A node is considered quantized if at least one criteria is met: - - Its dtype is not floating point or complex => integer - - It is one of the special cases where the node has been created in to_edge, e.g. - .Scalar operations that have been promoted .Tensor operations - where the scalar is replaced by a full op. - - It has been marked as quantized in the ArmAnnotationInfo custom meta. - - Args: - node (torch.fx.Node): The FX node to check. - - Returns: - bool: True if the node is quantized, False otherwise. - """ - node_dtype = get_first_fake_tensor(node).dtype - if not node_dtype.is_complex and not node_dtype.is_floating_point: - return True - if node.target in ( - exir_ops.edge.aten.full_like.default, - *ComputeConstantOpsAOT.targeted_ops, - ): - # Special cases where nodes have been created in to_edge, e.g. - # .Scalar operations that have been promoted .Tensor operations - # where the scalar is replaced by a full op. - if all(user.target in Q_OPS for user in node.users): - return True - for user in node.users: - if ( - user.target - == exir_ops.edge.dim_order_ops._to_dim_order_copy.default - ): - dim_order_dtype = get_first_fake_tensor(user).dtype - if dim_order_dtype.is_complex or dim_order_dtype.is_floating_point: - return False - else: - return False - return True - return ( - ArmAnnotationInfo.CUSTOM_META_KEY in node.meta.get("custom", {}) - and node.meta["custom"][ArmAnnotationInfo.CUSTOM_META_KEY].quantized - ) - - def is_node_supported( - self, submodules: typing.Mapping[str, torch.nn.Module], node: fx.Node - ) -> bool: - if node.op != "call_function": - return False - - if node.target in (*DQ_OPS, *Q_OPS): - return True - - if not self._is_quantized(node): - self.reporter.report_reject( - node, "Node was not marked as quantized in the Arm backend." - ) - return False - return True + return node.op == "call_function" and node.target in TOSA_PRO_FP_SupportList class CheckProperQuantization(OperatorSupportBase): @@ -498,6 +427,7 @@ def is_node_supported( class CheckFloat64Inputs(OperatorSupportBase): + def __init__( self, exported_program: ExportedProgram, reporter: WhyNoPartitionReporter ): diff --git a/backends/arm/quantizer/arm_quantizer_utils.py b/backends/arm/quantizer/arm_quantizer_utils.py index d4b4d967fba..c1137ea4149 100644 --- a/backends/arm/quantizer/arm_quantizer_utils.py +++ b/backends/arm/quantizer/arm_quantizer_utils.py @@ -1,6 +1,6 @@ # Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. # Copyright 2024-2025 Arm Limited and/or its affiliates. +# All rights reserved. # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. @@ -14,8 +14,6 @@ from typing import cast -from executorch.backends.arm.common.annotation_meta import ArmAnnotationInfo - from torch.fx import Node from torchao.quantization.pt2e.quantizer import QuantizationAnnotation @@ -67,10 +65,4 @@ def mark_node_as_annotated(node: Node) -> None: """ if Q_ANNOTATION_KEY not in node.meta: node.meta[Q_ANNOTATION_KEY] = QuantizationAnnotation() - annotation_info = ArmAnnotationInfo( - quantized=True, - ) node.meta[Q_ANNOTATION_KEY]._annotated = True - meta_custom = node.meta.get("custom", {}) - meta_custom[ArmAnnotationInfo.CUSTOM_META_KEY] = annotation_info - node.meta["custom"] = meta_custom diff --git a/backends/arm/quantizer/quantization_annotator.py b/backends/arm/quantizer/quantization_annotator.py index a24a110b1e1..dc3beb5370a 100644 --- a/backends/arm/quantizer/quantization_annotator.py +++ b/backends/arm/quantizer/quantization_annotator.py @@ -394,7 +394,6 @@ def _match_pattern( torch.ops.aten.view.default, torch.ops.aten.view_as.default, torch.ops.aten.view_copy.default, - torch.ops.aten._unsafe_view.default, torch.ops.aten.select.int, torch.ops.aten.select_copy.int, torch.ops.aten.slice.Tensor, @@ -427,7 +426,6 @@ def _match_pattern( ] _one_to_one_shared_input_or_input_act_qspec = [ - torch.ops.aten.alias.default, torch.ops.aten.clone.default, torch.ops.aten.hardtanh.default, torch.ops.aten.hardtanh_.default, @@ -695,10 +693,10 @@ def any_or_hardtanh_min_zero(n: Node): ] quant_properties.quant_output = None elif node.target in [ + torch.ops.aten.scalar_tensor.default, torch.ops.aten.full.default, torch.ops.aten.full, torch.ops.aten.fill_.Scalar, - torch.ops.aten.scalar_tensor.default, ]: quant_properties.quant_inputs = [] quant_properties.quant_output = _QuantProperty(0, output_act_qspec) diff --git a/backends/arm/test/misc/test_int64.py b/backends/arm/test/misc/test_int64.py index 46a97fff1df..d6d6d6cb39c 100644 --- a/backends/arm/test/misc/test_int64.py +++ b/backends/arm/test/misc/test_int64.py @@ -68,6 +68,10 @@ def forward(self, x: torch.Tensor): ConstAdd(torch.int64, 2**40), (torch.rand(10) - 0.5,), ), + "int64_in+float_const": ( + ConstAdd(torch.float32), + (torch.randint(0, 10, (10,)),), + ), "fp32_in+int64_buffer_chain": ( BufferChainAdd(torch.int64), (torch.rand(2, 5, 3) - 0.5,), @@ -90,7 +94,7 @@ def test_int64_tosa_FP(test_data: Tuple): ArmTester( model, inputs, - common.get_tosa_compile_spec("TOSA-1.0+FP"), + common.get_tosa_compile_spec("TOSA-1.0+FP", custom_path="tosa/int64"), ) .export() .to_edge_transform_and_lower() diff --git a/backends/arm/test/misc/test_quant_custom_meta.py b/backends/arm/test/misc/test_quant_custom_meta.py deleted file mode 100644 index d18a1d39e45..00000000000 --- a/backends/arm/test/misc/test_quant_custom_meta.py +++ /dev/null @@ -1,100 +0,0 @@ -# Copyright 2025 Arm Limited and/or its affiliates. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - -import torch -from executorch.backends.arm.quantizer import ( - get_symmetric_quantization_config, - TOSAQuantizer, -) -from executorch.backends.arm.test.tester.test_pipeline import TosaPipelineINT -from executorch.backends.arm.tosa import TosaSpecification -from executorch.backends.xnnpack.test.tester import Quantize - - -class AddSigmoidMul(torch.nn.Module): - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self.sigmoid = torch.nn.Sigmoid() - - def forward(self, x, y): - return self.sigmoid(x + y) * x - - -def get_selective_quantizer(modules): - quantizer = TOSAQuantizer(TosaSpecification.create_from_string("TOSA-1.0+INT")) - quantizer.set_global(get_symmetric_quantization_config()) - for module in modules: - quantizer.set_module_type(module, None) - - return Quantize(quantizer, get_symmetric_quantization_config()) - - -def test_qdq_squeezed_fp_op(): - """Test that a float operation surrounded by quantize-dequantize pairs - is correctly handled by the partitioner and the TOSA backend. - Pattern: - q -> dq -> add -> q -> dq -> sigmoid -> q -> dq -> mul -> dq -> q - |_____Non-delegated____| - """ - aten_op = "torch.ops.aten.add.Tensor" - exir_op = "executorch_exir_dialects_edge__ops_aten_add_Tensor" - module = AddSigmoidMul() - x = torch.randn(2, 3, 4) - y = torch.randn(2, 3, 4) - pipeline = TosaPipelineINT( - module=module, test_data=(x, y), aten_op=aten_op, exir_op=exir_op - ) - pipeline.change_args("quantize", get_selective_quantizer([torch.nn.Sigmoid])) - pipeline.change_args( - "check_count.exir", - { - "torch.ops.higher_order.executorch_call_delegate": 2, - "executorch_exir_dialects_edge__ops_aten_sigmoid_default": 1, - "executorch_exir_dialects_edge__ops_quantized_decomposed_dequantize_per_tensor_default": 2, - "executorch_exir_dialects_edge__ops_quantized_decomposed_quantize_per_tensor_default": 3, - }, - ) - pipeline.run() - - -class MulAddSigmoidConv(torch.nn.Module): - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self.sigmoid = torch.nn.Sigmoid() - self.conv = torch.nn.Conv1d(3, 3, 1) - - def forward(self, x, y): - return self.conv(self.sigmoid(x + y * x)) - - -def test_quantized_to_float_transition(): - """Test that a model executing quantized ops followed by float ops - is correctly handled by the partitioner and the TOSA backend. - Pattern: - q -> dq -> mul -> q -> dq -> add -> q -> dq -> sigmoid -> conv - |____Non-delegated___| - """ - aten_op = "torch.ops.aten.add.Tensor" - exir_op = "executorch_exir_dialects_edge__ops_aten_add_Tensor" - module = MulAddSigmoidConv() - x = torch.randn(2, 3, 4) - y = torch.randn(2, 3, 4) - pipeline = TosaPipelineINT( - module=module, test_data=(x, y), aten_op=aten_op, exir_op=exir_op - ) - pipeline.change_args( - "quantize", get_selective_quantizer([torch.nn.Sigmoid, torch.nn.Conv1d]) - ) - pipeline.change_args( - "check_count.exir", - { - "torch.ops.higher_order.executorch_call_delegate": 1, - "executorch_exir_dialects_edge__ops_aten_sigmoid_default": 1, - "executorch_exir_dialects_edge__ops_aten_convolution_default": 1, - "executorch_exir_dialects_edge__ops_quantized_decomposed_dequantize_per_tensor_default": 1, - "executorch_exir_dialects_edge__ops_quantized_decomposed_quantize_per_tensor_default": 2, - }, - ) - pipeline.run() diff --git a/backends/arm/test/models/stable_diffusion/test_SD3Transformer2DModel.py b/backends/arm/test/models/stable_diffusion/test_SD3Transformer2DModel.py index 6444b8417f2..3e1f19dd39c 100644 --- a/backends/arm/test/models/stable_diffusion/test_SD3Transformer2DModel.py +++ b/backends/arm/test/models/stable_diffusion/test_SD3Transformer2DModel.py @@ -39,8 +39,7 @@ class TestSD3Transformer2DModel: ops_after_partitioner_INT = { "executorch_exir_dialects_edge__ops_dim_order_ops__to_dim_order_copy_default": 2, - "torch.ops.higher_order.executorch_call_delegate": 3, - "executorch_exir_dialects_edge__ops_aten_permute_copy_default": 1, + "torch.ops.higher_order.executorch_call_delegate": 2, } def _prepare_inputs( diff --git a/backends/arm/test/models/test_nn_functional.py b/backends/arm/test/models/test_nn_functional.py index e585e82ad9d..4896074b544 100644 --- a/backends/arm/test/models/test_nn_functional.py +++ b/backends/arm/test/models/test_nn_functional.py @@ -102,6 +102,7 @@ def test_nn_functional_FP(test_data): @parametrize( "test_data", module_tests, + {"normalize": "MLETORCH-1255: Unsupported dtype in InsertTableOpsPass"}, ) def test_nn_functional_INT(test_data): module, inputs = test_data diff --git a/backends/arm/test/models/test_torch_functions.py b/backends/arm/test/models/test_torch_functions.py index 54a9a6ae676..7f9bbdba177 100644 --- a/backends/arm/test/models/test_torch_functions.py +++ b/backends/arm/test/models/test_torch_functions.py @@ -126,12 +126,10 @@ def test_torch_fns_FP(test_data): xfails={ "nonzero": "torch.fx.experimental.symbolic_shapes.GuardOnDataDependentSymNode: Could not guard on data-dependent expression Eq(u4, 0). " "Requires dynamic output shape.", - "eye": "ValueError: Failed processing buffer placeholder: aten_arange_start_step_1_pre_computed_common. " - "Is the original torch function supported?", "topk": "NotImplementedError: No registered serialization name for found", "sort": "NotImplementedError: No registered serialization name for found", }, - strict=True, + strict=False, ) def test_torch_fns_INT(test_data): module, inputs = test_data diff --git a/backends/arm/test/ops/test_eye.py b/backends/arm/test/ops/test_eye.py index 5c829acc145..eef32259c10 100644 --- a/backends/arm/test/ops/test_eye.py +++ b/backends/arm/test/ops/test_eye.py @@ -95,7 +95,7 @@ def test_eye_u85_INT(test_data: test_data_t): input_data(), EyeAdd.aten_op, use_to_edge_transform_and_lower=True, - ) + ).dump_artifact("to_edge_transform_and_lower") pipeline.pop_stage("check.quant_nodes") pipeline.run() diff --git a/backends/arm/tosa/partitioner.py b/backends/arm/tosa/partitioner.py index bf84f8e4a2a..fdb4dc62abf 100644 --- a/backends/arm/tosa/partitioner.py +++ b/backends/arm/tosa/partitioner.py @@ -22,7 +22,6 @@ from executorch.backends.arm._passes.convert_expand_copy_to_repeat import ( calculate_multiples, ) - from executorch.backends.arm.common.type import ensure_type from executorch.backends.arm.constants import DQ_OPS, Q_OPS from executorch.backends.arm.operator_support.tosa_supported_operators import ( @@ -317,7 +316,7 @@ def partition(self, exported_program: ExportedProgram) -> PartitionResult: tagged_exported_program=exported_program, partition_tags=partition_tags ) - def ops_to_not_decompose( # noqa: C901 + def ops_to_not_decompose( self, ep: ExportedProgram, ) -> Tuple[List[torch._ops.OpOverload], Optional[Callable[[torch.fx.Node], bool]]]: @@ -337,24 +336,17 @@ def ops_to_not_decompose( # noqa: C901 """ ops_to_not_decompose_if_quant_op = { - torch.ops.aten.eye.default, torch.ops.aten.hardsigmoid.default, torch.ops.aten.hardswish.default, torch.ops.aten.linear.default, - torch.ops.aten.linspace.default, } ops_to_not_decompose_if_fp = { - torch.ops.aten.eye.default, - torch.ops.aten.logit.default, torch.ops.aten.linear.default, - torch.ops.aten.linspace.default, } ops_to_not_decompose_always = { - torch.ops.aten.logit.default, - } - ops_to_not_decompose_if_integer = { torch.ops.aten.eye.default, torch.ops.aten.linspace.default, + torch.ops.aten.logit.default, } def filter_fn(node: torch.fx.Node) -> bool: @@ -408,42 +400,15 @@ def filter_fn(node: torch.fx.Node) -> bool: ): correct_output_quant = True - if correct_input_quant and correct_output_quant: - return True - - if node.target in ops_to_not_decompose_if_integer: - # We only want to tag nodes as do_not_decompose if we are sure that - # we can partition them. We partition them if one or more of the - # following is true: - # 1. The node outputs an integer type. - # 2. All users cast the output to an integer type. - - dtype = get_first_fake_tensor(node).dtype - if not dtype.is_floating_point and not dtype.is_complex: - return True - - output_nodes = node.users - for user in output_nodes: - if user.target != torch.ops.aten.to.dtype: - return False - else: - cast_dtype = get_first_fake_tensor(user).dtype - if cast_dtype.is_complex or cast_dtype.is_floating_point: - return False - return True + return correct_input_quant and correct_output_quant - if node.target in ops_to_not_decompose_if_fp: - if self.tosa_spec.support_float(): - return True - if node.target in ops_to_not_decompose_always: - return True - return False + # By default, do not decompose the operator + return True ops_to_not_decompose = list( ops_to_not_decompose_always | ops_to_not_decompose_if_quant_op | ops_to_not_decompose_if_fp - | ops_to_not_decompose_if_integer ) if not self.tosa_spec.is_U55_subset: