Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
63 changes: 63 additions & 0 deletions backends/arm/operator_support/ethos_u55_support.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,6 +174,69 @@ def is_node_supported(
shape_t = list[int]


class EthosU55ViewCheck(OperatorSupportBase):

def __init__(self, reporter: WhyNoPartitionReporter):
super().__init__()
self.reporter = reporter

def axes_product(self, nhwc_shape: shape_t) -> int:
product = 1
for axes in nhwc_shape:
product *= axes
return product

# TODO: Extend this check to comply with u55 restrictions
def is_node_supported(
self, submodules: typing.Mapping[str, torch.nn.Module], node: fx.Node
) -> bool:
"""
Check whether a given view node is supported on U55.

Currently only checks dtypes and product of axes.

It is not the view operator itself that is not supported on U55. In order for the
view operator to be compatible with the channels-last format of TosaBackend,
transposes may need to be inserted before and after the view op. If that happens
and that transpose operator does not adhere to the limitations then it will
result in the following error:

CPU performance estimation for "Transpose" not implemented.
...
CPU operations are not supported for GraphAPI input

Args:
node: The FX node representing the view_copy operator.

Returns:
False if the operator is not support and True if it is supported.
"""
if not node.target == exir_ops.edge.aten.view_copy.default:
return True

shape = list(get_first_fake_tensor(node).shape)
dtype = _try_determine_dtype(node)
permutation = list(typing.cast(list[int], node.args[1]))

rank = len(shape)
if rank > 4:
if dtype == torch.int32:
self.reporter.report_reject(
node, f"No support for {permutation=} in int32."
)
return False

if dtype in (torch.int8, torch.int16):
if self.axes_product(shape) > 65536:
self.reporter.report_reject(
node,
f"No support for {shape=}, {dtype=}. Product of axes must be <65536",
)
return False

return True


class EthosU55TransposeCheck(OperatorSupportBase):

def __init__(self, reporter: WhyNoPartitionReporter):
Expand Down
2 changes: 2 additions & 0 deletions backends/arm/operator_support/tosa_supported_operators.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
EthosU55DtypeSupport,
EthosU55NotSupported,
EthosU55TransposeCheck,
EthosU55ViewCheck,
)
from executorch.backends.arm.tosa_quant_utils import dq_ops, q_ops
from executorch.backends.arm.tosa_specification import TosaSpecification
Expand Down Expand Up @@ -133,6 +134,7 @@ def tosa_support_factory(
negative_checks.append(EthosU55NotSupported(reporter))
negative_checks.append(EthosU55DtypeSupport(reporter))
negative_checks.append(EthosU55TransposeCheck(reporter))
negative_checks.append(EthosU55ViewCheck(reporter))

return chain(
reporter.wrap_check(
Expand Down
20 changes: 20 additions & 0 deletions backends/arm/test/ops/test_view.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from executorch.backends.arm.test.tester.test_pipeline import (
EthosU55PipelineBI,
EthosU85PipelineBI,
OpNotSupportedPipeline,
TosaPipelineBI,
TosaPipelineMI,
)
Expand Down Expand Up @@ -44,6 +45,10 @@ class View(torch.nn.Module):
"rand_4d_2_4_same": lambda: (torch.rand(2, 3, 2, 3), (2, 3, 3, 2)),
}

rank_product_too_large = {
"rand_4d_large": lambda: (torch.rand(1, 49, 16, 128), (1, 16, 49, 128)),
}

def __init__(self, new_shape):
super().__init__()
self.new_shape = new_shape
Expand Down Expand Up @@ -104,6 +109,21 @@ def test_view_u55_BI(test_data: Tuple):
pipeline.run()


@common.parametrize("test_data", View.rank_product_too_large, xfails=xfails)
@common.XfailIfNoCorstone300
def test_view_u55_BI_not_delegated(test_data: Tuple):
test_tensor, new_shape = test_data()
pipeline = OpNotSupportedPipeline[input_t1](
View(new_shape),
(test_tensor,),
{"executorch_exir_dialects_edge__ops_aten_view_copy": 1},
n_expected_delegates=0,
quantize=True,
u55_subset=True,
)
pipeline.run()


@common.parametrize("test_data", View.needs_transpose_tests, xfails=xfails)
@common.XfailIfNoCorstone320
def test_view_u85_BI(test_data: Tuple):
Expand Down
Loading