diff --git a/backends/arm/_passes/__init__.py b/backends/arm/_passes/__init__.py index bc53606cab6..245085ecd5a 100644 --- a/backends/arm/_passes/__init__.py +++ b/backends/arm/_passes/__init__.py @@ -21,6 +21,9 @@ from .convert_int64_output_ops_to_int32 import ConvertInt64OutputOpsToInt32Pass # noqa from .convert_int_pow_to_mul import ConvertIntPowToMuls # noqa from .convert_minmax_pass import ConvertMinMaxPass # noqa +from .convert_permute_singleton_to_view_pass import ( # noqa + ConvertPermuteSingletonToViewPass, +) from .convert_split_to_slice import ConvertSplitToSlicePass # noqa from .convert_squeezes_to_view import ConvertSqueezesToViewPass # noqa from .convert_to_clamp import ConvertToClampPass # noqa diff --git a/backends/arm/_passes/arm_pass_manager.py b/backends/arm/_passes/arm_pass_manager.py index a086d23dc40..9157f7d6a4b 100644 --- a/backends/arm/_passes/arm_pass_manager.py +++ b/backends/arm/_passes/arm_pass_manager.py @@ -27,6 +27,7 @@ ConvertIntPowToMuls, ConvertMinMaxPass, ConvertMmToBmmPass, + ConvertPermuteSingletonToViewPass, ConvertSplitToSlicePass, ConvertSqueezesToViewPass, ConvertToClampPass, @@ -234,6 +235,7 @@ def _tosa_pipeline( self.add_pass(CastToInt32Pass()) self.add_pass(BroadcastArgsPass()) + self.add_pass(ConvertPermuteSingletonToViewPass()) self.add_pass(FuseViewCopyTransform()) self.add_pass(FuseConstantArgsPass(exported_program)) self.add_pass(DecomposeConv2dWithInt16ActivationPass()) diff --git a/backends/arm/_passes/convert_permute_singleton_to_view_pass.py b/backends/arm/_passes/convert_permute_singleton_to_view_pass.py new file mode 100644 index 00000000000..aca0e88c787 --- /dev/null +++ b/backends/arm/_passes/convert_permute_singleton_to_view_pass.py @@ -0,0 +1,62 @@ +# Copyright 2025 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + + +from typing import Sequence, Set, Tuple, Type + +from executorch.exir.dialects._ops import ops as exir_ops +from executorch.exir.pass_base import ExportPass + +from torch._ops import OpOverload + + +_PERMUTE_TARGETS: Tuple[OpOverload, ...] = ( + exir_ops.edge.aten.permute.default, + exir_ops.edge.aten.permute_copy.default, +) + + +class ConvertPermuteSingletonToViewPass(ExportPass): + """Replace permutations that only move singleton axes with a reshape. + + Examples: + x = rand(1,1,1,4) + y = permute(x, (0,3,1,2)) + + becomes: + x = rand(1,1,1,4) + y = view_copy(x, (1,4,1,1)) + """ + + _passes_required_after: Set[Type[ExportPass]] = set() + + def call_operator(self, op, args, kwargs, meta): + if op not in _PERMUTE_TARGETS: + return super().call_operator(op, args, kwargs, meta) + + input_tensor = args[0].data + permutation = args[1] + if not is_singleton_permutation(input_tensor.shape, permutation): + return super().call_operator(op, args, kwargs, meta) + + output_shape = meta["val"].shape + view_args = (args[0], output_shape) + return super().call_operator( + exir_ops.edge.aten.view_copy.default, view_args, kwargs, meta + ) + + +def is_singleton_permutation(shape: Sequence[int], permutation: Sequence[int]) -> bool: + """ + Treat as a view only when non-singleton axes keep their order; singleton + axes may move freely since they carry no data volume. + """ + rank = len(shape) + normalized_perm = [d % rank for d in permutation] + + non_singleton_axes = [i for i, size in enumerate(shape) if size != 1] + permuted_non_singleton_axes = [axis for axis in normalized_perm if shape[axis] != 1] + + return permuted_non_singleton_axes == non_singleton_axes diff --git a/backends/arm/operator_support/ethos_u55_support.py b/backends/arm/operator_support/ethos_u55_support.py index f6fdada7d52..6c171e101aa 100644 --- a/backends/arm/operator_support/ethos_u55_support.py +++ b/backends/arm/operator_support/ethos_u55_support.py @@ -18,6 +18,9 @@ import torch.fx as fx from executorch.backends.arm._passes.arm_pass_utils import get_first_fake_tensor +from executorch.backends.arm._passes.convert_permute_singleton_to_view_pass import ( + is_singleton_permutation, +) from executorch.backends.arm._passes.insert_table_ops import TableOps from executorch.backends.arm.operators.op_permute import transform_permutation_vector from executorch.backends.arm.tosa.utils import tosa_shape @@ -430,10 +433,17 @@ def _permute_constraint_i8_i16( ) -> bool: """Return True if permutation meets i8/i16 constraints.""" N, H, W, C = nhwc_shape + + if is_singleton_permutation(nhwc_shape, permutation): + return True + match permutation: case (0, 1, 2, 3): # NHWC -> NHWC return True - case (0, 2, 1, 3) | (0, 1, 3, 2) | (0, 3, 1, 2): # NHWC -> NWHC, NHCW, NCWH + case ( + (0, 2, 1, 3) | (0, 1, 3, 2) | (0, 3, 1, 2) | (0, 2, 3, 1) | (0, 3, 2, 1) + ): + # NHWC -> NWHC, NHCW, NCWH, NCHW, NCHW -> NHWC return N * H <= 65536 and W <= 65536 and C <= 65536 case _: return self.axes_product(nhwc_shape) <= 65536 diff --git a/backends/arm/test/ops/test_permute.py b/backends/arm/test/ops/test_permute.py index c9fe32bf86c..8938ebcc27e 100644 --- a/backends/arm/test/ops/test_permute.py +++ b/backends/arm/test/ops/test_permute.py @@ -38,6 +38,10 @@ "rank_4": lambda: (torch.rand(1, 5, 1, 10), [0, 2, 3, 1]), "rank_4_2": lambda: (torch.rand(1, 2, 5, 10), [1, 0, 2, 3]), "rank_4_3": lambda: (torch.rand(1, 10, 10, 5), [2, 0, 1, 3]), + "rank_4_large": lambda: (torch.rand(2, 8, 64, 65), [0, 2, 3, 1]), + "rank_3_large": lambda: (torch.rand(16, 64, 65), [1, 2, 0]), + "reshape_large_1": lambda: (torch.rand(1, 1, 65537), [0, 2, 1]), + "reshape_large_2": lambda: (torch.rand(65537, 1, 1), [1, 2, 0]), } diff --git a/backends/arm/test/passes/test_convert_permute_singleton_to_view_pass.py b/backends/arm/test/passes/test_convert_permute_singleton_to_view_pass.py new file mode 100644 index 00000000000..eb395403e3f --- /dev/null +++ b/backends/arm/test/passes/test_convert_permute_singleton_to_view_pass.py @@ -0,0 +1,100 @@ +# Copyright 2025 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Tuple + +import torch + +from executorch.backends.arm._passes import ConvertPermuteSingletonToViewPass +from executorch.backends.arm.test.tester.test_pipeline import PassPipeline + +input_t = Tuple[torch.Tensor] + + +class PermuteSingletonAxesModule(torch.nn.Module): + def forward(self, x: torch.Tensor) -> torch.Tensor: + return x.permute(0, 2, 3, 1) + + @staticmethod + def input() -> input_t: + return (torch.randn(2, 1, 3, 4),) + + +def test_convert_permute_singleton_to_view_applies(): + module = PermuteSingletonAxesModule() + pipeline = PassPipeline[input_t]( + module, + PermuteSingletonAxesModule.input(), + quantize=False, + ops_before_pass={ + "executorch_exir_dialects_edge__ops_aten_permute_copy_default": 1, + }, + ops_after_pass={ + "executorch_exir_dialects_edge__ops_aten_view_copy_default": 1, + }, + ops_not_after_pass=[ + "executorch_exir_dialects_edge__ops_aten_permute_copy_default", + ], + pass_list=[ConvertPermuteSingletonToViewPass], + ) + pipeline.run() + + +class PermuteNonSingletonModule(torch.nn.Module): + def forward(self, x: torch.Tensor) -> torch.Tensor: + return x.permute(0, 2, 1) + + @staticmethod + def input() -> input_t: + return (torch.randn(2, 3, 4),) + + +def test_convert_permute_singleton_to_view_skips_non_singleton(): + module = PermuteNonSingletonModule() + pipeline = PassPipeline[input_t]( + module, + PermuteNonSingletonModule.input(), + quantize=False, + ops_before_pass={ + "executorch_exir_dialects_edge__ops_aten_permute_copy_default": 1, + }, + ops_after_pass={ + "executorch_exir_dialects_edge__ops_aten_permute_copy_default": 1, + }, + ops_not_after_pass=[ + "executorch_exir_dialects_edge__ops_aten_view_copy_default", + ], + pass_list=[ConvertPermuteSingletonToViewPass], + ) + pipeline.run() + + +class PermuteSameSizedNonSingletonModule(torch.nn.Module): + def forward(self, x: torch.Tensor) -> torch.Tensor: + return x.permute(2, 1, 0) + + @staticmethod + def input() -> input_t: + return (torch.randn(2, 1, 2),) + + +def test_convert_permute_singleton_to_view_skips_same_sized_non_singleton(): + module = PermuteSameSizedNonSingletonModule() + pipeline = PassPipeline[input_t]( + module, + PermuteSameSizedNonSingletonModule.input(), + quantize=False, + ops_before_pass={ + "executorch_exir_dialects_edge__ops_aten_permute_copy_default": 1, + }, + ops_after_pass={ + "executorch_exir_dialects_edge__ops_aten_permute_copy_default": 1, + }, + ops_not_after_pass=[ + "executorch_exir_dialects_edge__ops_aten_view_copy_default", + ], + pass_list=[ConvertPermuteSingletonToViewPass], + ) + pipeline.run()