diff --git a/backends/arm/common/annotation_meta.py b/backends/arm/common/annotation_meta.py
new file mode 100644
index 00000000000..a857e36bb3f
--- /dev/null
+++ b/backends/arm/common/annotation_meta.py
@@ -0,0 +1,19 @@
+# Copyright 2025 Arm Limited and/or its affiliates.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+from dataclasses import dataclass
+
+
+@dataclass(frozen=True)
+class ArmAnnotationInfo:
+    """
+    Data class that carries Arm-specific annotation information through the
+    pipeline. It is attached to node.meta['custom'] and propagated through the
+    partitioning and backend stages. Since it travels the whole pipeline, it is
+    intentionally minimal and only records whether the node is quantized.
+    """
+
+    quantized: bool
+    CUSTOM_META_KEY: str = "_arm_annotation_info"
diff --git a/backends/arm/operator_support/tosa_supported_operators.py b/backends/arm/operator_support/tosa_supported_operators.py
index 1f8405e8744..551a2c5c18b 100644
--- a/backends/arm/operator_support/tosa_supported_operators.py
+++ b/backends/arm/operator_support/tosa_supported_operators.py
@@ -21,6 +21,7 @@
     FuseQuantizedActivationPass,
 )
 from executorch.backends.arm._passes.insert_table_ops import TableOps
+from executorch.backends.arm.common.annotation_meta import ArmAnnotationInfo
 from executorch.backends.arm.constants import DQ_OPS, MAX_RANK, Q_OPS
 from executorch.backends.arm.operator_support.ethos_u55_support import (
     EthosU55CastCheck,
@@ -140,6 +141,7 @@ def tosa_support_factory(
     ]
     if not tosa_spec.support_float():
+        negative_checks.append(CheckArmQuantized(reporter))
         negative_checks.append(CheckProperQuantization(reporter))
         if tosa_spec.is_U55_subset:
             negative_checks.append(EthosU55NotSupported(reporter))
@@ -167,7 +169,6 @@ class TOSAProINTSupportList(OperatorSupportBase):
     def is_node_supported(
         self, submodules: typing.Mapping[str, torch.nn.Module], node: fx.Node
     ) -> bool:
-
         return node.op == "call_function" and node.target in TOSA_PRO_INT_SupportList
@@ -180,10 +181,80 @@ class TOSAProFPSupportList(OperatorSupportBase):
     def is_node_supported(
         self, submodules: typing.Mapping[str, torch.nn.Module], node: fx.Node
    ) -> bool:
-
         return node.op == "call_function" and node.target in TOSA_PRO_FP_SupportList
 
 
+class CheckArmQuantized(OperatorSupportBase):
+    """
+    Check whether the node was marked as quantized in the Arm backend.
+    This is used to ensure that nodes quantized in the Arm backend are only
+    partitioned if they are supported by the TOSA backend.
+    """
+
+    def __init__(self, reporter: WhyNoPartitionReporter):
+        self.reporter = reporter
+
+    def _is_quantized(self, node: torch.fx.Node) -> bool:
+        """Checks if the node is quantized.
+
+        A node is considered quantized if at least one criterion is met:
+        - Its dtype is neither floating point nor complex, i.e. it is an
+          integer type.
+        - It is one of the special cases where the node has been created in
+          to_edge, e.g. .Scalar operations that have been promoted to .Tensor
+          operations where the scalar is replaced by a full op.
+        - It has been marked as quantized in the ArmAnnotationInfo custom meta.
+
+        Args:
+            node (torch.fx.Node): The FX node to check.
+
+        Returns:
+            bool: True if the node is quantized, False otherwise.
+ """ + node_dtype = get_first_fake_tensor(node).dtype + if not node_dtype.is_complex and not node_dtype.is_floating_point: + return True + if node.target in ( + exir_ops.edge.aten.full_like.default, + *ComputeConstantOpsAOT.targeted_ops, + ): + # Special cases where nodes have been created in to_edge, e.g. + # .Scalar operations that have been promoted .Tensor operations + # where the scalar is replaced by a full op. + if all(user.target in Q_OPS for user in node.users): + return True + for user in node.users: + if ( + user.target + == exir_ops.edge.dim_order_ops._to_dim_order_copy.default + ): + dim_order_dtype = get_first_fake_tensor(user).dtype + if dim_order_dtype.is_complex or dim_order_dtype.is_floating_point: + return False + else: + return False + return True + return ( + ArmAnnotationInfo.CUSTOM_META_KEY in node.meta.get("custom", {}) + and node.meta["custom"][ArmAnnotationInfo.CUSTOM_META_KEY].quantized + ) + + def is_node_supported( + self, submodules: typing.Mapping[str, torch.nn.Module], node: fx.Node + ) -> bool: + if node.op != "call_function": + return False + + if node.target in (*DQ_OPS, *Q_OPS): + return True + + if not self._is_quantized(node): + self.reporter.report_reject( + node, "Node was not marked as quantized in the Arm backend." + ) + return False + return True + + class CheckProperQuantization(OperatorSupportBase): """ For targeted nodes, check that it has been quantized as expected. In most cases this means that a pair of quantize @@ -427,7 +498,6 @@ def is_node_supported( class CheckFloat64Inputs(OperatorSupportBase): - def __init__( self, exported_program: ExportedProgram, reporter: WhyNoPartitionReporter ): diff --git a/backends/arm/quantizer/arm_quantizer_utils.py b/backends/arm/quantizer/arm_quantizer_utils.py index c1137ea4149..d4b4d967fba 100644 --- a/backends/arm/quantizer/arm_quantizer_utils.py +++ b/backends/arm/quantizer/arm_quantizer_utils.py @@ -1,6 +1,6 @@ # Copyright (c) Meta Platforms, Inc. and affiliates. -# Copyright 2024-2025 Arm Limited and/or its affiliates. # All rights reserved. +# Copyright 2024-2025 Arm Limited and/or its affiliates. # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. 
@@ -14,6 +14,8 @@
 
 from typing import cast
 
+from executorch.backends.arm.common.annotation_meta import ArmAnnotationInfo
+
 from torch.fx import Node
 
 from torchao.quantization.pt2e.quantizer import QuantizationAnnotation
@@ -65,4 +67,10 @@ def mark_node_as_annotated(node: Node) -> None:
     """
     if Q_ANNOTATION_KEY not in node.meta:
         node.meta[Q_ANNOTATION_KEY] = QuantizationAnnotation()
+    annotation_info = ArmAnnotationInfo(
+        quantized=True,
+    )
     node.meta[Q_ANNOTATION_KEY]._annotated = True
+    meta_custom = node.meta.get("custom", {})
+    meta_custom[ArmAnnotationInfo.CUSTOM_META_KEY] = annotation_info
+    node.meta["custom"] = meta_custom
diff --git a/backends/arm/quantizer/quantization_annotator.py b/backends/arm/quantizer/quantization_annotator.py
index dc3beb5370a..a24a110b1e1 100644
--- a/backends/arm/quantizer/quantization_annotator.py
+++ b/backends/arm/quantizer/quantization_annotator.py
@@ -394,6 +394,7 @@ def _match_pattern(
     torch.ops.aten.view.default,
     torch.ops.aten.view_as.default,
     torch.ops.aten.view_copy.default,
+    torch.ops.aten._unsafe_view.default,
     torch.ops.aten.select.int,
     torch.ops.aten.select_copy.int,
     torch.ops.aten.slice.Tensor,
@@ -426,6 +427,7 @@
 ]
 
 _one_to_one_shared_input_or_input_act_qspec = [
+    torch.ops.aten.alias.default,
     torch.ops.aten.clone.default,
     torch.ops.aten.hardtanh.default,
     torch.ops.aten.hardtanh_.default,
@@ -693,10 +695,10 @@ def any_or_hardtanh_min_zero(n: Node):
         ]
         quant_properties.quant_output = None
     elif node.target in [
-        torch.ops.aten.scalar_tensor.default,
         torch.ops.aten.full.default,
         torch.ops.aten.full,
         torch.ops.aten.fill_.Scalar,
+        torch.ops.aten.scalar_tensor.default,
     ]:
         quant_properties.quant_inputs = []
         quant_properties.quant_output = _QuantProperty(0, output_act_qspec)
diff --git a/backends/arm/test/misc/test_int64.py b/backends/arm/test/misc/test_int64.py
index d6d6d6cb39c..46a97fff1df 100644
--- a/backends/arm/test/misc/test_int64.py
+++ b/backends/arm/test/misc/test_int64.py
@@ -68,10 +68,6 @@ def forward(self, x: torch.Tensor):
         ConstAdd(torch.int64, 2**40),
         (torch.rand(10) - 0.5,),
     ),
-    "int64_in+float_const": (
-        ConstAdd(torch.float32),
-        (torch.randint(0, 10, (10,)),),
-    ),
     "fp32_in+int64_buffer_chain": (
         BufferChainAdd(torch.int64),
         (torch.rand(2, 5, 3) - 0.5,),
@@ -94,7 +90,7 @@ def test_int64_tosa_FP(test_data: Tuple):
         ArmTester(
             model,
             inputs,
-            common.get_tosa_compile_spec("TOSA-1.0+FP", custom_path="tosa/int64"),
+            common.get_tosa_compile_spec("TOSA-1.0+FP"),
         )
         .export()
         .to_edge_transform_and_lower()
diff --git a/backends/arm/test/misc/test_quant_custom_meta.py b/backends/arm/test/misc/test_quant_custom_meta.py
new file mode 100644
index 00000000000..d18a1d39e45
--- /dev/null
+++ b/backends/arm/test/misc/test_quant_custom_meta.py
@@ -0,0 +1,100 @@
+# Copyright 2025 Arm Limited and/or its affiliates.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch
+from executorch.backends.arm.quantizer import (
+    get_symmetric_quantization_config,
+    TOSAQuantizer,
+)
+from executorch.backends.arm.test.tester.test_pipeline import TosaPipelineINT
+from executorch.backends.arm.tosa import TosaSpecification
+from executorch.backends.xnnpack.test.tester import Quantize
+
+
+class AddSigmoidMul(torch.nn.Module):
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.sigmoid = torch.nn.Sigmoid()
+
+    def forward(self, x, y):
+        return self.sigmoid(x + y) * x
+
+
+def get_selective_quantizer(modules):
+    quantizer = TOSAQuantizer(TosaSpecification.create_from_string("TOSA-1.0+INT"))
+    quantizer.set_global(get_symmetric_quantization_config())
+    for module in modules:
+        quantizer.set_module_type(module, None)
+
+    return Quantize(quantizer, get_symmetric_quantization_config())
+
+
+def test_qdq_squeezed_fp_op():
+    """Test that a float operation surrounded by quantize-dequantize pairs
+    is correctly handled by the partitioner and the TOSA backend.
+    Pattern:
+        q -> dq -> add -> q -> dq -> sigmoid -> q -> dq -> mul -> dq -> q
+                               |_____Non-delegated____|
+    """
+    aten_op = "torch.ops.aten.add.Tensor"
+    exir_op = "executorch_exir_dialects_edge__ops_aten_add_Tensor"
+    module = AddSigmoidMul()
+    x = torch.randn(2, 3, 4)
+    y = torch.randn(2, 3, 4)
+    pipeline = TosaPipelineINT(
+        module=module, test_data=(x, y), aten_op=aten_op, exir_op=exir_op
+    )
+    pipeline.change_args("quantize", get_selective_quantizer([torch.nn.Sigmoid]))
+    pipeline.change_args(
+        "check_count.exir",
+        {
+            "torch.ops.higher_order.executorch_call_delegate": 2,
+            "executorch_exir_dialects_edge__ops_aten_sigmoid_default": 1,
+            "executorch_exir_dialects_edge__ops_quantized_decomposed_dequantize_per_tensor_default": 2,
+            "executorch_exir_dialects_edge__ops_quantized_decomposed_quantize_per_tensor_default": 3,
+        },
+    )
+    pipeline.run()
+
+
+class MulAddSigmoidConv(torch.nn.Module):
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.sigmoid = torch.nn.Sigmoid()
+        self.conv = torch.nn.Conv1d(3, 3, 1)
+
+    def forward(self, x, y):
+        return self.conv(self.sigmoid(x + y * x))
+
+
+def test_quantized_to_float_transition():
+    """Test that a model executing quantized ops followed by float ops
+    is correctly handled by the partitioner and the TOSA backend.
+    Pattern:
+        q -> dq -> mul -> q -> dq -> add -> q -> dq -> sigmoid -> conv
+                                             |____Non-delegated___|
+    """
+    aten_op = "torch.ops.aten.add.Tensor"
+    exir_op = "executorch_exir_dialects_edge__ops_aten_add_Tensor"
+    module = MulAddSigmoidConv()
+    x = torch.randn(2, 3, 4)
+    y = torch.randn(2, 3, 4)
+    pipeline = TosaPipelineINT(
+        module=module, test_data=(x, y), aten_op=aten_op, exir_op=exir_op
+    )
+    pipeline.change_args(
+        "quantize", get_selective_quantizer([torch.nn.Sigmoid, torch.nn.Conv1d])
+    )
+    pipeline.change_args(
+        "check_count.exir",
+        {
+            "torch.ops.higher_order.executorch_call_delegate": 1,
+            "executorch_exir_dialects_edge__ops_aten_sigmoid_default": 1,
+            "executorch_exir_dialects_edge__ops_aten_convolution_default": 1,
+            "executorch_exir_dialects_edge__ops_quantized_decomposed_dequantize_per_tensor_default": 1,
+            "executorch_exir_dialects_edge__ops_quantized_decomposed_quantize_per_tensor_default": 2,
+        },
+    )
+    pipeline.run()
diff --git a/backends/arm/test/models/stable_diffusion/test_SD3Transformer2DModel.py b/backends/arm/test/models/stable_diffusion/test_SD3Transformer2DModel.py
index 3e1f19dd39c..6444b8417f2 100644
--- a/backends/arm/test/models/stable_diffusion/test_SD3Transformer2DModel.py
+++ b/backends/arm/test/models/stable_diffusion/test_SD3Transformer2DModel.py
@@ -39,7 +39,8 @@ class TestSD3Transformer2DModel:
 
     ops_after_partitioner_INT = {
         "executorch_exir_dialects_edge__ops_dim_order_ops__to_dim_order_copy_default": 2,
-        "torch.ops.higher_order.executorch_call_delegate": 2,
+        "torch.ops.higher_order.executorch_call_delegate": 3,
+        "executorch_exir_dialects_edge__ops_aten_permute_copy_default": 1,
     }
 
     def _prepare_inputs(
diff --git a/backends/arm/test/models/test_nn_functional.py b/backends/arm/test/models/test_nn_functional.py
index 4896074b544..e585e82ad9d 100644
--- a/backends/arm/test/models/test_nn_functional.py
+++ b/backends/arm/test/models/test_nn_functional.py
@@ -102,7 +102,6 @@ def test_nn_functional_FP(test_data):
 @parametrize(
     "test_data",
     module_tests,
-    {"normalize": "MLETORCH-1255: Unsupported dtype in InsertTableOpsPass"},
 )
 def test_nn_functional_INT(test_data):
     module, inputs = test_data
diff --git a/backends/arm/test/models/test_torch_functions.py b/backends/arm/test/models/test_torch_functions.py
index 7f9bbdba177..54a9a6ae676 100644
--- a/backends/arm/test/models/test_torch_functions.py
+++ b/backends/arm/test/models/test_torch_functions.py
@@ -126,10 +126,12 @@ def test_torch_fns_FP(test_data):
     xfails={
         "nonzero": "torch.fx.experimental.symbolic_shapes.GuardOnDataDependentSymNode: Could not guard on data-dependent expression Eq(u4, 0). "
         "Requires dynamic output shape.",
+        "eye": "ValueError: Failed processing buffer placeholder: aten_arange_start_step_1_pre_computed_common. "
" + "Is the original torch function supported?", "topk": "NotImplementedError: No registered serialization name for found", "sort": "NotImplementedError: No registered serialization name for found", }, - strict=False, + strict=True, ) def test_torch_fns_INT(test_data): module, inputs = test_data diff --git a/backends/arm/test/ops/test_eye.py b/backends/arm/test/ops/test_eye.py index eef32259c10..5c829acc145 100644 --- a/backends/arm/test/ops/test_eye.py +++ b/backends/arm/test/ops/test_eye.py @@ -95,7 +95,7 @@ def test_eye_u85_INT(test_data: test_data_t): input_data(), EyeAdd.aten_op, use_to_edge_transform_and_lower=True, - ).dump_artifact("to_edge_transform_and_lower") + ) pipeline.pop_stage("check.quant_nodes") pipeline.run() diff --git a/backends/arm/tosa/partitioner.py b/backends/arm/tosa/partitioner.py index fdb4dc62abf..bf84f8e4a2a 100644 --- a/backends/arm/tosa/partitioner.py +++ b/backends/arm/tosa/partitioner.py @@ -22,6 +22,7 @@ from executorch.backends.arm._passes.convert_expand_copy_to_repeat import ( calculate_multiples, ) + from executorch.backends.arm.common.type import ensure_type from executorch.backends.arm.constants import DQ_OPS, Q_OPS from executorch.backends.arm.operator_support.tosa_supported_operators import ( @@ -316,7 +317,7 @@ def partition(self, exported_program: ExportedProgram) -> PartitionResult: tagged_exported_program=exported_program, partition_tags=partition_tags ) - def ops_to_not_decompose( + def ops_to_not_decompose( # noqa: C901 self, ep: ExportedProgram, ) -> Tuple[List[torch._ops.OpOverload], Optional[Callable[[torch.fx.Node], bool]]]: @@ -336,17 +337,24 @@ def ops_to_not_decompose( """ ops_to_not_decompose_if_quant_op = { + torch.ops.aten.eye.default, torch.ops.aten.hardsigmoid.default, torch.ops.aten.hardswish.default, torch.ops.aten.linear.default, + torch.ops.aten.linspace.default, } ops_to_not_decompose_if_fp = { + torch.ops.aten.eye.default, + torch.ops.aten.logit.default, torch.ops.aten.linear.default, + torch.ops.aten.linspace.default, } ops_to_not_decompose_always = { + torch.ops.aten.logit.default, + } + ops_to_not_decompose_if_integer = { torch.ops.aten.eye.default, torch.ops.aten.linspace.default, - torch.ops.aten.logit.default, } def filter_fn(node: torch.fx.Node) -> bool: @@ -400,15 +408,42 @@ def filter_fn(node: torch.fx.Node) -> bool: ): correct_output_quant = True - return correct_input_quant and correct_output_quant + if correct_input_quant and correct_output_quant: + return True + + if node.target in ops_to_not_decompose_if_integer: + # We only want to tag nodes as do_not_decompose if we are sure that + # we can partition them. We partition them if one or more of the + # following is true: + # 1. The node outputs an integer type. + # 2. All users cast the output to an integer type. 
+
+                dtype = get_first_fake_tensor(node).dtype
+                if not dtype.is_floating_point and not dtype.is_complex:
+                    return True
+
+                output_nodes = node.users
+                for user in output_nodes:
+                    if user.target != torch.ops.aten.to.dtype:
+                        return False
+                    else:
+                        cast_dtype = get_first_fake_tensor(user).dtype
+                        if cast_dtype.is_complex or cast_dtype.is_floating_point:
+                            return False
+                return True
 
-            # By default, do not decompose the operator
-            return True
+            if node.target in ops_to_not_decompose_if_fp:
+                if self.tosa_spec.support_float():
+                    return True
+            if node.target in ops_to_not_decompose_always:
+                return True
+            return False
 
         ops_to_not_decompose = list(
             ops_to_not_decompose_always
             | ops_to_not_decompose_if_quant_op
             | ops_to_not_decompose_if_fp
+            | ops_to_not_decompose_if_integer
        )
 
         if not self.tosa_spec.is_U55_subset:
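
The ops_to_not_decompose_if_integer branch in filter_fn above reduces to two conditions: keep the op undecomposed when it already produces an integer dtype, or when every user is an aten.to.dtype cast to an integer dtype. A minimal sketch of that heuristic over plain dtypes; keep_undecomposed and user_cast_dtypes are illustrative names, not part of this patch, and None marks a user that is not a cast:

```python
from typing import Optional, Sequence

import torch


def keep_undecomposed(
    node_dtype: torch.dtype,
    user_cast_dtypes: Sequence[Optional[torch.dtype]],
) -> bool:
    # Case 1: the node itself yields an integer (non-float, non-complex) type.
    if not node_dtype.is_floating_point and not node_dtype.is_complex:
        return True
    # Case 2: every user is a cast, and every cast target is an integer type.
    for cast_dtype in user_cast_dtypes:
        if cast_dtype is None:  # user is not an aten.to.dtype cast
            return False
        if cast_dtype.is_floating_point or cast_dtype.is_complex:
            return False
    return True


assert keep_undecomposed(torch.int32, [])               # integer output
assert keep_undecomposed(torch.float32, [torch.int64])  # all users cast to int
assert not keep_undecomposed(torch.float32, [torch.float16])
```

This mirrors why the check is conservative: a float-producing eye or linspace is only tagged do_not_decompose when the partitioner can prove the result ends up integer, since otherwise the op could be left undecomposed but unpartitionable.
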
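
The new tests in test_quant_custom_meta.py exercise this machinery by quantizing a model only partially. A condensed sketch of the selective setup they use, mirroring get_selective_quantizer from the diff above: a global symmetric INT config, with individual module types opted out by mapping them to None (imports exactly as in the diff):

```python
import torch

from executorch.backends.arm.quantizer import (
    get_symmetric_quantization_config,
    TOSAQuantizer,
)
from executorch.backends.arm.tosa import TosaSpecification

# Global INT config for every module in the model...
quantizer = TOSAQuantizer(TosaSpecification.create_from_string("TOSA-1.0+INT"))
quantizer.set_global(get_symmetric_quantization_config())
# ...except Sigmoid, which stays in float by mapping its type to None.
quantizer.set_module_type(torch.nn.Sigmoid, None)
```

The unquantized modules then lack the ArmAnnotationInfo custom meta, so CheckArmQuantized rejects them on INT-only TOSA specs and the partitioner leaves them outside the delegate, which is exactly the delegate/non-delegate split the tests assert on.
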