diff --git a/backends/arm/_passes/__init__.py b/backends/arm/_passes/__init__.py index 406f47c66dc..64affa8507a 100644 --- a/backends/arm/_passes/__init__.py +++ b/backends/arm/_passes/__init__.py @@ -32,6 +32,7 @@ from .decompose_maxpool2d_with_dilation import DecomposeMaxPool2DPass # noqa from .decompose_meandim_pass import DecomposeMeanDimPass # noqa from .decompose_ne_pass import DecomposeNotEqualPass # noqa +from .decompose_round_pass import DecomposeRoundPass # noqa from .decompose_select import DecomposeSelectPass # noqa from .decompose_silu_pass import DecomposeSiluPass # noqa from .decompose_softmax_pass import DecomposeSoftmaxPass # noqa diff --git a/backends/arm/_passes/arm_pass_manager.py b/backends/arm/_passes/arm_pass_manager.py index d386da3ed72..672911ce199 100644 --- a/backends/arm/_passes/arm_pass_manager.py +++ b/backends/arm/_passes/arm_pass_manager.py @@ -35,6 +35,7 @@ DecomposeMaxPool2DPass, DecomposeMeanDimPass, DecomposeNotEqualPass, + DecomposeRoundPass, DecomposeSelectPass, DecomposeSiluPass, DecomposeSoftmaxPass, @@ -139,6 +140,7 @@ def _tosa_080_BI_pipeline(self, exported_program: ExportedProgram) -> GraphModul return self._transform(exported_program.graph_module) def _tosa_080_MI_pipeline(self, exported_program: ExportedProgram) -> GraphModule: + self.add_pass(DecomposeRoundPass()) self.add_pass(DecomposeSqrtPass()) self.add_pass(ConvertIntPowToMuls()) self.add_pass(ReplaceScalarWithTensorArgPassTOSAMI()) @@ -219,6 +221,7 @@ def transform_for_annotation_pipeline(self, graph_module: GraphModule): self.add_pass(InsertCastForOpsWithInt64InputPass()) self.add_pass(DecomposeEmbeddingPass()) self.add_pass(DecomposeScaledDotProductAttention()) + self.add_pass(DecomposeRoundPass()) self.add_pass(ReplaceScalarWithTensorArgPassTOSABI()) self.add_pass(ScalarsToAttributePass()) self.add_pass(DecomposeGroupNormPass()) diff --git a/backends/arm/_passes/decompose_round_pass.py b/backends/arm/_passes/decompose_round_pass.py new file mode 100644 index 
# Copyright 2025 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import torch
from executorch.backends.arm._passes import ArmPass
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.dialects.edge._ops import EdgeOpOverload
from torch._ops import OpOverload

# NOTE: the same change set also adds exir_ops.edge.aten.round.default to
# backends/arm/operator_support/tosa_supported_operators.py, both to the
# supported-op list in is_node_supported and to needs_decomp_dict (value None).

Op = OpOverload | EdgeOpOverload


def _get_round_decomposition_ops(op) -> tuple[Op, Op, Op, Op, Op, Op, Op]:
    """
    Returns the (full_op, ge_op, add_op, sub_op, floor_op, ceil_op, where_op) for the
    given round operation. The ops depend on whether the round op is an aten or edge op.

    Raises:
        RuntimeError: if ``op`` is neither the edge nor the aten ``round.default``
            overload.
    """
    if op == exir_ops.edge.aten.round.default:
        return (
            exir_ops.edge.aten.full.default,
            exir_ops.edge.aten.ge.Tensor,
            exir_ops.edge.aten.add.Scalar,
            exir_ops.edge.aten.sub.Scalar,
            exir_ops.edge.aten.floor.default,
            exir_ops.edge.aten.ceil.default,
            exir_ops.edge.aten.where.self,
        )
    if op == torch.ops.aten.round.default:
        return (
            torch.ops.aten.full.default,
            torch.ops.aten.ge.Tensor,
            torch.ops.aten.add.Scalar,
            torch.ops.aten.sub.Scalar,
            torch.ops.aten.floor.default,
            torch.ops.aten.ceil.default,
            torch.ops.aten.where.self,
        )
    raise RuntimeError(f"Can't get round decomposition ops for op {op}")


