From 0e650996826ba735b6f2a08a7dc996b5a296a135 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Per=20=C3=85strand?=
Date: Thu, 7 Aug 2025 12:05:51 +0200
Subject: [PATCH 1/2] Arm backend: Propagate node info from quantizer to
 backend
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Use the Node meta 'custom' field to propagate information from the
quantizer to the partitioner using a new ArmAnnotationInfo data class.
This allows us to track quantized nodes reliably, which is needed to
decide which nodes should 'fold' their quantization parameters and
which should be kept in fp when mixing integer and float operations
in a sub-graph.

Co-authored-by: Per Åstrand
Signed-off-by: Oscar Andersson
Change-Id: I398bf52e14d58fce56aa46ace74e45f45050c81b
---
 backends/arm/common/annotation_meta.py | 39 +++++++
 .../tosa_supported_operators.py | 75 +++++++++++++
 backends/arm/quantizer/arm_quantizer_utils.py | 10 +-
 .../arm/quantizer/quantization_annotator.py | 4 +-
 backends/arm/test/misc/test_int64.py | 6 +-
 .../arm/test/misc/test_quant_custom_meta.py | 100 ++++++++++++++++++
 .../arm/test/misc/test_save_exported_model.py | 62 +++++++++++
 .../test_SD3Transformer2DModel.py | 3 +-
 .../arm/test/models/test_nn_functional.py | 1 -
 .../arm/test/models/test_torch_functions.py | 4 +-
 backends/arm/test/ops/test_eye.py | 2 +-
 backends/arm/tosa/partitioner.py | 45 +++++++-
 12 files changed, 335 insertions(+), 16 deletions(-)
 create mode 100644 backends/arm/common/annotation_meta.py
 create mode 100644 backends/arm/test/misc/test_quant_custom_meta.py
 create mode 100644 backends/arm/test/misc/test_save_exported_model.py

diff --git a/backends/arm/common/annotation_meta.py b/backends/arm/common/annotation_meta.py
new file mode 100644
index 00000000000..12ef80ae70b
--- /dev/null
+++ b/backends/arm/common/annotation_meta.py
@@ -0,0 +1,39 @@
+# Copyright 2025 Arm Limited and/or its affiliates.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+from __future__ import annotations
+
+from dataclasses import dataclass
+from typing import Any, Mapping, Optional
+
+
+@dataclass(frozen=True, init=False)
+class ArmAnnotationInfo(dict):
+    """
+    Dataclass wrapper that behaves like a dict so serialization can treat it as
+    a plain mapping, while still exposing a typed attribute for convenience.
+    """
+
+    quantized: bool
+    CUSTOM_META_KEY: str = "_arm_annotation_info"
+
+    def __init__(
+        self,
+        value: Optional[Mapping[str, Any]] = None,
+        *,
+        quantized: Optional[bool] = None,
+    ) -> None:
+        if quantized is not None:
+            resolved = bool(quantized)
+
+        elif isinstance(value, Mapping):
+            resolved = bool(value.get("quantized", False))
+
+        else:
+            raise TypeError(
+                "ArmAnnotationInfo expects a mapping with a 'quantized' entry or a keyword 'quantized'."
+            )
+        dict.__init__(self, quantized=resolved)
+        object.__setattr__(self, "quantized", resolved)
diff --git a/backends/arm/operator_support/tosa_supported_operators.py b/backends/arm/operator_support/tosa_supported_operators.py
index e4050f1dc49..fa0b106c00d 100644
--- a/backends/arm/operator_support/tosa_supported_operators.py
+++ b/backends/arm/operator_support/tosa_supported_operators.py
@@ -27,6 +27,7 @@
     FuseQuantizedActivationPass,
 )
 from executorch.backends.arm._passes.insert_table_ops import TableOps
+from executorch.backends.arm.common.annotation_meta import ArmAnnotationInfo
 from executorch.backends.arm.constants import DQ_OPS, MAX_RANK, Q_OPS
 from executorch.backends.arm.operator_support.ethos_u55_support import (
     EthosU55CastCheck,
@@ -209,6 +210,7 @@ def tosa_support_factory(
     ]
 
     if not tosa_spec.support_float():
+        negative_checks.append(CheckArmQuantized(reporter))
         negative_checks.append(CheckProperQuantization(reporter))
         if tosa_spec.is_U55_subset:
             negative_checks.append(EthosU55NotSupported(reporter))
@@ -260,6 +262,79 @@ def is_node_supported(
         return node.op == "call_function" and node.target in TOSA_PRO_FP_SupportList
 
 
+class CheckArmQuantized(OperatorSupportBase):
+    """
+    Check if the node was marked as quantized in the Arm backend.
+    This is used to ensure that such quantized nodes are only partitioned
+    if they are supported by the TOSA backend.
+    """
+
+    def __init__(self, reporter: WhyNoPartitionReporter):
+        self.reporter = reporter
+
+    def _is_quantized(self, node: torch.fx.Node) -> bool:
+        """Checks if the node is quantized.
+
+        A node is considered quantized if at least one criterion is met:
+        - Its dtype is neither floating point nor complex, i.e. it is integer.
+        - It is one of the special cases where the node has been created in to_edge, e.g.
+          .Scalar operations that have been promoted to .Tensor operations
+          where the scalar is replaced by a full op.
+        - It has been marked as quantized in the ArmAnnotationInfo custom meta.
+
+        Args:
+            node (torch.fx.Node): The FX node to check.
+
+        Returns:
+            bool: True if the node is quantized, False otherwise.
+        """
+        node_dtype = get_first_fake_tensor(node).dtype
+        if not node_dtype.is_complex and not node_dtype.is_floating_point:
+            return True
+        if node.target in (
+            exir_ops.edge.aten.full_like.default,
+            *ComputeConstantOpsAOT.targeted_ops,
+        ):
+            # Special cases where nodes have been created in to_edge, e.g.
+            # .Scalar operations that have been promoted to .Tensor operations
+            # where the scalar is replaced by a full op.
+            if all(user.target in Q_OPS for user in node.users):
+                return True
+            for user in node.users:
+                if (
+                    user.target
+                    == exir_ops.edge.dim_order_ops._to_dim_order_copy.default
+                ):
+                    dim_order_dtype = get_first_fake_tensor(user).dtype
+                    if dim_order_dtype.is_complex or dim_order_dtype.is_floating_point:
+                        return False
+                else:
+                    return False
+            return True
+        return (
+            ArmAnnotationInfo.CUSTOM_META_KEY in node.meta.get("custom", {})
+            and ArmAnnotationInfo(
+                node.meta["custom"][ArmAnnotationInfo.CUSTOM_META_KEY]
+            ).quantized
+        )
+
+    def is_node_supported(
+        self, submodules: typing.Mapping[str, torch.nn.Module], node: fx.Node
+    ) -> bool:
+        if node.op != "call_function":
+            return False
+
+        if node.target in (*DQ_OPS, *Q_OPS):
+            return True
+
+        if not self._is_quantized(node):
+            self.reporter.report_reject(
+                node, "Node was not marked as quantized in the Arm backend."
+            )
+            return False
+        return True
+
+
 class CheckProperQuantization(OperatorSupportBase):
     """Ensure targeted nodes are properly quantized.
diff --git a/backends/arm/quantizer/arm_quantizer_utils.py b/backends/arm/quantizer/arm_quantizer_utils.py
index c1137ea4149..7bd8e00c22b 100644
--- a/backends/arm/quantizer/arm_quantizer_utils.py
+++ b/backends/arm/quantizer/arm_quantizer_utils.py
@@ -1,6 +1,6 @@
 # Copyright (c) Meta Platforms, Inc. and affiliates.
-# Copyright 2024-2025 Arm Limited and/or its affiliates.
 # All rights reserved.
+# Copyright 2024-2025 Arm Limited and/or its affiliates.
 #
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
@@ -14,6 +14,8 @@
 
 from typing import cast
 
+from executorch.backends.arm.common.annotation_meta import ArmAnnotationInfo
+
 from torch.fx import Node
 
 from torchao.quantization.pt2e.quantizer import QuantizationAnnotation
@@ -65,4 +67,10 @@ def mark_node_as_annotated(node: Node) -> None:
     """
     if Q_ANNOTATION_KEY not in node.meta:
         node.meta[Q_ANNOTATION_KEY] = QuantizationAnnotation()
+    annotation_info = ArmAnnotationInfo(
+        quantized=True,
+    )
     node.meta[Q_ANNOTATION_KEY]._annotated = True
+    meta_custom = node.meta.get("custom", {})
+    meta_custom[ArmAnnotationInfo.CUSTOM_META_KEY] = dict(annotation_info)
+    node.meta["custom"] = meta_custom
diff --git a/backends/arm/quantizer/quantization_annotator.py b/backends/arm/quantizer/quantization_annotator.py
index dc3beb5370a..a24a110b1e1 100644
--- a/backends/arm/quantizer/quantization_annotator.py
+++ b/backends/arm/quantizer/quantization_annotator.py
@@ -394,6 +394,7 @@ def _match_pattern(
     torch.ops.aten.view.default,
     torch.ops.aten.view_as.default,
     torch.ops.aten.view_copy.default,
+    torch.ops.aten._unsafe_view.default,
     torch.ops.aten.select.int,
     torch.ops.aten.select_copy.int,
     torch.ops.aten.slice.Tensor,
@@ -426,6 +427,7 @@ def _match_pattern(
 ]
 
 _one_to_one_shared_input_or_input_act_qspec = [
+    torch.ops.aten.alias.default,
     torch.ops.aten.clone.default,
     torch.ops.aten.hardtanh.default,
     torch.ops.aten.hardtanh_.default,
@@ -693,10 +695,10 @@ def any_or_hardtanh_min_zero(n: Node):
         ]
         quant_properties.quant_output = None
     elif node.target in [
-        torch.ops.aten.scalar_tensor.default,
         torch.ops.aten.full.default,
         torch.ops.aten.full,
         torch.ops.aten.fill_.Scalar,
+        torch.ops.aten.scalar_tensor.default,
     ]:
         quant_properties.quant_inputs = []
         quant_properties.quant_output = _QuantProperty(0, output_act_qspec)
diff --git a/backends/arm/test/misc/test_int64.py b/backends/arm/test/misc/test_int64.py
index d6d6d6cb39c..46a97fff1df 100644
--- a/backends/arm/test/misc/test_int64.py
+++ b/backends/arm/test/misc/test_int64.py
@@ -68,10 +68,6 @@ def forward(self, x: torch.Tensor):
         ConstAdd(torch.int64, 2**40),
         (torch.rand(10) - 0.5,),
     ),
-    "int64_in+float_const": (
-        ConstAdd(torch.float32),
-        (torch.randint(0, 10, (10,)),),
-    ),
     "fp32_in+int64_buffer_chain": (
         BufferChainAdd(torch.int64),
         (torch.rand(2, 5, 3) - 0.5,),
@@ -94,7 +90,7 @@ def test_int64_tosa_FP(test_data: Tuple):
         ArmTester(
             model,
             inputs,
-            common.get_tosa_compile_spec("TOSA-1.0+FP", custom_path="tosa/int64"),
+            common.get_tosa_compile_spec("TOSA-1.0+FP"),
         )
         .export()
         .to_edge_transform_and_lower()
diff --git a/backends/arm/test/misc/test_quant_custom_meta.py b/backends/arm/test/misc/test_quant_custom_meta.py
new file mode 100644
index 00000000000..d18a1d39e45
--- /dev/null
+++ b/backends/arm/test/misc/test_quant_custom_meta.py
@@ -0,0 +1,100 @@
+# Copyright 2025 Arm Limited and/or its affiliates.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch
+from executorch.backends.arm.quantizer import (
+    get_symmetric_quantization_config,
+    TOSAQuantizer,
+)
+from executorch.backends.arm.test.tester.test_pipeline import TosaPipelineINT
+from executorch.backends.arm.tosa import TosaSpecification
+from executorch.backends.xnnpack.test.tester import Quantize
+
+
+class AddSigmoidMul(torch.nn.Module):
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.sigmoid = torch.nn.Sigmoid()
+
+    def forward(self, x, y):
+        return self.sigmoid(x + y) * x
+
+
+def get_selective_quantizer(modules):
+    quantizer = TOSAQuantizer(TosaSpecification.create_from_string("TOSA-1.0+INT"))
+    quantizer.set_global(get_symmetric_quantization_config())
+    for module in modules:
+        quantizer.set_module_type(module, None)
+
+    return Quantize(quantizer, get_symmetric_quantization_config())
+
+
+def test_qdq_squeezed_fp_op():
+    """Test that a float operation surrounded by quantize-dequantize pairs
+    is correctly handled by the partitioner and the TOSA backend.
+    Pattern:
+    q -> dq -> add -> q -> dq -> sigmoid -> q -> dq -> mul -> dq -> q
+                      |_____Non-delegated____|
+    """
+    aten_op = "torch.ops.aten.add.Tensor"
+    exir_op = "executorch_exir_dialects_edge__ops_aten_add_Tensor"
+    module = AddSigmoidMul()
+    x = torch.randn(2, 3, 4)
+    y = torch.randn(2, 3, 4)
+    pipeline = TosaPipelineINT(
+        module=module, test_data=(x, y), aten_op=aten_op, exir_op=exir_op
+    )
+    pipeline.change_args("quantize", get_selective_quantizer([torch.nn.Sigmoid]))
+    pipeline.change_args(
+        "check_count.exir",
+        {
+            "torch.ops.higher_order.executorch_call_delegate": 2,
+            "executorch_exir_dialects_edge__ops_aten_sigmoid_default": 1,
+            "executorch_exir_dialects_edge__ops_quantized_decomposed_dequantize_per_tensor_default": 2,
+            "executorch_exir_dialects_edge__ops_quantized_decomposed_quantize_per_tensor_default": 3,
+        },
+    )
+    pipeline.run()
+
+
+class MulAddSigmoidConv(torch.nn.Module):
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.sigmoid = torch.nn.Sigmoid()
+        self.conv = torch.nn.Conv1d(3, 3, 1)
+
+    def forward(self, x, y):
+        return self.conv(self.sigmoid(x + y * x))
+
+
+def test_quantized_to_float_transition():
+    """Test that a model executing quantized ops followed by float ops
+    is correctly handled by the partitioner and the TOSA backend.
+    Pattern:
+    q -> dq -> mul -> q -> dq -> add -> q -> dq -> sigmoid -> conv
+                                             |____Non-delegated___|
+    """
+    aten_op = "torch.ops.aten.add.Tensor"
+    exir_op = "executorch_exir_dialects_edge__ops_aten_add_Tensor"
+    module = MulAddSigmoidConv()
+    x = torch.randn(2, 3, 4)
+    y = torch.randn(2, 3, 4)
+    pipeline = TosaPipelineINT(
+        module=module, test_data=(x, y), aten_op=aten_op, exir_op=exir_op
+    )
+    pipeline.change_args(
+        "quantize", get_selective_quantizer([torch.nn.Sigmoid, torch.nn.Conv1d])
+    )
+    pipeline.change_args(
+        "check_count.exir",
+        {
+            "torch.ops.higher_order.executorch_call_delegate": 1,
+            "executorch_exir_dialects_edge__ops_aten_sigmoid_default": 1,
+            "executorch_exir_dialects_edge__ops_aten_convolution_default": 1,
+            "executorch_exir_dialects_edge__ops_quantized_decomposed_dequantize_per_tensor_default": 1,
+            "executorch_exir_dialects_edge__ops_quantized_decomposed_quantize_per_tensor_default": 2,
+        },
+    )
+    pipeline.run()
diff --git a/backends/arm/test/misc/test_save_exported_model.py b/backends/arm/test/misc/test_save_exported_model.py
new file mode 100644
index 00000000000..f393fca920c
--- /dev/null
+++ b/backends/arm/test/misc/test_save_exported_model.py
@@ -0,0 +1,62 @@
+# Copyright 2025 Arm Limited and/or its affiliates.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import os
+
+import torch
+from executorch.backends.arm.common.annotation_meta import ArmAnnotationInfo
+from executorch.backends.arm.quantizer import (
+    get_symmetric_quantization_config,
+    TOSAQuantizer,
+)
+from executorch.backends.arm.tosa import TosaSpecification
+from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e
+
+
+class SimpleModule(torch.nn.Module):
+    example_inputs = (torch.randn(1, 10),)
+
+    def __init__(self):
+        super().__init__()
+        self.linear = torch.nn.Linear(10, 10)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        return self.linear(x)
+
+
+def test_save_load_exported_int_model():
+    module = SimpleModule().eval()
+    example_inputs = module.example_inputs
+    exported_module = torch.export.export(module, example_inputs)
+
+    # Set up quantizer
+    quantizer = TOSAQuantizer(TosaSpecification.create_from_string("TOSA-1.0+INT"))
+    quantizer.set_global(get_symmetric_quantization_config())
+    # Quantize model
+    prepared_module = prepare_pt2e(exported_module.module(), quantizer)
+    prepared_module(*example_inputs)
+    quantized_module = convert_pt2e(prepared_module)
+    quantized_exported_module = torch.export.export(quantized_module, example_inputs)
+
+    base_path = "arm_test/misc/"
+    if not os.path.exists(base_path):
+        os.makedirs(base_path)
+    file_path = base_path + "exported_module.pt2"
+    # Verify that we can save the model
+    torch.export.save(quantized_exported_module, file_path)
+
+    # Verify that we can load the model back
+    loaded_model = torch.export.load(file_path)
+    for original_node, loaded_node in zip(
+        quantized_exported_module.graph.nodes, loaded_model.graph.nodes
+    ):
+        # Verify that the custom metadata is preserved after save/load
+        assert original_node.meta.get("custom", {}) == loaded_node.meta.get(
+            "custom", {}
+        )
+        if original_node.target == torch.ops.aten.linear.default:
+            assert ArmAnnotationInfo.CUSTOM_META_KEY in original_node.meta.get(
+                "custom", {}
+            )
diff --git a/backends/arm/test/models/stable_diffusion/test_SD3Transformer2DModel.py b/backends/arm/test/models/stable_diffusion/test_SD3Transformer2DModel.py
index 3e1f19dd39c..6444b8417f2 100644
--- a/backends/arm/test/models/stable_diffusion/test_SD3Transformer2DModel.py
+++ b/backends/arm/test/models/stable_diffusion/test_SD3Transformer2DModel.py
@@ -39,7 +39,8 @@ class TestSD3Transformer2DModel:
 
     ops_after_partitioner_INT = {
         "executorch_exir_dialects_edge__ops_dim_order_ops__to_dim_order_copy_default": 2,
-        "torch.ops.higher_order.executorch_call_delegate": 2,
+        "torch.ops.higher_order.executorch_call_delegate": 3,
+        "executorch_exir_dialects_edge__ops_aten_permute_copy_default": 1,
     }
 
     def _prepare_inputs(
diff --git a/backends/arm/test/models/test_nn_functional.py b/backends/arm/test/models/test_nn_functional.py
index 4896074b544..e585e82ad9d 100644
--- a/backends/arm/test/models/test_nn_functional.py
+++ b/backends/arm/test/models/test_nn_functional.py
@@ -102,7 +102,6 @@ def test_nn_functional_FP(test_data):
 @parametrize(
     "test_data",
     module_tests,
-    {"normalize": "MLETORCH-1255: Unsupported dtype in InsertTableOpsPass"},
 )
 def test_nn_functional_INT(test_data):
     module, inputs = test_data
diff --git a/backends/arm/test/models/test_torch_functions.py b/backends/arm/test/models/test_torch_functions.py
index 7f9bbdba177..54a9a6ae676 100644
--- a/backends/arm/test/models/test_torch_functions.py
+++ b/backends/arm/test/models/test_torch_functions.py
@@ -126,10 +126,12 @@ def test_torch_fns_FP(test_data):
     xfails={
         "nonzero": "torch.fx.experimental.symbolic_shapes.GuardOnDataDependentSymNode: Could not guard on data-dependent expression Eq(u4, 0). "
        "Requires dynamic output shape.",
+        "eye": "ValueError: Failed processing buffer placeholder: aten_arange_start_step_1_pre_computed_common. "
+        "Is the original torch function supported?",
         "topk": "NotImplementedError: No registered serialization name for found",
         "sort": "NotImplementedError: No registered serialization name for found",
     },
-    strict=False,
+    strict=True,
 )
 def test_torch_fns_INT(test_data):
     module, inputs = test_data
diff --git a/backends/arm/test/ops/test_eye.py b/backends/arm/test/ops/test_eye.py
index eef32259c10..5c829acc145 100644
--- a/backends/arm/test/ops/test_eye.py
+++ b/backends/arm/test/ops/test_eye.py
@@ -95,7 +95,7 @@ def test_eye_u85_INT(test_data: test_data_t):
         input_data(),
         EyeAdd.aten_op,
         use_to_edge_transform_and_lower=True,
-    ).dump_artifact("to_edge_transform_and_lower")
+    )
     pipeline.pop_stage("check.quant_nodes")
     pipeline.run()
diff --git a/backends/arm/tosa/partitioner.py b/backends/arm/tosa/partitioner.py
index fdb4dc62abf..bf84f8e4a2a 100644
--- a/backends/arm/tosa/partitioner.py
+++ b/backends/arm/tosa/partitioner.py
@@ -22,6 +22,7 @@
 from executorch.backends.arm._passes.convert_expand_copy_to_repeat import (
     calculate_multiples,
 )
+
 from executorch.backends.arm.common.type import ensure_type
 from executorch.backends.arm.constants import DQ_OPS, Q_OPS
 from executorch.backends.arm.operator_support.tosa_supported_operators import (
@@ -316,7 +317,7 @@ def partition(self, exported_program: ExportedProgram) -> PartitionResult:
             tagged_exported_program=exported_program, partition_tags=partition_tags
         )
 
-    def ops_to_not_decompose(
+    def ops_to_not_decompose(  # noqa: C901
         self,
         ep: ExportedProgram,
     ) -> Tuple[List[torch._ops.OpOverload], Optional[Callable[[torch.fx.Node], bool]]]:
@@ -336,17 +337,24 @@ def ops_to_not_decompose(
         """
         ops_to_not_decompose_if_quant_op = {
+            torch.ops.aten.eye.default,
             torch.ops.aten.hardsigmoid.default,
             torch.ops.aten.hardswish.default,
             torch.ops.aten.linear.default,
+            torch.ops.aten.linspace.default,
         }
         ops_to_not_decompose_if_fp = {
+            torch.ops.aten.eye.default,
+
             torch.ops.aten.logit.default,
             torch.ops.aten.linear.default,
+            torch.ops.aten.linspace.default,
         }
         ops_to_not_decompose_always = {
+            torch.ops.aten.logit.default,
+        }
+        ops_to_not_decompose_if_integer = {
             torch.ops.aten.eye.default,
             torch.ops.aten.linspace.default,
-            torch.ops.aten.logit.default,
         }
 
         def filter_fn(node: torch.fx.Node) -> bool:
@@ -400,15 +408,42 @@ def filter_fn(node: torch.fx.Node) -> bool:
                 ):
                     correct_output_quant = True
 
-                return correct_input_quant and correct_output_quant
+                if correct_input_quant and correct_output_quant:
+                    return True
+
+            if node.target in ops_to_not_decompose_if_integer:
+                # We only want to tag nodes as do_not_decompose if we are sure that
+                # we can partition them. We partition them if one or more of the
+                # following is true:
+                # 1. The node outputs an integer type.
+                # 2. All users cast the output to an integer type.
+
+                dtype = get_first_fake_tensor(node).dtype
+                if not dtype.is_floating_point and not dtype.is_complex:
+                    return True
+
+                output_nodes = node.users
+                for user in output_nodes:
+                    if user.target != torch.ops.aten.to.dtype:
+                        return False
+                    else:
+                        cast_dtype = get_first_fake_tensor(user).dtype
+                        if cast_dtype.is_complex or cast_dtype.is_floating_point:
+                            return False
+                return True
 
-            # By default, do not decompose the operator
-            return True
+            if node.target in ops_to_not_decompose_if_fp:
+                if self.tosa_spec.support_float():
+                    return True
+            if node.target in ops_to_not_decompose_always:
+                return True
+            return False
 
         ops_to_not_decompose = list(
             ops_to_not_decompose_always
             | ops_to_not_decompose_if_quant_op
             | ops_to_not_decompose_if_fp
+            | ops_to_not_decompose_if_integer
         )
 
         if not self.tosa_spec.is_U55_subset:

From c55735693b2e0ace23e4cd58374f0b5288f78266 Mon Sep 17 00:00:00 2001
From: Oscar Andersson
Date: Fri, 14 Nov 2025 07:57:11 +0100
Subject: [PATCH 2/2] Arm backend: Update TARGETS

Signed-off-by: Oscar Andersson
Change-Id: I84aa43f880e0ea7bbefb30bf7f31fc0e3b362e5a
---
 backends/arm/TARGETS | 6 +-----
 1 file changed, 1 insertion(+), 5 deletions(-)

diff --git a/backends/arm/TARGETS b/backends/arm/TARGETS
index d58a525d4f4..6e81adfed6f 100644
--- a/backends/arm/TARGETS
+++ b/backends/arm/TARGETS
@@ -17,11 +17,7 @@ runtime.python_library(
 )
 runtime.python_library(
     name = "common",
-    srcs = [
-        "common/__init__.py",
-        "common/debug.py",
-        "common/type.py",
-    ],
+    srcs = glob(["common/*.py"]),
     deps = [
         "fbsource//third-party/tosa_tools:serializer",
         "//caffe2:torch",
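
A minimal sketch of how the annotation introduced in patch 1/2 round-trips
from quantizer to partitioner: mark_node_as_annotated stores a plain dict
under node.meta["custom"] so torch.export.save can serialize it, and
CheckArmQuantized._is_quantized rebuilds the typed wrapper to read the flag.
The bare dict standing in for a real torch.fx node's meta is an assumption
made for brevity; this snippet is illustrative, not part of either patch.

    from executorch.backends.arm.common.annotation_meta import ArmAnnotationInfo

    # Quantizer side: persist the annotation as a plain, serializable mapping.
    meta: dict = {}  # stands in for node.meta on a torch.fx Node
    custom = meta.setdefault("custom", {})
    custom[ArmAnnotationInfo.CUSTOM_META_KEY] = dict(ArmAnnotationInfo(quantized=True))

    # Partitioner side: rebuild the typed wrapper from the stored mapping
    # and read the flag, as CheckArmQuantized._is_quantized does.
    stored = meta.get("custom", {}).get(ArmAnnotationInfo.CUSTOM_META_KEY)
    assert stored is not None and ArmAnnotationInfo(stored).quantized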