From a90f1aa0aba016e59d6804cd0a647e7f6330fc6f Mon Sep 17 00:00:00 2001
From: Erik Lundell
Date: Mon, 16 Sep 2024 18:00:45 +0200
Subject: [PATCH] Support bitwise and, xor, and or ops in Arm backend

The three ops are very similar and are therefore implemented together.
No quantization is added, since quantization does not make sense for
bitwise ops.

Add a factory for creating simple two-input NodeVisitors. The factory
can be extended for future ops of this kind.

Signed-off-by: Erik Lundell
Change-Id: Ic3615067cd03b8775f4b028f601b6b2366487c9d
---
 backends/arm/operator_support/__init__.py   |   1 +
 .../arm/operator_support/bitwise_support.py |  33 +++
 backends/arm/operators/__init__.py          |   1 +
 backends/arm/operators/ops_binary.py        |  51 +++++
 backends/arm/test/ops/test_bitwise.py       | 195 ++++++++++++++++++
 5 files changed, 281 insertions(+)
 create mode 100644 backends/arm/operator_support/bitwise_support.py
 create mode 100644 backends/arm/operators/ops_binary.py
 create mode 100644 backends/arm/test/ops/test_bitwise.py

diff --git a/backends/arm/operator_support/__init__.py b/backends/arm/operator_support/__init__.py
index c6895cce492..2ac23b0e91b 100644
--- a/backends/arm/operator_support/__init__.py
+++ b/backends/arm/operator_support/__init__.py
@@ -6,6 +6,7 @@
 # pyre-unsafe
 
 from . import (  # noqa
+    bitwise_support,
     convolution_support,
     pool_2d_support,
     reduce_sum_support,
diff --git a/backends/arm/operator_support/bitwise_support.py b/backends/arm/operator_support/bitwise_support.py
new file mode 100644
index 00000000000..e0604622064
--- /dev/null
+++ b/backends/arm/operator_support/bitwise_support.py
@@ -0,0 +1,33 @@
+# Copyright 2025 Arm Limited and/or its affiliates.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch.fx as fx
+from executorch.backends.arm.operator_support.tosa_supported_operators import (
+    register_tosa_support_check,
+    SupportedTOSAOperatorCheck,
+)
+from executorch.backends.arm.tosa_specification import Tosa_0_80, TosaSpecification
+from executorch.exir.dialects._ops import ops as exir_ops
+
+
+@register_tosa_support_check
+class BitwiseSupported(SupportedTOSAOperatorCheck):
+    targets = [
+        exir_ops.edge.aten.bitwise_and.Tensor,
+        exir_ops.edge.aten.bitwise_or.Tensor,
+        exir_ops.edge.aten.bitwise_xor.Tensor,
+    ]
+
+    tosa_specs = [
+        TosaSpecification.create_from_string("TOSA-0.80+BI"),
+        TosaSpecification.create_from_string("TOSA-0.80+MI"),
+    ]
+
+    def is_node_tosa_supported(self, node: fx.Node, tosa_spec: TosaSpecification):
+        # U55 case, Vela 4.2.0 (25.02 release)
+        if isinstance(tosa_spec, Tosa_0_80) and tosa_spec.is_U55_subset:
+            return False
+
+        return True
diff --git a/backends/arm/operators/__init__.py b/backends/arm/operators/__init__.py
index f57ba092bc4..d42940b3740 100644
--- a/backends/arm/operators/__init__.py
+++ b/backends/arm/operators/__init__.py
@@ -44,4 +44,5 @@
     op_transpose,
     op_upsample_nearest2d,
     op_view,
+    ops_binary,
 )
diff --git a/backends/arm/operators/ops_binary.py b/backends/arm/operators/ops_binary.py
new file mode 100644
index 00000000000..9ce561d0b6d
--- /dev/null
+++ b/backends/arm/operators/ops_binary.py
@@ -0,0 +1,51 @@
+# Copyright 2025 Arm Limited and/or its affiliates.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+# pyre-unsafe
+
+from typing import List
+
+import serializer.tosa_serializer as ts
+import torch
+import torch.fx
+
+from executorch.backends.arm.operators.node_visitor import (
+    NodeVisitor,
+    register_node_visitor,
+)
+from executorch.backends.arm.tosa_mapping import TosaArg
+from serializer.tosa_serializer import TosaOp
+
+
+def binary_operator_factory(bw_target: str, tosa_op):
+    """Creates and registers NodeVisitors for operators that have two inputs and map directly to a TOSA op."""
+
+    class BinaryOperator(NodeVisitor):
+        target = bw_target
+
+        def define_node(
+            self,
+            node: torch.fx.Node,
+            tosa_graph: ts.TosaSerializer,
+            inputs: List[TosaArg],
+            output: TosaArg,
+        ) -> None:
+
+            if not (inputs[0].dtype == inputs[1].dtype == output.dtype):
+                raise ValueError(
+                    "All inputs and outputs need same dtype. "
+                    f"Got {inputs[0].dtype=}, {inputs[1].dtype=}, {output.dtype=}."
+                )
+
+            tosa_graph.addOperator(
+                tosa_op, [inputs[0].name, inputs[1].name], [output.name]
+            )
+
+    register_node_visitor(BinaryOperator)
+
+
+binary_operator_factory("aten.bitwise_and.Tensor", TosaOp.Op().BITWISE_AND)
+binary_operator_factory("aten.bitwise_xor.Tensor", TosaOp.Op().BITWISE_XOR)
+binary_operator_factory("aten.bitwise_or.Tensor", TosaOp.Op().BITWISE_OR)
diff --git a/backends/arm/test/ops/test_bitwise.py b/backends/arm/test/ops/test_bitwise.py
new file mode 100644
index 00000000000..59d82a6564f
--- /dev/null
+++ b/backends/arm/test/ops/test_bitwise.py
@@ -0,0 +1,195 @@
+# Copyright 2025 Arm Limited and/or its affiliates.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import unittest
+
+from typing import Callable, NamedTuple, Tuple
+
+import torch
+from executorch.backends.arm.test import common, conftest
+from executorch.backends.arm.test.tester.arm_tester import ArmTester
+from parameterized import parameterized
+
+
+class DataTuple(NamedTuple):
+    name: str
+    tensor1: torch.Tensor
+    tensor2: torch.Tensor
+
+
+class OpTuple(NamedTuple):
+    name: str
+    operator: torch.nn.Module
+
+
+class And(torch.nn.Module):
+    def forward(self, tensor1: torch.Tensor, tensor2: torch.Tensor):
+        return tensor1.bitwise_and(tensor2)
+
+
+class Xor(torch.nn.Module):
+    def forward(self, tensor1: torch.Tensor, tensor2: torch.Tensor):
+        return tensor1.bitwise_xor(tensor2)
+
+
+class Or(torch.nn.Module):
+    def forward(self, tensor1: torch.Tensor, tensor2: torch.Tensor):
+        return tensor1.bitwise_or(tensor2)
+
+
+test_data_suite: list[DataTuple] = [
+    DataTuple(
+        "zeros",
+        torch.zeros(1, 10, 10, 10, dtype=torch.int32),
+        torch.zeros(1, 10, 10, 10, dtype=torch.int32),
+    ),
+    DataTuple(
+        "ones",
+        torch.ones(10, 10, 10, dtype=torch.int8),
+        torch.ones(10, 10, 10, dtype=torch.int8),
+    ),
+    DataTuple(
+        "rand_rank2",
+        torch.randint(-128, 127, (10, 10), dtype=torch.int8),
+        torch.randint(-128, 127, (10, 10), dtype=torch.int8),
+    ),
+    DataTuple(
+        "rand_rank4",
+        torch.randint(-128, 127, (1, 10, 10, 10), dtype=torch.int8),
+        torch.randint(-128, 127, (1, 10, 10, 10), dtype=torch.int8),
+    ),
+]
+
+
+ops: list[OpTuple] = [
+    OpTuple("and", And()),
+    OpTuple("or", Or()),
+    OpTuple("xor", Xor()),
+]
+
+full_test_suite = []
+for op in ops:
+    for test_data in test_data_suite:
+        full_test_suite.append(
+            (
+                f"{op.name}_{test_data.name}",
+                op.operator,
+                test_data.tensor1,
+                test_data.tensor2,
+            )
+        )
+
+del test_data
+del ops
+
+
+class TestBitwise(unittest.TestCase):
+
+    def _test_bitwise_tosa_MI_pipeline(
+        self, module: torch.nn.Module, test_data: Tuple[torch.Tensor, torch.Tensor]
+    ):
+        (
+            ArmTester(
+                module,
+                example_inputs=test_data,
+                compile_spec=common.get_tosa_compile_spec("TOSA-0.80+MI"),
+            )
+            .export()
+            .to_edge_transform_and_lower()
+            .check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
+            .to_executorch()
+            .run_method_and_compare_outputs(inputs=test_data)
+        )
+
+    def _test_bitwise_tosa_BI_pipeline(
+        self, module: torch.nn.Module, test_data: Tuple[torch.Tensor, torch.Tensor]
+    ):
+        (
+            ArmTester(
+                module,
+                example_inputs=test_data,
+                compile_spec=common.get_tosa_compile_spec(
+                    "TOSA-0.80+BI"
+                ),
+            )
+            .export()
+            .to_edge_transform_and_lower()
+            .check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
+            .to_executorch()
+            .run_method_and_compare_outputs(inputs=test_data)
+        )
+
+    def _test_bitwise_tosa_u55_BI_pipeline(
+        self, module: torch.nn.Module, test_data: Tuple[torch.Tensor, torch.Tensor]
+    ):
+        # Tests that we don't delegate these ops since they are not supported on U55.
+        (
+            ArmTester(
+                module,
+                example_inputs=test_data,
+                compile_spec=common.get_u55_compile_spec(),
+            )
+            .export()
+            .to_edge_transform_and_lower()
+            .check_count({"torch.ops.higher_order.executorch_call_delegate": 0})
+        )
+
+    def _test_bitwise_tosa_u85_BI_pipeline(
+        self, module: torch.nn.Module, test_data: Tuple[torch.Tensor, torch.Tensor]
+    ):
+        tester = (
+            ArmTester(
+                module,
+                example_inputs=test_data,
+                compile_spec=common.get_u85_compile_spec(),
+            )
+            .export()
+            .to_edge_transform_and_lower()
+            .check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
+            .to_executorch()
+            .serialize()
+        )
+        if conftest.is_option_enabled("corstone_fvp"):
+            tester.run_method_and_compare_outputs(inputs=test_data)
+
+    @parameterized.expand(full_test_suite)
+    def test_tosa_MI(
+        self,
+        test_name: str,
+        operator: Callable,
+        tensor1: torch.Tensor,
+        tensor2: torch.Tensor,
+    ):
+        self._test_bitwise_tosa_MI_pipeline(operator, (tensor1, tensor2))
+
+    @parameterized.expand(full_test_suite)
+    def test_tosa_BI(
+        self,
+        test_name: str,
+        operator: Callable,
+        tensor1: torch.Tensor,
+        tensor2: torch.Tensor,
+    ):
+        self._test_bitwise_tosa_BI_pipeline(operator, (tensor1, tensor2))
+
+    @parameterized.expand(full_test_suite)
+    def test_tosa_u55_BI(
+        self,
+        test_name: str,
+        operator: Callable,
+        tensor1: torch.Tensor,
+        tensor2: torch.Tensor,
+    ):
+        self._test_bitwise_tosa_u55_BI_pipeline(operator, (tensor1, tensor2))
+
+    @parameterized.expand(full_test_suite)
+    def test_tosa_u85_BI(
+        self,
+        test_name: str,
+        operator: Callable,
+        tensor1: torch.Tensor,
+        tensor2: torch.Tensor,
+    ):
+        self._test_bitwise_tosa_u85_BI_pipeline(operator, (tensor1, tensor2))
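
As noted in the commit message, binary_operator_factory is the extension
point for future two-input ops that lower one-to-one to a TOSA operator.
A minimal sketch of such an extension, assuming a hypothetical follow-up
that maps aten.logical_and.Tensor onto TOSA's LOGICAL_AND op (the target
and op choice here are illustrative only, not part of this patch):

    # Hypothetical extension, not included in this patch.
    # A single factory call creates and registers a NodeVisitor whose
    # define_node checks that both inputs and the output share a dtype and
    # then emits the given TOSA op; no per-op visitor file is needed.
    from executorch.backends.arm.operators.ops_binary import binary_operator_factory
    from serializer.tosa_serializer import TosaOp

    binary_operator_factory("aten.logical_and.Tensor", TosaOp.Op().LOGICAL_AND)

A matching SupportedTOSAOperatorCheck entry (as in bitwise_support.py)
would still be needed so the partitioner delegates the new op.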