Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions backends/arm/_passes/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,9 @@
from .convert_int64_output_ops_to_int32 import ConvertInt64OutputOpsToInt32Pass # noqa
from .convert_int_pow_to_mul import ConvertIntPowToMuls # noqa
from .convert_minmax_pass import ConvertMinMaxPass # noqa
from .convert_permute_singleton_to_view_pass import ( # noqa
ConvertPermuteSingletonToViewPass,
)
from .convert_split_to_slice import ConvertSplitToSlicePass # noqa
from .convert_squeezes_to_view import ConvertSqueezesToViewPass # noqa
from .convert_to_clamp import ConvertToClampPass # noqa
Expand Down
2 changes: 2 additions & 0 deletions backends/arm/_passes/arm_pass_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
ConvertIntPowToMuls,
ConvertMinMaxPass,
ConvertMmToBmmPass,
ConvertPermuteSingletonToViewPass,
ConvertSplitToSlicePass,
ConvertSqueezesToViewPass,
ConvertToClampPass,
Expand Down Expand Up @@ -234,6 +235,7 @@ def _tosa_pipeline(
self.add_pass(CastToInt32Pass())
self.add_pass(BroadcastArgsPass())

self.add_pass(ConvertPermuteSingletonToViewPass())
self.add_pass(FuseViewCopyTransform())
self.add_pass(FuseConstantArgsPass(exported_program))
self.add_pass(DecomposeConv2dWithInt16ActivationPass())
Expand Down
62 changes: 62 additions & 0 deletions backends/arm/_passes/convert_permute_singleton_to_view_pass.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
# Copyright 2025 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.


from typing import Sequence, Set, Tuple, Type

from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.pass_base import ExportPass

from torch._ops import OpOverload


_PERMUTE_TARGETS: Tuple[OpOverload, ...] = (
exir_ops.edge.aten.permute.default,
exir_ops.edge.aten.permute_copy.default,
)


class ConvertPermuteSingletonToViewPass(ExportPass):
"""Replace permutations that only move singleton axes with a reshape.

Examples:
x = rand(1,1,1,4)
y = permute(x, (0,3,1,2))

becomes:
x = rand(1,1,1,4)
y = view_copy(x, (1,4,1,1))
"""

_passes_required_after: Set[Type[ExportPass]] = set()

def call_operator(self, op, args, kwargs, meta):
if op not in _PERMUTE_TARGETS:
return super().call_operator(op, args, kwargs, meta)

input_tensor = args[0].data
permutation = args[1]
if not is_singleton_permutation(input_tensor.shape, permutation):
return super().call_operator(op, args, kwargs, meta)

output_shape = meta["val"].shape
view_args = (args[0], output_shape)
return super().call_operator(
exir_ops.edge.aten.view_copy.default, view_args, kwargs, meta
)


def is_singleton_permutation(shape: Sequence[int], permutation: Sequence[int]) -> bool:
"""
Treat as a view only when non-singleton axes keep their order; singleton
axes may move freely since they carry no data volume.
"""
rank = len(shape)
normalized_perm = [d % rank for d in permutation]

non_singleton_axes = [i for i, size in enumerate(shape) if size != 1]
permuted_non_singleton_axes = [axis for axis in normalized_perm if shape[axis] != 1]

return permuted_non_singleton_axes == non_singleton_axes
12 changes: 11 additions & 1 deletion backends/arm/operator_support/ethos_u55_support.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,9 @@
import torch.fx as fx

from executorch.backends.arm._passes.arm_pass_utils import get_first_fake_tensor
from executorch.backends.arm._passes.convert_permute_singleton_to_view_pass import (
is_singleton_permutation,
)
from executorch.backends.arm._passes.insert_table_ops import TableOps
from executorch.backends.arm.operators.op_permute import transform_permutation_vector
from executorch.backends.arm.tosa.utils import tosa_shape
Expand Down Expand Up @@ -430,10 +433,17 @@ def _permute_constraint_i8_i16(
) -> bool:
"""Return True if permutation meets i8/i16 constraints."""
N, H, W, C = nhwc_shape

if is_singleton_permutation(nhwc_shape, permutation):
return True

match permutation:
case (0, 1, 2, 3): # NHWC -> NHWC
return True
case (0, 2, 1, 3) | (0, 1, 3, 2) | (0, 3, 1, 2): # NHWC -> NWHC, NHCW, NCWH
case (
(0, 2, 1, 3) | (0, 1, 3, 2) | (0, 3, 1, 2) | (0, 2, 3, 1) | (0, 3, 2, 1)
):
# NHWC -> NWHC, NHCW, NCWH, NCHW, NCHW -> NHWC
return N * H <= 65536 and W <= 65536 and C <= 65536
case _:
return self.axes_product(nhwc_shape) <= 65536
Expand Down
4 changes: 4 additions & 0 deletions backends/arm/test/ops/test_permute.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,10 @@
"rank_4": lambda: (torch.rand(1, 5, 1, 10), [0, 2, 3, 1]),
"rank_4_2": lambda: (torch.rand(1, 2, 5, 10), [1, 0, 2, 3]),
"rank_4_3": lambda: (torch.rand(1, 10, 10, 5), [2, 0, 1, 3]),
"rank_4_large": lambda: (torch.rand(2, 8, 64, 65), [0, 2, 3, 1]),
"rank_3_large": lambda: (torch.rand(16, 64, 65), [1, 2, 0]),
"reshape_large_1": lambda: (torch.rand(1, 1, 65537), [0, 2, 1]),
"reshape_large_2": lambda: (torch.rand(65537, 1, 1), [1, 2, 0]),
}


Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
# Copyright 2025 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import Tuple

