Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions backends/arm/_passes/TARGETS
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,12 @@ runtime.python_library(
"//executorch/backends/arm:common",
"//executorch/backends/arm/tosa:utils",
"//executorch/backends/arm/tosa/dialect:lib",
"//executorch/backends/transforms:fuse_cascaded_transpose_or_permute_ops",
"//executorch/backends/transforms:fuse_cascaded_view_ops",
"//executorch/backends/transforms:fuse_transpose_or_permute_op_pairs_pass",
"//executorch/backends/transforms:remove_permutes_around_elementwise_ops",
"//executorch/backends/transforms:postpone_permute_below_squeeze_view",
"//executorch/backends/transforms:replace_nop_transpose_or_permute_with_view",
"//executorch/exir:lib",
],
)
Expand Down
2 changes: 0 additions & 2 deletions backends/arm/_passes/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
from . import arm_pass_utils # noqa
from .arm_pass import ArmPass # noqa # usort: skip
from .accumulate_index_put_pass import AccumulateIndexPutPass # noqa
from .annotate_output_dim_order_pass import AnnotateOutputDimOrderPass # noqa
from .broadcast_args_pass import BroadcastArgsPass # noqa
from .canonicalize_gather_pass import CanonicalizeGatherPass # noqa
from .cast_int64_pass import CastInt64BuffersToInt32Pass # noqa
Expand Down Expand Up @@ -165,7 +164,6 @@
from .rewrite_upsample import RewriteUpsamplePass # noqa
from .scalars_to_attribute_pass import ScalarsToAttributePass # noqa
from .size_adjust_input_pass import SizeAdjustInputPass # noqa
from .to_tosa_memory_format_pass import ToTosaMemoryFormatPass # noqa
from .unsqueeze_before_repeat_pass import UnsqueezeBeforeRepeatPass # noqa
from .unsqueeze_scalar_placeholders_pass import UnsqueezeScalarPlaceholdersPass # noqa
from .replace_inf_and_limit_values_pass import ( # noqa # usort: skip
Expand Down
28 changes: 0 additions & 28 deletions backends/arm/_passes/annotate_output_dim_order_pass.py

This file was deleted.

57 changes: 46 additions & 11 deletions backends/arm/_passes/arm_pass_manager.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# Copyright 2024-2026 Arm Limited and/or its affiliates.
Expand All @@ -12,7 +12,6 @@

from executorch.backends.arm._passes import (
AccumulateIndexPutPass,
AnnotateOutputDimOrderPass,
BroadcastArgsPass,
CanonicalizeGatherPass,
CastInt64BuffersToInt32Pass,
Expand Down Expand Up @@ -44,7 +43,6 @@
DecomposeAtanPass,
DecomposeAvgPool2dPass,
DecomposeBatchNormNoStatsPass,
DecomposeConvWithInt16ActivationPass,
DecomposeCoshPass,
DecomposeCosineSimilarityPass,
DecomposeCumsumPass,
Expand All @@ -58,7 +56,6 @@
DecomposeFloorDividePass,
DecomposeGeluPass,
DecomposeGluPass,
DecomposeGroupedConvPass,
DecomposeGroupNormPass,
DecomposeGruPass,
DecomposeIndexCopyPass,
Expand Down Expand Up @@ -141,7 +138,6 @@
RewriteUpsamplePass,
ScalarsToAttributePass,
SizeAdjustInputPass,
ToTosaMemoryFormatPass,
UnsqueezeBeforeRepeatPass,
UnsqueezeScalarPlaceholdersPass,
)
Expand All @@ -157,7 +153,26 @@
TosaLoweringContext,
TosaSpecification,
)
from executorch.backends.transforms.fuse_cascaded_transpose_or_permute_ops import (
FuseCascadedTransposeOrPermuteOps,
)
from executorch.backends.transforms.fuse_cascaded_view_ops import (
FuseCascadedViewOps,
)
from executorch.backends.transforms.fuse_transpose_or_permute_op_pairs_pass import (
FuseTransposeOrPermuteOpPairsPass,
)
from executorch.backends.transforms.postpone_permute_below_squeeze_view import (
PostponePermuteOpBelowSqueezeOrUnsqueezeLikeView,
)
from executorch.backends.transforms.remove_permutes_around_elementwise_ops import (
RemovePermutesAroundElementwiseOps,
)
from executorch.backends.transforms.replace_nop_transpose_or_permute_with_view import (
ReplaceNopTransposeOrPermuteWithViewPass,
)
from executorch.exir import ExportedProgram
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.pass_base import ExportPass
from executorch.exir.pass_manager import PassManager
from torch._export.utils import _get_shape_env_from_gm
Expand Down Expand Up @@ -385,9 +400,6 @@
# Allow subclasses to configure pass insertions before building pipeline
self._configure_pass_insertions(exported_program)

# Preprocessing passes
self.add_pass(AnnotateOutputDimOrderPass())

