Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 60 additions & 0 deletions backends/arm/operator_support/ethos_u55_support.py
Original file line number Diff line number Diff line change
Expand Up @@ -384,3 +384,63 @@ def is_node_supported(
return False

return True


class EthosU55CastCheck(OperatorSupportBase):
"""Reject unsupported casts on U55.

U55 does not support casting from INT32 or any casts involving BOOL. Note that
casting from one dtype to the same dtype is a no-op and is supported.


Attributes:
reporter (WhyNoPartitionReporter): Reporter for rejection reasons.

"""

targets = [
exir_ops.edge.dim_order_ops._to_dim_order_copy.default,
]

def __init__(self, reporter: WhyNoPartitionReporter):
"""Initialize the check with a reporter.

Args:
reporter (WhyNoPartitionReporter): Reporter for rejection reasons.

"""
super().__init__()
self.reporter = reporter

def is_node_supported(
self, submodules: typing.Mapping[str, torch.nn.Module], node: fx.Node
) -> bool:
"""Return True if the node satisfies the cast constraints of U55.

Args:
submodules (typing.Mapping[str, torch.nn.Module]): Exported modules.
node (fx.Node): FX node to check.

Returns:
bool: True if supported; otherwise, False.

"""
if node.target not in self.targets:
return True
input_dtype = get_first_fake_tensor(node.all_input_nodes[0]).dtype
output_dtype = get_first_fake_tensor(node).dtype
if input_dtype == output_dtype:
# This is ok as this will not result in a cast
return True
if input_dtype in (torch.bool, torch.int32):
self.reporter.report_reject(
node, f"Casting from {input_dtype} is not supported on U55."
)
return False
if output_dtype in (torch.bool,):
self.reporter.report_reject(
node, f"Casting to {output_dtype} is not supported on U55."
)
return False

return True
2 changes: 2 additions & 0 deletions backends/arm/operator_support/tosa_supported_operators.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from executorch.backends.arm._passes.insert_table_ops import TableOps
from executorch.backends.arm.constants import DQ_OPS, Q_OPS
from executorch.backends.arm.operator_support.ethos_u55_support import (
EthosU55CastCheck,
EthosU55DtypeSupport,
EthosU55NotSupported,
EthosU55TransposeCheck,
Expand Down Expand Up @@ -141,6 +142,7 @@ def tosa_support_factory(
negative_checks.append(EthosU55DtypeSupport(reporter))
negative_checks.append(EthosU55TransposeCheck(reporter))
negative_checks.append(EthosU55ViewCheck(reporter))
negative_checks.append(EthosU55CastCheck(reporter))

return chain(
reporter.wrap_check(
Expand Down
29 changes: 29 additions & 0 deletions backends/arm/test/ops/test_to_copy.py
Original file line number Diff line number Diff line change
Expand Up @@ -244,3 +244,32 @@ def test_to_tosa_INT_not_delegated_REDUNDANT_CAST(test_data: Tuple):
non_delegated_ops={}, # These are removed outside of the Arm backend so the graph is empty
)
pipeline.run()


_TO_COPY_DATA_INT_U55_REJECT = {
"rand_bool_int8": lambda: (
torch.randint(0, 2, (1, 2, 3, 4), dtype=torch.bool),
torch.int8,
),
"rand_int16_bool": lambda: (
torch.randint(-1000, 1000, (1, 2, 3, 4), dtype=torch.int16),
torch.bool,
),
"rand_int32_int8": lambda: (
torch.randint(-1000, 1000, (1, 2, 3, 4), dtype=torch.int32),
torch.int8,
),
}


@common.parametrize("test_data", _TO_COPY_DATA_INT_U55_REJECT)
def test_to_u55_INT(test_data: Tuple):
test_tensor, new_dtype = test_data()
pipeline = OpNotSupportedPipeline[input_t1](
Cast(new_dtype),
(test_tensor,),
u55_subset=True,
quantize=True,
non_delegated_ops={}, # These are removed outside of the Arm backend so the graph is empty
)
pipeline.run()
Loading