From e493cb7f324c333dae6661338c6b90cfd1a77d12 Mon Sep 17 00:00:00 2001 From: Emma Kujala Date: Fri, 25 Jul 2025 10:16:30 +0200 Subject: [PATCH] Arm backend: Add decomposition and test for acos Signed-off-by: Emma Kujala Change-Id: If63c1fca67194925f0c0a1fd412c3c5d0d52b187 --- backends/arm/_passes/__init__.py | 2 +- backends/arm/_passes/arm_pass_manager.py | 4 +- ...ass.py => decompose_asin_and_acos_pass.py} | 93 +++++++------- backends/arm/_passes/insert_table_ops.py | 1 + .../tosa_supported_operators.py | 1 + .../arm/quantizer/quantization_annotator.py | 1 + backends/arm/test/ops/test_acos.py | 119 ++++++++++++++++++ 7 files changed, 175 insertions(+), 46 deletions(-) rename backends/arm/_passes/{decompose_asin_pass.py => decompose_asin_and_acos_pass.py} (72%) create mode 100644 backends/arm/test/ops/test_acos.py diff --git a/backends/arm/_passes/__init__.py b/backends/arm/_passes/__init__.py index aa19778306f..4cb7c55f5b1 100644 --- a/backends/arm/_passes/__init__.py +++ b/backends/arm/_passes/__init__.py @@ -25,7 +25,7 @@ from .decompose_acosh_pass import DecomposeAcoshPass # noqa from .decompose_adaptive_avg_pool2d_pass import DecomposeAdaptiveAvgPool2dPass # noqa from .decompose_addmm_pass import DecomposeAddmmPass # noqa -from .decompose_asin_pass import DecomposeAsinPass # noqa +from .decompose_asin_and_acos_pass import DecomposeAsinAndAcosPass # noqa from .decompose_asinh_pass import DecomposeAsinhPass # noqa from .decompose_atan_pass import DecomposeAtanPass # noqa from .decompose_atanh_pass import DecomposeAtanhPass # noqa diff --git a/backends/arm/_passes/arm_pass_manager.py b/backends/arm/_passes/arm_pass_manager.py index 69feb112701..df3cc270387 100644 --- a/backends/arm/_passes/arm_pass_manager.py +++ b/backends/arm/_passes/arm_pass_manager.py @@ -30,8 +30,8 @@ DecomposeAcoshPass, DecomposeAdaptiveAvgPool2dPass, DecomposeAddmmPass, + DecomposeAsinAndAcosPass, DecomposeAsinhPass, - DecomposeAsinPass, DecomposeAtanhPass, DecomposeAtanPass, DecomposeAvgPool2d, @@ -171,9 +171,9 @@ def _tosa_FP_pipeline(self, exported_program: ExportedProgram) -> GraphModule: self.add_pass(DecomposeMaskedFill()) self.add_pass(DecomposeRoundPass()) self.add_pass(DecomposeAcoshPass()) - self.add_pass(DecomposeAsinPass()) self.add_pass(DecomposeAsinhPass()) self.add_pass(DecomposeCoshPass()) + self.add_pass(DecomposeAsinAndAcosPass()) self.add_pass(DecomposeSqrtPass()) self.add_pass(DecomposeAtanPass()) self.add_pass(DecomposeAtanhPass()) diff --git a/backends/arm/_passes/decompose_asin_pass.py b/backends/arm/_passes/decompose_asin_and_acos_pass.py similarity index 72% rename from backends/arm/_passes/decompose_asin_pass.py rename to backends/arm/_passes/decompose_asin_and_acos_pass.py index 0c0bcdf7f49..e067f17b0ca 100644 --- a/backends/arm/_passes/decompose_asin_pass.py +++ b/backends/arm/_passes/decompose_asin_and_acos_pass.py @@ -15,10 +15,11 @@ # For MI case edge_asin_op = (exir_ops.edge.aten.asin.default,) +edge_acos_op = (exir_ops.edge.aten.acos.default,) -def get_asin_decomposition(op) -> tuple: - if op in edge_asin_op: +def get_decomposition(op) -> tuple: + if op in (edge_asin_op + edge_acos_op): return ( exir_ops.edge.aten.mul.Tensor, exir_ops.edge.aten.add.Tensor, @@ -31,25 +32,26 @@ def get_asin_decomposition(op) -> tuple: exir_ops.edge.aten.lt.Scalar, exir_ops.edge.aten.sub.Tensor, exir_ops.edge.aten.full_like.default, - exir_ops.edge.aten.where.self, exir_ops.edge.aten.neg.default, ) - raise RuntimeError(f"Can't get asin decomposition for op {op}") + raise RuntimeError(f"Can't get 
decomposition for op {op}") -class DecomposeAsinPass(ArmPass): +class DecomposeAsinAndAcosPass(ArmPass): """ - This pass decomposes asin into a rational approximation for small values + This pass decomposes asin and acos into a rational approximation for small values and a transformed rational approximation for large values. - Example: - y = asin(x) - Becomes: + + The decomposition is based on the following mathematical identities: if abs(x) < 0.5: - y = x + P(x^2) / Q(x^2) + asin(x) = x + P(x^2) / Q(x^2) + acos(x) = π/2 - asin(x) else: - y = π/2 - 2 * (s + s^3 * Q(z) / P(z)) - where P and Q are polynomials defined in the function. + asin(x) = π/2 - 2 * (s + s^3 * Q(z) / P(z)) + acos(x) = 2 * (s + s^3 * Q(z) / P(z)) + where P and Q are polynomials defined in the function and s is the square root of z. + """ def _build_polynomial( @@ -84,11 +86,25 @@ def _build_polynomial( ) return result + def _combine_branches( + self, + bool_op, + bool_args: tuple[torch.Tensor, float], + branches: tuple[torch.Tensor, torch.Tensor], + meta: dict[str, str], + ) -> torch.Tensor: + where_op = exir_ops.edge.aten.where.self + mask = super().call_operator(bool_op, bool_args, {}, meta, True) + branch_true, branch_false = branches + return super().call_operator( + where_op, (mask, branch_true, branch_false), {}, meta, True + ) + def call_operator(self, op, args, kwargs, meta): - if op not in edge_asin_op: + if op not in (edge_asin_op + edge_acos_op): return super().call_operator(op, args, kwargs, meta) logging.info( - f"Approximating asin. This may introduce small numerical errors. For details, see {__file__}." + f"Approximating {op}. This may introduce small numerical errors. For details, see {__file__}." ) x = args[0] half = 0.5 @@ -111,9 +127,8 @@ def call_operator(self, op, args, kwargs, meta): lt_op, sub_op, full_like_op, - where_op, neg_op, - ) = get_asin_decomposition(op) + ) = get_decomposition(op) # Coefficients for the rational approximation, calculated with the Minimax (Remez) method p_coefficients = [ @@ -129,7 +144,6 @@ def call_operator(self, op, args, kwargs, meta): x_abs = super().call_operator(abs_op, (x,), {}, meta, True) # Step 1: compute asin_small - rational approximation for [0,0.5] - y = super().call_operator(mul_op, (x_abs, x_abs), {}, meta, True) x3 = super().call_operator(mul_op, (x_abs, y), {}, meta, True) @@ -154,47 +168,40 @@ def call_operator(self, op, args, kwargs, meta): Qz = self._build_polynomial(q_coefficients, z, meta) numer = super().call_operator(mul_op, (s3, Pz), {}, meta, True) + # Calculate r_large = P(z) / Q(z) r_large = super().call_operator(div_op, (numer, Qz), {}, meta, True) # Calculate asin_large = pi/2 - 2 * (s + s^3 * Q(z) / P(z)) t1 = super().call_operator(add_op, (s, r_large), {}, meta, True) t2 = super().call_operator(mul_op_scalar, (t1, two), {}, meta, True) + diff = super().call_operator(sub_op_scalar, (t2, pi_over_2), {}, meta, True) tmp_neg_ones = super().call_operator( full_like_op, (diff, neg_one), {}, meta, True ) asin_large = super().call_operator(mul_op, (diff, tmp_neg_ones), {}, meta, True) - # Combine branches - is_large = super().call_operator(gt_op, (x_abs, half), {}, meta, True) - asin_unsigned = super().call_operator( - where_op, - ( - is_large, - asin_large, - asin_small, - ), - {}, - meta, - True, + asin_unsigned = self._combine_branches( + gt_op, (x_abs, half), (asin_large, asin_small), meta ) # Handle x < 0 - is_neg = super().call_operator(lt_op, (x, zero), {}, meta, True) - # Compute -asin_unsigned negated_asin = super().call_operator(neg_op, 
(asin_unsigned,), {}, meta, True) - # Combine branches for signed asin - asin_signed = super().call_operator( - where_op, - ( - is_neg, - negated_asin, - asin_unsigned, - ), - {}, - meta, - True, + asin = self._combine_branches( + lt_op, (x, zero), (negated_asin, asin_unsigned), meta ) - return asin_signed + if op in edge_acos_op: + # If x <= 0.5: acos(x) = pi/2 - asin(x) + const_tensor = super().call_operator( + full_like_op, (x, pi_over_2), {}, meta, True + ) + acos_small = super().call_operator( + sub_op, (const_tensor, asin), {}, meta, True + ) + # If x > 0.5, acos(x) = 2 * (s + s^3 * Q(z) / P(z)) = t2 + acos = self._combine_branches(gt_op, (x, half), (t2, acos_small), meta) + return acos + + return asin diff --git a/backends/arm/_passes/insert_table_ops.py b/backends/arm/_passes/insert_table_ops.py index bead5e993f5..4e58f9d9b9b 100644 --- a/backends/arm/_passes/insert_table_ops.py +++ b/backends/arm/_passes/insert_table_ops.py @@ -61,6 +61,7 @@ class TableOps: exir_ops.edge.aten.asin.default: torch.asin, exir_ops.edge.aten.asinh.default: torch.asinh, exir_ops.edge.aten.cosh.default: torch.cosh, + exir_ops.edge.aten.acos.default: torch.acos, } # Targets that must be treated explicitly diff --git a/backends/arm/operator_support/tosa_supported_operators.py b/backends/arm/operator_support/tosa_supported_operators.py index 26a728677b5..fa0b2d69a70 100644 --- a/backends/arm/operator_support/tosa_supported_operators.py +++ b/backends/arm/operator_support/tosa_supported_operators.py @@ -261,6 +261,7 @@ def is_node_supported( exir_ops.edge.aten.cosh.default, exir_ops.edge.aten.glu.default, exir_ops.edge.aten.logit.default, + exir_ops.edge.aten.acos.default, ] return supported diff --git a/backends/arm/quantizer/quantization_annotator.py b/backends/arm/quantizer/quantization_annotator.py index a6f5671a881..84e493a6dba 100644 --- a/backends/arm/quantizer/quantization_annotator.py +++ b/backends/arm/quantizer/quantization_annotator.py @@ -289,6 +289,7 @@ def _match_pattern( torch.ops.aten.atanh.default, torch.ops.aten.asinh.default, torch.ops.aten.cosh.default, + torch.ops.aten.acos.default, ] _one_to_one_shared_input_qspec = [ diff --git a/backends/arm/test/ops/test_acos.py b/backends/arm/test/ops/test_acos.py new file mode 100644 index 00000000000..102d979352e --- /dev/null +++ b/backends/arm/test/ops/test_acos.py @@ -0,0 +1,119 @@ +# Copyright 2025 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
+from typing import Tuple
+
+import torch
+
+from executorch.backends.arm.test import common
+from executorch.backends.arm.test.tester.test_pipeline import (
+    EthosU55PipelineINT,
+    EthosU85PipelineINT,
+    TosaPipelineFP,
+    TosaPipelineINT,
+    VgfPipeline,
+)
+
+input_t = Tuple[torch.Tensor]
+aten_op = "torch.ops.aten.acos.default"
+exir_op = "executorch_exir_dialects_edge__ops_aten__acos_default"
+
+
+test_data_suite = {
+    "ones": lambda: torch.ones(1, 7, 10, 12),
+    "rand_in_range": lambda: (torch.rand(10, 10) - 0.5) * 2,  # Uniform in [-1, 1)
+    "ramp_valid": lambda: torch.linspace(-1.0, 1.0, steps=160),
+    "edge_cases": lambda: torch.tensor([-1.0, 0.0, 1.0]),
+    "1d_tensor": lambda: torch.linspace(-1.0, 1.0, steps=10),  # Shape: [10]
+    "2d_batch": lambda: torch.tensor(
+        [[-1.0, -0.5, 0.0, 0.5, 1.0], [0.9, -0.9, 0.3, -0.3, 0.0]]
+    ),  # Shape: [2, 5]
+    "3d_batch": lambda: torch.rand(4, 5, 6) * 2 - 1,  # Shape: [4, 5, 6] in [-1, 1)
+    "3d_mixed_shape": lambda: (torch.rand(7, 15, 2) - 0.5) * 2,
+    "4d_mixed": lambda: torch.linspace(-1, 1, steps=1 * 3 * 4 * 5).reshape(
+        1, 3, 4, 5
+    ),  # Shape: [1, 3, 4, 5]
+    "4d_random": lambda: (torch.rand(1, 5, 10, 7) - 0.5) * 2,
+    "bool_casted": lambda: torch.ones(3, 3, dtype=torch.bool).to(
+        dtype=torch.float32
+    ),  # All 1.0 (edge case)
+}
+
+
+class Acos(torch.nn.Module):
+
+    def forward(self, x: torch.Tensor):
+        return torch.acos(x)
+
+
+@common.parametrize("test_data", test_data_suite)
+def test_acos_tosa_FP(test_data: Tuple):
+    pipeline = TosaPipelineFP[input_t](
+        Acos(),
+        (test_data(),),
+        aten_op,
+        exir_op=exir_op,
+    )
+    pipeline.run()
+
+
+@common.parametrize("test_data", test_data_suite)
+def test_acos_tosa_INT(test_data: Tuple):
+    pipeline = TosaPipelineINT[input_t](
+        Acos(),
+        (test_data(),),
+        aten_op=aten_op,
+        exir_op=exir_op,
+    )
+    pipeline.run()
+
+
+@common.parametrize("test_data", test_data_suite)
+@common.XfailIfNoCorstone300
+def test_acos_u55_INT(test_data: Tuple):
+    pipeline = EthosU55PipelineINT[input_t](
+        Acos(),
+        (test_data(),),
+        aten_ops=aten_op,
+        exir_ops=exir_op,
+    )
+    pipeline.run()
+
+
+@common.parametrize("test_data", test_data_suite)
+@common.XfailIfNoCorstone320
+def test_acos_u85_INT(test_data: Tuple):
+    pipeline = EthosU85PipelineINT[input_t](
+        Acos(),
+        (test_data(),),
+        aten_ops=aten_op,
+        exir_ops=exir_op,
+    )
+    pipeline.run()
+
+
+@common.parametrize("test_data", test_data_suite)
+@common.SkipIfNoModelConverter
+def test_acos_vgf_FP(test_data: Tuple):
+    pipeline = VgfPipeline[input_t](
+        Acos(),
+        (test_data(),),
+        [],
+        [],
+        tosa_version="TOSA-1.0+FP",
+    )
+    pipeline.run()
+
+
+@common.parametrize("test_data", test_data_suite)
+@common.SkipIfNoModelConverter
+def test_acos_vgf_INT(test_data: Tuple):
+    pipeline = VgfPipeline[input_t](
+        Acos(),
+        (test_data(),),
+        [],
+        [],
+        tosa_version="TOSA-1.0+INT",
+    )
+    pipeline.run()
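
Note on the math behind the pass (illustrative only, not part of the patch): the large-input branch rests on the identity asin(|x|) = pi/2 - 2 * asin(s) with z = (1 - |x|) / 2 and s = sqrt(z), which also gives acos(|x|) = 2 * asin(s) directly, while for |x| <= 0.5 the pass approximates asin(|x|) directly and derives acos(x) = pi/2 - asin(x). The sketch below checks this branch structure numerically in plain PyTorch. It is a hypothetical reference: the helper name acos_via_asin_decomposition is invented for illustration, and torch.asin stands in for the P/Q minimax rational approximation that the pass actually builds out of edge-dialect mul/add/div/sqrt/where ops.

import torch


def acos_via_asin_decomposition(x: torch.Tensor) -> torch.Tensor:
    # Hypothetical reference, not the pass itself: torch.asin stands in for the
    # rational approximation r(|x|) ~= asin(|x|) built from the P/Q polynomials.
    half = 0.5
    pi_over_2 = 1.5707963267948966

    x_abs = x.abs()
    # Small-input branch: approximate asin on |x| directly.
    asin_small = torch.asin(x_abs)
    # Large-input branch: with z = (1 - |x|) / 2 and s = sqrt(z),
    # asin(|x|) = pi/2 - 2 * asin(s), so acos(|x|) = 2 * asin(s) (t2 in the pass).
    z = (1.0 - x_abs) / 2.0
    s = torch.sqrt(z)
    t2 = 2.0 * torch.asin(s)
    asin_large = pi_over_2 - t2

    # Select the branch on |x|, then restore the sign of x.
    asin_unsigned = torch.where(x_abs > half, asin_large, asin_small)
    asin = torch.where(x < 0.0, -asin_unsigned, asin_unsigned)

    # acos follows from asin: pi/2 - asin(x) for x <= 0.5, t2 directly for x > 0.5.
    return torch.where(x > half, t2, pi_over_2 - asin)


x = torch.linspace(-1.0, 1.0, steps=101, dtype=torch.float64)
assert torch.allclose(acos_via_asin_decomposition(x), torch.acos(x))

Because the real pass only swaps torch.asin in this sketch for a fixed-degree polynomial ratio, the identities themselves introduce no error; the small numerical deviations mentioned in the pass's log message come from that polynomial approximation alone.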