# Node transformation passes (pre q/dq folding)
self.add_passes(
[
Expand Down Expand Up @@ -455,7 +467,6 @@
DecomposeFloorDividePass(),
DecomposeGeluPass(),
DecomposeAddSubAlphaPass(),
DecomposeGroupedConvPass(),
DecomposeUnfoldToGatherPass(),
DecomposeEmbeddingPass(),
DecomposeIndexSelectToGatherPass(),
Expand Down Expand Up @@ -518,7 +529,6 @@
ConvertPermuteSingletonToViewPass(),
RewriteHighRankSingletonPermutePass(),
FuseViewCopyTransformPass(),
DecomposeConvWithInt16ActivationPass(),
DecomposeSumPass(),
InsertTableOpsPass(exported_program),
]
Expand All @@ -532,7 +542,6 @@
RewriteConvPass(exported_program),
RewriteMatmulPass(),
RewritePadPass(),
RewriteSlicePass(),
InsertConstShapesPass(),
]
)
Expand All @@ -542,14 +551,40 @@
[
CastInt64BuffersToInt32Pass(exported_program),
FuseEqualPlaceholdersPass(exported_program),
FuseConstantArgsPass(exported_program),
FuseConsecutiveConcatShapesPass(),
ToTosaMemoryFormatPass(exported_program),
RemoveNoopPass(),
InsertRescalePass(),
InsertDataLayoutCastsPass(),
]
)

# Additional optimization passes for permutes
# Fuse identity permute pairs across RESCALE ops
fuse_pairs = FuseTransposeOrPermuteOpPairsPass()
fuse_pairs.bypass_ops = fuse_pairs.bypass_ops | {
exir_ops.backend.tosa.RESCALE.default,
}

# Remove permutes around elementwise ops including RESCALE
remove_around = RemovePermutesAroundElementwiseOps()
remove_around.permutable_ops = remove_around.permutable_ops | {
exir_ops.backend.tosa.RESCALE.default,
}

self.add_passes(
[
remove_around,
RewriteSlicePass(),
fuse_pairs,
ReplaceNopTransposeOrPermuteWithViewPass(),
PostponePermuteOpBelowSqueezeOrUnsqueezeLikeView(),
FuseCascadedTransposeOrPermuteOps(),
FuseCascadedViewOps(),
InsertConstShapesPass(),
]
)

# Apply all pass insertions once after all passes are collected
self._apply_pass_insertions()

Expand Down
5 changes: 0 additions & 5 deletions backends/arm/_passes/arm_pass_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -352,11 +352,6 @@ def set_node_arg(node: torch.fx.Node, i: int | str, value):
raise RuntimeError("Invalid type")


def get_output_dim_orders(graph_module):
output_node = graph_module.graph.output_node()
return [get_first_fake_tensor(node).dim_order() for node in output_node.args[0]]


def is_nested_control_flow_graph(graph_module: GraphModule) -> bool:
"""Returns True if graph_module is a nested control-flow graph."""

Expand Down
170 changes: 120 additions & 50 deletions backends/arm/_passes/rewrite_avg_pool2d_pass.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
# Copyright 2026 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
Expand All @@ -7,69 +7,139 @@

import torch
from executorch.backends.arm._passes import ArmPass
from executorch.backends.arm._passes.arm_pass_utils import (
create_node,
get_first_fake_tensor,
)
from executorch.backends.arm.operators.operator_validation_utils import (
adjust_pooling_pad_if_needed,
)
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.pass_base import ExportPass
from executorch.exir.pass_base import ExportPass, PassResult

from .fuse_constant_ops_pass import ComputeConstantOpsAOTPass

_NCHW_TO_NHWC = [0, 2, 3, 1]
_NHWC_TO_NCHW = [0, 3, 1, 2]


class RewriteAvgPool2dPass(ArmPass):
"""Rewrite aten.avg_pool2d calls to TOSA AVG_POOL2D op."""
"""Rewrite aten.avg_pool2d calls to TOSA AVG_POOL2D op with NHWC layout."""

# Target the original avg_pool2d operator
targeted_ops = {exir_ops.edge.aten.avg_pool2d.default}
_passes_required_after: Set[Type[ExportPass]] = {
ComputeConstantOpsAOTPass,
}

def call_operator(self, op, args, kwargs, meta, updated=False):

# Only rewrite avg_pool2d
if op not in self.targeted_ops:
return super().call_operator(op, args, kwargs, meta, updated)

x = args[0]
pad_h, pad_w = args[3]
# Make sure pad corresponds to TOSA
pad = [pad_h, pad_w, pad_h, pad_w]

_, _, h, w = x.data.shape
kernel_h, kernel_w = args[1]
stride_h, stride_w = args[2]

ceil_mode = args[4] if len(args) > 4 else False

# Adjust padding if necessary
pad[1] = adjust_pooling_pad_if_needed(h, kernel_h, stride_h, pad[1], ceil_mode)
pad[3] = adjust_pooling_pad_if_needed(w, kernel_w, stride_w, pad[3], ceil_mode)

