diff --git a/backends/arm/operator_support/__init__.py b/backends/arm/operator_support/__init__.py index 2ac23b0e91b..c6895cce492 100644 --- a/backends/arm/operator_support/__init__.py +++ b/backends/arm/operator_support/__init__.py @@ -6,7 +6,6 @@ # pyre-unsafe from . import ( # noqa - bitwise_support, convolution_support, pool_2d_support, reduce_sum_support, diff --git a/backends/arm/operator_support/bitwise_support.py b/backends/arm/operator_support/bitwise_support.py deleted file mode 100644 index e0604622064..00000000000 --- a/backends/arm/operator_support/bitwise_support.py +++ /dev/null @@ -1,33 +0,0 @@ -# Copyright 2025 Arm Limited and/or its affiliates. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - -import torch.fx as fx -from executorch.backends.arm.operator_support.tosa_supported_operators import ( - register_tosa_support_check, - SupportedTOSAOperatorCheck, -) -from executorch.backends.arm.tosa_specification import Tosa_0_80, TosaSpecification -from executorch.exir.dialects._ops import ops as exir_ops - - -@register_tosa_support_check -class BitwiseSupported(SupportedTOSAOperatorCheck): - targets = [ - exir_ops.edge.aten.bitwise_and.Tensor, - exir_ops.edge.aten.bitwise_or.Tensor, - exir_ops.edge.aten.bitwise_xor.Tensor, - ] - - tosa_specs = [ - TosaSpecification.create_from_string("TOSA-0.80+BI"), - TosaSpecification.create_from_string("TOSA-0.80+MI"), - ] - - def is_node_tosa_supported(self, node: fx.Node, tosa_spec: TosaSpecification): - # U55 case, Vela 4.2.0 (25.02 release) - if isinstance(tosa_spec, Tosa_0_80) and tosa_spec.is_U55_subset: - return False - - return True diff --git a/backends/arm/operator_support/tosa_supported_operators.py b/backends/arm/operator_support/tosa_supported_operators.py index 6fe70aa696c..cff64fe95ce 100644 --- a/backends/arm/operator_support/tosa_supported_operators.py +++ b/backends/arm/operator_support/tosa_supported_operators.py @@ -11,13 +11,13 @@ from typing import final, Optional, Sequence, Type import torch - import torch.fx as fx + from executorch.backends.arm._passes.arm_pass_utils import get_first_fake_tensor from executorch.backends.arm._passes.fuse_quantized_activation_pass import ( FuseQuantizedActivationPass, ) -from executorch.backends.arm.tosa_specification import TosaSpecification +from executorch.backends.arm.tosa_specification import Tosa_0_80, TosaSpecification from executorch.exir.dialects._ops import ops as exir_ops from torch.fx.passes.operator_support import any_chain, chain, OperatorSupportBase from torch.fx.passes.utils.source_matcher_utils import get_source_partitions @@ -90,6 +90,7 @@ def tosa_support_factory( if not tosa_spec.support_float(): negative_checks.append(NeedsDecompositionCheck()) negative_checks.append(CheckProperQuantization()) + negative_checks.append(EthosU55NotSupported(tosa_spec)) return chain( any_chain( BaseTOSASupportList(), @@ -111,6 +112,9 @@ def is_node_supported( supported = node.op == "call_function" and node.target in [ exir_ops.edge.aten.abs.default, exir_ops.edge.aten.add.Tensor, + exir_ops.edge.aten.bitwise_and.Tensor, + exir_ops.edge.aten.bitwise_or.Tensor, + exir_ops.edge.aten.bitwise_xor.Tensor, exir_ops.edge.aten.expand_copy.default, exir_ops.edge.aten.cat.default, exir_ops.edge.aten.clamp.default, @@ -170,6 +174,31 @@ def is_node_supported( return supported +class EthosU55NotSupported(OperatorSupportBase): + """ + Certain operators are not supported on U55. These are listed in `unsupported` in + is_node_supported(). + """ + + def __init__(self, tosa_spec: TosaSpecification): + self.tosa_spec = tosa_spec + + def is_node_supported( + self, submodules: typing.Mapping[str, torch.nn.Module], node: fx.Node + ) -> bool: + if isinstance(self.tosa_spec, Tosa_0_80) and self.tosa_spec.is_U55_subset: + unsupported_ops = [ + exir_ops.edge.aten.bitwise_and.Tensor, + exir_ops.edge.aten.bitwise_or.Tensor, + exir_ops.edge.aten.bitwise_xor.Tensor, + ] + + if node.target in unsupported_ops: + return False + + return True + + class NeedsDecompositionCheck(OperatorSupportBase): """ Targeted operators need to be decomposed prior to quantization in order to get a pair of q-dq-nodes surrounding