Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion backends/arm/operator_support/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
# pyre-unsafe

from . import ( # noqa
bitwise_support,
convolution_support,
pool_2d_support,
reduce_sum_support,
Expand Down
33 changes: 0 additions & 33 deletions backends/arm/operator_support/bitwise_support.py

This file was deleted.

33 changes: 31 additions & 2 deletions backends/arm/operator_support/tosa_supported_operators.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,13 @@
from typing import final, Optional, Sequence, Type

import torch

import torch.fx as fx

from executorch.backends.arm._passes.arm_pass_utils import get_first_fake_tensor
from executorch.backends.arm._passes.fuse_quantized_activation_pass import (
FuseQuantizedActivationPass,
)
from executorch.backends.arm.tosa_specification import TosaSpecification
from executorch.backends.arm.tosa_specification import Tosa_0_80, TosaSpecification
from executorch.exir.dialects._ops import ops as exir_ops
from torch.fx.passes.operator_support import any_chain, chain, OperatorSupportBase
from torch.fx.passes.utils.source_matcher_utils import get_source_partitions
Expand Down Expand Up @@ -90,6 +90,7 @@ def tosa_support_factory(
if not tosa_spec.support_float():
negative_checks.append(NeedsDecompositionCheck())
negative_checks.append(CheckProperQuantization())
negative_checks.append(EthosU55NotSupported(tosa_spec))
return chain(
any_chain(
BaseTOSASupportList(),
Expand All @@ -111,6 +112,9 @@ def is_node_supported(
supported = node.op == "call_function" and node.target in [
exir_ops.edge.aten.abs.default,
exir_ops.edge.aten.add.Tensor,
exir_ops.edge.aten.bitwise_and.Tensor,
exir_ops.edge.aten.bitwise_or.Tensor,
exir_ops.edge.aten.bitwise_xor.Tensor,
exir_ops.edge.aten.expand_copy.default,
exir_ops.edge.aten.cat.default,
exir_ops.edge.aten.clamp.default,
Expand Down Expand Up @@ -170,6 +174,31 @@ def is_node_supported(
return supported


class EthosU55NotSupported(OperatorSupportBase):
"""
Certain operators are not supported on U55. These are listed in `unsupported` in
is_node_supported().
"""

def __init__(self, tosa_spec: TosaSpecification):
self.tosa_spec = tosa_spec

def is_node_supported(
self, submodules: typing.Mapping[str, torch.nn.Module], node: fx.Node
) -> bool:
if isinstance(self.tosa_spec, Tosa_0_80) and self.tosa_spec.is_U55_subset:
unsupported_ops = [
exir_ops.edge.aten.bitwise_and.Tensor,
exir_ops.edge.aten.bitwise_or.Tensor,
exir_ops.edge.aten.bitwise_xor.Tensor,
]

if node.target in unsupported_ops:
return False

return True


class NeedsDecompositionCheck(OperatorSupportBase):
"""
Targeted operators need to be decomposed prior to quantization in order to get a pair of q-dq-nodes surrounding
Expand Down
Loading