From 1cc2d53deb4e22dd982fca29d3059fbfdea6b72b Mon Sep 17 00:00:00 2001 From: George Gekov Date: Wed, 15 Apr 2026 17:37:00 +0100 Subject: [PATCH] Arm backend: Fix rejection criteria of TRANSPOSE from VIEW When delegating a VIEW for Ethos-U55, we were overly pessimistic about whether we can delegate the TRANSPOSE that is needed for the NHWC -> NCHW or NCHW -> NHWC permutation. As a result, some RESHAPEs were left over on the CPU when they could actually have been run on the NPU. Signed-off-by: George Gekov Change-Id: I34cc3b38cf0dbb0ceee32ac5d0044805c4e1f085 --- .../arm/operator_support/ethos_u55_support.py | 62 ++++++++++++------- backends/arm/test/ops/test_view.py | 8 ++- 2 files changed, 45 insertions(+), 25 deletions(-) diff --git a/backends/arm/operator_support/ethos_u55_support.py b/backends/arm/operator_support/ethos_u55_support.py index c0c795c2cfa..c55d22a0f2c 100644 --- a/backends/arm/operator_support/ethos_u55_support.py +++ b/backends/arm/operator_support/ethos_u55_support.py @@ -10,6 +10,7 @@ """ import typing +from itertools import combinations import torch import torch.fx as fx @@ -281,20 +282,37 @@ def __init__(self, reporter: WhyNoPartitionReporter): _MAX_AXIS_PRODUCT = 65536 - def axes_product(self, shape: shape_t) -> int: - """Return the product of all axes in ``shape``. - + def _max_product_axis(self, shape: shape_t): + """ Args: shape (shape_t): Shape. Returns: - int: Product of the axis sizes. - + True if the TRANSPOSE can be run on the Ethos-U55 + False if the TRANSPOSE cannot be run on the Ethos-U55 + + For a tensor of rank N, the product of any combination of + N - 2 axes needs to be at most 65536. E.g. for a rank 4 tensor, + N*H, N*W, N*C, H*W, H*C, W*C should all be at most 65536 to + be able to run the TRANSPOSE on Ethos-U55. 
+ The full TRANSPOSE requirements for the Ethos-U55 are listed in + https://gitlab.arm.com/artificial-intelligence/ethos-u/ethos-u-vela/-/blob/main/SUPPORTED_OPS.md """ - product = 1 - for axes in shape: - product *= axes - return product + rank = len(shape) + if rank < 3: + product = 1 + for idx in shape: + product *= idx + return product <= self._MAX_AXIS_PRODUCT + + else: + for axes in combinations(range(rank), rank - 2): + product = 1 + for idx in axes: + product *= shape[idx] + if product > self._MAX_AXIS_PRODUCT: + return False + return True def _check_rank_constraints( self, @@ -322,11 +340,11 @@ def _check_rank_constraints( output_rank = len(output_shape) if input_rank > 4: - if self.axes_product(input_shape) > self._MAX_AXIS_PRODUCT: + if not (self._max_product_axis(input_shape)): self.reporter.report_reject( node, f"Input may require transpose operator. No support for {input_shape=}, " - f"{dtype=}. Product of axes must be <={self._MAX_AXIS_PRODUCT}", + f"{dtype=}. Product of any rank - 2 axes must be <={self._MAX_AXIS_PRODUCT}", ) return False if dtype == torch.int32: @@ -337,12 +355,12 @@ def _check_rank_constraints( return False if output_rank > 4: - if self.axes_product(output_shape) > self._MAX_AXIS_PRODUCT: + if not (self._max_product_axis(output_shape)): shape = output_shape self.reporter.report_reject( node, f"Operator may require transpose operator. No support for {shape=}, " - f"{dtype=}. Product of axes must be <={self._MAX_AXIS_PRODUCT}", + f"{dtype=}. Product of any rank - 2 axes must be <={self._MAX_AXIS_PRODUCT}", ) return False if dtype == torch.int32: @@ -450,24 +468,22 @@ def _check_transpose_constraints( ) return False - if ( - needs_input_transpose - and self.axes_product(input_shape) > self._MAX_AXIS_PRODUCT - ): + # For TRANSPOSE originating from a VIEW, we know we will only do + # NHWC -> NCHW or NCHW -> NHWC permutations, hence we only need to validate + # these two TRANSPOSEs. 
For the general case of any permutation on TRANSPOSE, + # we reason via the checks in EthosU55TransposeCheck + if needs_input_transpose and not (self._max_product_axis(input_shape)): self.reporter.report_reject( node, f"Operator requires transpose operator. No support for {input_shape=}, " - f"{dtype=}. Product of axes must be <{self._MAX_AXIS_PRODUCT}", + f"{dtype=}. Product of any rank - 2 axes must be <={self._MAX_AXIS_PRODUCT}", ) return False - if ( - needs_output_transpose - and self.axes_product(output_shape) > self._MAX_AXIS_PRODUCT - ): + if needs_output_transpose and not (self._max_product_axis(output_shape)): self.reporter.report_reject( node, f"Operator requires transpose operator. No support for {output_shape=}, " - f"{dtype=}. Product of axes must be <{self._MAX_AXIS_PRODUCT}", + f"{dtype=}. Product of any rank - 2 axes must be <={self._MAX_AXIS_PRODUCT}", ) return False diff --git a/backends/arm/test/ops/test_view.py b/backends/arm/test/ops/test_view.py index 2bee8ede0aa..e3b09cbd959 100644 --- a/backends/arm/test/ops/test_view.py +++ b/backends/arm/test/ops/test_view.py @@ -51,6 +51,8 @@ class View(torch.nn.Module): "rand_5d_5d": lambda: (torch.rand(1, 1, 4, 5, 6), (1, 1, 4, -1, 6)), "rand_5d_3d": lambda: (torch.rand(1, 1, 4, 5, 6), (2, 3, -1)), "rand_3d_5d": lambda: (torch.rand(4, 5, 6), (1, 1, 2, -1, 3)), + "rank4_rank3_large": lambda: (torch.rand(1, 256, 6, 48), (6, 48, 256)), + "rank5_rank4_large": lambda: (torch.rand(1, 256, 2, 3, 48), (1, 256, 6, 48)), } needs_transpose_tests_fp16 = { @@ -65,8 +67,7 @@ class View(torch.nn.Module): } rank_product_too_large = { - "rand_4d_large": lambda: (torch.rand(1, 49, 16, 128), (1, 16, 49, 128)), - "rand_5d_large": lambda: (torch.rand(2, 25, 16, 8, 64), (2, 16, 25, 8, 64)), + "rand_5d_large": lambda: (torch.rand(2, 256, 512, 8, 64), (2, 512, 256, 8, 64)), } def __init__(self, new_shape): @@ -116,6 +117,9 @@ def test_view_u55_INT(test_data: Tuple): aten_op, exir_ops=[], ) + pipeline.change_args( + 
"check_not.exir", ["executorch_exir_dialects_edge__ops_aten_view_copy_default"] + ) pipeline.run()