diff --git a/backends/arm/operator_support/ethos_u55_support.py b/backends/arm/operator_support/ethos_u55_support.py
index 225efeab01f..983aa091eec 100644
--- a/backends/arm/operator_support/ethos_u55_support.py
+++ b/backends/arm/operator_support/ethos_u55_support.py
@@ -384,3 +384,63 @@ def is_node_supported(
             return False
 
         return True
+
+
+class EthosU55CastCheck(OperatorSupportBase):
+    """Reject unsupported casts on U55.
+
+    U55 does not support casting from INT32 or any casts involving BOOL, so
+    ``_to_dim_order_copy`` nodes performing such casts are rejected. A node whose
+    input and output dtypes are equal is accepted, since it performs no cast.
+
+    Attributes:
+        reporter (WhyNoPartitionReporter): Reporter for rejection reasons.
+
+    """
+
+    targets = [
+        exir_ops.edge.dim_order_ops._to_dim_order_copy.default,
+    ]
+
+    def __init__(self, reporter: WhyNoPartitionReporter):
+        """Initialize the check with a reporter.
+
+        Args:
+            reporter (WhyNoPartitionReporter): Receives the reason for each rejection.
+
+        """
+        super().__init__()
+        self.reporter = reporter
+
+    def is_node_supported(
+        self, submodules: typing.Mapping[str, torch.nn.Module], node: fx.Node
+    ) -> bool:
+        """Return True if the node satisfies the cast constraints of U55.
+
+        Args:
+            submodules (typing.Mapping[str, torch.nn.Module]): Unused by this check.
+            node (fx.Node): FX node to check; non-target nodes are accepted.
+
+        Returns:
+            bool: False, after reporting the reason, for a rejected cast; else True.
+
+        """
+        if node.target not in self.targets:
+            return True
+        input_dtype = get_first_fake_tensor(node.all_input_nodes[0]).dtype
+        output_dtype = get_first_fake_tensor(node).dtype
+        if input_dtype == output_dtype:
+            # Same dtype in and out: no actual cast takes place, so accept.
+            return True
+        if input_dtype in (torch.bool, torch.int32):
+            self.reporter.report_reject(
+                node, f"Casting from {input_dtype} is not supported on U55."
+            )
+            return False
+        if output_dtype in (torch.bool,):
+            self.reporter.report_reject(
+                node, f"Casting to {output_dtype} is not supported on U55."
+            )
+            return False
+
+        return True
diff --git a/backends/arm/operator_support/tosa_supported_operators.py b/backends/arm/operator_support/tosa_supported_operators.py
index b580fbb9a9a..86c53e4aff1 100644
--- a/backends/arm/operator_support/tosa_supported_operators.py
+++ b/backends/arm/operator_support/tosa_supported_operators.py
@@ -21,6 +21,7 @@
 from executorch.backends.arm._passes.insert_table_ops import TableOps
 from executorch.backends.arm.constants import DQ_OPS, Q_OPS
 from executorch.backends.arm.operator_support.ethos_u55_support import (
+    EthosU55CastCheck,
     EthosU55DtypeSupport,
     EthosU55NotSupported,
     EthosU55TransposeCheck,
@@ -141,6 +142,7 @@ def tosa_support_factory(
         negative_checks.append(EthosU55DtypeSupport(reporter))
         negative_checks.append(EthosU55TransposeCheck(reporter))
         negative_checks.append(EthosU55ViewCheck(reporter))
+        negative_checks.append(EthosU55CastCheck(reporter))
 
     return chain(
         reporter.wrap_check(
diff --git a/backends/arm/test/ops/test_to_copy.py b/backends/arm/test/ops/test_to_copy.py
index 5c01788c805..42d136c52c1 100644
--- a/backends/arm/test/ops/test_to_copy.py
+++ b/backends/arm/test/ops/test_to_copy.py
@@ -244,3 +244,32 @@ def test_to_tosa_INT_not_delegated_REDUNDANT_CAST(test_data: Tuple):
         non_delegated_ops={},  # These are removed outside of the Arm backend so the graph is empty
     )
     pipeline.run()
+
+
+_TO_COPY_DATA_INT_U55_REJECT = {
+    "rand_bool_int8": lambda: (
+        torch.randint(0, 2, (1, 2, 3, 4), dtype=torch.bool),
+        torch.int8,
+    ),
+    "rand_int16_bool": lambda: (
+        torch.randint(-1000, 1000, (1, 2, 3, 4), dtype=torch.int16),
+        torch.bool,
+    ),
+    "rand_int32_int8": lambda: (
+        torch.randint(-1000, 1000, (1, 2, 3, 4), dtype=torch.int32),
+        torch.int8,
+    ),
+}
+
+
+@common.parametrize("test_data", _TO_COPY_DATA_INT_U55_REJECT)
+def test_to_u55_INT(test_data: Tuple):
+    test_tensor, new_dtype = test_data()
+    pipeline = OpNotSupportedPipeline[input_t1](
+        Cast(new_dtype),
+        (test_tensor,),
+        u55_subset=True,
+        quantize=True,
+        non_delegated_ops={},  # Presumably removed outside the Arm backend, as in the test above - TODO confirm
+    )
+    pipeline.run()