diff --git a/backends/arm/_passes/arm_pass_manager.py b/backends/arm/_passes/arm_pass_manager.py index 28d70591e5e..331d45e9124 100644 --- a/backends/arm/_passes/arm_pass_manager.py +++ b/backends/arm/_passes/arm_pass_manager.py @@ -27,6 +27,7 @@ from executorch.backends.arm._passes.convert_squeezes_to_view import ( # type: ignore[import-not-found] ConvertSqueezesToViewPass, ) +from executorch.backends.arm._passes.convert_to_clamp import ConvertToClampPass from executorch.backends.arm._passes.decompose_batchnorm_pass import ( DecomposeBatchNormPass, ) @@ -104,6 +105,7 @@ def _tosa_080_BI_pipeline(self, exported_program: ExportedProgram) -> GraphModul self.add_pass(DecomposeLinearPass()) self.add_pass(ConvertMeanDimToAveragePoolPass()) self.add_pass(ConvertFullLikeToFullPass()) + self.add_pass(ConvertToClampPass()) self.add_pass(ReplaceScalarWithTensorArgPass()) self.add_pass(AnnotateDecomposedMatmulPass()) @@ -144,6 +146,8 @@ def _tosa_080_MI_pipeline(self, exported_program: ExportedProgram) -> GraphModul self.add_pass(DecomposeDivPass()) self.add_pass(DecomposeSoftmaxesPass()) self.add_pass(ConvertFullLikeToFullPass()) + self.add_pass(ConvertToClampPass()) + self.add_pass(AnnotateDecomposedMatmulPass()) self.add_pass(QuantizeOperatorArguments()) self.add_pass(FoldAndAnnotateQParamsPass()) # type: ignore[call-arg] diff --git a/backends/arm/_passes/convert_to_clamp.py b/backends/arm/_passes/convert_to_clamp.py new file mode 100644 index 00000000000..8f2c9b16f9a --- /dev/null +++ b/backends/arm/_passes/convert_to_clamp.py @@ -0,0 +1,36 @@ +# Copyright 2025 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Tuple + +from executorch.exir.dialects._ops import ops as exir_ops +from executorch.exir.pass_base import ExportPass + +edge_operators = { + exir_ops.edge.aten.hardtanh.default, + exir_ops.edge.aten.relu.default, +} + + +def get_clamp_params(op, args) -> Tuple[float | None, float | None]: + if op == exir_ops.edge.aten.hardtanh.default: + return args[1], args[2] + elif op == exir_ops.edge.aten.relu.default: + return 0.0, None + else: + raise ValueError(f"Getting clamp parameters for op {op} is not implemented.") + + +class ConvertToClampPass(ExportPass): + def call_operator(self, op, args, kwargs, meta): + if op not in edge_operators: + return super().call_operator(op, args, kwargs, meta) + + return super().call_operator( + exir_ops.edge.aten.clamp.default, + (args[0], *get_clamp_params(op, args)), + {}, + meta, + ) diff --git a/backends/arm/operators/__init__.py b/backends/arm/operators/__init__.py index 735debe367f..6c6f3ab32f5 100644 --- a/backends/arm/operators/__init__.py +++ b/backends/arm/operators/__init__.py @@ -20,7 +20,6 @@ op_ge, op_get_item, op_gt, - op_hardtanh, op_le, op_log, op_lt, @@ -30,7 +29,6 @@ op_mul, op_permute, op_reciprocal, - op_relu, op_repeat, op_rescale, op_rshift, diff --git a/backends/arm/operators/op_hardtanh.py b/backends/arm/operators/op_hardtanh.py deleted file mode 100644 index fc0ee552a9f..00000000000 --- a/backends/arm/operators/op_hardtanh.py +++ /dev/null @@ -1,66 +0,0 @@ -# Copyright 2023-2025 Arm Limited and/or its affiliates. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - -# pyre-unsafe -from typing import List - -import serializer.tosa_serializer as ts # type: ignore -import torch - -# pyre-fixme[21]: 'Could not find a module corresponding to import `executorch.backends.arm._passes.fold_qdq_with_annotated_qparams_pass`.' -from executorch.backends.arm._passes.fold_qdq_with_annotated_qparams_pass import ( - get_input_qparams, -) -from executorch.backends.arm.operators.node_visitor import ( - NodeVisitor, - register_node_visitor, -) -from executorch.backends.arm.tosa_mapping import TosaArg - -from serializer.tosa_serializer import TosaOp - - -@register_node_visitor -class HardTanhVisitor(NodeVisitor): - target = "aten.hardtanh.default" - - def __init__(self, *args): - super().__init__(*args) - - def define_node( - self, - node: torch.fx.Node, - tosa_graph: ts.TosaSerializer, - inputs: List[TosaArg], - output: TosaArg, - ) -> None: - attr = ts.TosaSerializerAttribute() - - if inputs[0].dtype == ts.DType.INT8: - # Get quant parameters - input_qparams = get_input_qparams(node) # pyre-ignore[16] - qargs = input_qparams[0] - # Convert to quantized representation - clamp_min_qs = qargs.quantize_value(inputs[1].number).item() - clamp_max_qs = qargs.quantize_value(inputs[2].number).item() - # Set fp values to 0.0 since they are not used - clamp_min_fp = 0.0 - clamp_max_fp = 0.0 - else: - clamp_min_fp = inputs[1].number - clamp_max_fp = inputs[2].number - # Set qs values to 0 since they are not used - clamp_min_qs = 0 - clamp_max_qs = 0 - - attr.ClampAttribute( - tosa_graph.builder, - clamp_min_qs, - clamp_max_qs, - clamp_min_fp, - clamp_max_fp, - ) - - tosa_graph.addOperator(TosaOp.Op().CLAMP, [inputs[0].name], [output.name], attr) diff --git a/backends/arm/operators/op_relu.py b/backends/arm/operators/op_relu.py deleted file mode 100644 index c37e4b3e75d..00000000000 --- a/backends/arm/operators/op_relu.py +++ /dev/null @@ -1,59 +0,0 @@ -# Copyright 2024-2025 Arm Limited and/or its affiliates. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - -# pyre-unsafe - -import serializer.tosa_serializer as ts # type: ignore -import torch.fx - -# pyre-fixme[21]: 'Could not find a module corresponding to import `executorch.backends.arm._passes.fold_qdq_with_annotated_qparams_pass`.' -from executorch.backends.arm._passes.fold_qdq_with_annotated_qparams_pass import ( - get_output_qparams, -) -from executorch.backends.arm.operators.node_visitor import ( - NodeVisitor, - register_node_visitor, -) -from executorch.backends.arm.tosa_mapping import TosaArg -from serializer.tosa_serializer import TosaOp - - -@register_node_visitor -class ReluVisitor(NodeVisitor): - target = "aten.relu.default" - - def __init__(self, *args): - super().__init__(*args) - - def define_node( - self, - node: torch.fx.Node, - tosa_graph: ts.TosaSerializer, - inputs: list[TosaArg], - output: TosaArg, - ) -> None: - attr = ts.TosaSerializerAttribute() - - clamp_min_fp = 0.0 - clamp_max_fp = 0.0 - clamp_min_qs = 0 - clamp_max_qs = 0 - if inputs[0].dtype == ts.DType.INT8: - out_qargs = get_output_qparams(node) # pyre-ignore[16] - clamp_min_qs = out_qargs[0].quantize_value(0).item() - clamp_max_qs = out_qargs[0].quantize_value(float("inf")).item() - else: - clamp_min_fp = 0 - clamp_max_fp = float("inf") - - attr.ClampAttribute( - tosa_graph.builder, - clamp_min_qs, - clamp_max_qs, - clamp_min_fp, - clamp_max_fp, - ) - - tosa_graph.addOperator(TosaOp.Op().CLAMP, [inputs[0].name], [output.name], attr) diff --git a/backends/arm/test/passes/test_convert_to_clamp.py b/backends/arm/test/passes/test_convert_to_clamp.py new file mode 100644 index 00000000000..0b106b7bc82 --- /dev/null +++ b/backends/arm/test/passes/test_convert_to_clamp.py @@ -0,0 +1,80 @@ +# Copyright 2025 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import unittest + +import torch +from executorch.backends.arm._passes.convert_to_clamp import ConvertToClampPass + +from executorch.backends.arm.test import common +from executorch.backends.arm.test.tester.arm_tester import ArmTester + +from executorch.backends.xnnpack.test.tester.tester import RunPasses + + +class HardTanh(torch.nn.Module): + def __init__(self): + super().__init__() + + self.hardtanh = torch.nn.Hardtanh() + + def forward(self, x): + return self.hardtanh(x) + + def get_inputs(self): + return (torch.rand(1, 64, 64, 3),) + + +class ReLU(torch.nn.Module): + def __init__(self): + super().__init__() + + self.relu = torch.nn.ReLU() + + def forward(self, x): + return self.relu(x) + + def get_inputs(self): + return (torch.rand(1, 64, 64, 3),) + + +class TestConvertToClampPass(unittest.TestCase): + """ + Tests the ConvertToClampPass which converts hardtanh.default and relu.default to clamp.default + """ + + def test_tosa_MI_hardtahn(self): + module = HardTanh() + test_pass_stage = RunPasses([ConvertToClampPass]) + ( + ArmTester( + module, + example_inputs=module.get_inputs(), + compile_spec=common.get_tosa_compile_spec("TOSA-0.80+MI"), + ) + .export() + .to_edge() + .check(["executorch_exir_dialects_edge__ops_aten_hardtanh_default"]) + .run_passes(test_pass_stage) + .check(["executorch_exir_dialects_edge__ops_aten_clamp_default"]) + .check_not(["executorch_exir_dialects_edge__ops_aten_hardtanh_default"]) + ) + + def test_tosa_MI_relu(self): + module = ReLU() + test_pass_stage = RunPasses([ConvertToClampPass]) + ( + ArmTester( + module, + example_inputs=module.get_inputs(), + compile_spec=common.get_tosa_compile_spec("TOSA-0.80+MI"), + ) + .export() + .to_edge() + .check(["executorch_exir_dialects_edge__ops_aten_relu_default"]) + .run_passes(test_pass_stage) + .check(["executorch_exir_dialects_edge__ops_aten_clamp_default"]) + .check_not(["executorch_exir_dialects_edge__ops_aten_relu_default"]) + )