# Materialize zero-point constants
in_qparams = meta.data.get("input_qparams", {})
in_zp_val = in_qparams[0].get_zp_per_tensor() if 0 in in_qparams else 0
# Materialize input zero-point as a scalar tensor
input_zp = super().call_scalar(in_zp_val, meta)

out_qparams = meta.data.get("output_qparams", {})
out_zp_val = out_qparams[0].get_zp_per_tensor() if 0 in out_qparams else 0
# Materialize output zero-point as a scalar tensor
output_zp = super().call_scalar(out_zp_val, meta)

# Determine accumulator dtype for AVG_POOL2D: INT32 for integer inputs, FP32 otherwise
if x.data.dtype in (torch.int8, torch.int16):
acc_type = torch.int32
else:
acc_type = torch.float32

tosa_args = (args[0], input_zp, output_zp, *args[1:3], pad, acc_type)

# Emit TOSA AVG_POOL2D with normalized args
return super().call_operator(
exir_ops.backend.tosa.AVG_POOL2D.default,
tosa_args,
{},
meta,
True,
@staticmethod
def _insert_permute(graph_module, anchor_node, input_node, perm, before=True):
ctx = (
graph_module.graph.inserting_before(anchor_node)
if before
else graph_module.graph.inserting_after(anchor_node)
)
with ctx:
return create_node(
graph=graph_module.graph,
op_target=exir_ops.edge.aten.permute_copy.default,
args=(input_node, perm),
from_node=input_node,
)

def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
modified = False

for node in list(graph_module.graph.nodes):
if node.op != "call_function" or node.target not in self.targeted_ops:
continue

modified = True
x = node.args[0]

pad_h, pad_w = node.args[3]
pad = [pad_h, pad_w, pad_h, pad_w]

input_fake = get_first_fake_tensor(x)
_, _, h, w = input_fake.shape
kernel_h, kernel_w = node.args[1]
stride_h, stride_w = node.args[2]

ceil_mode = node.args[4] if len(node.args) > 4 else False

pad[1] = adjust_pooling_pad_if_needed(h, kernel_h, stride_h, pad[1], ceil_mode)
pad[3] = adjust_pooling_pad_if_needed(w, kernel_w, stride_w, pad[3], ceil_mode)

# Determine zero-points and accumulator type
in_qparams = node.meta.get("input_qparams", {})
in_zp_val = in_qparams[0].get_zp_per_tensor() if 0 in in_qparams else 0

out_qparams = node.meta.get("output_qparams", {})
out_zp_val = out_qparams[0].get_zp_per_tensor() if 0 in out_qparams else 0

if input_fake.dtype in (torch.int8, torch.int16):
acc_type = torch.int32
else:
acc_type = torch.float32

# Insert NCHW → NHWC permute on input
x_permuted = self._insert_permute(
graph_module, node, x, _NCHW_TO_NHWC, before=True
)

# Materialize zp scalars as graph constants using aten.full with
# explicit dtype matching the input tensor. This ensures the
# pre-computed buffer placeholders carry the correct type for
# INT-only TOSA profiles (avoids defaulting to float32).
zp_kwargs = {"dtype": input_fake.dtype, "device": input_fake.device}
with graph_module.graph.inserting_before(node):
input_zp_node = create_node(
graph=graph_module.graph,
op_target=exir_ops.edge.aten.full.default,
args=((1,), in_zp_val),
kwargs=zp_kwargs,
from_node=node,
)
output_zp_node = create_node(
graph=graph_module.graph,
op_target=exir_ops.edge.aten.full.default,
args=((1,), out_zp_val),
kwargs=zp_kwargs,
from_node=node,
)

kernel = list(node.args[1])
stride = list(node.args[2])

tosa_args = (x_permuted, input_zp_node, output_zp_node, kernel, stride, pad, acc_type)

# Create TOSA AVG_POOL2D node
with graph_module.graph.inserting_after(node):
tosa_op = create_node(
graph=graph_module.graph,
op_target=exir_ops.backend.tosa.AVG_POOL2D.default,
args=tosa_args,
from_node=node,
inherit_qparams=True,
)

# Compute correct NHWC FakeTensor
input_fake_nhwc = input_fake.permute(_NCHW_TO_NHWC)
input_zp_fake = torch.tensor(in_zp_val, dtype=input_fake.dtype)
output_zp_fake = torch.tensor(out_zp_val, dtype=input_fake.dtype)
tosa_node_fake = exir_ops.backend.tosa.AVG_POOL2D.default(
input_fake_nhwc, input_zp_fake, output_zp_fake, kernel, stride, pad, acc_type
)
tosa_op.meta["val"] = tosa_node_fake

# Insert NHWC → NCHW permute on output
output_permute = self._insert_permute(
graph_module, tosa_op, tosa_op, _NHWC_TO_NCHW, before=False
)

node.replace_all_uses_with(output_permute)
graph_module.graph.erase_node(node)

if modified:
graph_module.recompile()
graph_module = super().call(graph_module).graph_module
return PassResult(graph_module, modified)
Loading
Loading