diff --git a/backends/arm/_passes/__init__.py b/backends/arm/_passes/__init__.py
index 4d2449f946c..bd2437ea377 100644
--- a/backends/arm/_passes/__init__.py
+++ b/backends/arm/_passes/__init__.py
@@ -14,6 +14,7 @@
 from .cast_to_int32_pass import CastToInt32Pass  # noqa
 from .conv1d_unsqueeze_pass import Conv1dUnsqueezePass  # noqa
 from .convert_any_default_dim_dims_pass import ConvertAnyDefaultDimDimsPass  # noqa
+from .convert_elu_params import ConvertELUParamsPass  # noqa
 from .convert_expand_copy_to_repeat import ConvertExpandCopyToRepeatPass  # noqa
 from .convert_full_like_to_full_pass import ConvertFullLikeToFullPass  # noqa
 from .convert_int_pow_to_mul import ConvertIntPowToMuls  # noqa
@@ -34,6 +35,7 @@
 from .decompose_cosine_similarity_pass import DecomposeCosineSimilarityPass  # noqa
 from .decompose_cumsum_pass import DecomposeCumsumPass  # noqa
 from .decompose_div_pass import DecomposeDivPass  # noqa
+from .decompose_elu_pass import DecomposeEluPass  # noqa
 from .decompose_embedding_pass import DecomposeEmbeddingPass  # noqa  # noqa
 from .decompose_expm1_pass import DecomposeExpm1Pass  # noqa
 from .decompose_gelu_pass import DecomposeGeluPass  # noqa
diff --git a/backends/arm/_passes/arm_pass_manager.py b/backends/arm/_passes/arm_pass_manager.py
index 7592be1d7da..94aac1dd615 100644
--- a/backends/arm/_passes/arm_pass_manager.py
+++ b/backends/arm/_passes/arm_pass_manager.py
@@ -18,6 +18,7 @@
     ComputeConstantOpsAOT,
     Conv1dUnsqueezePass,
     ConvertAnyDefaultDimDimsPass,
+    ConvertELUParamsPass,
     ConvertExpandCopyToRepeatPass,
     ConvertFullLikeToFullPass,
     ConvertIntPowToMuls,
@@ -39,6 +40,7 @@
     DecomposeCosineSimilarityPass,
     DecomposeCumsumPass,
     DecomposeDivPass,
+    DecomposeEluPass,
     DecomposeEmbeddingPass,
     DecomposeExpm1Pass,
     DecomposeGeluPass,
@@ -132,6 +134,7 @@ def _tosa_INT_pipeline(self, exported_program: ExportedProgram) -> GraphModule:
         self.add_pass(ReplaceScalarWithTensorArgPassTOSABI())
         self.add_pass(AnnotateDecomposedMatmulPass())
         self.add_pass(QuantizeOperatorArguments())
+        self.add_pass(ConvertELUParamsPass())
         self.add_pass(FoldAndAnnotateQParamsPass(exported_program))  # type: ignore[call-arg]
         self.add_pass(RetraceFoldedDtypesPass())
         self.add_pass(UnsqueezeScalarPlaceholdersPass(exported_program))
@@ -180,6 +183,8 @@ def _tosa_FP_pipeline(self, exported_program: ExportedProgram) -> GraphModule:
         self.add_pass(DecomposeAtanPass())
         self.add_pass(DecomposeAtanhPass())
         self.add_pass(DecomposeAddmmPass())
+        self.add_pass(DecomposeEluPass())
+        self.add_pass(DecomposeExpm1Pass())
         self.add_pass(ConvertIntPowToMuls())
         self.add_pass(CastBoolToInt8Pass())
         self.add_pass(DecomposeSinhPass())
diff --git a/backends/arm/_passes/convert_elu_params.py b/backends/arm/_passes/convert_elu_params.py
new file mode 100644
index 00000000000..7da58ae4bb4
--- /dev/null
+++ b/backends/arm/_passes/convert_elu_params.py
@@ -0,0 +1,53 @@
+# Copyright 2025 Arm Limited and/or its affiliates.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch
+from executorch.backends.arm._passes.arm_pass_utils import create_node
+from executorch.exir.dialects._ops import ops as exir_ops
+from executorch.exir.pass_base import ExportPass, PassResult
+
+
+class ConvertELUParamsPass(ExportPass):
+    """
+    Pass to convert the alpha, scale, and input_scale kwargs of the ELU
+    operator from float to int.
+
+    input_scale is set to 2 because the outputs appear to stay the same for
+    any value of input_scale other than 1.
+    """
+
+    def call(self, graph_module: torch.fx.GraphModule):
+        modified_graph = False
+        graph = graph_module.graph
+        node_list = graph.find_nodes(
+            op="call_function", target=exir_ops.edge.aten.elu.default
+        )
+        for node in node_list:
+            with graph.inserting_after(node):
+                replace_node = create_node(graph, exir_ops.edge.aten.elu.default)
+                old_args = list(node.args)
+
+                alpha = old_args[1] if len(old_args) > 1 else 1.0
+                scale = 1.0
+                input_scale = 2.0
+
+                replace_node.args = (old_args[0],)
+
+                updated_kwargs = dict(node.kwargs)
+                updated_kwargs["alpha"] = int(alpha)
+                updated_kwargs["scale"] = int(scale)
+                updated_kwargs["input_scale"] = int(input_scale)
+
+                replace_node.kwargs = updated_kwargs
+
+                node.replace_all_uses_with(replace_node)
+                graph.erase_node(node)
+
+                modified_graph = True
+        if modified_graph:
+            graph_module.recompile()
+            graph_module = super().call(graph_module).graph_module
+
+        return PassResult(graph_module, modified_graph)
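Note (editorial, not part of the patch): a minimal sketch of the arg/kwarg rewrite ConvertELUParamsPass performs on each elu node, with the graph machinery stripped away. convert_elu_kwargs is a hypothetical helper; the constants mirror the pass body (scale fixed to 1.0, input_scale to 2.0, all three kwargs cast to int):

    def convert_elu_kwargs(args: tuple, kwargs: dict) -> tuple[tuple, dict]:
        # Positional alpha (default 1.0) is moved into kwargs as an int.
        alpha = args[1] if len(args) > 1 else 1.0
        updated = dict(kwargs)
        updated["alpha"] = int(alpha)
        updated["scale"] = int(1.0)
        updated["input_scale"] = int(2.0)
        return (args[0],), updated

    print(convert_elu_kwargs(("x", 2.0), {}))
    # (('x',), {'alpha': 2, 'scale': 1, 'input_scale': 2})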
diff --git a/backends/arm/_passes/decompose_elu_pass.py b/backends/arm/_passes/decompose_elu_pass.py
new file mode 100644
index 00000000000..743f1b46f4d
--- /dev/null
+++ b/backends/arm/_passes/decompose_elu_pass.py
@@ -0,0 +1,85 @@
+# Copyright 2025 Arm Limited and/or its affiliates.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+from executorch.backends.arm._passes import ArmPass
+from executorch.exir.dialects._ops import ops as exir_ops
+
+edge_elu_ops = (exir_ops.edge.aten.elu.default,)
+
+
+def get_elu_decomposition(op) -> tuple:
+    """
+    Returns the decomposition of the given aten.elu operation into
+    its equivalent TOSA-supported operations.
+
+    Only the edge dialect elu op is currently handled. The decomposition
+    strategy is:
+        elu(x, alpha) → where(greater_or_eq(x, 0), x, alpha * (exp(x) - 1))
+
+    Returns:
+        A tuple (expm1_op, ge_op, where_op, mul_op) corresponding to the appropriate operator
+        overloads for the input op.
+
+    Raises:
+        RuntimeError: If the provided operator is not a supported elu variant.
+    """
+
+    if op in edge_elu_ops:
+        return (
+            exir_ops.edge.aten.expm1.default,
+            exir_ops.edge.aten.ge.Scalar,
+            exir_ops.edge.aten.where.self,
+            exir_ops.edge.aten.mul.Scalar,
+        )
+
+    raise RuntimeError(f"Can't get elu decomposition for op {op}")
+
+
+class DecomposeEluPass(ArmPass):
+    """
+    A transformation pass that decomposes unsupported 'aten.elu' operations
+    into a combination of supported TOSA-equivalent operations.
+
+    Since TOSA does not provide a native ELU operator, this pass rewrites:
+        elu(x, alpha) → where(greater_or_eq(x, 0), x, alpha * (exp(x) - 1))
+
+    Supported input ops:
+    - exir_ops.edge.aten.elu.default(x)
+
+    These are replaced with:
+    - exir_ops.edge.aten.expm1.default
+    - exir_ops.edge.aten.ge.Scalar
+    - exir_ops.edge.aten.where.self
+    - exir_ops.edge.aten.mul.Scalar
+    """
+
+    def call_operator(self, op, args, kwargs, meta):
+        if op not in edge_elu_ops:
+            return super().call_operator(op, args, kwargs, meta, updated=False)
+
+        (
+            expm1_op,
+            ge_op,
+            where_op,
+            mul_op,
+        ) = get_elu_decomposition(op)
+
+        input = args[0]
+        alpha = args[1] if len(args) > 1 else 1.0
+
+        if alpha == 0:
+            relu_op = exir_ops.edge.aten.relu.default
+            return super().call_operator(relu_op, (input,), {}, meta, updated=True)
+
+        expm1_node = super().call_operator(expm1_op, (input,), {}, meta, updated=True)
+        mul_node = super().call_operator(
+            mul_op, (expm1_node, alpha), {}, meta, updated=True
+        )
+        ge_node = super().call_operator(ge_op, (input, 0.0), {}, meta, updated=True)
+        where_node = super().call_operator(
+            where_op, (ge_node, input, mul_node), {}, meta, updated=True
+        )
+
+        return where_node
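Note (editorial, not part of the patch): the decomposition reads where(x >= 0, x, alpha * expm1(x)), matching the argument order passed to the where op in call_operator above. A quick self-contained check that this agrees with PyTorch's reference ELU:

    import torch

    def elu_decomposed(x: torch.Tensor, alpha: float) -> torch.Tensor:
        # Mirrors the pass: ge -> expm1 -> mul -> where.
        return torch.where(x >= 0, x, alpha * torch.expm1(x))

    x = torch.linspace(-5.0, 5.0, steps=101)
    for alpha in (0.5, 1.0, 2.0):
        ref = torch.nn.functional.elu(x, alpha=alpha)
        assert torch.allclose(elu_decomposed(x, alpha), ref)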
diff --git a/backends/arm/_passes/insert_table_ops.py b/backends/arm/_passes/insert_table_ops.py
index 3506ce20f1a..fb5d7de5e12 100644
--- a/backends/arm/_passes/insert_table_ops.py
+++ b/backends/arm/_passes/insert_table_ops.py
@@ -59,6 +59,7 @@ class TableOps:
     special_table_ops: Set[EdgeOpOverload] = {
         exir_ops.edge.aten.pow.Tensor_Scalar,
         exir_ops.edge.aten.gelu.default,
+        exir_ops.edge.aten.elu.default,
     }
 
     def __init__(self, exported_program: ExportedProgram):
@@ -92,6 +93,11 @@ def __getitem__(self, node: Node):
                 return lambda x: torch.nn.functional.gelu(
                     x, approximate=approximate
                 ).flatten()
+            case exir_ops.edge.aten.elu.default:
+                input_alpha = cast(int, node.kwargs["alpha"])
+                return lambda x: torch.nn.functional.elu(
+                    x, alpha=input_alpha
+                ).flatten()
             case _:
                 # Op must be handled if it's inside self.special_ops
                 raise AssertionError("Unhandled table operation")
diff --git a/backends/arm/operator_support/tosa_supported_operators.py b/backends/arm/operator_support/tosa_supported_operators.py
index 004f2f1ccc9..d813fbda531 100644
--- a/backends/arm/operator_support/tosa_supported_operators.py
+++ b/backends/arm/operator_support/tosa_supported_operators.py
@@ -263,6 +263,7 @@ def is_node_supported(
             exir_ops.edge.aten.glu.default,
             exir_ops.edge.aten.logit.default,
             exir_ops.edge.aten.acos.default,
+            exir_ops.edge.aten.elu.default,
         ]
 
         return supported
diff --git a/backends/arm/quantizer/quantization_annotator.py b/backends/arm/quantizer/quantization_annotator.py
index 3546b9af716..e6eac4d80eb 100644
--- a/backends/arm/quantizer/quantization_annotator.py
+++ b/backends/arm/quantizer/quantization_annotator.py
@@ -266,6 +266,7 @@ def _match_pattern(
         torch.ops.aten.erf.default,
         torch.ops.aten.exp.default,
         torch.ops.aten.expm1.default,
+        torch.ops.aten.elu.default,
         torch.ops.aten.floor.default,
         torch.ops.aten.log.default,
         torch.ops.aten.reciprocal.default,
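Note (editorial, not part of the patch): in the INT pipeline the elu node is ultimately lowered to a lookup table, and the lambda registered in insert_table_ops.py supplies the float reference function that gets sampled. A hedged sketch of how such a function can be materialized as a 256-entry int8 table; build_elu_table and the scale/zero-point values are made-up examples, the real pass derives them from the folded quantization parameters:

    import torch

    def build_elu_table(alpha: float, in_scale: float, in_zp: int,
                        out_scale: float, out_zp: int) -> torch.Tensor:
        qmin, qmax = -128, 127
        q = torch.arange(qmin, qmax + 1, dtype=torch.float32)
        x = (q - in_zp) * in_scale                   # dequantize int8 inputs
        y = torch.nn.functional.elu(x, alpha=alpha)  # reference function
        table = torch.round(y / out_scale) + out_zp  # requantize outputs
        return table.clamp(qmin, qmax).to(torch.int8)

    table = build_elu_table(alpha=1.0, in_scale=0.05, in_zp=0,
                            out_scale=0.05, out_zp=0)
    assert table.shape == (256,)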
diff --git a/backends/arm/test/ops/test_elu.py b/backends/arm/test/ops/test_elu.py
new file mode 100644
index 00000000000..884f54c0202
--- /dev/null
+++ b/backends/arm/test/ops/test_elu.py
@@ -0,0 +1,133 @@
+# Copyright 2025 Arm Limited and/or its affiliates.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+from typing import Tuple
+
+import torch
+import torch.nn as nn
+
+from executorch.backends.arm.test import common
+from executorch.backends.arm.test.tester.test_pipeline import (
+    EthosU55PipelineINT,
+    EthosU85PipelineINT,
+    TosaPipelineFP,
+    TosaPipelineINT,
+    VgfPipeline,
+)
+
+test_data_suite = {
+    # name: lambda returning (alpha, test_data)
+    "zeros_default": lambda: (1.0, torch.zeros(1, 10, 10, 10)),
+    "ones_default": lambda: (1.0, torch.ones(10, 10, 10)),
+    "rand_default": lambda: (1.0, torch.rand(10, 10) - 0.5),
+    "randn_pos_default": lambda: (1.0, torch.randn(1, 2, 3, 3) + 10),
+    "randn_neg_default": lambda: (1.0, torch.randn(2, 4, 3) - 10),
+    "ramp_default": lambda: (1.0, torch.arange(-16, 16, 0.2)),
+    "large_pos_default": lambda: (1.0, torch.randn(3, 3) * 1e6 + 1e7),
+    "large_neg_default": lambda: (1.0, -torch.empty(5).uniform_(1e5, 1e8)),
+    "small_pos_default": lambda: (1.0, torch.empty(5).uniform_(1e-8, 1e-5)),
+    "small_neg_default": lambda: (1.0, -torch.empty(5).uniform_(1e-8, 1e-5)),
+    "zeros_custom": lambda: (2.0, torch.zeros(1, 10, 10, 10)),
+    "ones_custom": lambda: (2.0, torch.ones(10, 10, 10)),
+    "rand_custom": lambda: (2.0, torch.rand(10, 10) - 0.5),
+    "randn_pos_custom": lambda: (2.0, torch.randn(1, 3, 3) + 10),
+    "randn_neg_custom": lambda: (2.0, torch.randn(1, 2, 4, 3) - 10),
+    "ramp_custom": lambda: (2.0, torch.arange(-16, 16, 0.2)),
+    "large_pos_custom": lambda: (2.0, torch.randn(3, 3) * 1e6 + 1e7),
+    "large_neg_custom": lambda: (2.0, -torch.empty(5).uniform_(1e5, 1e8)),
+    "small_pos_custom": lambda: (2.0, torch.empty(5).uniform_(1e-8, 1e-5)),
+    "small_neg_custom": lambda: (2.0, -torch.empty(5).uniform_(1e-8, 1e-5)),
+    "zeros_zero": lambda: (0.0, torch.zeros(1, 10, 10, 10)),
+    "ones_zero": lambda: (0.0, torch.ones(10, 10, 10)),
+    "rand_zero": lambda: (0.0, torch.rand(10, 10) - 0.5),
+    "randn_pos_zero": lambda: (0.0, torch.randn(1, 3, 3) + 10),
+    "randn_neg_zero": lambda: (0.0, torch.randn(1, 2, 4, 3) - 10),
+    "ramp_zero": lambda: (0.0, torch.arange(-16, 16, 0.2)),
+    "large_pos_zero": lambda: (0.0, torch.randn(3, 3) * 1e6 + 1e7),
+    "large_neg_zero": lambda: (0.0, -torch.empty(5).uniform_(1e5, 1e8)),
+    "small_pos_zero": lambda: (0.0, torch.empty(5).uniform_(1e-8, 1e-5)),
+    "small_neg_zero": lambda: (0.0, -torch.empty(5).uniform_(1e-8, 1e-5)),
+}
+
+
+class Elu(nn.Module):
+    aten_op = "torch.ops.aten.elu.default"
+    exir_op = "executorch_exir_dialects_edge__ops_aten_elu_default"
+
+    def __init__(self, input_alpha: float = 1.0):
+        super().__init__()
+        self.elu = torch.nn.ELU(alpha=input_alpha)
+
+    def forward(self, input_: torch.Tensor):
+        return self.elu(input_)
+
+
+input_t1 = Tuple[torch.Tensor]
+
+
+@common.parametrize("test_module", test_data_suite)
+def test_elu_tosa_FP(test_module: input_t1):
+    alpha, test_data = test_module()
+    pipeline = TosaPipelineFP[input_t1](
+        Elu(alpha), (test_data,), aten_op=Elu.aten_op, exir_op=Elu.exir_op
+    )
+    pipeline.run()
+
+
+@common.parametrize("test_module", test_data_suite)
+def test_elu_tosa_INT(test_module: input_t1):
+    alpha, test_data = test_module()
+    pipeline = TosaPipelineINT[input_t1](
+        Elu(alpha), (test_data,), aten_op=Elu.aten_op, exir_op=Elu.exir_op
+    )
+    pipeline.run()
+
+
+@common.XfailIfNoCorstone300
+@common.parametrize("test_module", test_data_suite)
+def test_elu_u55_INT(test_module: input_t1):
+    alpha, test_data = test_module()
+    pipeline = EthosU55PipelineINT[input_t1](
+        Elu(alpha), (test_data,), aten_ops=Elu.aten_op, exir_ops=Elu.exir_op
+    )
+    pipeline.run()
+
+
+@common.XfailIfNoCorstone320
+@common.parametrize("test_module", test_data_suite) +def test_elu_u85_INT(test_module: input_t1): + alpha, test_data = test_module() + pipeline = EthosU85PipelineINT[input_t1]( + Elu(alpha), (test_data,), aten_ops=Elu.aten_op, exir_ops=Elu.exir_op + ) + pipeline.run() + + +@common.SkipIfNoModelConverter +@common.parametrize("test_module", test_data_suite) +def test_elu_vgf_FP(test_module: input_t1): + alpha, test_data = test_module() + pipeline = VgfPipeline[input_t1]( + Elu(alpha), + (test_data,), + aten_op=Elu.aten_op, + exir_op=Elu.exir_op, + tosa_version="TOSA-1.0+FP", + ) + pipeline.run() + + +@common.SkipIfNoModelConverter +@common.parametrize("test_module", test_data_suite) +def test_elu_vgf_INT(test_module: input_t1): + alpha, test_data = test_module() + pipeline = VgfPipeline[input_t1]( + Elu(alpha), + (test_data,), + aten_op=Elu.aten_op, + exir_op=Elu.exir_op, + tosa_version="TOSA-1.0+INT", + ) + pipeline.run()
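Note (editorial, not part of the patch): the *_zero cases in the suite exercise the alpha == 0 shortcut in DecomposeEluPass, which emits relu directly instead of the full decomposition. A quick check that the shortcut is mathematically sound:

    import torch

    x = torch.randn(4, 4)
    # ELU with alpha == 0 zeroes the negative branch, which is exactly ReLU.
    assert torch.allclose(torch.nn.functional.elu(x, alpha=0.0), torch.relu(x))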