diff --git a/backends/arm/_passes/__init__.py b/backends/arm/_passes/__init__.py index 2a75606cb70..f9efa898331 100644 --- a/backends/arm/_passes/__init__.py +++ b/backends/arm/_passes/__init__.py @@ -22,6 +22,7 @@ from .convert_split_to_slice import ConvertSplitToSlicePass # noqa from .convert_squeezes_to_view import ConvertSqueezesToViewPass # noqa from .convert_to_clamp import ConvertToClampPass # noqa +from .decompose_acosh_pass import DecomposeAcoshPass # noqa from .decompose_atan_pass import DecomposeAtanPass # noqa from .decompose_avg_pool2d import DecomposeAvgPool2d # noqa from .decompose_batch_norm_no_stats import DecomposeBatchNormNoStatsPass # noqa diff --git a/backends/arm/_passes/arm_pass_manager.py b/backends/arm/_passes/arm_pass_manager.py index 596decd65bb..f4a8af27ff8 100644 --- a/backends/arm/_passes/arm_pass_manager.py +++ b/backends/arm/_passes/arm_pass_manager.py @@ -25,6 +25,7 @@ ConvertSplitToSlicePass, ConvertSqueezesToViewPass, ConvertToClampPass, + DecomposeAcoshPass, DecomposeAtanPass, DecomposeAvgPool2d, DecomposeBatchNormNoStatsPass, @@ -151,6 +152,7 @@ def _tosa_080_BI_pipeline(self, exported_program: ExportedProgram) -> GraphModul def _tosa_080_MI_pipeline(self, exported_program: ExportedProgram) -> GraphModule: self.add_pass(DecomposeRoundPass()) + self.add_pass(DecomposeAcoshPass()) self.add_pass(DecomposeSqrtPass()) self.add_pass(DecomposeAtanPass()) self.add_pass(ConvertIntPowToMuls()) diff --git a/backends/arm/_passes/decompose_acosh_pass.py b/backends/arm/_passes/decompose_acosh_pass.py new file mode 100644 index 00000000000..1d92dd68c4a --- /dev/null +++ b/backends/arm/_passes/decompose_acosh_pass.py @@ -0,0 +1,52 @@ +# Copyright 2025 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
+
+# pyre-unsafe
+
+from executorch.backends.arm._passes import ArmPass
+from executorch.exir.dialects._ops import ops as exir_ops
+
+# For MI case
+edge_acosh_op = exir_ops.edge.aten.acosh.default
+
+
+class DecomposeAcoshPass(ArmPass):
+    """
+    Decomposes acosh to supported TOSA-operations.
+    This decomposition is based on the mathematical identity:
+    acosh(x) = log(x + sqrt((x-1)(x+1)))
+    """
+
+    def call_operator(self, op, args, kwargs, meta, updated=False):
+
+        if op is not edge_acosh_op:
+            return super().call_operator(op, args, kwargs, meta, updated)
+
+        log_op, sqrt_op, mul_op, sub_op, add_op, add_op_scalar = (
+            exir_ops.edge.aten.log.default,
+            exir_ops.edge.aten.sqrt.default,
+            exir_ops.edge.aten.mul.Tensor,
+            exir_ops.edge.aten.sub.Scalar,
+            exir_ops.edge.aten.add.Tensor,
+            exir_ops.edge.aten.add.Scalar,
+        )
+
+        x = args[0]
+
+        # (x-1)(x+1)
+        sub = super().call_operator(sub_op, (x, 1.0), {}, meta, True)
+        add = super().call_operator(add_op_scalar, (x, 1.0), {}, meta, True)
+        mul = super().call_operator(mul_op, (sub, add), {}, meta, True)
+
+        # sqrt((x-1)(x+1))
+        sqrt = super().call_operator(sqrt_op, (mul,), {}, meta, True)
+
+        # x + sqrt((x-1)(x+1))
+        add = super().call_operator(add_op, (x, sqrt), {}, meta, True)
+
+        # out = ln(x + sqrt((x-1)(x+1)))
+        out = super().call_operator(log_op, (add,), {}, meta, True)
+
+        return out
diff --git a/backends/arm/_passes/insert_table_ops.py b/backends/arm/_passes/insert_table_ops.py
index b31b6c7106d..28b4700ce39 100644
--- a/backends/arm/_passes/insert_table_ops.py
+++ b/backends/arm/_passes/insert_table_ops.py
@@ -55,6 +55,7 @@ class TableOps:
         exir_ops.edge.aten.hardsigmoid.default: torch.nn.functional.hardsigmoid,
         exir_ops.edge.aten.hardswish.default: torch.nn.functional.hardswish,
         exir_ops.edge.aten.sinh.default: torch.sinh,
+        exir_ops.edge.aten.acosh.default: torch.acosh,
     }
 
     # Targets that must be treated explicitly
