diff --git a/backends/arm/operator_support/tosa_supported_operators.py b/backends/arm/operator_support/tosa_supported_operators.py index 1fa626efce1..4bbcec57cba 100644 --- a/backends/arm/operator_support/tosa_supported_operators.py +++ b/backends/arm/operator_support/tosa_supported_operators.py @@ -91,6 +91,7 @@ class BaseTOSASupportList(OperatorSupportBase): def is_node_supported(self, submodules, node: fx.Node) -> bool: supported = node.op == "call_function" and node.target in [ + exir_ops.edge.aten.abs.default, exir_ops.edge.aten.add.Tensor, exir_ops.edge.aten.expand_copy.default, exir_ops.edge.aten.cat.default, diff --git a/backends/arm/operators/__init__.py b/backends/arm/operators/__init__.py index 735debe367f..98d96828ad5 100644 --- a/backends/arm/operators/__init__.py +++ b/backends/arm/operators/__init__.py @@ -7,6 +7,7 @@ from . import ( # noqa node_visitor, + op_abs, op_add, op_avg_pool2d, op_bmm, diff --git a/backends/arm/operators/op_abs.py b/backends/arm/operators/op_abs.py new file mode 100644 index 00000000000..886a96fd520 --- /dev/null +++ b/backends/arm/operators/op_abs.py @@ -0,0 +1,133 @@ +# Copyright 2025 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
@register_node_visitor
class AbsVisitor_080_BI(NodeVisitor):
    """Lower ``aten.abs.default`` to a TOSA ``ABS`` operator (TOSA-0.80+BI).

    Accepts INT8 and INT32 tensors. INT32 is passed straight to ``ABS``.
    INT8 is first rescaled to INT32, ``ABS`` is emitted on an INT32
    intermediate tensor, and the result is rescaled back to INT8.
    """

    target = "aten.abs.default"

    tosa_specs = [
        TosaSpecification.create_from_string("TOSA-0.80+BI"),
    ]

    def __init__(self, *args):
        super().__init__(*args)

    @staticmethod
    def _add_abs_operator(
        tosa_graph: ts.TosaSerializer, input_name: str, output_name: str
    ) -> None:
        """Append a single-input, single-output ``ABS`` operator to the graph."""
        tosa_graph.addOperator(
            TosaOp.Op().ABS,
            [input_name],
            [output_name],
            None,
        )

    def define_node(
        self,
        node: Node,
        tosa_graph: ts.TosaSerializer,
        inputs: List[TosaArg],
        output: TosaArg,
    ) -> None:
        """Emit the integer ``ABS`` lowering for ``node`` into ``tosa_graph``.

        Args:
            node: The FX node being lowered; forwarded to the rescale helpers.
            tosa_graph: Serializer that receives the emitted operators.
            inputs: Operator arguments; only ``inputs[0]`` is read here.
            output: Destination tensor of the node.

        Raises:
            ValueError: If input and output dtypes differ, or if the dtype is
                neither INT8 nor INT32.
        """
        in_dtype = inputs[0].dtype

        # Specification (0.80) states that input and output types
        # should all be the same
        if in_dtype != output.dtype:
            raise ValueError(
                "All inputs and outputs need same dtype. "
                f"Got {inputs[0].dtype=}, {output.dtype=}"
            )
        # Handle int8 (quantized) and int32
        if in_dtype not in (ts.DType.INT8, ts.DType.INT32):
            raise ValueError(
                "All inputs need to be INT8 or INT32. " f"Got {inputs[0].dtype=}"
            )

        if in_dtype == ts.DType.INT32:
            # Non quantized input, natively supported by TOSA.abs
            self._add_abs_operator(tosa_graph, inputs[0].name, output.name)
            return

        # INT8 (quantized) path. Input and output dtypes are equal (checked
        # above), so the output is INT8 as well and must be rescaled back.
        rescaled_inputs, scale_back = tqutils.insert_rescale_ops_to_int32(
            tosa_graph, inputs, node
        )
        intermediate_shape = tutils.tosa_shape(output.shape, output.dim_order)
        abs_output = tosa_graph.addIntermediate(intermediate_shape, ts.DType.INT32)

        # Do the INT32 Abs
        self._add_abs_operator(tosa_graph, rescaled_inputs[0].name, abs_output.name)

        # Scale output back to 8 bit
        tqutils.insert_rescale_op_to_int8(tosa_graph, abs_output, scale_back, node)


@register_node_visitor
class AbsVisitor_080_MI(AbsVisitor_080_BI):
    """Lower ``aten.abs.default`` for TOSA-0.80+MI.

    INT8 and INT32 are delegated to the inherited BI lowering. FP32 is
    lowered directly to a TOSA ``ABS`` operator.
    """

    # inheriting 'target' from BI class

    tosa_specs = [
        TosaSpecification.create_from_string("TOSA-0.80+MI"),
    ]

    def __init__(self, *args):
        super().__init__(*args)

    def define_node(
        self,
        node: Node,
        tosa_graph: ts.TosaSerializer,
        inputs: List[TosaArg],
        output: TosaArg,
    ) -> None:
        """Emit the ``ABS`` lowering for ``node``, including the FP32 case.

        Args:
            node: The FX node being lowered.
            tosa_graph: Serializer that receives the emitted operators.
            inputs: Operator arguments; only ``inputs[0]`` is read here.
            output: Destination tensor of the node.

        Raises:
            ValueError: If input and output dtypes differ, or if the dtype is
                not one of INT8, INT32 or FP32.
        """
        in_dtype = inputs[0].dtype

        # Specification (0.80) states that input and output types
        # should all be the same
        if in_dtype != output.dtype:
            raise ValueError(
                "All inputs and output need same dtype. "
                f"Got {inputs[0].dtype=}, {output.dtype=}"
            )

        if in_dtype in (ts.DType.INT8, ts.DType.INT32):
            # Call the inherited define_node for handling integers
            super().define_node(node, tosa_graph, inputs, output)
            return

        # FP32 Abs lowering
        if in_dtype != ts.DType.FP32:
            raise ValueError("All inputs need to be FP32. " f"Got {inputs[0].dtype=}")

        # The output dtype equals the input dtype (checked above), so it is
        # FP32 here and needs no separate validation.
        self._add_abs_operator(tosa_graph, inputs[0].name, output.name)
class TestAbs(unittest.TestCase):
    """Pipeline tests for lowering ``torch.abs`` through the Arm backend."""

    class Abs(torch.nn.Module):
        """Minimal module whose forward pass is a single ``torch.abs`` call."""

        # One single-tensor argument tuple per parameterized test case.
        test_parameters = [
            (torch.zeros(5),),
            (torch.full((5,), -1, dtype=torch.float32),),
            (torch.ones(5) * -1,),
            (torch.randn(8),),
            (torch.randn(2, 3, 4),),
            (torch.randn(1, 2, 3, 4),),
            (torch.normal(mean=0, std=10, size=(2, 3, 4)),),
        ]

        def forward(self, x):
            return torch.abs(x)

    def _test_abs_tosa_MI_pipeline(
        self, module: torch.nn.Module, test_data: Tuple[torch.Tensor]
    ):
        """Export without quantization for TOSA-0.80+MI, delegate and compare."""
        mi_spec = common.get_tosa_compile_spec("TOSA-0.80+MI")
        exported = ArmTester(
            module, example_inputs=test_data, compile_spec=mi_spec
        ).export()
        # Stage 1: the exported graph holds exactly one abs and no quant ops.
        exported = exported.check_count({"torch.ops.aten.abs.default": 1})
        exported = exported.check_not(["torch.ops.quantized_decomposed"])
        # Stage 2: after partitioning a single delegate call remains.
        partitioned = exported.to_edge().partition()
        partitioned = partitioned.check_not(["torch.ops.aten.abs.default"])
        partitioned = partitioned.check_count(
            {"torch.ops.higher_order.executorch_call_delegate": 1}
        )
        # Stage 3: run and compare against the reference outputs.
        partitioned.to_executorch().run_method_and_compare_outputs(inputs=test_data)

    def _test_abs_tosa_BI_pipeline(
        self, module: torch.nn.Module, test_data: Tuple[torch.Tensor]
    ):
        """Quantize, export for TOSA-0.80+BI, delegate and compare (qtol=1)."""
        bi_spec = common.get_tosa_compile_spec("TOSA-0.80+BI")
        exported = (
            ArmTester(module, example_inputs=test_data, compile_spec=bi_spec)
            .quantize()
            .export()
        )
        # Stage 1: one abs op surrounded by quantize/dequantize ops.
        exported = exported.check_count({"torch.ops.aten.abs.default": 1})
        exported = exported.check(["torch.ops.quantized_decomposed"])
        # Stage 2: after partitioning a single delegate call remains.
        partitioned = exported.to_edge().partition()
        partitioned = partitioned.check_count(
            {"torch.ops.higher_order.executorch_call_delegate": 1}
        )
        # Stage 3: run and compare against the reference outputs.
        partitioned.to_executorch().run_method_and_compare_outputs(
            inputs=test_data, qtol=1
        )

    def _test_abs_ethosu_BI_pipeline(
        self,
        compile_spec: list[CompileSpec],
        module: torch.nn.Module,
        test_data: Tuple[torch.Tensor],
    ):
        """Quantize, lower and serialize for an Ethos-U compile spec.

        Outputs are only compared when the ``corstone_fvp`` option is enabled.
        """
        exported = (
            ArmTester(module, example_inputs=test_data, compile_spec=compile_spec)
            .quantize()
            .export()
        )
        exported = exported.check_count({"torch.ops.aten.abs.default": 1})
        exported = exported.check(["torch.ops.quantized_decomposed"])
        partitioned = exported.to_edge().partition()
        partitioned = partitioned.check_count(
            {"torch.ops.higher_order.executorch_call_delegate": 1}
        )
        tester = partitioned.to_executorch().serialize()
        if not conftest.is_option_enabled("corstone_fvp"):
            return
        tester.run_method_and_compare_outputs(qtol=1, inputs=test_data)

    @parameterized.expand(Abs.test_parameters)
    def test_abs_tosa_MI(self, test_data: torch.Tensor):
        # The pipelines take a tuple of example inputs.
        self._test_abs_tosa_MI_pipeline(self.Abs(), (test_data,))

    @parameterized.expand(Abs.test_parameters)
    def test_abs_tosa_BI(self, test_data: torch.Tensor):
        # The pipelines take a tuple of example inputs.
        self._test_abs_tosa_BI_pipeline(self.Abs(), (test_data,))

    @parameterized.expand(Abs.test_parameters)
    @pytest.mark.corstone_fvp
    def test_abs_u55_BI(self, test_data: torch.Tensor):
        # Ethos-U55 variant of the quantized pipeline.
        u55_spec = common.get_u55_compile_spec()
        self._test_abs_ethosu_BI_pipeline(u55_spec, self.Abs(), (test_data,))

    @parameterized.expand(Abs.test_parameters)
    @pytest.mark.corstone_fvp
    def test_abs_u85_BI(self, test_data: torch.Tensor):
        # Ethos-U85 variant of the quantized pipeline.
        u85_spec = common.get_u85_compile_spec()
        self._test_abs_ethosu_BI_pipeline(u85_spec, self.Abs(), (test_data,))