1 change: 1 addition & 0 deletions backends/arm/_passes/__init__.py
@@ -131,6 +131,7 @@
from .match_arg_dtype_pass import MatchArgDtypePass # noqa
from .match_arg_ranks_pass import MatchArgRanksPass # noqa
from .mm_to_bmm_pass import ConvertMmToBmmPass # noqa
from .normalize_delegate_io_layout_pass import NormalizeDelegateIOLayoutPass # noqa
from .normalize_index_put_bool_index_tensor_pass import ( # noqa
NormalizeIndexPutBoolIndexTensorPass,
)
30 changes: 19 additions & 11 deletions backends/arm/_passes/arm_pass_manager.py
@@ -12,7 +12,6 @@

from executorch.backends.arm._passes import (
AccumulateIndexPutPass,
AnnotateOutputDimOrderPass,
BroadcastArgsPass,
CanonicalizeGatherPass,
CastInt64BuffersToInt32Pass,
@@ -44,7 +43,6 @@
DecomposeAtanPass,
DecomposeAvgPool2dPass,
DecomposeBatchNormNoStatsPass,
DecomposeConvWithInt16ActivationPass,
DecomposeCoshPass,
DecomposeCosineSimilarityPass,
DecomposeCumsumPass,
@@ -117,6 +115,7 @@
InsertTableOpsPass,
MatchArgDtypePass,
MatchArgRanksPass,
NormalizeDelegateIOLayoutPass,
NormalizeIndexPutBoolIndexTensorPass,
NormalizeIndexPutNoneIndicesPass,
NormalizeWhileInitialArgsPass,
@@ -142,7 +141,6 @@
RewriteUpsamplePass,
ScalarsToAttributePass,
SizeAdjustInputPass,
ToTosaMemoryFormatPass,
UnsqueezeBeforeRepeatPass,
UnsqueezeScalarPlaceholdersPass,
)
@@ -158,6 +156,16 @@
TosaLoweringContext,
TosaSpecification,
)
from executorch.backends.transforms.fuse_cascaded_transpose_or_permute_ops import (
FuseCascadedTransposeOrPermuteOps,
)
from executorch.backends.transforms.postpone_permute_below_squeeze_view import (
PostponePermuteOpBelowSqueezeOrUnsqueezeLikeView,
)

from executorch.backends.transforms.remove_permutes_around_elementwise_ops import (
RemovePermutesAroundElementwiseOps,
)
from executorch.exir import ExportedProgram
from executorch.exir.pass_base import ExportPass
from executorch.exir.pass_manager import PassManager
@@ -386,12 +394,10 @@ def _tosa_pipeline(
# Allow subclasses to configure pass insertions before building pipeline
self._configure_pass_insertions(exported_program)

# Preprocessing passes
self.add_pass(AnnotateOutputDimOrderPass())

# Node transformation passes (pre q/dq folding)
self.add_passes(
[
NormalizeDelegateIOLayoutPass(exported_program),
Collaborator: nit: Maybe add a comment noting that this pass needs to run first-ish, and why?

FuseQuantizedActivationPass(),
RewriteBoolToFp32CastViaInt8Pass(),
CanonicalizeGatherPass(),
@@ -516,12 +522,9 @@ def _tosa_pipeline(
ConvertSqueezesToViewPass(),
CastToInt32Pass(),
BroadcastArgsPass(),
ConvertPermuteSingletonToViewPass(),
RewriteHighRankSingletonPermutePass(),
FuseViewCopyTransformPass(),
DecomposeConvWithInt16ActivationPass(),
DecomposeSumPass(),
InsertTableOpsPass(exported_program),
RemoveNoopPass(),
]
)

@@ -534,6 +537,12 @@
RewriteMatmulPass(),
RewritePadPass(),
RewriteSlicePass(),
FuseViewCopyTransformPass(),
RemovePermutesAroundElementwiseOps(),
PostponePermuteOpBelowSqueezeOrUnsqueezeLikeView(),
FuseCascadedTransposeOrPermuteOps(),
ConvertPermuteSingletonToViewPass(),
RewriteHighRankSingletonPermutePass(),
InsertConstShapesPass(),
]
)
@@ -544,7 +553,6 @@
CastInt64BuffersToInt32Pass(exported_program),
FuseEqualPlaceholdersPass(exported_program),
FuseConsecutiveConcatShapesPass(),
ToTosaMemoryFormatPass(exported_program),
EnsureUniqueOutputNodesPass(),
RemoveNoopPass(),
InsertRescalePass(),
9 changes: 9 additions & 0 deletions backends/arm/_passes/arm_pass_utils.py
@@ -397,3 +397,12 @@ def get_cond_while_submodules_nested(
}
# collect cond/while submodules (using mapping indices)
return _get_control_flow_submodules(graph_module, mapping)


def to_2tuple(value):
"""Normalizes scalars, and 1-element sequences to a tuple of length 2."""
if isinstance(value, int):
return (value, value)
if len(value) == 1:
return (value[0], value[0])
return tuple(value)
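
For reference, the normalization behavior this helper provides (a quick sanity check, not part of the diff):

# to_2tuple normalizes the argument forms accepted for kernel_size/stride/padding.
assert to_2tuple(3) == (3, 3)        # bare int
assert to_2tuple([2]) == (2, 2)      # 1-element sequence, e.g. stride=[2]
assert to_2tuple((1, 2)) == (1, 2)   # already a pair -> passed through
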
14 changes: 7 additions & 7 deletions backends/arm/_passes/conv1d_unsqueeze_pass.py
@@ -47,7 +47,7 @@ def call_operator(self, op, args, kwargs, meta):
x_meta.data["output_qparams"] = {}

x = args[0]
x_unsqueezed_shape = list(x.data.shape) + [1]
x_unsqueezed_shape = list(x.data.shape[:-1]) + [1] + [x.data.shape[-1]]
x = super().call_operator(
exir_ops.edge.aten.view_copy.default,
(x, x_unsqueezed_shape),
@@ -61,7 +61,7 @@ def call_operator(self, op, args, kwargs, meta):
w_meta.data["output_qparams"] = {}

w = args[1]
w_unsqueezed_shape = list(w.data.shape) + [1]
w_unsqueezed_shape = list(w.data.shape[:-1]) + [1] + [w.data.shape[-1]]
w = super().call_operator(
exir_ops.edge.aten.view_copy.default,
(w, w_unsqueezed_shape),
@@ -74,11 +74,11 @@ def call_operator(self, op, args, kwargs, meta):
x,
w,
args[2],
args[3] + [1], # stride
args[4] + [0], # padding
args[5] + [1], # dilation
[1] + args[3], # stride
[0] + args[4], # padding
[1] + args[5], # dilation
args[6],
args[7] + [0],
[0] + args[7],
args[8],
)
x = super().call_operator(
@@ -88,7 +88,7 @@ def call_operator(self, op, args, kwargs, meta):
x_squeezed_meta = meta.copy()
x_squeezed_meta.data["input_qparams"] = {}
x_squeezed_meta.data["output_qparams"] = {}
x_squeezed_shape = list(x.data.shape)[:-1]
x_squeezed_shape = list(x.data.shape[:-2]) + [x.data.shape[-1]]
x = super().call_operator(
exir_ops.edge.aten.view_copy.default,
(x, x_squeezed_shape),
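The fix above moves the inserted singleton dimension from after the length dimension to before it, so a (N, C, L) conv1d input becomes a (N, C, 1, L) conv2d input, and the per-dimension arguments (stride, padding, dilation, output_padding) are now prepended rather than appended to match. A minimal sketch of the shape arithmetic, with illustrative values not taken from the source:

# Assuming a rank-3 conv1d input of shape (N, C, L).
shape = [2, 8, 16]                           # (N, C, L)
unsqueezed = shape[:-1] + [1] + [shape[-1]]  # singleton goes *before* L -> (N, C, 1, L)
assert unsqueezed == [2, 8, 1, 16]

stride = [3]                                 # original conv1d stride
assert [1] + stride == [1, 3]                # prepended entry covers the new leading spatial dim

out = [2, 4, 1, 6]                           # hypothetical conv2d output (N, C_out, 1, L_out)
squeezed = out[:-2] + [out[-1]]              # drop the singleton -> (N, C_out, L_out)
assert squeezed == [2, 4, 6]
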
137 changes: 137 additions & 0 deletions backends/arm/_passes/normalize_delegate_io_layout_pass.py
@@ -0,0 +1,137 @@
# Copyright 2026 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import Any, Set, Type

import torch
from executorch.backends.arm._passes import ArmPass
from executorch.backends.arm._passes.arm_pass_utils import (
create_node,
get_first_fake_tensor,
is_param_node,
)
from executorch.exir import ExportedProgram
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.pass_base import ExportPass, PassResult


class NormalizeDelegateIOLayoutPass(ArmPass):
"""Adjust delegated boundary tensor shapes and insert permutes at I/O."""

_passes_required_after: Set[Type[ExportPass]] = set()

def __init__(self, exported_program: ExportedProgram, *args, **kwargs) -> None:
super().__init__(*args, **kwargs)
self.exported_program = exported_program

@staticmethod
def _inverse_permutation(perm: tuple[int, ...]) -> tuple[int, ...]:
inverse = [0] * len(perm)
for idx, axis in enumerate(perm):
inverse[axis] = idx
return tuple(inverse)

@staticmethod
def _permute_shape(shape: torch.Size, perm: tuple[int, ...]) -> tuple[int, ...]:
return tuple(shape[axis] for axis in perm)

@staticmethod
def _is_identity_dim_order(dim_order: tuple[int, ...]) -> bool:
return dim_order == tuple(range(len(dim_order)))

def _normalize_input_layout(self, graph_module: torch.fx.GraphModule) -> bool:
modified = False
for node in graph_module.graph.nodes:
if node.op != "placeholder" or is_param_node(self.exported_program, node):
continue

input_fake = get_first_fake_tensor(node)
dim_order = input_fake.dim_order()
if self._is_identity_dim_order(dim_order):
continue

boundary_shape = self._permute_shape(input_fake.shape, dim_order)
node.meta["val"] = input_fake.reshape(boundary_shape)

transpose_perm = self._inverse_permutation(dim_order)
with graph_module.graph.inserting_after(node):
permute_node = create_node(
graph_module.graph,
exir_ops.edge.aten.permute_copy.default,
args=(node, list(transpose_perm)),
from_node=node,
)
permute_node.meta["val"] = exir_ops.edge.aten.permute_copy.default(
node.meta["val"], list(transpose_perm)
)

users = [user for user in node.users if user != permute_node]
for user in users:
user.replace_input_with(node, permute_node)

modified = True

return modified

def _rewrite_output_arg(
self, arg: Any, graph_module: torch.fx.GraphModule
) -> tuple[Any, bool]:
if isinstance(arg, torch.fx.Node):
output_fake = get_first_fake_tensor(arg)
dim_order = output_fake.dim_order()
Contributor: can't this change during the lowering?

Collaborator (Author): It can, so it is important that this is the first pass run during the lowering!

if self._is_identity_dim_order(dim_order):
return arg, False

with graph_module.graph.inserting_after(arg):
permute_node = create_node(
graph_module.graph,
exir_ops.edge.aten.permute_copy.default,
args=(arg, list(dim_order)),
from_node=arg,
)
permute_node.meta["val"] = exir_ops.edge.aten.permute_copy.default(
output_fake, list(dim_order)
)

return permute_node, True

if isinstance(arg, tuple):
modified = False
rewritten = []
for item in arg:
new_item, item_modified = self._rewrite_output_arg(item, graph_module)
rewritten.append(new_item)
modified = modified or item_modified
return tuple(rewritten), modified

if isinstance(arg, list):
modified = False
rewritten = []
for item in arg:
new_item, item_modified = self._rewrite_output_arg(item, graph_module)
rewritten.append(new_item)
modified = modified or item_modified
return rewritten, modified

return arg, False

def _normalize_output_layout(self, graph_module: torch.fx.GraphModule) -> bool:
output_node = graph_module.graph.output_node()
rewritten_outputs, modified = self._rewrite_output_arg(
output_node.args[0], graph_module
)
if modified:
output_node.args = (rewritten_outputs,)
return modified

def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
modified = self._normalize_input_layout(graph_module)
modified = self._normalize_output_layout(graph_module) or modified

if modified:
graph_module.recompile()
graph_module = super().call(graph_module).graph_module

return PassResult(graph_module, modified)
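
A minimal, self-contained sketch of the layout normalization this pass performs, assuming a 4D channels-last input; inverse_permutation below is a local copy of the pass's _inverse_permutation helper:

import torch

def inverse_permutation(perm):
    # Invert a permutation: inverse[perm[i]] = i.
    inverse = [0] * len(perm)
    for idx, axis in enumerate(perm):
        inverse[axis] = idx
    return tuple(inverse)

# Channels-last input: logical shape (N, C, H, W) = (1, 3, 4, 5),
# physical dim order (0, 2, 3, 1), i.e. NHWC in memory.
x = torch.randn(1, 3, 4, 5).to(memory_format=torch.channels_last)
dim_order = x.dim_order()                              # (0, 2, 3, 1)

# The pass rewrites the placeholder to the physical (boundary) shape ...
boundary = x.permute(dim_order)                        # shape (1, 4, 5, 3), contiguous in memory
assert boundary.is_contiguous()

# ... and inserts a permute_copy back to the logical order for internal users.
restored = boundary.permute(inverse_permutation(dim_order))
assert torch.equal(restored, x)
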
49 changes: 40 additions & 9 deletions backends/arm/_passes/rewrite_avg_pool2d_pass.py
@@ -7,6 +7,8 @@

import torch
from executorch.backends.arm._passes import ArmPass
from executorch.backends.arm._passes.arm_pass_utils import to_2tuple
from executorch.backends.arm.constants import NHWC_INVERSE_ORDER, NHWC_ORDER
from executorch.backends.arm.operators.operator_validation_utils import (
adjust_pooling_pad_if_needed,
)
@@ -32,19 +34,25 @@ def call_operator(self, op, args, kwargs, meta, updated=False):
return super().call_operator(op, args, kwargs, meta, updated)

x = args[0]
pad_h, pad_w = args[3]
kernel = to_2tuple(args[1])

stride = to_2tuple(args[2]) if len(args) > 2 else ()
if not stride:
stride = kernel # default to kernel_size

pad_h, pad_w = to_2tuple(args[3]) if len(args) > 3 else (0, 0)
# Make sure pad corresponds to TOSA
pad = [pad_h, pad_w, pad_h, pad_w]

_, _, h, w = x.data.shape
kernel_h, kernel_w = args[1]
stride_h, stride_w = args[2]

ceil_mode = args[4] if len(args) > 4 else False

# Adjust padding if necessary
pad[1] = adjust_pooling_pad_if_needed(h, kernel_h, stride_h, pad[1], ceil_mode)
pad[3] = adjust_pooling_pad_if_needed(w, kernel_w, stride_w, pad[3], ceil_mode)
pad[1] = adjust_pooling_pad_if_needed(
x.data.shape[2], kernel[0], stride[0], pad[1], ceil_mode
)
pad[3] = adjust_pooling_pad_if_needed(
x.data.shape[3], kernel[1], stride[1], pad[3], ceil_mode
)

# Materialize zero-point constants
in_qparams = meta.data.get("input_qparams", {})
@@ -63,13 +71,36 @@ def call_operator(self, op, args, kwargs, meta, updated=False):
else:
acc_type = torch.float32

tosa_args = (args[0], input_zp, output_zp, *args[1:3], pad, acc_type)
pre_permute = super().call_operator(
exir_ops.edge.aten.permute_copy.default,
(x, list(NHWC_ORDER)),
{},
meta,
updated=True,
)

tosa_args = (
pre_permute,
input_zp,
output_zp,
list(kernel),
list(stride),
pad,
acc_type,
)

# Emit TOSA AVG_POOL2D with normalized args
return super().call_operator(
tosa_avg_pool = super().call_operator(
exir_ops.backend.tosa.AVG_POOL2D.default,
tosa_args,
{},
meta,
True,
)
return super().call_operator(
exir_ops.edge.aten.permute_copy.default,
(tosa_avg_pool, list(NHWC_INVERSE_ORDER)),
{},
meta,
updated=True,
)
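
The rewrite now sandwiches the TOSA AVG_POOL2D between a pre- and post-permute so the operator sees NHWC data. A sketch of the layout round trip, assuming NHWC_ORDER = (0, 2, 3, 1) and NHWC_INVERSE_ORDER = (0, 3, 1, 2) (values implied by the constant names, not confirmed from backends/arm/constants):

import torch

NHWC_ORDER = (0, 2, 3, 1)          # assumed values
NHWC_INVERSE_ORDER = (0, 3, 1, 2)

x = torch.randn(1, 3, 8, 16)                     # NCHW input
nhwc = x.permute(NHWC_ORDER)                     # pre-permute: (1, 8, 16, 3)

# Stand-in for tosa.AVG_POOL2D, which consumes and produces NHWC; a plain
# NCHW avg_pool2d is used here purely to keep the sketch runnable.
pooled_nhwc = torch.nn.functional.avg_pool2d(
    nhwc.permute(0, 3, 1, 2), kernel_size=2, stride=2
).permute(0, 2, 3, 1)                            # back to NHWC: (1, 4, 8, 3)

out = pooled_nhwc.permute(NHWC_INVERSE_ORDER)    # post-permute to NCHW
assert out.shape == (1, 3, 4, 8)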