diff --git a/backends/arm/operator_support/tosa_supported_operators.py
b/backends/arm/operator_support/tosa_supported_operators.py index cdb27b7c31e..3163b4841dc 100644 --- a/backends/arm/operator_support/tosa_supported_operators.py +++ b/backends/arm/operator_support/tosa_supported_operators.py @@ -245,6 +245,7 @@ def is_node_supported( exir_ops.edge.aten.alias_copy.default, exir_ops.edge.aten.sinh.default, exir_ops.edge.aten.atan.default, + exir_ops.edge.aten.acosh.default, ] return supported diff --git a/backends/arm/quantizer/quantization_annotator.py b/backends/arm/quantizer/quantization_annotator.py index 2c61aea60c3..d3afa4149ba 100644 --- a/backends/arm/quantizer/quantization_annotator.py +++ b/backends/arm/quantizer/quantization_annotator.py @@ -215,6 +215,7 @@ def _match_pattern( torch.ops.aten.gelu.default, torch.ops.aten.sinh.default, torch.ops.aten.atan.default, + torch.ops.aten.acosh.default, ] _one_to_one_shared_input_qspec = [ diff --git a/backends/arm/test/ops/test_acosh.py b/backends/arm/test/ops/test_acosh.py new file mode 100644 index 00000000000..00742105b63 --- /dev/null +++ b/backends/arm/test/ops/test_acosh.py @@ -0,0 +1,114 @@ +# Copyright 2025 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
+from typing import Tuple + +import pytest + +import torch + +from executorch.backends.arm.test import common +from executorch.backends.arm.test.tester.test_pipeline import ( + EthosU55PipelineBI, + EthosU85PipelineBI, + TosaPipelineBI, + TosaPipelineMI, +) + +input_t = Tuple[torch.Tensor] # Input x +aten_op = "torch.ops.aten.acosh.default" + + +test_data_suite = { + # Valid input cases + "ones": lambda: torch.ones(1, 7, 10, 12), + "just_above_one": lambda: torch.tensor([1.0001, 1.01, 1.1, 2.0]), + "rand_valid": lambda: torch.rand(10, 10) * 10 + 1, # [1, 11) + "ramp_valid": lambda: torch.linspace(1.0, 20.0, steps=160), + "large": lambda: torch.tensor([10.0, 100.0, 1000.0, 1e6]), + "mixed_valid": lambda: torch.tensor([1.0, 2.0, 10.0, 100.0]), +} + +test_data_suite_xfails = { + # Invalid input cases (should return nan or error) + "zeros": lambda: torch.zeros(1, 5, 3, 2), + "neg_ones": lambda: -torch.ones(10, 10, 10), + "rand_invalid": lambda: torch.rand(10, 10), # [0, 1) + "ramp_invalid": lambda: torch.linspace(-10.0, 0.99, steps=160), + "near_zero": lambda: torch.tensor([-1e-6, 0.0, 1e-6]), + "large_negative": lambda: torch.tensor([-100.0, -10.0, 0.0]), +} + + +class Acosh(torch.nn.Module): + + def forward(self, x: torch.Tensor): + return torch.acosh(x) + + +@common.parametrize("test_data", test_data_suite) +def test_acosh_tosa_MI(test_data: Tuple): + pipeline = TosaPipelineMI[input_t]( + Acosh(), + (test_data(),), + aten_op, + exir_op=[], + ) + pipeline.run() + + +@common.parametrize("test_data", test_data_suite) +def test_acosh_tosa_BI(test_data: Tuple): + pipeline = TosaPipelineBI[input_t]( + Acosh(), + (test_data(),), + aten_op=[], + ) + pipeline.run() + + +@common.parametrize("test_data", test_data_suite) +@common.XfailIfNoCorstone300 +def test_acosh_u55_BI(test_data: Tuple): + pipeline = EthosU55PipelineBI[input_t]( + Acosh(), + (test_data(),), + aten_ops=[], + ) + pipeline.run() + + +@common.parametrize("test_data", test_data_suite_xfails) 
+@pytest.mark.xfail(reason="Invalid inputs are currently not handled")
+def test_acosh_u55_BI_xfail(test_data: Tuple):
+    pipeline = EthosU55PipelineBI[input_t](
+        Acosh(),
+        (test_data(),),
+        aten_ops=[],
+        run_on_fvp=False,
+    )
+    pipeline.run()
+
+
+@common.parametrize("test_data", test_data_suite)
+@common.XfailIfNoCorstone320
+def test_acosh_u85_BI(test_data: Tuple):
+    pipeline = EthosU85PipelineBI[input_t](
+        Acosh(),
+        (test_data(),),
+        aten_ops=[],
+    )
+    pipeline.run()
+
+
+@common.parametrize("test_data", test_data_suite_xfails)
+@pytest.mark.xfail(reason="Invalid inputs are currently not handled")
+def test_acosh_u85_BI_xfail(test_data: Tuple):
+    pipeline = EthosU85PipelineBI[input_t](
+        Acosh(),
+        (test_data(),),
+        aten_ops=[],
+        run_on_fvp=False,
+    )
+    pipeline.run()