Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion backends/arm/_passes/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
from .decompose_acosh_pass import DecomposeAcoshPass # noqa
from .decompose_adaptive_avg_pool2d_pass import DecomposeAdaptiveAvgPool2dPass # noqa
from .decompose_addmm_pass import DecomposeAddmmPass # noqa
from .decompose_asin_pass import DecomposeAsinPass # noqa
from .decompose_asin_and_acos_pass import DecomposeAsinAndAcosPass # noqa
from .decompose_asinh_pass import DecomposeAsinhPass # noqa
from .decompose_atan_pass import DecomposeAtanPass # noqa
from .decompose_atanh_pass import DecomposeAtanhPass # noqa
Expand Down
4 changes: 2 additions & 2 deletions backends/arm/_passes/arm_pass_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,8 @@
DecomposeAcoshPass,
DecomposeAdaptiveAvgPool2dPass,
DecomposeAddmmPass,
DecomposeAsinAndAcosPass,
DecomposeAsinhPass,
DecomposeAsinPass,
DecomposeAtanhPass,
DecomposeAtanPass,
DecomposeAvgPool2d,
Expand Down Expand Up @@ -171,9 +171,9 @@ def _tosa_FP_pipeline(self, exported_program: ExportedProgram) -> GraphModule:
self.add_pass(DecomposeMaskedFill())
self.add_pass(DecomposeRoundPass())
self.add_pass(DecomposeAcoshPass())
self.add_pass(DecomposeAsinPass())
self.add_pass(DecomposeAsinhPass())
self.add_pass(DecomposeCoshPass())
self.add_pass(DecomposeAsinAndAcosPass())
self.add_pass(DecomposeSqrtPass())
self.add_pass(DecomposeAtanPass())
self.add_pass(DecomposeAtanhPass())
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,11 @@

# For MI case
edge_asin_op = (exir_ops.edge.aten.asin.default,)
edge_acos_op = (exir_ops.edge.aten.acos.default,)


def get_asin_decomposition(op) -> tuple:
if op in edge_asin_op:
def get_decomposition(op) -> tuple:
if op in (edge_asin_op + edge_acos_op):
return (
exir_ops.edge.aten.mul.Tensor,
exir_ops.edge.aten.add.Tensor,
Expand All @@ -31,25 +32,26 @@ def get_asin_decomposition(op) -> tuple:
exir_ops.edge.aten.lt.Scalar,
exir_ops.edge.aten.sub.Tensor,
exir_ops.edge.aten.full_like.default,
exir_ops.edge.aten.where.self,
exir_ops.edge.aten.neg.default,
)

raise RuntimeError(f"Can't get asin decomposition for op {op}")
raise RuntimeError(f"Can't get decomposition for op {op}")


class DecomposeAsinPass(ArmPass):
class DecomposeAsinAndAcosPass(ArmPass):
"""
This pass decomposes asin into a rational approximation for small values
This pass decomposes asin and acos into a rational approximation for small values
and a transformed rational approximation for large values.
Example:
y = asin(x)
Becomes:

The decomposition is based on the following mathematical identities:
if abs(x) < 0.5:
y = x + P(x^2) / Q(x^2)
asin(x) = x + P(x^2) / Q(x^2)
acos(x) = π/2 - asin(x)
else:
y = π/2 - 2 * (s + s^3 * Q(z) / P(z))
where P and Q are polynomials defined in the function.
asin(x) = π/2 - 2 * (s + s^3 * P(z) / Q(z))
acos(x) = 2 * (s + s^3 * P(z) / Q(z))
where P and Q are polynomials defined in the function and s is the square root of z.

"""

def _build_polynomial(
Expand Down Expand Up @@ -84,11 +86,25 @@ def _build_polynomial(
)
return result

def _combine_branches(
    self,
    bool_op,
    bool_args: tuple[torch.Tensor, float],
    branches: tuple[torch.Tensor, torch.Tensor],
    meta: dict[str, str],
) -> torch.Tensor:
    """Select element-wise between two branches using a comparison mask.

    The mask is produced by calling ``bool_op`` on ``bool_args`` through the
    parent pass's ``call_operator``; the result then drives an
    ``aten.where.self`` node.

    Args:
        bool_op: Operator that produces the selection mask.
        bool_args: Positional arguments forwarded to ``bool_op``.
        branches: ``(value_if_true, value_if_false)`` pair.
        meta: Node metadata forwarded unchanged to every emitted operator.

    Returns:
        The result of the emitted ``where`` operator.
    """
    # NOTE(review): the trailing ``True`` is passed positionally exactly as in
    # the rest of this pass; its meaning is defined by the parent
    # ``call_operator`` and is not visible here - confirm.
    condition = super().call_operator(bool_op, bool_args, {}, meta, True)
    # Unpack only after the mask has been emitted, as the original did.
    on_true, on_false = branches
    select_args = (condition, on_true, on_false)
    return super().call_operator(
        exir_ops.edge.aten.where.self, select_args, {}, meta, True
    )

def call_operator(self, op, args, kwargs, meta):
if op not in edge_asin_op:
if op not in (edge_asin_op + edge_acos_op):
return super().call_operator(op, args, kwargs, meta)
logging.info(
f"Approximating asin. This may introduce small numerical errors. For details, see {__file__}."
f"Approximating {op}. This may introduce small numerical errors. For details, see {__file__}."
)
x = args[0]
half = 0.5
Expand All @@ -111,9 +127,8 @@ def call_operator(self, op, args, kwargs, meta):
lt_op,
sub_op,
full_like_op,
where_op,
neg_op,
) = get_asin_decomposition(op)
) = get_decomposition(op)

