2 changes: 2 additions & 0 deletions backends/arm/_passes/__init__.py
@@ -14,6 +14,7 @@
from .cast_to_int32_pass import CastToInt32Pass # noqa
from .conv1d_unsqueeze_pass import Conv1dUnsqueezePass # noqa
from .convert_any_default_dim_dims_pass import ConvertAnyDefaultDimDimsPass # noqa
from .convert_elu_params import ConvertELUParamsPass # noqa
from .convert_expand_copy_to_repeat import ConvertExpandCopyToRepeatPass # noqa
from .convert_full_like_to_full_pass import ConvertFullLikeToFullPass # noqa
from .convert_int_pow_to_mul import ConvertIntPowToMuls # noqa
@@ -34,6 +35,7 @@
from .decompose_cosine_similarity_pass import DecomposeCosineSimilarityPass # noqa
from .decompose_cumsum_pass import DecomposeCumsumPass # noqa
from .decompose_div_pass import DecomposeDivPass # noqa
from .decompose_elu_pass import DecomposeEluPass # noqa
from .decompose_embedding_pass import DecomposeEmbeddingPass # noqa
from .decompose_expm1_pass import DecomposeExpm1Pass # noqa
from .decompose_gelu_pass import DecomposeGeluPass # noqa
5 changes: 5 additions & 0 deletions backends/arm/_passes/arm_pass_manager.py
@@ -18,6 +18,7 @@
ComputeConstantOpsAOT,
Conv1dUnsqueezePass,
ConvertAnyDefaultDimDimsPass,
ConvertELUParamsPass,
ConvertExpandCopyToRepeatPass,
ConvertFullLikeToFullPass,
ConvertIntPowToMuls,
@@ -39,6 +40,7 @@
DecomposeCosineSimilarityPass,
DecomposeCumsumPass,
DecomposeDivPass,
DecomposeEluPass,
DecomposeEmbeddingPass,
DecomposeExpm1Pass,
DecomposeGeluPass,
@@ -132,6 +134,7 @@ def _tosa_INT_pipeline(self, exported_program: ExportedProgram) -> GraphModule:
self.add_pass(ReplaceScalarWithTensorArgPassTOSABI())
self.add_pass(AnnotateDecomposedMatmulPass())
self.add_pass(QuantizeOperatorArguments())
self.add_pass(ConvertELUParamsPass())
self.add_pass(FoldAndAnnotateQParamsPass(exported_program)) # type: ignore[call-arg]
self.add_pass(RetraceFoldedDtypesPass())
self.add_pass(UnsqueezeScalarPlaceholdersPass(exported_program))
@@ -180,6 +183,8 @@ def _tosa_FP_pipeline(self, exported_program: ExportedProgram) -> GraphModule:
self.add_pass(DecomposeAtanPass())
self.add_pass(DecomposeAtanhPass())
self.add_pass(DecomposeAddmmPass())
self.add_pass(DecomposeEluPass())
self.add_pass(DecomposeExpm1Pass())
self.add_pass(ConvertIntPowToMuls())
self.add_pass(CastBoolToInt8Pass())
self.add_pass(DecomposeSinhPass())
53 changes: 53 additions & 0 deletions backends/arm/_passes/convert_elu_params.py
@@ -0,0 +1,53 @@
# Copyright 2025 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import torch
from executorch.backends.arm._passes.arm_pass_utils import create_node
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.pass_base import ExportPass, PassResult


class ConvertELUParamsPass(ExportPass):
"""
Pass to convert the alpha, scale and input_scale kwargs of the ELU
operator from float to int.

input_scale is set to 2 because the outputs appear to stay the same
regardless of its value, as long as that value is not 1.
"""

def call(self, graph_module: torch.fx.GraphModule):
modified_graph = False
graph = graph_module.graph
node_list = graph.find_nodes(
op="call_function", target=exir_ops.edge.aten.elu.default
)
for node in node_list:
with graph.inserting_after(node):
replace_node = create_node(graph, exir_ops.edge.aten.elu.default)
old_args = list(node.args)

alpha = old_args[1] if len(old_args) > 1 else 1.0
scale = 1.0
input_scale = 2.0

replace_node.args = (old_args[0],)

updated_kwargs = dict(node.kwargs)
updated_kwargs["alpha"] = int(alpha)
updated_kwargs["scale"] = int(scale)
updated_kwargs["input_scale"] = int(input_scale)

replace_node.kwargs = updated_kwargs

node.replace_all_uses_with(replace_node)
graph.erase_node(node)

modified_graph = True
if modified_graph:
graph_module.recompile()
graph_module = super().call(graph_module).graph_module

return PassResult(graph_module, modified_graph)
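
The graph-rewrite idiom used by ConvertELUParamsPass (insert a replacement node, move the ELU parameters into integer kwargs, redirect uses, erase the original) can be sketched on a plain torch.fx graph. This is a minimal illustration only: the toy module and its alpha value are made up, and the real pass operates on the exported edge-dialect graph via ExecuTorch's create_node helper rather than torch.fx.symbolic_trace.

import torch
import torch.fx
import torch.nn.functional as F


class ToyElu(torch.nn.Module):
    def forward(self, x):
        return F.elu(x, 2.0)


gm = torch.fx.symbolic_trace(ToyElu())
for node in list(gm.graph.nodes):
    # symbolic_trace records torch.nn.functional.elu as a call_function node.
    if node.op == "call_function" and node.target is F.elu:
        old_args = list(node.args)
        alpha = node.kwargs.get("alpha", old_args[1] if len(old_args) > 1 else 1.0)
        with gm.graph.inserting_after(node):
            # Re-create the call with alpha moved into the kwargs and cast to
            # int, mirroring what ConvertELUParamsPass does above.
            new_node = gm.graph.call_function(
                F.elu, (old_args[0],), {"alpha": int(alpha)}
            )
        node.replace_all_uses_with(new_node)
        gm.graph.erase_node(node)

gm.recompile()
print(gm.code)  # forward now calls elu with alpha passed as an int kwarg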
85 changes: 85 additions & 0 deletions backends/arm/_passes/decompose_elu_pass.py
@@ -0,0 +1,85 @@
# Copyright 2025 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from executorch.backends.arm._passes import ArmPass
from executorch.exir.dialects._ops import ops as exir_ops

edge_elu_ops = (exir_ops.edge.aten.elu.default,)


def get_elu_decomposition(op) -> tuple:
"""
Returns the decomposition of the given aten.elu operation into
its equivalent TOSA-supported operations.

Only the edge dialect op is currently handled. The decomposition strategy
is:
elu(x, alpha) → where(ge(x, 0), x, alpha * expm1(x))

Returns:
A tuple (expm1_op, ge_op, where_op, mul_op) corresponding to the appropriate operator
overloads for the input op.

Raises:
RuntimeError: If the provided operator is not a supported elu variant.
"""

if op in edge_elu_ops:
return (
exir_ops.edge.aten.expm1.default,
exir_ops.edge.aten.ge.Scalar,
exir_ops.edge.aten.where.self,
exir_ops.edge.aten.mul.Scalar,
)

raise RuntimeError(f"Can't get elu decomposition for op {op}")


class DecomposeEluPass(ArmPass):
"""
A transformation pass that decomposes unsupported 'aten.elu' operations
into a combination of supported TOSA-equivalent operations.

Since TOSA does not provide a native ELU operator, this pass rewrites:
elu(x, alpha) → where(ge(x, 0), x, alpha * expm1(x))

Supported input ops:
- exir_ops.edge.aten.elu.default(x, alpha)

These are replaced with:
- exir_ops.edge.aten.expm1.default
- exir_ops.edge.aten.ge.Scalar
- exir_ops.edge.aten.where.self
- exir_ops.edge.aten.mul.Scalar
"""

def call_operator(self, op, args, kwargs, meta):
if op not in edge_elu_ops:
return super().call_operator(op, args, kwargs, meta, updated=False)

(
expm1_op,
ge_op,
where_op,
mul_op,
) = get_elu_decomposition(op)

input = args[0]
alpha = args[1] if len(args) > 1 else 1.0

if alpha == 0:
relu_op = exir_ops.edge.aten.relu.default
return super().call_operator(relu_op, (input,), {}, meta, updated=True)

expm1_node = super().call_operator(expm1_op, (input,), {}, meta, updated=True)
mul_node = super().call_operator(
mul_op, (expm1_node, alpha), {}, meta, updated=True
)
ge_node = super().call_operator(ge_op, (input, 0.0), {}, meta, updated=True)
where_node = super().call_operator(
where_op, (ge_node, input, mul_node), {}, meta, updated=True
)

return where_node
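
Because the docstrings above state the rewrite only symbolically, here is a quick standalone numerical check (illustrative, not part of the pass) that where(x >= 0, x, alpha * expm1(x)) matches torch's ELU for a non-trivial alpha:

import torch

x = torch.linspace(-5.0, 5.0, steps=101)
alpha = 2.0

reference = torch.nn.functional.elu(x, alpha=alpha)
# Same structure the pass emits: ge -> expm1 -> mul -> where.
decomposed = torch.where(x >= 0, x, alpha * torch.expm1(x))

assert torch.allclose(reference, decomposed)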
6 changes: 6 additions & 0 deletions backends/arm/_passes/insert_table_ops.py
@@ -59,6 +59,7 @@ class TableOps:
special_table_ops: Set[EdgeOpOverload] = {
exir_ops.edge.aten.pow.Tensor_Scalar,
exir_ops.edge.aten.gelu.default,
exir_ops.edge.aten.elu.default,
}

def __init__(self, exported_program: ExportedProgram):
@@ -92,6 +93,11 @@ def __getitem__(self, node: Node):
return lambda x: torch.nn.functional.gelu(
x, approximate=approximate
).flatten()
case exir_ops.edge.aten.elu.default:
input_alpha = cast(int, node.kwargs["alpha"])
return lambda x: torch.nn.functional.elu(
x, alpha=input_alpha
).flatten()
case _:
# Op must be handled if it's inside self.special_ops
raise AssertionError("Unhandled table operation")
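
The lambda registered above only defines the float reference function for ELU; the actual table generation lives elsewhere in insert_table_ops.py and is not part of this diff. Purely as a hypothetical sketch (made-up scales and zero points, simplified rounding), an int8 lookup table for such a function could be derived like this:

import torch

# Placeholder quantization parameters (illustrative only).
input_scale, input_zp = 0.05, 0
output_scale, output_zp = 0.05, 0
alpha = 1


def elu_fn(x):
    # Same reference function the TableOps entry above returns.
    return torch.nn.functional.elu(x, alpha=alpha).flatten()


# Evaluate the float function at every dequantized int8 input value ...
qmin, qmax = -128, 127
int8_inputs = torch.arange(qmin, qmax + 1, dtype=torch.int32)
float_inputs = (int8_inputs - input_zp) * input_scale
float_outputs = elu_fn(float_inputs)

# ... and requantize the results into a 256-entry int8 table.
table = torch.clamp(
    torch.round(float_outputs / output_scale) + output_zp, qmin, qmax
).to(torch.int8)
assert table.numel() == 256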
@@ -263,6 +263,7 @@ def is_node_supported(
exir_ops.edge.aten.glu.default,
exir_ops.edge.aten.logit.default,
exir_ops.edge.aten.acos.default,
exir_ops.edge.aten.elu.default,
]

return supported
1 change: 1 addition & 0 deletions backends/arm/quantizer/quantization_annotator.py
@@ -266,6 +266,7 @@ def _match_pattern(
torch.ops.aten.erf.default,
torch.ops.aten.exp.default,
torch.ops.aten.expm1.default,
torch.ops.aten.elu.default,
torch.ops.aten.floor.default,
torch.ops.aten.log.default,
torch.ops.aten.reciprocal.default,
133 changes: 133 additions & 0 deletions backends/arm/test/ops/test_elu.py
@@ -0,0 +1,133 @@
# Copyright 2025 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import Tuple

import torch
import torch.nn as nn

from executorch.backends.arm.test import common
from executorch.backends.arm.test.tester.test_pipeline import (
EthosU55PipelineINT,
EthosU85PipelineINT,
TosaPipelineFP,
TosaPipelineINT,
VgfPipeline,
)

test_data_suite = {
# test_name: lambda returning (alpha, test_data)
"zeros_default": lambda: (1.0, torch.zeros(1, 10, 10, 10)),
"ones_default": lambda: (1.0, torch.ones(10, 10, 10)),
"rand_default": lambda: (1.0, torch.rand(10, 10) - 0.5),
"randn_pos_default": lambda: (1.0, torch.randn(1, 2, 3, 3) + 10),
"randn_neg_default": lambda: (1.0, torch.randn(2, 4, 3) - 10),
"ramp_default": lambda: (1.0, torch.arange(-16, 16, 0.2)),
"large_pos_default": lambda: (1.0, torch.randn(3, 3) * 1e6 + 1e7),
"large_neg_default": lambda: (1.0, -torch.empty(5).uniform_(1e5, 1e8)),
"small_pos_default": lambda: (1.0, torch.empty(5).uniform_(1e-8, 1e-5)),
"small_neg_default": lambda: (1.0, -torch.empty(5).uniform_(1e-8, 1e-5)),
"zeros_custom": lambda: (2.0, torch.zeros(1, 10, 10, 10)),
"ones_custom": lambda: (2.0, torch.ones(10, 10, 10)),
"rand_custom": lambda: (2.0, torch.rand(10, 10) - 0.5),
"randn_pos_custom": lambda: (2.0, torch.randn(1, 3, 3) + 10),
"randn_neg_custom": lambda: (2.0, torch.randn(1, 2, 4, 3) - 10),
"ramp_custom": lambda: (2.0, torch.arange(-16, 16, 0.2)),
"large_pos_custom": lambda: (2.0, torch.randn(3, 3) * 1e6 + 1e7),
"large_neg_custom": lambda: (2.0, -torch.empty(5).uniform_(1e5, 1e8)),
"small_pos_custom": lambda: (2.0, torch.empty(5).uniform_(1e-8, 1e-5)),
"small_neg_custom": lambda: (2.0, -torch.empty(5).uniform_(1e-8, 1e-5)),
"zeros_zero": lambda: (0.0, torch.zeros(1, 10, 10, 10)),
"ones_zero": lambda: (0.0, torch.ones(10, 10, 10)),
"rand_zero": lambda: (0.0, torch.rand(10, 10) - 0.5),
"randn_pos_zero": lambda: (0.0, torch.randn(1, 3, 3) + 10),
"randn_neg_zero": lambda: (0.0, torch.randn(1, 2, 4, 3) - 10),
"ramp_zero": lambda: (0.0, torch.arange(-16, 16, 0.2)),
"large_pos_zero": lambda: (0.0, torch.randn(3, 3) * 1e6 + 1e7),
"large_neg_zero": lambda: (0.0, -torch.empty(5).uniform_(1e5, 1e8)),
"small_pos_zero": lambda: (0.0, torch.empty(5).uniform_(1e-8, 1e-5)),
"small_neg_zero": lambda: (0.0, -torch.empty(5).uniform_(1e-8, 1e-5)),
}


class Elu(nn.Module):
aten_op = "torch.ops.aten.elu.default"
exir_op = "executorch_exir_dialects_edge__ops_aten__elu_default"

def __init__(self, input_alpha: float = 1.0):
super().__init__()
self.elu = torch.nn.ELU(alpha=input_alpha)

def forward(self, input_: torch.Tensor):
return self.elu(input_)


input_t1 = Tuple[torch.Tensor]


@common.parametrize("test_module", test_data_suite)
def test_elu_tosa_FP(test_module: input_t1):
alpha, test_data = test_module()
pipeline = TosaPipelineFP[input_t1](
Elu(alpha), (test_data,), aten_op=Elu.aten_op, exir_op=Elu.exir_op
)
pipeline.run()


@common.parametrize("test_module", test_data_suite)
def test_elu_tosa_INT(test_module: input_t1):
alpha, test_data = test_module()
pipeline = TosaPipelineINT[input_t1](
Elu(alpha), (test_data,), aten_op=Elu.aten_op, exir_op=Elu.exir_op
)
pipeline.run()


@common.XfailIfNoCorstone300
@common.parametrize("test_module", test_data_suite)
def test_elu_u55_INT(test_module: input_t1):
alpha, test_data = test_module()
pipeline = EthosU55PipelineINT[input_t1](
Elu(alpha), (test_data,), aten_ops=Elu.aten_op, exir_ops=Elu.exir_op
)
pipeline.run()


@common.XfailIfNoCorstone320
@common.parametrize("test_module", test_data_suite)
def test_elu_u85_INT(test_module: input_t1):
alpha, test_data = test_module()
pipeline = EthosU85PipelineINT[input_t1](
Elu(alpha), (test_data,), aten_ops=Elu.aten_op, exir_ops=Elu.exir_op
)
pipeline.run()


@common.SkipIfNoModelConverter
@common.parametrize("test_module", test_data_suite)
def test_elu_vgf_FP(test_module: input_t1):
alpha, test_data = test_module()
pipeline = VgfPipeline[input_t1](
Elu(alpha),
(test_data,),
aten_op=Elu.aten_op,
exir_op=Elu.exir_op,
tosa_version="TOSA-1.0+FP",
)
pipeline.run()


@common.SkipIfNoModelConverter
@common.parametrize("test_module", test_data_suite)
def test_elu_vgf_INT(test_module: input_t1):
alpha, test_data = test_module()
pipeline = VgfPipeline[input_t1](
Elu(alpha),
(test_data,),
aten_op=Elu.aten_op,
exir_op=Elu.exir_op,
tosa_version="TOSA-1.0+INT",
)
pipeline.run()
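
A single case from this suite can be selected with pytest's -k filter; assuming common.parametrize derives the test ids from the dict keys above, something like the following runs only the FP TOSA pipeline on the rand_default data:

pytest backends/arm/test/ops/test_elu.py -k "test_elu_tosa_FP and rand_default"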