From df441f534c86477eb8bd90e7a8c32c80f02336f9 Mon Sep 17 00:00:00 2001
From: Tom Allsop
Date: Mon, 6 Jan 2025 10:47:22 +0000
Subject: [PATCH 1/2] Add clamp operator to Arm backend

* Add support for aten.clamp
* Amend the QuantizeFullArgument pass, renaming it to
  QuantizeOperatorArguments, so that it also quantizes the min and max
  arguments of clamp

Signed-off-by: Tom Allsop
Change-Id: I432f4ec60facc50fe45ca05c98308924d6e18109
---
 backends/arm/_passes/arm_pass_manager.py      |   6 +-
 .../fold_qdq_with_annotated_qparams_pass.py   |  46 +++--
 .../tosa_supported_operators.py               |   1 +
 backends/arm/operators/__init__.py            |   1 +
 backends/arm/operators/op_clamp.py            | 140 +++++++++++++++
 .../arm/quantizer/quantization_annotator.py   |   2 +
 backends/arm/test/ops/test_clamp.py           | 165 ++++++++++++++++++
 7 files changed, 346 insertions(+), 15 deletions(-)
 create mode 100644 backends/arm/operators/op_clamp.py
 create mode 100644 backends/arm/test/ops/test_clamp.py

diff --git a/backends/arm/_passes/arm_pass_manager.py b/backends/arm/_passes/arm_pass_manager.py
index 9bac3b037cd..71e38d5120a 100644
--- a/backends/arm/_passes/arm_pass_manager.py
+++ b/backends/arm/_passes/arm_pass_manager.py
@@ -37,7 +37,7 @@
 from executorch.backends.arm._passes.decompose_var_pass import DecomposeVarPass
 from executorch.backends.arm._passes.fold_qdq_with_annotated_qparams_pass import (
     FoldAndAnnotateQParamsPass,
-    QuantizeFullArgument,
+    QuantizeOperatorArguments,
     RetraceFoldedDtypesPass,
 )
 from executorch.backends.arm._passes.fuse_quantized_activation_pass import (
@@ -88,7 +88,7 @@ def _tosa_080_BI_pipeline(self, exported_program: ExportedProgram) -> GraphModul
         self.add_pass(ConvertMeanDimToAveragePoolPass())
 
         self.add_pass(AnnotateDecomposedMatmulPass())
-        self.add_pass(QuantizeFullArgument())
+        self.add_pass(QuantizeOperatorArguments())
         self.add_pass(FoldAndAnnotateQParamsPass())
         self.add_pass(RetraceFoldedDtypesPass())
         self.add_pass(InsertTableOpsPass(exported_program))
@@ -124,7 +124,7 @@ def _tosa_080_MI_pipeline(self, exported_program: ExportedProgram) -> GraphModul
         self.add_pass(DecomposeSoftmaxesPass())
 
         self.add_pass(AnnotateDecomposedMatmulPass())
-        self.add_pass(QuantizeFullArgument())
+        self.add_pass(QuantizeOperatorArguments())
         self.add_pass(FoldAndAnnotateQParamsPass())
         self.add_pass(RetraceFoldedDtypesPass())
         self.add_pass(InsertTableOpsPass(exported_program))
diff --git a/backends/arm/_passes/fold_qdq_with_annotated_qparams_pass.py b/backends/arm/_passes/fold_qdq_with_annotated_qparams_pass.py
index b1e680b7bca..8c442566819 100644
--- a/backends/arm/_passes/fold_qdq_with_annotated_qparams_pass.py
+++ b/backends/arm/_passes/fold_qdq_with_annotated_qparams_pass.py
@@ -182,11 +182,14 @@ def call(self, graph_module: GraphModule) -> PassResult:
         return PassResult(graph_module, True)
 
 
-class QuantizeFullArgument(ExportPass):
+class QuantizeOperatorArguments(ExportPass):
     """
-    Make sure the fill_value for full.default is quantized. This pass needs to be run before
-    the folding pass above to make sure that the retraced output of the full.default op is
-    the right dtype.
+    This pass makes sure that the arguments to full.default and clamp.default are quantized correctly.
+    More specifically, this pass:
+        - Makes sure the fill_value for full.default is quantized. This pass needs to be run before
+          the folding pass above to make sure that the retraced output of the full.default op is
+          the right dtype.
+        - Makes sure the min and max values to clamp.default are quantized if the operator is quantized.
     """

     def call(self, graph_module: GraphModule) -> PassResult:
@@ -194,7 +197,10 @@ def call(self, graph_module: GraphModule) -> PassResult:
-        # Loop over the graph nodes and find full.default nodes.
+        # Loop over the graph nodes and find full.default and clamp.default nodes.
         for n in graph_module.graph.nodes:
             n = cast(Node, n)
-            if n.target != exir_ops.edge.aten.full.default:
+            if n.target not in {
+                exir_ops.edge.aten.clamp.default,
+                exir_ops.edge.aten.full.default,
+            }:
                 continue

             # Make sure we have a quantized operator
@@ -203,13 +209,29 @@ def call(self, graph_module: GraphModule) -> PassResult:
                 continue

             qargs = QuantArgs.from_operator(user.target, user.args)
-            if "dtype" not in n.kwargs.keys() or n.kwargs["dtype"] != qargs.dtype:
-                # replace the node arg with a quantized dito and also set dtype
-                # to get the right output according to the Edge IR specification:
-                # exir/dialects/edge/edge.yaml:3596
-                quantized_full_value = qargs.quantize_value(n.args[1]).item()
-                n.update_arg(1, quantized_full_value)
-                n.update_kwarg("dtype", qargs.dtype)
+
+            if n.target == exir_ops.edge.aten.full.default:
+                if "dtype" not in n.kwargs.keys() or n.kwargs["dtype"] != qargs.dtype:
+                    # Replace the node arg with a quantized ditto and also set dtype
+                    # to get the right output according to the Edge IR specification:
+                    # exir/dialects/edge/edge.yaml:3596
+                    quantized_full_value = qargs.quantize_value(n.args[1]).item()
+                    n.update_arg(1, quantized_full_value)
+                    n.update_kwarg("dtype", qargs.dtype)
+                    modified = True
+            elif n.target == exir_ops.edge.aten.clamp.default:
+                # Quantize the min and max arguments of clamp, if they are not None
+                min_val = n.args[1]
+                max_val = None if len(n.args) <= 2 else n.args[2]
+
+                if min_val is not None:
+                    quantized_min_val = qargs.quantize_value(min_val).item()
+                    n.update_arg(1, quantized_min_val)
+
+                if max_val is not None:
+                    quantized_max_val = qargs.quantize_value(max_val).item()
+                    n.update_arg(2, quantized_max_val)
+                modified = True

         return PassResult(graph_module, modified)

diff --git a/backends/arm/operator_support/tosa_supported_operators.py b/backends/arm/operator_support/tosa_supported_operators.py
index c3102a86a48..7d86269b9c9 100644
--- a/backends/arm/operator_support/tosa_supported_operators.py
+++ b/backends/arm/operator_support/tosa_supported_operators.py
@@ -76,6 +76,7 @@ def is_node_supported(self, submodules, node: fx.Node) -> bool:
             exir_ops.edge.aten.add.Tensor,
             exir_ops.edge.aten.expand_copy.default,
             exir_ops.edge.aten.cat.default,
+            exir_ops.edge.aten.clamp.default,
             exir_ops.edge.aten.bmm.default,
             exir_ops.edge.aten.permute_copy.default,
             exir_ops.edge.aten.hardtanh.default,
diff --git a/backends/arm/operators/__init__.py b/backends/arm/operators/__init__.py
index 5a97d33304e..aece200047d 100644
--- a/backends/arm/operators/__init__.py
+++ b/backends/arm/operators/__init__.py
@@ -12,6 +12,7 @@
     op_batch_norm,
     op_bmm,
     op_cat,
+    op_clamp,
     op_conv2d,
     op_eq,
     op_exp,
diff --git a/backends/arm/operators/op_clamp.py b/backends/arm/operators/op_clamp.py
new file mode 100644
index 00000000000..308d6cf8e1d
--- /dev/null
+++ b/backends/arm/operators/op_clamp.py
@@ -0,0 +1,140 @@
+# Copyright 2025 Arm Limited and/or its affiliates.
+# All rights reserved.
+# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree + +from numbers import Number +from typing import List, Tuple + +import serializer.tosa_serializer as ts + +import torch +from executorch.backends.arm.operators.node_visitor import ( + NodeVisitor, + register_node_visitor, +) + +from executorch.backends.arm.tosa_mapping import TosaArg +from executorch.backends.arm.tosa_specification import TosaSpecification +from serializer.tosa_serializer import TosaOp +from torch.fx import Node + + +@register_node_visitor +class ClampVisitor_080_BI(NodeVisitor): + target = "aten.clamp.default" + + tosa_specs = [ + TosaSpecification.create_from_string("TOSA-0.80+BI"), + ] + + def __init__(self, *args): + super().__init__(*args) + + def _create_clamp_node( + self, + tosa_graph: ts.TosaSerializer, + input_name: str, + output_name: str, + min_int: int, + max_int: int, + min_fp32: float, + max_fp32: float, + ): + attr = ts.TosaSerializerAttribute() + attr.ClampAttribute( + tosa_graph.builder, + min_int, + max_int, + min_fp32, + max_fp32, + ) + tosa_graph.addOperator(TosaOp.Op().CLAMP, [input_name], [output_name], attr) + + def _get_min_max_arguments( + self, + node: Node, + dtype_min: Number, + dtype_max: Number, + ) -> Tuple[Number, Number]: + assert 2 <= len(node.args) <= 3 + + min_arg = dtype_min + max_arg = dtype_max + + if node.args[1] is not None: + min_arg = node.args[1] + + if len(node.args) > 2: + if node.args[2] is not None: + max_arg = node.args[2] + + return min_arg, max_arg + + def define_node( + self, + node: Node, + tosa_graph: ts.TosaSerializer, + inputs: List[TosaArg], + output: TosaArg, + ) -> None: + assert len(node.all_input_nodes) == 1 + + min_int8, max_int8 = self._get_min_max_arguments( + node, + torch.iinfo(torch.int8).min, + torch.iinfo(torch.int8).max, + ) + + # NOTE: Quantization of the min/max arguments is handled by QuantizeOperatorArguments + self._create_clamp_node( + tosa_graph, + inputs[0].name, + output.name, + min_int8, + max_int8, + 0, + 0, + ) + + +@register_node_visitor +class ClampVisitor_080_MI(ClampVisitor_080_BI): + # inheriting 'target' from BI class + + tosa_specs = [ + TosaSpecification.create_from_string("TOSA-0.80+MI"), + ] + + def __init__(self, *args): + super().__init__(*args) + + def define_node( + self, + node: Node, + tosa_graph: ts.TosaSerializer, + inputs: List[TosaArg], + output: TosaArg, + ) -> None: + assert len(node.all_input_nodes) == 1 + + if inputs[0].dtype == ts.DType.INT8: + # Call the inherited define_node for handling integers + super().define_node(node, tosa_graph, inputs, output) + else: + min_fp32, max_fp32 = self._get_min_max_arguments( + node, + torch.finfo(torch.float32).min, + torch.finfo(torch.float32).max, + ) + + self._create_clamp_node( + tosa_graph, + inputs[0].name, + output.name, + 0, + 0, + min_fp32, + max_fp32, + ) diff --git a/backends/arm/quantizer/quantization_annotator.py b/backends/arm/quantizer/quantization_annotator.py index f2a124f2790..acb421ba4e4 100644 --- a/backends/arm/quantizer/quantization_annotator.py +++ b/backends/arm/quantizer/quantization_annotator.py @@ -186,6 +186,8 @@ def _match_pattern( torch.ops.aten.full.default, torch.ops.aten.flatten.using_ints, torch.ops.aten.dropout.default, + torch.ops.aten.clamp.default, + torch.ops.aten.clamp.Tensor, operator.getitem, ] diff --git a/backends/arm/test/ops/test_clamp.py b/backends/arm/test/ops/test_clamp.py new file mode 100644 index 00000000000..5cf333068ca --- /dev/null +++ 
b/backends/arm/test/ops/test_clamp.py
@@ -0,0 +1,165 @@
+# Copyright 2025 Arm Limited and/or its affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import unittest
+from numbers import Number
+from typing import Tuple, Union
+
+import pytest
+import torch
+
+from executorch.backends.arm.quantizer.arm_quantizer import (
+    ArmQuantizer,
+    get_symmetric_quantization_config,
+)
+from executorch.backends.arm.test import common, conftest
+from executorch.backends.arm.test.tester.arm_tester import ArmTester
+from executorch.backends.arm.tosa_specification import TosaSpecification
+from executorch.backends.xnnpack.test.tester.tester import Quantize
+from executorch.exir.backend.compile_spec_schema import CompileSpec
+from parameterized import parameterized
+
+
+test_data_suite = [
+    # (test_name, test_data, min, max)
+    ("rank_1", torch.rand(10) * 2, -1.0, 1.0),
+    ("rank_2", torch.rand(1, 35), 0.5, 0.8),
+    ("rank_3", torch.ones(1, 10, 10), -1, -1),
+    ("rank_4", torch.rand(1, 10, 10, 1) * 2, -0.1, 2.0),
+    ("rank_4_mixed_min_max_dtype", torch.rand(1, 10, 10, 5) + 10, 8.0, 10),
+    ("rank_4_no_min", torch.rand(1, 10, 10, 1) * 10, None, 5),
+    ("rank_4_no_max", torch.rand(1, 10, 10, 1) - 3, -3.3, None),
+]
+
+
+class TestClamp(unittest.TestCase):
+    """Tests Clamp Operator."""
+
+    class Clamp(torch.nn.Module):
+        def __init__(
+            self,
+            min: Union[torch.Tensor, Number, None],
+            max: Union[torch.Tensor, Number, None],
+        ):
+            super().__init__()
+
+            self.clamp_min = min
+            self.clamp_max = max
+
+        def forward(self, x):
+            return torch.clamp(x, self.clamp_min, self.clamp_max)
+
+    def _test_clamp_tosa_MI_pipeline(
+        self, module: torch.nn.Module, test_data: Tuple[torch.Tensor]
+    ):
+        (
+            ArmTester(
+                module,
+                example_inputs=test_data,
+                compile_spec=common.get_tosa_compile_spec("TOSA-0.80+MI"),
+            )
+            .export()
+            .check(["torch.ops.aten.clamp.default"])
+            .check_not(["torch.ops.quantized_decomposed"])
+            .to_edge_transform_and_lower()
+            .check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
+            .to_executorch()
+            .run_method_and_compare_outputs(inputs=test_data)
+        )
+
+    def _test_clamp_tosa_BI_pipeline(
+        self, module: torch.nn.Module, test_data: Tuple[torch.Tensor]
+    ):
+        tosa_spec = TosaSpecification.create_from_string("TOSA-0.80+BI")
+        compile_spec = common.get_tosa_compile_spec(tosa_spec)
+        quantizer = ArmQuantizer(tosa_spec).set_io(get_symmetric_quantization_config())
+        (
+            ArmTester(
+                module,
+                example_inputs=test_data,
+                compile_spec=compile_spec,
+            )
+            .quantize(Quantize(quantizer, get_symmetric_quantization_config()))
+            .export()
+            .check_count({"torch.ops.aten.clamp.default": 1})
+            .check(["torch.ops.quantized_decomposed"])
+            .to_edge_transform_and_lower()
+            .check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
+            .to_executorch()
+            .run_method_and_compare_outputs(inputs=test_data)
+        )
+
+    def _test_clamp_tosa_ethos_BI_pipeline(
+        self,
+        compile_spec: list[CompileSpec],
+        module: torch.nn.Module,
+        test_data: Tuple[torch.Tensor],
+    ):
+        tosa_spec = TosaSpecification.create_from_compilespecs(compile_spec)
+        quantizer = ArmQuantizer(tosa_spec).set_io(get_symmetric_quantization_config())
+        tester = (
+            ArmTester(
+                module,
+                example_inputs=test_data,
+                compile_spec=compile_spec,
+            )
+            .quantize(Quantize(quantizer, get_symmetric_quantization_config()))
+            .export()
+            .check_count({"torch.ops.aten.clamp.default": 1})
+            .check(["torch.ops.quantized_decomposed"])
+            
.to_edge_transform_and_lower() + .check_count({"torch.ops.higher_order.executorch_call_delegate": 1}) + .to_executorch() + .serialize() + ) + if conftest.is_option_enabled("corstone_fvp"): + tester.run_method_and_compare_outputs(qtol=1, inputs=test_data) + + @parameterized.expand(test_data_suite) + def test_clamp_tosa_MI( + self, + test_name: str, + test_data: torch.Tensor, + min: Union[torch.Tensor, Number, None], + max: Union[torch.Tensor, Number, None], + ): + self._test_clamp_tosa_MI_pipeline(self.Clamp(min, max), (test_data,)) + + @parameterized.expand(test_data_suite) + def test_clamp_tosa_BI( + self, + test_name: str, + test_data: torch.Tensor, + min: Union[torch.Tensor, Number, None], + max: Union[torch.Tensor, Number, None], + ): + self._test_clamp_tosa_BI_pipeline(self.Clamp(min, max), (test_data,)) + + @parameterized.expand(test_data_suite) + @pytest.mark.corstone_fvp + def test_clamp_tosa_u55_BI( + self, + test_name: str, + test_data: torch.Tensor, + min: Union[torch.Tensor, Number, None], + max: Union[torch.Tensor, Number, None], + ): + self._test_clamp_tosa_ethos_BI_pipeline( + common.get_u55_compile_spec(), self.Clamp(min, max), (test_data,) + ) + + @parameterized.expand(test_data_suite) + @pytest.mark.corstone_fvp + def test_clamp_tosa_u85_BI( + self, + test_name: str, + test_data: torch.Tensor, + min: Union[torch.Tensor, Number, None], + max: Union[torch.Tensor, Number, None], + ): + self._test_clamp_tosa_ethos_BI_pipeline( + common.get_u85_compile_spec(), self.Clamp(min, max), (test_data,) + ) From a39fbd534bf7925f33ff70bd39b36a241f4d0ca0 Mon Sep 17 00:00:00 2001 From: Tom Allsop Date: Mon, 6 Jan 2025 10:47:22 +0000 Subject: [PATCH 2/2] Address merge conflicts and typing issues Signed-off-by: Tom Allsop Change-Id: If9f6bbb9d23f00b9cea6d00c6aa97a6d5d3e77ab --- backends/arm/_passes/arm_pass_manager.py | 4 ++-- backends/arm/operators/op_clamp.py | 30 ++++++++++++++---------- 2 files changed, 19 insertions(+), 15 deletions(-) diff --git a/backends/arm/_passes/arm_pass_manager.py b/backends/arm/_passes/arm_pass_manager.py index 71e38d5120a..a501fed8f74 100644 --- a/backends/arm/_passes/arm_pass_manager.py +++ b/backends/arm/_passes/arm_pass_manager.py @@ -89,7 +89,7 @@ def _tosa_080_BI_pipeline(self, exported_program: ExportedProgram) -> GraphModul self.add_pass(AnnotateDecomposedMatmulPass()) self.add_pass(QuantizeOperatorArguments()) - self.add_pass(FoldAndAnnotateQParamsPass()) + self.add_pass(FoldAndAnnotateQParamsPass()) # type: ignore[call-arg] self.add_pass(RetraceFoldedDtypesPass()) self.add_pass(InsertTableOpsPass(exported_program)) @@ -125,7 +125,7 @@ def _tosa_080_MI_pipeline(self, exported_program: ExportedProgram) -> GraphModul self.add_pass(AnnotateDecomposedMatmulPass()) self.add_pass(QuantizeOperatorArguments()) - self.add_pass(FoldAndAnnotateQParamsPass()) + self.add_pass(FoldAndAnnotateQParamsPass()) # type: ignore[call-arg] self.add_pass(RetraceFoldedDtypesPass()) self.add_pass(InsertTableOpsPass(exported_program)) diff --git a/backends/arm/operators/op_clamp.py b/backends/arm/operators/op_clamp.py index 308d6cf8e1d..486da27c9a3 100644 --- a/backends/arm/operators/op_clamp.py +++ b/backends/arm/operators/op_clamp.py @@ -4,10 +4,9 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree -from numbers import Number -from typing import List, Tuple +from typing import Any, List, Tuple -import serializer.tosa_serializer as ts +import serializer.tosa_serializer as ts # type: ignore import 
torch from executorch.backends.arm.operators.node_visitor import ( @@ -41,7 +40,7 @@ def _create_clamp_node( max_int: int, min_fp32: float, max_fp32: float, - ): + ) -> None: attr = ts.TosaSerializerAttribute() attr.ClampAttribute( tosa_graph.builder, @@ -53,22 +52,27 @@ def _create_clamp_node( tosa_graph.addOperator(TosaOp.Op().CLAMP, [input_name], [output_name], attr) def _get_min_max_arguments( - self, - node: Node, - dtype_min: Number, - dtype_max: Number, - ) -> Tuple[Number, Number]: + self, node: Node, dtype_min: int | float, dtype_max: int | float + ) -> Tuple[int | float, int | float]: + + def cast_type(value: Any) -> int | float: + if isinstance(value, int): + return value + else: + # Attempt to cast to float + return float(value) + assert 2 <= len(node.args) <= 3 min_arg = dtype_min max_arg = dtype_max if node.args[1] is not None: - min_arg = node.args[1] + min_arg = cast_type(node.args[1]) if len(node.args) > 2: if node.args[2] is not None: - max_arg = node.args[2] + max_arg = cast_type(node.args[2]) return min_arg, max_arg @@ -92,8 +96,8 @@ def define_node( tosa_graph, inputs[0].name, output.name, - min_int8, - max_int8, + int(min_int8), + int(max_int8), 0, 0, )
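
For reference, the clamp bounds are already in the integer domain by the time op_clamp.py sees them. Below is a minimal standalone sketch of the arithmetic, assuming QuantArgs.quantize_value applies the usual affine mapping round(x / scale) + zero_point saturated to the dtype range; the helper name and parameters are illustrative, not the backend's API:

    import torch

    def quantize_value(x: float, scale: float, zero_point: int,
                       dtype: torch.dtype = torch.int8) -> int:
        # Map a float clamp bound into the integer domain, then saturate
        # to the dtype's representable range (affine quantization).
        qmin, qmax = torch.iinfo(dtype).min, torch.iinfo(dtype).max
        q = int(round(x / scale)) + zero_point
        return max(qmin, min(qmax, q))

    # With scale=0.05 and zero_point=0, clamp(x, min=-1.0, max=1.0) lowers
    # to a TOSA CLAMP with int8 bounds -20 and 20:
    print(quantize_value(-1.0, 0.05, 0))  # -20
    print(quantize_value(1.0, 0.05, 0))   # 20

When a bound is omitted, as in the rank_4_no_min and rank_4_no_max test cases, ClampVisitor_080_BI substitutes the int8 dtype limits from torch.iinfo, so the emitted CLAMP attribute always carries concrete min and max values.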