class DecomposeRoundPass(ArmPass):
    """
    Decomposes the round operation into a sequence of more primitive operations.

    For inputs >= 0 the result is floor(x + 0.5), and for inputs < 0 the result is
    ceil(x - 0.5).

    NOTE(review): this matches round(x) except on exact .5 ties. The decomposition
    rounds ties away from zero (0.5 -> 1, 2.5 -> 3, -0.5 -> -1), whereas torch.round
    is documented as rounding halves to even (0.5 -> 0, 2.5 -> 2, -0.5 -> -0).
    Confirm that this deviation is acceptable for the targeted backends.

    Example:
        %zero = full((1,), 0.0, dtype=torch.float32)
        %is_non_negative = ge(x, %zero)
        %plus_half = add(x, 0.5)
        %minus_half = sub(x, 0.5)
        %floor = floor(%plus_half)
        %ceil = ceil(%minus_half)
        %result = where(%is_non_negative, %floor, %ceil)
    """

    def call_operator(self, op, args, kwargs, meta, updated=False):
        """
        Replace a round op with its floor/ceil/where decomposition.

        Any op other than the edge or aten ``round.default`` overload is forwarded
        unchanged to the parent implementation. For round, every emitted op is
        created through the parent ``call_operator`` with ``updated=True`` and
        reuses the ``meta`` of the original node; the ``kwargs`` of the round call
        are forwarded to every emitted op except ``full``.

        Returns:
            The value produced by the parent ``call_operator``: the untouched op
            for non-round ops, otherwise the final ``where`` of the decomposition.
        """
        if op not in (exir_ops.edge.aten.round.default, torch.ops.aten.round.default):
            return super().call_operator(op, args, kwargs, meta, updated)

        x = args[0]
        # Distinct names for operator handles and for the values they produce, so
        # an op handle is never overwritten by its own result.
        full_op, ge_op, add_op, sub_op, floor_op, ceil_op, where_op = (
            _get_round_decomposition_ops(op)
        )

        # Single-element zero constant to compare against; its dtype is fixed to
        # float32 regardless of the dtype of x.
        zero = super().call_operator(
            full_op,
            args=((1,), 0.0),
            kwargs={"dtype": torch.float32},
            meta=meta,
            updated=True,
        )
        is_non_negative = super().call_operator(
            ge_op, (x, zero), kwargs, meta, updated=True
        )

        # Both branches are always computed; `where` selects per element.
        plus_half = super().call_operator(add_op, (x, 0.5), kwargs, meta, updated=True)
        minus_half = super().call_operator(sub_op, (x, 0.5), kwargs, meta, updated=True)
        floored = super().call_operator(
            floor_op, (plus_half,), kwargs, meta, updated=True
        )
        ceiled = super().call_operator(
            ceil_op, (minus_half,), kwargs, meta, updated=True
        )
        return super().call_operator(
            where_op,
            (is_non_negative, floored, ceiled),
            kwargs,
            meta,
            updated=True,
        )
# Copyright 2025 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.


from typing import Callable, Tuple

import pytest
import torch
from executorch.backends.arm.test import common
from executorch.backends.arm.test.tester.test_pipeline import (
    EthosU55PipelineBI,
    EthosU85PipelineBI,
    TosaPipelineBI,
    TosaPipelineMI,
)

input_t1 = Tuple[torch.Tensor]  # Input x

# Each parametrized case receives a zero-argument factory, not a tensor; the
# tensor is only built when the test calls it.
input_factory_t = Callable[[], torch.Tensor]

aten_op = "torch.ops.aten.round.default"
exir_op = "executorch_exir_dialects_edge__ops_aten_round_default"

# Maps test name -> factory producing the input tensor.
# NOTE(review): none of these inputs is built to contain exact .5 ties, so the
# tie-breaking behaviour of round is not deliberately exercised here.
test_data_suite = {
    "zeros": lambda: torch.zeros(1, 10, 10, 10),
    "ones": lambda: torch.ones(10, 10, 10),
    "rand": lambda: torch.rand(10, 10) - 0.5,
    "randn_pos": lambda: torch.randn(10) + 10,
    "randn_neg": lambda: torch.randn(10) - 10,
    "ramp": lambda: torch.arange(-16, 16, 0.2),
}


class Round(torch.nn.Module):
    """Minimal module whose forward pass is a single ``Tensor.round`` call."""

    def forward(self, x: torch.Tensor):
        return x.round()


@common.parametrize("test_data", test_data_suite)
def test_round_tosa_MI(test_data: input_factory_t):
    """Run ``Round`` through ``TosaPipelineMI`` for every suite input."""
    pipeline = TosaPipelineMI[input_t1](
        Round(),
        (test_data(),),
        aten_op,
        exir_op,
    )
    pipeline.run()


@common.parametrize("test_data", test_data_suite)
def test_round_tosa_BI(test_data: input_factory_t):
    """Run ``Round`` through ``TosaPipelineBI``; an empty aten-op list is passed."""
    pipeline = TosaPipelineBI[input_t1](
        Round(),
        (test_data(),),
        [],
        exir_op,
    )
    pipeline.run()


@common.parametrize("test_data", test_data_suite)
@common.XfailIfNoCorstone300
@pytest.mark.xfail(reason="where.self not supported on U55")
def test_round_u55_BI(test_data: input_factory_t):
    """Run ``Round`` through ``EthosU55PipelineBI``; marked xfail (see reason)."""
    pipeline = EthosU55PipelineBI[input_t1](
        Round(),
        (test_data(),),
        [],
        exir_op,
    )
    pipeline.run()


@common.parametrize("test_data", test_data_suite)
@common.XfailIfNoCorstone320
def test_round_u85_BI(test_data: input_factory_t):
    """Run ``Round`` through ``EthosU85PipelineBI`` for every suite input."""
    pipeline = EthosU85PipelineBI[input_t1](
        Round(),
        (test_data(),),
        [],
        exir_op,
    )
    pipeline.run()