Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion backends/arm/_passes/add_bias_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,15 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import Set, Type

import torch
from executorch.backends.arm._passes import ArmPass
from executorch.backends.arm._passes.arm_pass_utils import get_first_fake_tensor
from executorch.backends.transforms.utils import create_constant_placeholder

from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.pass_base import PassResult
from executorch.exir.pass_base import ExportPass, PassResult
from torch.export.graph_signature import InputKind


Expand All @@ -19,6 +21,8 @@ class AddBiasPass(ArmPass):
The bias is set to zero.
"""

_passes_required_after: Set[Type[ExportPass]] = set()

targeted_ops = (exir_ops.edge.aten.convolution.default,)

def call(self, graph_module):
Expand Down
4 changes: 3 additions & 1 deletion backends/arm/_passes/annotate_decomposed_matmul.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

import itertools
import operator
from typing import cast, List
from typing import cast, List, Set, Type

import torch
from executorch.backends.arm._passes.arm_pass_utils import create_node
Expand All @@ -29,6 +29,8 @@ class AnnotateDecomposedMatmulPass(ExportPass):
matmul-op (can be mm or bmm).
"""

_passes_required_after: Set[Type[ExportPass]] = set()

def _match_partition_to_node(
self, node: torch.fx.Node, partitioned_inputs: List[torch.fx.Node]
) -> torch.fx.Node:
Expand Down
7 changes: 6 additions & 1 deletion backends/arm/_passes/annotate_output_dim_order_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,12 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.


from typing import Set, Type

from executorch.backends.arm._passes import ArmPass
from executorch.backends.arm._passes.arm_pass_utils import get_output_dim_orders
from executorch.exir.pass_base import PassResult
from executorch.exir.pass_base import ExportPass, PassResult