# Coefficients for the rational approximation, calculated with the Minimax (Remez) method
p_coefficients = [
Expand All @@ -129,7 +144,6 @@ def call_operator(self, op, args, kwargs, meta):
x_abs = super().call_operator(abs_op, (x,), {}, meta, True)

# Step 1: compute asin_small - rational approximation for [0,0.5]

y = super().call_operator(mul_op, (x_abs, x_abs), {}, meta, True)
x3 = super().call_operator(mul_op, (x_abs, y), {}, meta, True)

Expand All @@ -154,47 +168,40 @@ def call_operator(self, op, args, kwargs, meta):
Qz = self._build_polynomial(q_coefficients, z, meta)

numer = super().call_operator(mul_op, (s3, Pz), {}, meta, True)

# Calculate r_large = s^3 * P(z) / Q(z)
r_large = super().call_operator(div_op, (numer, Qz), {}, meta, True)

# Calculate asin_large = pi/2 - 2 * (s + s^3 * P(z) / Q(z))
t1 = super().call_operator(add_op, (s, r_large), {}, meta, True)
t2 = super().call_operator(mul_op_scalar, (t1, two), {}, meta, True)

diff = super().call_operator(sub_op_scalar, (t2, pi_over_2), {}, meta, True)
tmp_neg_ones = super().call_operator(
full_like_op, (diff, neg_one), {}, meta, True
)
asin_large = super().call_operator(mul_op, (diff, tmp_neg_ones), {}, meta, True)

# Combine branches
is_large = super().call_operator(gt_op, (x_abs, half), {}, meta, True)
asin_unsigned = super().call_operator(
where_op,
(
is_large,
asin_large,
asin_small,
),
{},
meta,
True,
asin_unsigned = self._combine_branches(
gt_op, (x_abs, half), (asin_large, asin_small), meta
)

# Handle x < 0
is_neg = super().call_operator(lt_op, (x, zero), {}, meta, True)
# Compute -asin_unsigned
negated_asin = super().call_operator(neg_op, (asin_unsigned,), {}, meta, True)
# Combine branches for signed asin
asin_signed = super().call_operator(
where_op,
(
is_neg,
negated_asin,
asin_unsigned,
),
{},
meta,
True,
asin = self._combine_branches(
lt_op, (x, zero), (negated_asin, asin_unsigned), meta
)

return asin_signed
if op in edge_acos_op:
# If x <= 0.5: acos(x) = pi/2 - asin(x)
const_tensor = super().call_operator(
full_like_op, (x, pi_over_2), {}, meta, True
)
acos_small = super().call_operator(
sub_op, (const_tensor, asin), {}, meta, True
)
# If x > 0.5, acos(x) = 2 * (s + s^3 * P(z) / Q(z)) = t2
acos = self._combine_branches(gt_op, (x, half), (t2, acos_small), meta)
return acos

return asin
1 change: 1 addition & 0 deletions backends/arm/_passes/insert_table_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ class TableOps:
exir_ops.edge.aten.asin.default: torch.asin,
exir_ops.edge.aten.asinh.default: torch.asinh,
exir_ops.edge.aten.cosh.default: torch.cosh,
exir_ops.edge.aten.acos.default: torch.acos,
}

# Targets that must be treated explicitly
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -261,6 +261,7 @@ def is_node_supported(
exir_ops.edge.aten.cosh.default,
exir_ops.edge.aten.glu.default,
exir_ops.edge.aten.logit.default,
exir_ops.edge.aten.acos.default,
]

return supported
Expand Down
1 change: 1 addition & 0 deletions backends/arm/quantizer/quantization_annotator.py
Original file line number Diff line number Diff line change
Expand Up @@ -289,6 +289,7 @@ def _match_pattern(
torch.ops.aten.atanh.default,
torch.ops.aten.asinh.default,
torch.ops.aten.cosh.default,
torch.ops.aten.acos.default,
]

_one_to_one_shared_input_qspec = [
Expand Down
119 changes: 119 additions & 0 deletions backends/arm/test/ops/test_acos.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,119 @@
# Copyright 2025 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
from typing import Tuple

import torch

from executorch.backends.arm.test import common
from executorch.backends.arm.test.tester.test_pipeline import (
EthosU55PipelineINT,
EthosU85PipelineINT,
TosaPipelineFP,
TosaPipelineINT,
VgfPipeline,
)

input_t = Tuple[torch.Tensor]
aten_op = "torch.ops.aten.acos.default"
exir_op = "executorch_exir_dialects_edge__ops_aten__acos_default"


# Named input factories shared by every acos test below. Each value is a
# zero-argument lambda, so the tensor is only built when a test calls it.
# All inputs are kept within [-1, 1].
test_data_suite = {
"ones": lambda: torch.ones(1, 7, 10, 12),  # All 1.0
"rand_in_range": lambda: (torch.rand(10, 10) - 0.5) * 2, # Uniform in [-1, 1)
"ramp_valid": lambda: torch.linspace(-1.0, 1.0, steps=160),  # Evenly spaced ramp over [-1, 1]
"edge_cases": lambda: torch.tensor([-1.0, 0.0, 1.0]),  # Both endpoints and zero
"1d_tensor": lambda: torch.linspace(-1.0, 1.0, steps=10), # Shape: [10]
"2d_batch": lambda: torch.tensor(
[[-1.0, -0.5, 0.0, 0.5, 1.0], [0.9, -0.9, 0.3, -0.3, 0.0]]
), # Shape: [2, 5]
"3d_batch": lambda: torch.rand(4, 5, 6) * 2 - 1, # Shape: [4, 5, 6] in [-1, 1)
"3d_mixed_shape": lambda: (torch.rand(7, 15, 2) - 0.5) * 2,  # Shape: [7, 15, 2] in [-1, 1)
"4d_mixed": lambda: torch.linspace(-1, 1, steps=1 * 3 * 4 * 5).reshape(
1, 3, 4, 5
), # Shape: [1, 3, 4, 5]
"4d_random": lambda: (torch.rand(1, 5, 10, 7) - 0.5) * 2,  # Shape: [1, 5, 10, 7] in [-1, 1)
"bool_casted": lambda: torch.ones(3, 3, dtype=torch.bool).to(
dtype=torch.float32
), # All 1.0 (edge case)
}


class Acos(torch.nn.Module):
    """Minimal module whose forward pass is a single ``torch.acos`` call."""

    def forward(self, x: torch.Tensor):
        """Return the element-wise arccosine of ``x``."""
        result = torch.acos(x)
        return result


@common.parametrize("test_data", test_data_suite)
def test_acos_tosa_FP(test_data: Tuple):
    """Run the acos module through ``TosaPipelineFP`` for one suite entry."""
    module = Acos()
    inputs = (test_data(),)
    TosaPipelineFP[input_t](module, inputs, aten_op, exir_op=exir_op).run()


@common.parametrize("test_data", test_data_suite)
def test_acos_tosa_INT(test_data: Tuple):
    """Run the acos module through ``TosaPipelineINT`` for one suite entry."""
    module = Acos()
    inputs = (test_data(),)
    int_pipeline = TosaPipelineINT[input_t](
        module, inputs, aten_op=aten_op, exir_op=exir_op
    )
    int_pipeline.run()


@common.parametrize("test_data", test_data_suite)
@common.XfailIfNoCorstone300
def test_acos_u55_INT(test_data: Tuple):
    """Run the acos module through ``EthosU55PipelineINT`` for one suite entry."""
    module = Acos()
    inputs = (test_data(),)
    u55_pipeline = EthosU55PipelineINT[input_t](
        module, inputs, aten_ops=aten_op, exir_ops=exir_op
    )
    u55_pipeline.run()


@common.parametrize("test_data", test_data_suite)
@common.XfailIfNoCorstone320
def test_acos_u85_INT(test_data: Tuple):
    """Run the acos module through ``EthosU85PipelineINT`` for one suite entry."""
    module = Acos()
    inputs = (test_data(),)
    u85_pipeline = EthosU85PipelineINT[input_t](
        module, inputs, aten_ops=aten_op, exir_ops=exir_op
    )
    u85_pipeline.run()


@common.parametrize("test_data", test_data_suite)
@common.SkipIfNoModelConverter
def test_acos_vgf_FP(test_data: Tuple):
    """Run the acos module through ``VgfPipeline`` with the TOSA-1.0+FP profile."""
    module = Acos()
    inputs = (test_data(),)
    # The two empty lists are passed positionally, exactly as before.
    vgf_pipeline = VgfPipeline[input_t](
        module, inputs, [], [], tosa_version="TOSA-1.0+FP"
    )
    vgf_pipeline.run()


@common.parametrize("test_data", test_data_suite)
@common.SkipIfNoModelConverter
def test_acos_vgf_INT(test_data: Tuple):
    """Run the acos module through ``VgfPipeline`` with the TOSA-1.0+INT profile."""
    module = Acos()
    inputs = (test_data(),)
    # The two empty lists are passed positionally, exactly as before.
    vgf_pipeline = VgfPipeline[input_t](
        module, inputs, [], [], tosa_version="TOSA-1.0+INT"
    )
    vgf_pipeline.run()
Loading