import torch

from executorch.backends.arm._passes import ConvertPermuteSingletonToViewPass
from executorch.backends.arm.test.tester.test_pipeline import PassPipeline

input_t = Tuple[torch.Tensor]


class PermuteSingletonAxesModule(torch.nn.Module):
def forward(self, x: torch.Tensor) -> torch.Tensor:
return x.permute(0, 2, 3, 1)

@staticmethod
def input() -> input_t:
return (torch.randn(2, 1, 3, 4),)


def test_convert_permute_singleton_to_view_applies():
module = PermuteSingletonAxesModule()
pipeline = PassPipeline[input_t](
module,
PermuteSingletonAxesModule.input(),
quantize=False,
ops_before_pass={
"executorch_exir_dialects_edge__ops_aten_permute_copy_default": 1,
},
ops_after_pass={
"executorch_exir_dialects_edge__ops_aten_view_copy_default": 1,
},
ops_not_after_pass=[
"executorch_exir_dialects_edge__ops_aten_permute_copy_default",
],
pass_list=[ConvertPermuteSingletonToViewPass],
)
pipeline.run()


class PermuteNonSingletonModule(torch.nn.Module):
def forward(self, x: torch.Tensor) -> torch.Tensor:
return x.permute(0, 2, 1)

@staticmethod
def input() -> input_t:
return (torch.randn(2, 3, 4),)


def test_convert_permute_singleton_to_view_skips_non_singleton():
module = PermuteNonSingletonModule()
pipeline = PassPipeline[input_t](
module,
PermuteNonSingletonModule.input(),
quantize=False,
ops_before_pass={
"executorch_exir_dialects_edge__ops_aten_permute_copy_default": 1,
},
ops_after_pass={
"executorch_exir_dialects_edge__ops_aten_permute_copy_default": 1,
},
ops_not_after_pass=[
"executorch_exir_dialects_edge__ops_aten_view_copy_default",
],
pass_list=[ConvertPermuteSingletonToViewPass],
)
pipeline.run()


class PermuteSameSizedNonSingletonModule(torch.nn.Module):
def forward(self, x: torch.Tensor) -> torch.Tensor:
return x.permute(2, 1, 0)

@staticmethod
def input() -> input_t:
return (torch.randn(2, 1, 2),)


def test_convert_permute_singleton_to_view_skips_same_sized_non_singleton():
module = PermuteSameSizedNonSingletonModule()
pipeline = PassPipeline[input_t](
module,
PermuteSameSizedNonSingletonModule.input(),
quantize=False,
ops_before_pass={
"executorch_exir_dialects_edge__ops_aten_permute_copy_default": 1,
},
ops_after_pass={
"executorch_exir_dialects_edge__ops_aten_permute_copy_default": 1,
},
ops_not_after_pass=[
"executorch_exir_dialects_edge__ops_aten_view_copy_default",
],
pass_list=[ConvertPermuteSingletonToViewPass],
)
pipeline.run()
Loading