class AnnotateOutputDimOrderPass(ArmPass):
Expand All @@ -14,6 +17,8 @@ class AnnotateOutputDimOrderPass(ArmPass):
for verifying that the dim order does not change unexpectedly in later passes.
"""

_passes_required_after: Set[Type[ExportPass]] = set()

def call(self, graph_module):
output_node = graph_module.graph.output_node()
output_node.meta["original_dim_orders"] = get_output_dim_orders(graph_module)
Expand Down
33 changes: 32 additions & 1 deletion backends/arm/_passes/arm_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,8 @@
# pyre-unsafe

import traceback
from typing import Optional
from abc import abstractmethod
from typing import List, Optional, Set, Type

import torch
from executorch.exir.pass_base import ExportPass, NodeMetadata
Expand All @@ -19,6 +20,36 @@ def __init__(self, exported_program: Optional[torch.export.ExportedProgram] = No
super(ArmPass, self).__init__()
self.exported_program = exported_program

@property
@abstractmethod
def _passes_required_after(self) -> Set[Type[ExportPass]]:
    """
    Set of pass classes that must be scheduled after this pass.

    Declared abstract here; each subclass is expected to supply its own
    value (an empty set when it imposes no ordering requirement).
    """

@staticmethod
def get_required_passes(pass_) -> List[str]:
"""
Returns the list of passes that must be run after this pass, sorted by name.
"""
if hasattr(pass_, "_passes_required_after"):
return sorted([ArmPass.get_name(p) for p in pass_._passes_required_after])
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

One issue I can see with this approach is transitive dependency and running into weird errors.

At a high level, do we see many cases where we run same pass multiple times in certain order for a given backend, or do something different, in terms of pass ordering, based on some flag or a model architecture?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not fully understanding the case of transitive dependencies. With this implementation, if we got the following scenario:

PassA:
    _passes_required_after = {PassB}

PassB:
    _passes_required_after = {PassC}

PassC:
    _passes_required_after = set()

Then PassA will transitively depend on PassC.

Then if we create a pipeline where we remove PassB like this:

[
    PassA,
    PassC,
]

Then we will get the following error:

The following constraints for passes are not met:
    PassB must run after PassA

Is that not a reasonable way to handle it? I cannot think of a case where a missing transitive dependency will slip through here and not get caught by validate_constraints_mandatory.

Then regarding your second point: We have a long-term plan to change the passes such that they will always be added in the same order, but let themselves check whether they should run or not. This means that the order will always remain the same no matter which state/hardware profile we are in. Does that answer your question or maybe I'm not fully following?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is that not a reasonable way to handle it?

Yeah I guess. Thanks. I was also thinking like,

PassA --(needs)--> Pass B
PassB --(needs)--> Pass C
PassC --(needs)--> Pass A

More broadly,

let themselves check whether they should run or not

I am failing to see the gain from this added complexity over a list[Pass] for a given backend. Mainly because this is not a user facing surface.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If you are saying we don't have to maintain a "per backend + config list of passes", perhaps that is a good enough reason to embrace this complexity. But I doubt we can get rid of backend "specialized" logic from the PassManager, but happy to be proven wrong :)

Copy link
Collaborator Author

@martinlsm martinlsm Sep 17, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks @digantdesai for the thorough review and feedback!

Regarding the cycles: yes, this implementation would not be able to handle that. But circular dependencies like that are something we want to avoid anyway. I cannot say it's 100% future-proof, but I think it's unlikely that we would ever need cyclic dependencies.

Regarding the added complexity: perhaps it adds a bit, yes. The idea is to uphold the intended ordering of the passes in cases where a developer wants to change the ordering. It's not obvious what should come before/after what when there are so many of them. One could say that the unit tests should protect against improper reordering, but this feature gives us quicker and clearer feedback in those cases.

I see that your argument about special cases might work against this design. In the worst case we would have to revisit this feature and try to extend/change it somehow. But we do hope to eliminate special cases and always list the passes in the same order at least.

Copy link
Contributor

@digantdesai digantdesai Sep 17, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

_passes_required_after being empty for almost all existing passes implies that a simple (ordered) list is sufficient :)
I don't want to block you here, and since this is inside Arm dir, I will let you make the call. Thanks for the discussion.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks! I see your point, but the actual dependencies are added in a later patch that is on its way.

else:
return []

@staticmethod
def get_name(pass_) -> str:
    """
    Return a display name for ``pass_``.

    An ``ExportPass`` instance is named after its class. Any other object
    is named by its own ``__name__`` attribute.

    Raises:
        ValueError: if ``pass_`` is neither an ``ExportPass`` instance nor
            an object with a ``__name__`` attribute.
    """
    # Pass instances: report the class they were instantiated from.
    if isinstance(pass_, ExportPass):
        return pass_.__class__.__name__

    # Anything else must carry its own name.
    if not hasattr(pass_, "__name__"):
        raise ValueError(
            f"Cannot get name for pass: {pass_}. It must be an instance of ExportPass or have a __name__ attribute."
        )

    return pass_.__name__

def call_operator(self, op, args, kwargs, meta, updated: Optional[bool] = False):
if not updated:
return super().call_operator(op, args, kwargs, meta)
Expand Down
33 changes: 32 additions & 1 deletion backends/arm/_passes/arm_pass_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,9 @@

# pyre-unsafe


from collections import defaultdict

import executorch.backends.arm.tosa.dialect # noqa: unused
from executorch.backends.arm._passes import (
AddBiasPass,
Expand Down Expand Up @@ -94,6 +97,7 @@
UnsqueezeScalarPlaceholdersPass,
)

from executorch.backends.arm._passes.arm_pass import ArmPass
from executorch.backends.arm.tosa.specification import (
TosaLoweringContext,
TosaSpecification,
Expand All @@ -115,6 +119,32 @@ def __init__(self, tosa_spec: TosaSpecification) -> None:
self.tosa_spec = tosa_spec
super().__init__()

def validate_constraints_mandatory(self):
    """
    Check that every pass required after another pass is actually scheduled.

    Walks ``self.passes`` in order. For each pass, the names returned by
    ``ArmPass.get_required_passes`` are recorded as outstanding
    requirements, and any outstanding requirement matching the current
    pass's own name is then cleared.

    Note that this differs from the original validate_constraints function,
    which only checks the order of passes.

    Raises:
        RuntimeError: if any requirement is still outstanding after the
            whole pass list has been walked. The message lists each
            unmet "X must run after Y" constraint.
    """
    # Maps a required pass name -> names of the passes that demanded it.
    outstanding = defaultdict(list)

    for scheduled_pass in self.passes:
        scheduled_name = ArmPass.get_name(scheduled_pass)

        for needed_name in ArmPass.get_required_passes(scheduled_pass):
            outstanding[needed_name].append(scheduled_name)

        # Reaching this pass satisfies every requirement recorded for its
        # name so far (including one it may have just placed on itself).
        outstanding.pop(scheduled_name, None)

    if not outstanding:
        return

    message_parts = ["The following constraints for passes are not met:\n"]
    message_parts.extend(
        f" - {needed} must run after {demander}\n"
        for needed, demanders in outstanding.items()
        for demander in demanders
    )
    raise RuntimeError("".join(message_parts))

def _transform(self, graph_module: GraphModule):
    """
    Invoke this pass manager on ``graph_module`` inside a
    ``TosaLoweringContext`` built from ``self.tosa_spec`` and return the
    ``graph_module`` attribute of the result.
    """
    lowering_context = TosaLoweringContext(self.tosa_spec)
    with lowering_context:
        result = self(graph_module)
        return result.graph_module
Expand All @@ -125,7 +155,6 @@ def _tosa_INT_pipeline(self, exported_program: ExportedProgram) -> GraphModule:
self.add_pass(RemoveGetItemPass())
self.add_pass(ConvertSplitToSlicePass())
self.add_pass(ConvertMmToBmmPass())
self.add_pass(DecomposeLinearVectorNormPass())
self.add_pass(
DecomposeMeanDimPass(exported_program.graph_module, self.tosa_spec)
)
Expand Down Expand Up @@ -175,6 +204,7 @@ def _tosa_INT_pipeline(self, exported_program: ExportedProgram) -> GraphModule:
self.add_pass(RemoveNoopPass())
self.add_pass(InsertRescalePass())

self.validate_constraints_mandatory()
return self._transform(exported_program.graph_module)

def _tosa_FP_pipeline(self, exported_program: ExportedProgram) -> GraphModule:
Expand Down Expand Up @@ -258,6 +288,7 @@ def _tosa_FP_pipeline(self, exported_program: ExportedProgram) -> GraphModule:
self.add_pass(RemoveNoopPass())
self.add_pass(InsertRescalePass())

self.validate_constraints_mandatory()
return self._transform(exported_program.graph_module)

def transform_to_backend_pipeline(self, exported_program: ExportedProgram):
Expand Down
6 changes: 5 additions & 1 deletion backends/arm/_passes/broadcast_args_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import Set, Type

from executorch.backends.arm._passes import ArmPass

from executorch.backends.arm._passes.arm_pass_utils import (
Expand All @@ -12,7 +14,7 @@

from executorch.exir.dialects._ops import ops as exir_ops

from executorch.exir.pass_base import PassResult
from executorch.exir.pass_base import ExportPass, PassResult
from torch.fx import GraphModule, Node


Expand All @@ -22,6 +24,8 @@ class BroadcastArgsPass(ArmPass):
This is done when more than one arg needs broadcasting.
"""

_passes_required_after: Set[Type[ExportPass]] = set()

targeted_ops = {
exir_ops.edge.aten.add.Tensor,
exir_ops.edge.aten.sub.Tensor,
Expand Down
4 changes: 4 additions & 0 deletions backends/arm/_passes/cast_bool_to_int8_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
# The TOSA BITWISE_AND, BITWISE_OR, and BITWISE_XOR don't handle bool as input
# If input/output is bool, let's add a cast/conversion pass before/after to/from int8.

from typing import Set, Type

import torch

from executorch.exir.dialects._ops import ops as exir_ops
Expand All @@ -15,6 +17,8 @@
class CastBoolToInt8Pass(ExportPass):
"""Casts the input to int8 if it is not already and casts back the output to the original input dtype."""

_passes_required_after: Set[Type[ExportPass]] = set()

targeted_ops = {
exir_ops.edge.aten.bitwise_and.Tensor,
exir_ops.edge.aten.bitwise_or.Tensor,
Expand Down
3 changes: 3 additions & 0 deletions backends/arm/_passes/cast_int64_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
# pyre-unsafe

import logging
from typing import Set, Type

import torch
from executorch.exir.pass_base import ExportPass, PassResult
Expand All @@ -19,6 +20,8 @@ class CastInt64BuffersToInt32Pass(ExportPass):
Cast int64 buffers to int32 if the int64 data is in int32 range.
"""

_passes_required_after: Set[Type[ExportPass]] = set()

def __init__(self, exported_program: torch.export.ExportedProgram):
super(CastInt64BuffersToInt32Pass, self).__init__()
self.exported_program = exported_program
Expand Down
4 changes: 4 additions & 0 deletions backends/arm/_passes/cast_to_int32_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import Set, Type

import torch

from executorch.exir.dialects._ops import ops as exir_ops
Expand All @@ -12,6 +14,8 @@
class CastToInt32Pass(ExportPass):
"""Casts the input to int32 if it is not already and casts back the output to the original input dtype."""

_passes_required_after: Set[Type[ExportPass]] = set()

targeted_ops = {
exir_ops.edge.aten.bitwise_left_shift.Tensor,
exir_ops.edge.aten.bitwise_right_shift.Tensor,
Expand Down
4 changes: 4 additions & 0 deletions backends/arm/_passes/conv1d_unsqueeze_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
# LICENSE file in the root directory of this source tree.


from typing import Set, Type

from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.pass_base import ExportPass

Expand All @@ -21,6 +23,8 @@ class Conv1dUnsqueezePass(ExportPass):
3) squeeze the output back down to 3d.
"""

_passes_required_after: Set[Type[ExportPass]] = set()

def call_operator(self, op, args, kwargs, meta):
if op != exir_ops.edge.aten.convolution.default:
return super().call_operator(op, args, kwargs, meta)
Expand Down
4 changes: 4 additions & 0 deletions backends/arm/_passes/convert_any_default_dim_dims_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import Set, Type

import torch
from executorch.exir.dialects._ops import ( # type: ignore[import-not-found]
ops as exir_ops,
Expand Down Expand Up @@ -44,6 +46,8 @@ class ConvertAnyDefaultDimDimsPass(ExportPass):
squeeze(dim = [dim1, dim2])
"""

_passes_required_after: Set[Type[ExportPass]] = set()

def call(self, graph_module: torch.fx.GraphModule):
modified = False
for node in graph_module.graph.nodes:
Expand Down
4 changes: 3 additions & 1 deletion backends/arm/_passes/convert_expand_copy_to_repeat.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
# pyre-unsafe

import logging
from typing import cast
from typing import cast, Set, Type

import torch

Expand Down Expand Up @@ -50,6 +50,8 @@ class ConvertExpandCopyToRepeatPass(ExportPass):
Replace expand copy with repeat since it is a repeat that can only repeat singleton dimensions.
"""

_passes_required_after: Set[Type[ExportPass]] = set()

expand_copy = exir_ops.edge.aten.expand_copy.default
repeat = exir_ops.edge.aten.repeat.default

Expand Down
4 changes: 4 additions & 0 deletions backends/arm/_passes/convert_full_like_to_full_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import Set, Type

from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.pass_base import ExportPass

Expand All @@ -19,6 +21,8 @@ class ConvertFullLikeToFullPass(ExportPass):
Skip layout and device since it's not relevant for our backend.
"""

_passes_required_after: Set[Type[ExportPass]] = set()

def call_operator(self, op, args, kwargs, meta):
if op not in [
exir_ops.edge.aten.full_like.default,
Expand Down
3 changes: 3 additions & 0 deletions backends/arm/_passes/convert_int64_const_ops_to_int32.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@


import logging
from typing import Set, Type

import torch
from executorch.backends.arm._passes.fuse_constant_ops_pass import ComputeConstantOpsAOT
Expand All @@ -30,6 +31,8 @@ class ConvertInt64ConstOpsToInt32Pass(ExportPass):
5. `torch.tensor`
"""

_passes_required_after: Set[Type[ExportPass]] = set()

torch_ops = [
torch.ops.aten.full.default,
torch.ops.aten.arange.default,
Expand Down
3 changes: 3 additions & 0 deletions backends/arm/_passes/convert_int64_output_ops_to_int32.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@


import logging
from typing import Set, Type

import torch
from executorch.backends.arm._passes.arm_pass_utils import (
Expand Down Expand Up @@ -44,6 +45,8 @@ class ConvertInt64OutputOpsToInt32Pass(ExportPass):
the int32 range.
"""

_passes_required_after: Set[Type[ExportPass]] = set()

aten_cast_ops = (
torch.ops.aten.to.dtype,
torch.ops.aten.to.dtype_layout,
Expand Down
5 changes: 5 additions & 0 deletions backends/arm/_passes/convert_int_pow_to_mul.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,11 @@

# pyre-unsafe

from typing import Set, Type

from executorch.backends.arm._passes import ArmPass
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.pass_base import ExportPass


class ConvertIntPowToMuls(ArmPass):
Expand All @@ -16,6 +19,8 @@ class ConvertIntPowToMuls(ArmPass):
Needs to be run before doing scalar to tensor conversion.
"""

_passes_required_after: Set[Type[ExportPass]] = set()

def call_operator(self, op, args, kwargs, meta):
if op != exir_ops.edge.aten.pow.Tensor_Scalar:
return super().call_operator(op, args, kwargs, meta)
Expand Down
Loading
Loading