From df3468fd4b67bfc479108775fd419e2da230132f Mon Sep 17 00:00:00 2001 From: Adrian Lundell Date: Fri, 18 Jul 2025 17:29:55 +0200 Subject: [PATCH 1/2] Arm backend: Add pass order validation to ArmPassManager MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Introduce a mechanism to enforce required ordering of passes in ArmPassManager. Each ArmPass must now declare which passes are required to run after it, ensuring ordering constraints are always upheld. This prevents accidental breakage when modifying pass ordering in the manager. Ordering constraints are verified by the new method ArmPass.validate_constraints_mandatory. We considered reusing torch.fx.passes.infra.pass_manager.PassManager.validate_constraints, but that utility only checks pairwise ordering and cannot enforce that a pass is actually run, which did not meet our needs. This patch only implements the mechanism and tests for it. Defining the actual pass orderings are done in a later patch. Change-Id: I6f822ec4192b0c8dd19b70d85905adcd08ca502f Signed-off-by: Adrian Lundell Co-authored-by: Martin Lindström --- backends/arm/_passes/add_bias_pass.py | 6 +- .../arm/_passes/annotate_decomposed_matmul.py | 4 +- .../_passes/annotate_output_dim_order_pass.py | 7 +- backends/arm/_passes/arm_pass.py | 33 ++++++- backends/arm/_passes/arm_pass_manager.py | 32 +++++++ backends/arm/_passes/broadcast_args_pass.py | 6 +- .../arm/_passes/cast_bool_to_int8_pass.py | 4 + backends/arm/_passes/cast_int64_pass.py | 3 + backends/arm/_passes/cast_to_int32_pass.py | 4 + backends/arm/_passes/conv1d_unsqueeze_pass.py | 4 + .../convert_any_default_dim_dims_pass.py | 4 + .../_passes/convert_expand_copy_to_repeat.py | 4 +- .../_passes/convert_full_like_to_full_pass.py | 4 + .../convert_int64_const_ops_to_int32.py | 3 + .../convert_int64_output_ops_to_int32.py | 3 + .../arm/_passes/convert_int_pow_to_mul.py | 5 + backends/arm/_passes/convert_minmax_pass.py | 4 + .../arm/_passes/convert_split_to_slice.py | 4 + .../arm/_passes/convert_squeezes_to_view.py | 4 + backends/arm/_passes/convert_to_clamp.py | 4 +- backends/arm/_passes/decompose_acosh_pass.py | 5 + .../decompose_adaptive_avg_pool2d_pass.py | 4 + backends/arm/_passes/decompose_addmm_pass.py | 5 + .../_passes/decompose_asin_and_acos_pass.py | 4 + backends/arm/_passes/decompose_asinh_pass.py | 5 + backends/arm/_passes/decompose_atan_pass.py | 4 + backends/arm/_passes/decompose_atanh_pass.py | 5 + backends/arm/_passes/decompose_avg_pool2d.py | 4 +- .../_passes/decompose_batch_norm_no_stats.py | 5 +- backends/arm/_passes/decompose_cosh_pass.py | 5 + .../decompose_cosine_similarity_pass.py | 4 + backends/arm/_passes/decompose_cumsum_pass.py | 5 +- backends/arm/_passes/decompose_div_pass.py | 6 +- backends/arm/_passes/decompose_elu_pass.py | 5 + .../arm/_passes/decompose_embedding_pass.py | 3 + backends/arm/_passes/decompose_expm1_pass.py | 5 + backends/arm/_passes/decompose_gelu_pass.py | 4 + backends/arm/_passes/decompose_glu_pass.py | 5 + .../arm/_passes/decompose_grouped_conv.py | 3 + .../arm/_passes/decompose_groupnorm_pass.py | 5 +- .../arm/_passes/decompose_layernorm_pass.py | 5 +- .../arm/_passes/decompose_leaky_relu_pass.py | 5 + .../decompose_linalg_vector_norm_pass.py | 4 + backends/arm/_passes/decompose_linear_pass.py | 6 +- backends/arm/_passes/decompose_logit_pass.py | 5 + backends/arm/_passes/decompose_masked_fill.py | 5 + .../decompose_maxpool2d_with_dilation.py | 4 + .../arm/_passes/decompose_meandim_pass.py | 4 + backends/arm/_passes/decompose_ne_pass.py | 5 + backends/arm/_passes/decompose_round_pass.py | 5 + backends/arm/_passes/decompose_select.py | 4 + backends/arm/_passes/decompose_sign_pass.py | 5 + backends/arm/_passes/decompose_silu_pass.py | 4 + backends/arm/_passes/decompose_sinh_pass.py | 5 + .../arm/_passes/decompose_softmax_pass.py | 4 + .../decompose_softmax_unstable_pass.py | 5 + backends/arm/_passes/decompose_sqrt_pass.py | 3 +- backends/arm/_passes/decompose_sum_pass.py | 4 + backends/arm/_passes/decompose_var_pass.py | 5 + .../decorate_fp32_to_int32_casting_pass.py | 5 + .../fold_qdq_with_annotated_qparams_pass.py | 8 +- backends/arm/_passes/fuse_batchnorm2d_pass.py | 4 + .../arm/_passes/fuse_constant_ops_pass.py | 5 + .../_passes/fuse_equal_placeholders_pass.py | 3 + .../_passes/fuse_quantized_activation_pass.py | 4 + backends/arm/_passes/insert_rescales_pass.py | 4 +- backends/arm/_passes/insert_table_ops.py | 4 +- backends/arm/_passes/match_arg_dtype_pass.py | 4 + backends/arm/_passes/match_arg_ranks_pass.py | 4 +- backends/arm/_passes/mm_to_bmm_pass.py | 4 + backends/arm/_passes/remove_noop_pass.py | 3 + .../arm/_passes/replace_inf_values_pass.py | 4 + .../replace_scalar_with_tensor_pass.py | 7 +- .../arm/_passes/scalars_to_attribute_pass.py | 4 +- .../arm/_passes/size_adjust_input_pass.py | 4 +- .../arm/_passes/to_tosa_memory_format_pass.py | 9 ++ .../_passes/unsqueeze_before_repeat_pass.py | 6 +- .../unsqueeze_scalar_placeholders_pass.py | 4 + .../arm/test/misc/test_pass_required_order.py | 95 +++++++++++++++++++ backends/transforms/decompose_sdpa.py | 3 + backends/transforms/fuse_view_copy.py | 4 + 81 files changed, 489 insertions(+), 23 deletions(-) create mode 100644 backends/arm/test/misc/test_pass_required_order.py diff --git a/backends/arm/_passes/add_bias_pass.py b/backends/arm/_passes/add_bias_pass.py index 31c0c0505cb..a8a76c0a47b 100644 --- a/backends/arm/_passes/add_bias_pass.py +++ b/backends/arm/_passes/add_bias_pass.py @@ -3,13 +3,15 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +from typing import Set, Type + import torch from executorch.backends.arm._passes import ArmPass from executorch.backends.arm._passes.arm_pass_utils import get_first_fake_tensor from executorch.backends.transforms.utils import create_constant_placeholder from executorch.exir.dialects._ops import ops as exir_ops -from executorch.exir.pass_base import PassResult +from executorch.exir.pass_base import ExportPass, PassResult from torch.export.graph_signature import InputKind @@ -19,6 +21,8 @@ class AddBiasPass(ArmPass): The bias is set to zero. """ + _passes_required_after: Set[Type[ExportPass]] = set() + targeted_ops = (exir_ops.edge.aten.convolution.default,) def call(self, graph_module): diff --git a/backends/arm/_passes/annotate_decomposed_matmul.py b/backends/arm/_passes/annotate_decomposed_matmul.py index 8156ca0b89d..81b7b36cc0b 100644 --- a/backends/arm/_passes/annotate_decomposed_matmul.py +++ b/backends/arm/_passes/annotate_decomposed_matmul.py @@ -7,7 +7,7 @@ import itertools import operator -from typing import cast, List +from typing import cast, List, Set, Type import torch from executorch.backends.arm._passes.arm_pass_utils import create_node @@ -29,6 +29,8 @@ class AnnotateDecomposedMatmulPass(ExportPass): matmul-op (can be mm or bmm). """ + _passes_required_after: Set[Type[ExportPass]] = set() + def _match_partition_to_node( self, node: torch.fx.Node, partitioned_inputs: List[torch.fx.Node] ) -> torch.fx.Node: diff --git a/backends/arm/_passes/annotate_output_dim_order_pass.py b/backends/arm/_passes/annotate_output_dim_order_pass.py index 08f93383a9c..8dc13326e4a 100644 --- a/backends/arm/_passes/annotate_output_dim_order_pass.py +++ b/backends/arm/_passes/annotate_output_dim_order_pass.py @@ -3,9 +3,12 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. + +from typing import Set, Type + from executorch.backends.arm._passes import ArmPass from executorch.backends.arm._passes.arm_pass_utils import get_output_dim_orders -from executorch.exir.pass_base import PassResult +from executorch.exir.pass_base import ExportPass, PassResult class AnnotateOutputDimOrderPass(ArmPass): @@ -14,6 +17,8 @@ class AnnotateOutputDimOrderPass(ArmPass): for verifying that the dim order does not change unexpectedly in later passes. """ + _passes_required_after: Set[Type[ExportPass]] = set() + def call(self, graph_module): output_node = graph_module.graph.output_node() output_node.meta["original_dim_orders"] = get_output_dim_orders(graph_module) diff --git a/backends/arm/_passes/arm_pass.py b/backends/arm/_passes/arm_pass.py index 085267a174e..c76b5d157a7 100644 --- a/backends/arm/_passes/arm_pass.py +++ b/backends/arm/_passes/arm_pass.py @@ -6,7 +6,8 @@ # pyre-unsafe import traceback -from typing import Optional +from abc import abstractmethod +from typing import List, Optional, Set, Type import torch from executorch.exir.pass_base import ExportPass, NodeMetadata @@ -19,6 +20,36 @@ def __init__(self, exported_program: Optional[torch.export.ExportedProgram] = No super(ArmPass, self).__init__() self.exported_program = exported_program + @property + @abstractmethod + def _passes_required_after(self) -> Set[Type[ExportPass]]: + """The subclass defines passes that must run after it""" + pass + + @staticmethod + def get_required_passes(pass_) -> List[str]: + """ + Returns the list of passes that must be run after this pass, sorted by name. + """ + if hasattr(pass_, "_passes_required_after"): + return sorted([ArmPass.get_name(p) for p in pass_._passes_required_after]) + else: + return [] + + @staticmethod + def get_name(pass_) -> str: + """ + Returns the name of the pass. + """ + if isinstance(pass_, ExportPass): + return pass_.__class__.__name__ + elif hasattr(pass_, "__name__"): + return pass_.__name__ + else: + raise ValueError( + f"Cannot get name for pass: {pass_}. It must be an instance of ExportPass or have a __name__ attribute." + ) + def call_operator(self, op, args, kwargs, meta, updated: Optional[bool] = False): if not updated: return super().call_operator(op, args, kwargs, meta) diff --git a/backends/arm/_passes/arm_pass_manager.py b/backends/arm/_passes/arm_pass_manager.py index f49206da67e..0a42e51e526 100644 --- a/backends/arm/_passes/arm_pass_manager.py +++ b/backends/arm/_passes/arm_pass_manager.py @@ -7,6 +7,9 @@ # pyre-unsafe + +from collections import defaultdict + import executorch.backends.arm.tosa.dialect # noqa: unused from executorch.backends.arm._passes import ( AddBiasPass, @@ -94,6 +97,7 @@ UnsqueezeScalarPlaceholdersPass, ) +from executorch.backends.arm._passes.arm_pass import ArmPass from executorch.backends.arm.tosa.specification import ( TosaLoweringContext, TosaSpecification, @@ -115,6 +119,32 @@ def __init__(self, tosa_spec: TosaSpecification) -> None: self.tosa_spec = tosa_spec super().__init__() + def validate_constraints_mandatory(self): + """ + Validates that necessary passes have run before transforming to backend. + + Note that this differs from the original validate_constraints function, which + only checks the order of passes. + """ + passes_to_run = defaultdict(list) + + for current_pass in self.passes: + current_pass_name = ArmPass.get_name(current_pass) + for required_pass_name in ArmPass.get_required_passes(current_pass): + passes_to_run[required_pass_name].append(current_pass_name) + + passes_to_run.pop(current_pass_name, None) + + if len(passes_to_run) > 0: + error_msg = "The following constraints for passes are not met:\n" + for required_pass, requiring_passes in passes_to_run.items(): + for requiring_pass in requiring_passes: + error_msg += ( + f" - {required_pass} must run after {requiring_pass}\n" + ) + + raise RuntimeError(error_msg) + def _transform(self, graph_module: GraphModule): with TosaLoweringContext(self.tosa_spec): return self(graph_module).graph_module @@ -175,6 +205,7 @@ def _tosa_INT_pipeline(self, exported_program: ExportedProgram) -> GraphModule: self.add_pass(RemoveNoopPass()) self.add_pass(InsertRescalePass()) + self.validate_constraints_mandatory() return self._transform(exported_program.graph_module) def _tosa_FP_pipeline(self, exported_program: ExportedProgram) -> GraphModule: @@ -258,6 +289,7 @@ def _tosa_FP_pipeline(self, exported_program: ExportedProgram) -> GraphModule: self.add_pass(RemoveNoopPass()) self.add_pass(InsertRescalePass()) + self.validate_constraints_mandatory() return self._transform(exported_program.graph_module) def transform_to_backend_pipeline(self, exported_program: ExportedProgram): diff --git a/backends/arm/_passes/broadcast_args_pass.py b/backends/arm/_passes/broadcast_args_pass.py index f125ba13ff4..659e6aca686 100644 --- a/backends/arm/_passes/broadcast_args_pass.py +++ b/backends/arm/_passes/broadcast_args_pass.py @@ -3,6 +3,8 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +from typing import Set, Type + from executorch.backends.arm._passes import ArmPass from executorch.backends.arm._passes.arm_pass_utils import ( @@ -12,7 +14,7 @@ from executorch.exir.dialects._ops import ops as exir_ops -from executorch.exir.pass_base import PassResult +from executorch.exir.pass_base import ExportPass, PassResult from torch.fx import GraphModule, Node @@ -22,6 +24,8 @@ class BroadcastArgsPass(ArmPass): This is done when more than one arg needs broadcasting. """ + _passes_required_after: Set[Type[ExportPass]] = set() + targeted_ops = { exir_ops.edge.aten.add.Tensor, exir_ops.edge.aten.sub.Tensor, diff --git a/backends/arm/_passes/cast_bool_to_int8_pass.py b/backends/arm/_passes/cast_bool_to_int8_pass.py index 1352671b01e..771b6d9e174 100644 --- a/backends/arm/_passes/cast_bool_to_int8_pass.py +++ b/backends/arm/_passes/cast_bool_to_int8_pass.py @@ -6,6 +6,8 @@ # The TOSA BITWISE_AND, BITWISE_OR, and BITWISE_XOR don't handle bool as input # If input/output is bool lest add a cast/conversion pass before/after to/from int8. +from typing import Set, Type + import torch from executorch.exir.dialects._ops import ops as exir_ops @@ -15,6 +17,8 @@ class CastBoolToInt8Pass(ExportPass): """Casts the input to int8 if it is not already and casts back the output to the original input dtype.""" + _passes_required_after: Set[Type[ExportPass]] = set() + targeted_ops = { exir_ops.edge.aten.bitwise_and.Tensor, exir_ops.edge.aten.bitwise_or.Tensor, diff --git a/backends/arm/_passes/cast_int64_pass.py b/backends/arm/_passes/cast_int64_pass.py index 8052c8fd2ce..d7b2a6b6b43 100644 --- a/backends/arm/_passes/cast_int64_pass.py +++ b/backends/arm/_passes/cast_int64_pass.py @@ -6,6 +6,7 @@ # pyre-unsafe import logging +from typing import Set, Type import torch from executorch.exir.pass_base import ExportPass, PassResult @@ -19,6 +20,8 @@ class CastInt64BuffersToInt32Pass(ExportPass): Cast int64 buffers to int32 if the int64 data is in int32 range. """ + _passes_required_after: Set[Type[ExportPass]] = set() + def __init__(self, exported_program: torch.export.ExportedProgram): super(CastInt64BuffersToInt32Pass, self).__init__() self.exported_program = exported_program diff --git a/backends/arm/_passes/cast_to_int32_pass.py b/backends/arm/_passes/cast_to_int32_pass.py index c4b009e2b88..2e574568235 100644 --- a/backends/arm/_passes/cast_to_int32_pass.py +++ b/backends/arm/_passes/cast_to_int32_pass.py @@ -3,6 +3,8 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +from typing import Set, Type + import torch from executorch.exir.dialects._ops import ops as exir_ops @@ -12,6 +14,8 @@ class CastToInt32Pass(ExportPass): """Casts the input to int32 if it is not already and casts back the output to the original input dtype.""" + _passes_required_after: Set[Type[ExportPass]] = set() + targeted_ops = { exir_ops.edge.aten.bitwise_left_shift.Tensor, exir_ops.edge.aten.bitwise_right_shift.Tensor, diff --git a/backends/arm/_passes/conv1d_unsqueeze_pass.py b/backends/arm/_passes/conv1d_unsqueeze_pass.py index 56f674e9066..718c94fc196 100644 --- a/backends/arm/_passes/conv1d_unsqueeze_pass.py +++ b/backends/arm/_passes/conv1d_unsqueeze_pass.py @@ -6,6 +6,8 @@ # LICENSE file in the root directory of this source tree. +from typing import Set, Type + from executorch.exir.dialects._ops import ops as exir_ops from executorch.exir.pass_base import ExportPass @@ -21,6 +23,8 @@ class Conv1dUnsqueezePass(ExportPass): 3) squeeze the output back down to 3d. """ + _passes_required_after: Set[Type[ExportPass]] = set() + def call_operator(self, op, args, kwargs, meta): if op != exir_ops.edge.aten.convolution.default: return super().call_operator(op, args, kwargs, meta) diff --git a/backends/arm/_passes/convert_any_default_dim_dims_pass.py b/backends/arm/_passes/convert_any_default_dim_dims_pass.py index 7085f17add0..f4ec0c57b2a 100644 --- a/backends/arm/_passes/convert_any_default_dim_dims_pass.py +++ b/backends/arm/_passes/convert_any_default_dim_dims_pass.py @@ -3,6 +3,8 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +from typing import Set, Type + import torch from executorch.exir.dialects._ops import ( # type: ignore[import-not-found] ops as exir_ops, @@ -44,6 +46,8 @@ class ConvertAnyDefaultDimDimsPass(ExportPass): squeeze(dim = [dim1, dim2]) """ + _passes_required_after: Set[Type[ExportPass]] = set() + def call(self, graph_module: torch.fx.GraphModule): modified = False for node in graph_module.graph.nodes: diff --git a/backends/arm/_passes/convert_expand_copy_to_repeat.py b/backends/arm/_passes/convert_expand_copy_to_repeat.py index ee509c7ebb5..1c6b52b150a 100644 --- a/backends/arm/_passes/convert_expand_copy_to_repeat.py +++ b/backends/arm/_passes/convert_expand_copy_to_repeat.py @@ -6,7 +6,7 @@ # pyre-unsafe import logging -from typing import cast +from typing import cast, Set, Type import torch @@ -50,6 +50,8 @@ class ConvertExpandCopyToRepeatPass(ExportPass): Replace expand copy with repeat since it is a repeat that can only repeat singleton dimensions. """ + _passes_required_after: Set[Type[ExportPass]] = set() + expand_copy = exir_ops.edge.aten.expand_copy.default repeat = exir_ops.edge.aten.repeat.default diff --git a/backends/arm/_passes/convert_full_like_to_full_pass.py b/backends/arm/_passes/convert_full_like_to_full_pass.py index 234e2ecda82..2f46e19005a 100644 --- a/backends/arm/_passes/convert_full_like_to_full_pass.py +++ b/backends/arm/_passes/convert_full_like_to_full_pass.py @@ -3,6 +3,8 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +from typing import Set, Type + from executorch.exir.dialects._ops import ops as exir_ops from executorch.exir.pass_base import ExportPass @@ -19,6 +21,8 @@ class ConvertFullLikeToFullPass(ExportPass): Skip layout and device since it's not relevant for our backend. """ + _passes_required_after: Set[Type[ExportPass]] = set() + def call_operator(self, op, args, kwargs, meta): if op not in [ exir_ops.edge.aten.full_like.default, diff --git a/backends/arm/_passes/convert_int64_const_ops_to_int32.py b/backends/arm/_passes/convert_int64_const_ops_to_int32.py index 704c89dbd78..9af44f56f11 100644 --- a/backends/arm/_passes/convert_int64_const_ops_to_int32.py +++ b/backends/arm/_passes/convert_int64_const_ops_to_int32.py @@ -7,6 +7,7 @@ import logging +from typing import Set, Type import torch from executorch.backends.arm._passes.fuse_constant_ops_pass import ComputeConstantOpsAOT @@ -30,6 +31,8 @@ class ConvertInt64ConstOpsToInt32Pass(ExportPass): 5. `torch.tensor` """ + _passes_required_after: Set[Type[ExportPass]] = set() + torch_ops = [ torch.ops.aten.full.default, torch.ops.aten.arange.default, diff --git a/backends/arm/_passes/convert_int64_output_ops_to_int32.py b/backends/arm/_passes/convert_int64_output_ops_to_int32.py index 788201be6c8..d0d29d14e30 100644 --- a/backends/arm/_passes/convert_int64_output_ops_to_int32.py +++ b/backends/arm/_passes/convert_int64_output_ops_to_int32.py @@ -7,6 +7,7 @@ import logging +from typing import Set, Type import torch from executorch.backends.arm._passes.arm_pass_utils import ( @@ -44,6 +45,8 @@ class ConvertInt64OutputOpsToInt32Pass(ExportPass): the int32 range. """ + _passes_required_after: Set[Type[ExportPass]] = set() + aten_cast_ops = ( torch.ops.aten.to.dtype, torch.ops.aten.to.dtype_layout, diff --git a/backends/arm/_passes/convert_int_pow_to_mul.py b/backends/arm/_passes/convert_int_pow_to_mul.py index f22a2fd0b3c..8f9b3a9cb4b 100644 --- a/backends/arm/_passes/convert_int_pow_to_mul.py +++ b/backends/arm/_passes/convert_int_pow_to_mul.py @@ -5,8 +5,11 @@ # pyre-unsafe +from typing import Set, Type + from executorch.backends.arm._passes import ArmPass from executorch.exir.dialects._ops import ops as exir_ops +from executorch.exir.pass_base import ExportPass class ConvertIntPowToMuls(ArmPass): @@ -16,6 +19,8 @@ class ConvertIntPowToMuls(ArmPass): Needs to be run before doing scalar to tensor conversion. """ + _passes_required_after: Set[Type[ExportPass]] = set() + def call_operator(self, op, args, kwargs, meta): if op != exir_ops.edge.aten.pow.Tensor_Scalar: return super().call_operator(op, args, kwargs, meta) diff --git a/backends/arm/_passes/convert_minmax_pass.py b/backends/arm/_passes/convert_minmax_pass.py index 9f409632c20..2cf59ab2300 100644 --- a/backends/arm/_passes/convert_minmax_pass.py +++ b/backends/arm/_passes/convert_minmax_pass.py @@ -3,6 +3,8 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +from typing import Set, Type + import torch from executorch.exir.dialects._ops import ops as exir_ops from executorch.exir.pass_base import ExportPass, PassResult @@ -29,6 +31,8 @@ class ConvertMinMaxPass(ExportPass): squeeze(dim = [dim1, dim2]) """ + _passes_required_after: Set[Type[ExportPass]] = set() + def check_argmax(self, node): """ Raises a RuntimeError if the argmax value returned by the min/max op is used in the graph. diff --git a/backends/arm/_passes/convert_split_to_slice.py b/backends/arm/_passes/convert_split_to_slice.py index 67bd9d73e81..7578c07ca53 100644 --- a/backends/arm/_passes/convert_split_to_slice.py +++ b/backends/arm/_passes/convert_split_to_slice.py @@ -5,6 +5,8 @@ # pyre-unsafe +from typing import Set, Type + import torch.fx from executorch.backends.arm._passes.arm_pass_utils import ( create_node, @@ -19,6 +21,8 @@ class ConvertSplitToSlicePass(ExportPass): Replace a split operation with many slice operations. """ + _passes_required_after: Set[Type[ExportPass]] = set() + split_ops = ( exir_ops.edge.aten.split_with_sizes_copy.default, exir_ops.edge.aten.split_copy.Tensor, diff --git a/backends/arm/_passes/convert_squeezes_to_view.py b/backends/arm/_passes/convert_squeezes_to_view.py index 889dbe74172..9c5d26a7c22 100644 --- a/backends/arm/_passes/convert_squeezes_to_view.py +++ b/backends/arm/_passes/convert_squeezes_to_view.py @@ -6,6 +6,8 @@ # pyre-unsafe +from typing import Set, Type + from executorch.exir.dialects._ops import ops as exir_ops from executorch.exir.pass_base import ExportPass @@ -15,6 +17,8 @@ class ConvertSqueezesToViewPass(ExportPass): Replaces squeeze/unsqueeze operators with view. These are simply special cases of the view op, so removing them gives us less cases to handle in the node visitiors. """ + _passes_required_after: Set[Type[ExportPass]] = set() + def call_operator(self, op, args, kwargs, meta): if op not in [ exir_ops.edge.aten.squeeze_copy.dims, diff --git a/backends/arm/_passes/convert_to_clamp.py b/backends/arm/_passes/convert_to_clamp.py index 8f2c9b16f9a..3f8cac30b96 100644 --- a/backends/arm/_passes/convert_to_clamp.py +++ b/backends/arm/_passes/convert_to_clamp.py @@ -3,7 +3,7 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -from typing import Tuple +from typing import Set, Tuple, Type from executorch.exir.dialects._ops import ops as exir_ops from executorch.exir.pass_base import ExportPass @@ -24,6 +24,8 @@ def get_clamp_params(op, args) -> Tuple[float | None, float | None]: class ConvertToClampPass(ExportPass): + _passes_required_after: Set[Type[ExportPass]] = set() + def call_operator(self, op, args, kwargs, meta): if op not in edge_operators: return super().call_operator(op, args, kwargs, meta) diff --git a/backends/arm/_passes/decompose_acosh_pass.py b/backends/arm/_passes/decompose_acosh_pass.py index 1d92dd68c4a..30c5c137482 100644 --- a/backends/arm/_passes/decompose_acosh_pass.py +++ b/backends/arm/_passes/decompose_acosh_pass.py @@ -5,8 +5,11 @@ # pyre-unsafe +from typing import Set, Type + from executorch.backends.arm._passes import ArmPass from executorch.exir.dialects._ops import ops as exir_ops +from executorch.exir.pass_base import ExportPass # For MI case edge_acosh_op = exir_ops.edge.aten.acosh.default @@ -19,6 +22,8 @@ class DecomposeAcoshPass(ArmPass): acosh(x) = log(x + sqrt((x-1)(x+1)) """ + _passes_required_after: Set[Type[ExportPass]] = set() + def call_operator(self, op, args, kwargs, meta, updated=False): if op is not edge_acosh_op: diff --git a/backends/arm/_passes/decompose_adaptive_avg_pool2d_pass.py b/backends/arm/_passes/decompose_adaptive_avg_pool2d_pass.py index abfcc8e3945..f1623b4aca7 100644 --- a/backends/arm/_passes/decompose_adaptive_avg_pool2d_pass.py +++ b/backends/arm/_passes/decompose_adaptive_avg_pool2d_pass.py @@ -4,12 +4,14 @@ # LICENSE file in the root directory of this source tree. from math import ceil, floor +from typing import Set, Type import torch from executorch.backends.arm._passes import ArmPass from executorch.exir.dialects._ops import ops as exir_ops +from executorch.exir.pass_base import ExportPass edge_ops = (exir_ops.edge.aten._adaptive_avg_pool2d.default,) aten_ops = (torch.ops.aten.adaptive_avg_pool2d.default,) @@ -41,6 +43,8 @@ class DecomposeAdaptiveAvgPool2dPass(ArmPass): The output is of size output_size_h x output_size_w for any input. """ + _passes_required_after: Set[Type[ExportPass]] = set() + def call_operator(self, op, args, kwargs, meta, updated=False): if op not in (edge_ops + aten_ops): return super().call_operator(op, args, kwargs, meta, updated) diff --git a/backends/arm/_passes/decompose_addmm_pass.py b/backends/arm/_passes/decompose_addmm_pass.py index b59a8cb02d3..142f3143f38 100644 --- a/backends/arm/_passes/decompose_addmm_pass.py +++ b/backends/arm/_passes/decompose_addmm_pass.py @@ -3,10 +3,13 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +from typing import Set, Type + import torch from executorch.backends.arm._passes import ArmPass from executorch.exir.dialects._ops import ops as exir_ops +from executorch.exir.pass_base import ExportPass # For MI case @@ -36,6 +39,8 @@ def get_ops(op): class DecomposeAddmmPass(ArmPass): """Decomposes the addmm operator into tensor multiplication and addition.""" + _passes_required_after: Set[Type[ExportPass]] = set() + def call_operator(self, op, args, kwargs, meta): if op not in [edge_addmm, aten_addmm]: return super().call_operator(op, args, kwargs, meta) diff --git a/backends/arm/_passes/decompose_asin_and_acos_pass.py b/backends/arm/_passes/decompose_asin_and_acos_pass.py index e067f17b0ca..c083cc669c2 100644 --- a/backends/arm/_passes/decompose_asin_and_acos_pass.py +++ b/backends/arm/_passes/decompose_asin_and_acos_pass.py @@ -7,11 +7,13 @@ import logging from math import pi +from typing import Set, Type import torch from executorch.backends.arm._passes import ArmPass from executorch.exir.dialects._ops import ops as exir_ops +from executorch.exir.pass_base import ExportPass # For MI case edge_asin_op = (exir_ops.edge.aten.asin.default,) @@ -54,6 +56,8 @@ class DecomposeAsinAndAcosPass(ArmPass): """ + _passes_required_after: Set[Type[ExportPass]] = set() + def _build_polynomial( self, coefficients: list[float], variable: torch.Tensor, meta: dict[str, str] ) -> torch.Tensor: diff --git a/backends/arm/_passes/decompose_asinh_pass.py b/backends/arm/_passes/decompose_asinh_pass.py index a0b78c51a77..b8f7300beb5 100644 --- a/backends/arm/_passes/decompose_asinh_pass.py +++ b/backends/arm/_passes/decompose_asinh_pass.py @@ -6,8 +6,11 @@ # pyre-unsafe +from typing import Set, Type + from executorch.backends.arm._passes import ArmPass from executorch.exir.dialects._ops import ops as exir_ops +from executorch.exir.pass_base import ExportPass # For MI case edge_asinh_op = (exir_ops.edge.aten.asinh.default,) @@ -20,6 +23,8 @@ class DecomposeAsinhPass(ArmPass): asinh(x) = log(x + sqrt(x^2 + 1)) """ + _passes_required_after: Set[Type[ExportPass]] = set() + def call_operator(self, op, args, kwargs, meta): if op not in edge_asinh_op: return super().call_operator(op, args, kwargs, meta) diff --git a/backends/arm/_passes/decompose_atan_pass.py b/backends/arm/_passes/decompose_atan_pass.py index 57b9dde5216..7faef26a245 100644 --- a/backends/arm/_passes/decompose_atan_pass.py +++ b/backends/arm/_passes/decompose_atan_pass.py @@ -5,9 +5,11 @@ import logging from math import pi +from typing import Set, Type from executorch.backends.arm._passes import ArmPass from executorch.exir.dialects._ops import ops as exir_ops +from executorch.exir.pass_base import ExportPass edge_atan = exir_ops.edge.aten.atan.default # MI case @@ -35,6 +37,8 @@ def _get_atan_ops(op): class DecomposeAtanPass(ArmPass): """Decomposes the atan operator into a rational (Padé) approximation.""" + _passes_required_after: Set[Type[ExportPass]] = set() + def _rational_approximation(self, z, ops, meta): """Creates a (2,1) Padé approximation for atan(x) on [-1, 1].""" diff --git a/backends/arm/_passes/decompose_atanh_pass.py b/backends/arm/_passes/decompose_atanh_pass.py index dfdad41e556..d06598923b3 100644 --- a/backends/arm/_passes/decompose_atanh_pass.py +++ b/backends/arm/_passes/decompose_atanh_pass.py @@ -3,8 +3,11 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +from typing import Set, Type + from executorch.backends.arm._passes import ArmPass from executorch.exir.dialects._ops import ops as exir_ops +from executorch.exir.pass_base import ExportPass edge_atanh = exir_ops.edge.aten.atanh.default # MI case @@ -30,6 +33,8 @@ class DecomposeAtanhPass(ArmPass): atanh(x) = 0.5 * log((1 + x) / (1 - x)) """ + _passes_required_after: Set[Type[ExportPass]] = set() + def call_operator(self, op, args, kwargs, meta): if op is not edge_atanh: return super().call_operator(op, args, kwargs, meta, updated=False) diff --git a/backends/arm/_passes/decompose_avg_pool2d.py b/backends/arm/_passes/decompose_avg_pool2d.py index 21ed6b518c7..0240661053b 100644 --- a/backends/arm/_passes/decompose_avg_pool2d.py +++ b/backends/arm/_passes/decompose_avg_pool2d.py @@ -4,6 +4,8 @@ # LICENSE file in the root directory of this source tree. +from typing import Set, Type + import torch from executorch.backends.arm.operators.operator_validation_utils import ( adjust_pooling_pad_if_needed, @@ -34,7 +36,7 @@ def get_decomposition(op) -> tuple: class DecomposeAvgPool2d(ExportPass): - """ """ + _passes_required_after: Set[Type[ExportPass]] = set() def call_operator(self, op, args, kwargs, meta): if op not in (edge_div_ops + aten_div_ops): diff --git a/backends/arm/_passes/decompose_batch_norm_no_stats.py b/backends/arm/_passes/decompose_batch_norm_no_stats.py index 5fdb8db2d7c..82937241369 100644 --- a/backends/arm/_passes/decompose_batch_norm_no_stats.py +++ b/backends/arm/_passes/decompose_batch_norm_no_stats.py @@ -6,12 +6,13 @@ # pyre-unsafe import operator +from typing import Set, Type import torch from executorch.backends.arm._passes import ArmPass from executorch.backends.arm._passes.arm_pass_utils import create_node from executorch.exir.dialects._ops import ops as exir_ops -from executorch.exir.pass_base import PassResult +from executorch.exir.pass_base import ExportPass, PassResult class DecomposeBatchNormNoStatsPass(ArmPass): @@ -33,6 +34,8 @@ class DecomposeBatchNormNoStatsPass(ArmPass): Source: https://pytorch.org/docs/stable/generated/torch.nn.BatchNorm2d.html """ + _passes_required_after: Set[Type[ExportPass]] = set() + def call(self, graph_module: torch.fx.GraphModule) -> PassResult: # noqa: C901 bn_ops = ( exir_ops.edge.aten._native_batch_norm_legit.no_stats, diff --git a/backends/arm/_passes/decompose_cosh_pass.py b/backends/arm/_passes/decompose_cosh_pass.py index a94cf9ecff0..b71ca388651 100644 --- a/backends/arm/_passes/decompose_cosh_pass.py +++ b/backends/arm/_passes/decompose_cosh_pass.py @@ -3,8 +3,11 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +from typing import Set, Type + from executorch.backends.arm._passes import ArmPass from executorch.exir.dialects._ops import ops as exir_ops +from executorch.exir.pass_base import ExportPass # For MI case edge_cosh = exir_ops.edge.aten.cosh.default @@ -19,6 +22,8 @@ class DecomposeCoshPass(ArmPass): """ + _passes_required_after: Set[Type[ExportPass]] = set() + def call_operator(self, op, args, kwargs, meta, updated=False): if op is not edge_cosh: return super().call_operator(op, args, kwargs, meta, updated) diff --git a/backends/arm/_passes/decompose_cosine_similarity_pass.py b/backends/arm/_passes/decompose_cosine_similarity_pass.py index 9978e653408..e2ab01b345f 100644 --- a/backends/arm/_passes/decompose_cosine_similarity_pass.py +++ b/backends/arm/_passes/decompose_cosine_similarity_pass.py @@ -3,6 +3,8 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +from typing import Set, Type + import torch from executorch.exir.pass_base import ExportPass @@ -22,6 +24,8 @@ class DecomposeCosineSimilarityPass(ExportPass): out = div(dot, denom) """ + _passes_required_after: Set[Type[ExportPass]] = set() + def call_operator(self, op, args, kwargs, meta): if op not in torch_cosine_similarity: return super().call_operator(op, args, kwargs, meta) diff --git a/backends/arm/_passes/decompose_cumsum_pass.py b/backends/arm/_passes/decompose_cumsum_pass.py index 155ccd11594..04e6275c6c1 100644 --- a/backends/arm/_passes/decompose_cumsum_pass.py +++ b/backends/arm/_passes/decompose_cumsum_pass.py @@ -4,6 +4,7 @@ # LICENSE file in the root directory of this source tree. from math import prod +from typing import Set, Type import torch from executorch.backends.arm._passes import ArmPass @@ -12,7 +13,7 @@ from executorch.backends.transforms.utils import create_constant_placeholder from executorch.exir.dialects._ops import ops as exir_ops -from executorch.exir.pass_base import PassResult +from executorch.exir.pass_base import ExportPass, PassResult from torch.export.graph_signature import InputKind @@ -39,6 +40,8 @@ class DecomposeCumsumPass(ArmPass): And the convolution is applied over dimension H. """ + _passes_required_after: Set[Type[ExportPass]] = set() + def call(self, graph_module): graph = graph_module.graph targets = (exir_ops.edge.aten.cumsum.default, torch.ops.aten.cumsum.default) diff --git a/backends/arm/_passes/decompose_div_pass.py b/backends/arm/_passes/decompose_div_pass.py index 893531dac69..b6e289ff049 100644 --- a/backends/arm/_passes/decompose_div_pass.py +++ b/backends/arm/_passes/decompose_div_pass.py @@ -1,4 +1,4 @@ -# Copyright 2024 Arm Limited and/or its affiliates. +# Copyright 2024-2025 Arm Limited and/or its affiliates. # All rights reserved. # # This source code is licensed under the BSD-style license found in the @@ -6,6 +6,8 @@ # pyre-unsafe +from typing import Set, Type + import torch from executorch.exir.dialects._ops import ops as exir_ops from executorch.exir.pass_base import ExportPass @@ -37,6 +39,8 @@ class DecomposeDivPass(ExportPass): y = mul(a,x) """ + _passes_required_after: Set[Type[ExportPass]] = set() + def call_operator(self, op, args, kwargs, meta): if op not in (edge_div_ops + aten_div_ops): return super().call_operator(op, args, kwargs, meta) diff --git a/backends/arm/_passes/decompose_elu_pass.py b/backends/arm/_passes/decompose_elu_pass.py index 743f1b46f4d..ba3d32b7529 100644 --- a/backends/arm/_passes/decompose_elu_pass.py +++ b/backends/arm/_passes/decompose_elu_pass.py @@ -3,8 +3,11 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +from typing import Set, Type + from executorch.backends.arm._passes import ArmPass from executorch.exir.dialects._ops import ops as exir_ops +from executorch.exir.pass_base import ExportPass edge_elu_ops = (exir_ops.edge.aten.elu.default,) @@ -55,6 +58,8 @@ class DecomposeEluPass(ArmPass): - exir_ops.edge.aten.mul.Scalar """ + _passes_required_after: Set[Type[ExportPass]] = set() + def call_operator(self, op, args, kwargs, meta): if op not in edge_elu_ops: return super().call_operator(op, args, kwargs, meta, updated=False) diff --git a/backends/arm/_passes/decompose_embedding_pass.py b/backends/arm/_passes/decompose_embedding_pass.py index 6de971f402f..5b2ad27eaf6 100644 --- a/backends/arm/_passes/decompose_embedding_pass.py +++ b/backends/arm/_passes/decompose_embedding_pass.py @@ -8,6 +8,7 @@ import logging from math import prod +from typing import Set, Type import torch from executorch.exir.dialects._ops import ops as exir_ops @@ -33,6 +34,8 @@ class DecomposeEmbeddingPass(ExportPass): i = indices is expected to be int32 before this pass """ + _passes_required_after: Set[Type[ExportPass]] = set() + aten_ops = (torch.ops.aten.embedding.default,) edge_ops = (exir_ops.edge.aten.embedding.default,) diff --git a/backends/arm/_passes/decompose_expm1_pass.py b/backends/arm/_passes/decompose_expm1_pass.py index 5b1b90495b5..21d3c975de3 100644 --- a/backends/arm/_passes/decompose_expm1_pass.py +++ b/backends/arm/_passes/decompose_expm1_pass.py @@ -3,8 +3,11 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +from typing import Set, Type + from executorch.backends.arm._passes import ArmPass from executorch.exir.dialects._ops import ops as exir_ops +from executorch.exir.pass_base import ExportPass edge_expm1_ops = (exir_ops.edge.aten.expm1.default,) # MI case @@ -68,6 +71,8 @@ class DecomposeExpm1Pass(ArmPass): - exir_ops.edge.aten.logical_and.default """ + _passes_required_after: Set[Type[ExportPass]] = set() + def call_operator(self, op, args, kwargs, meta): if op not in edge_expm1_ops: return super().call_operator(op, args, kwargs, meta, updated=False) diff --git a/backends/arm/_passes/decompose_gelu_pass.py b/backends/arm/_passes/decompose_gelu_pass.py index 6e72175e68b..ef6a4753b8c 100644 --- a/backends/arm/_passes/decompose_gelu_pass.py +++ b/backends/arm/_passes/decompose_gelu_pass.py @@ -3,6 +3,8 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +from typing import Set, Type + import torch from executorch.backends.arm._passes.arm_pass_utils import get_node_arg from executorch.exir.dialects._ops import ops as exir_ops @@ -77,6 +79,8 @@ class DecomposeGeluPass(ExportPass): %op7 = mul(%op6, %FULL_0_5) """ + _passes_required_after: Set[Type[ExportPass]] = set() + def call_operator(self, op, args, kwargs, meta): if op not in torch_gelu + edge_gelu: return super().call_operator(op, args, kwargs, meta) diff --git a/backends/arm/_passes/decompose_glu_pass.py b/backends/arm/_passes/decompose_glu_pass.py index 183dc89cf61..6b53609c951 100644 --- a/backends/arm/_passes/decompose_glu_pass.py +++ b/backends/arm/_passes/decompose_glu_pass.py @@ -3,9 +3,12 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +from typing import Set, Type + import torch from executorch.backends.arm._passes import ArmPass from executorch.exir.dialects._ops import ops as exir_ops +from executorch.exir.pass_base import ExportPass # For FP case @@ -36,6 +39,8 @@ def get_ops(op): class DecomposeGluPass(ArmPass): """Decomposes the GLU operator into hadamard product and sigmoid.""" + _passes_required_after: Set[Type[ExportPass]] = set() + def call_operator(self, op, args, kwargs, meta): if op not in [edge_glu, aten_glu]: return super().call_operator(op, args, kwargs, meta) diff --git a/backends/arm/_passes/decompose_grouped_conv.py b/backends/arm/_passes/decompose_grouped_conv.py index ce9fe9c9937..2f0d7b4d72c 100644 --- a/backends/arm/_passes/decompose_grouped_conv.py +++ b/backends/arm/_passes/decompose_grouped_conv.py @@ -4,6 +4,7 @@ # LICENSE file in the root directory of this source tree. from copy import copy +from typing import Set, Type import torch from executorch.backends.arm._passes.quant_args import QuantArgs @@ -33,6 +34,8 @@ class DecomposeGroupedConv(ExportPass): x = cat(x1, x2) """ + _passes_required_after: Set[Type[ExportPass]] = set() + @staticmethod def _get_decomposition(op): match op: diff --git a/backends/arm/_passes/decompose_groupnorm_pass.py b/backends/arm/_passes/decompose_groupnorm_pass.py index c6cb1b05e40..7f0d7fdeafd 100644 --- a/backends/arm/_passes/decompose_groupnorm_pass.py +++ b/backends/arm/_passes/decompose_groupnorm_pass.py @@ -6,12 +6,13 @@ # pyre-unsafe import operator +from typing import Set, Type import torch from executorch.backends.arm._passes import ArmPass from executorch.backends.arm._passes.arm_pass_utils import create_node from executorch.exir.dialects._ops import ops as exir_ops -from executorch.exir.pass_base import PassResult +from executorch.exir.pass_base import ExportPass, PassResult def get_group_norm_decomposition(op) -> tuple: @@ -57,6 +58,8 @@ class DecomposeGroupNormPass(ArmPass): Source: https://pytorch.org/docs/stable/generated/torch.nn.GroupNorm.html """ + _passes_required_after: Set[Type[ExportPass]] = set() + def call(self, graph_module: torch.fx.GraphModule): modified = False for node in graph_module.graph.nodes: diff --git a/backends/arm/_passes/decompose_layernorm_pass.py b/backends/arm/_passes/decompose_layernorm_pass.py index e6cbdfb91a0..0710ed37b45 100644 --- a/backends/arm/_passes/decompose_layernorm_pass.py +++ b/backends/arm/_passes/decompose_layernorm_pass.py @@ -6,12 +6,13 @@ # pyre-unsafe import operator +from typing import Set, Type import torch from executorch.backends.arm._passes import ArmPass from executorch.backends.arm._passes.arm_pass_utils import create_node from executorch.exir.dialects._ops import ops as exir_ops -from executorch.exir.pass_base import PassResult +from executorch.exir.pass_base import ExportPass, PassResult def get_layer_norm_decomposition(op) -> tuple: @@ -56,6 +57,8 @@ class DecomposeLayerNormPass(ArmPass): Source: https://pytorch.org/docs/stable/generated/torch.nn.LayerNorm.html """ + _passes_required_after: Set[Type[ExportPass]] = set() + def call(self, graph_module: torch.fx.GraphModule): for node in graph_module.graph.nodes: if node.op != "call_function" or node.target not in ( diff --git a/backends/arm/_passes/decompose_leaky_relu_pass.py b/backends/arm/_passes/decompose_leaky_relu_pass.py index e896cc584be..8ae13a76eb0 100644 --- a/backends/arm/_passes/decompose_leaky_relu_pass.py +++ b/backends/arm/_passes/decompose_leaky_relu_pass.py @@ -6,9 +6,12 @@ # pyre-unsafe +from typing import Set, Type + import torch from executorch.backends.arm._passes import ArmPass from executorch.exir.dialects._ops import ops as exir_ops +from executorch.exir.pass_base import ExportPass edge_ops = (exir_ops.edge.aten.leaky_relu.default,) torch_ops = (torch.ops.aten.leaky_relu.default,) @@ -46,6 +49,8 @@ class DecomposeLeakyReLUPass(ArmPass): %op5 = add(%op1,%op4) """ + _passes_required_after: Set[Type[ExportPass]] = set() + def call_operator(self, op, args, kwargs, meta): if op not in (edge_ops + torch_ops): return super().call_operator(op, args, kwargs, meta) diff --git a/backends/arm/_passes/decompose_linalg_vector_norm_pass.py b/backends/arm/_passes/decompose_linalg_vector_norm_pass.py index 9f036c0524f..17441981654 100644 --- a/backends/arm/_passes/decompose_linalg_vector_norm_pass.py +++ b/backends/arm/_passes/decompose_linalg_vector_norm_pass.py @@ -3,6 +3,8 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +from typing import Set, Type + import torch from executorch.exir.pass_base import ExportPass @@ -28,6 +30,8 @@ class DecomposeLinearVectorNormPass(ExportPass): dtype prior, but we dont know this from FX graph. """ + _passes_required_after: Set[Type[ExportPass]] = set() + torch_linalg_vector_norm = (torch.ops.aten.linalg_vector_norm.default,) def call_operator(self, op, args, kwargs, meta): diff --git a/backends/arm/_passes/decompose_linear_pass.py b/backends/arm/_passes/decompose_linear_pass.py index 3d154d9b81e..70268c77a1d 100644 --- a/backends/arm/_passes/decompose_linear_pass.py +++ b/backends/arm/_passes/decompose_linear_pass.py @@ -5,6 +5,8 @@ # pyre-unsafe +from typing import Set, Type + import numpy as np from executorch.backends.arm._passes import ArmPass from executorch.backends.arm._passes.arm_pass_utils import ( @@ -12,7 +14,7 @@ get_first_fake_tensor, ) from executorch.exir.dialects._ops import ops as exir_ops -from executorch.exir.pass_base import PassResult +from executorch.exir.pass_base import ExportPass, PassResult class DecomposeLinearPass(ArmPass): @@ -25,6 +27,8 @@ class DecomposeLinearPass(ArmPass): output = view(conv2d) """ + _passes_required_after: Set[Type[ExportPass]] = set() + def call(self, graph_module): for node in graph_module.graph.nodes: if node.op != "call_function": diff --git a/backends/arm/_passes/decompose_logit_pass.py b/backends/arm/_passes/decompose_logit_pass.py index 40e2b22cb54..a82650f0b9e 100644 --- a/backends/arm/_passes/decompose_logit_pass.py +++ b/backends/arm/_passes/decompose_logit_pass.py @@ -3,10 +3,13 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +from typing import Set, Type + import torch from executorch.backends.arm._passes import ArmPass from executorch.exir.dialects._ops import ops as exir_ops +from executorch.exir.pass_base import ExportPass # For FP case @@ -60,6 +63,8 @@ class DecomposeLogitPass(ArmPass): log(y * reciprocal((-1) * y + 1)) """ + _passes_required_after: Set[Type[ExportPass]] = set() + def call_operator(self, op, args, kwargs, meta): if op not in [edge_logit, aten_logit]: return super().call_operator(op, args, kwargs, meta) diff --git a/backends/arm/_passes/decompose_masked_fill.py b/backends/arm/_passes/decompose_masked_fill.py index fbf3079c92b..ced58aa3920 100644 --- a/backends/arm/_passes/decompose_masked_fill.py +++ b/backends/arm/_passes/decompose_masked_fill.py @@ -6,10 +6,13 @@ # pyre-unsafe +from typing import Set, Type + import torch from executorch.backends.arm._passes import ArmPass from executorch.exir.dialects._ops import ops as exir_ops +from executorch.exir.pass_base import ExportPass edge_ops = (exir_ops.edge.aten.masked_fill.Scalar,) @@ -37,6 +40,8 @@ class DecomposeMaskedFill(ArmPass): Decomposed to a where and a full_like operator. """ + _passes_required_after: Set[Type[ExportPass]] = set() + def call_operator(self, op, args, kwargs, meta, updated=False): if op not in (edge_ops + aten_ops): return super().call_operator(op, args, kwargs, meta, updated) diff --git a/backends/arm/_passes/decompose_maxpool2d_with_dilation.py b/backends/arm/_passes/decompose_maxpool2d_with_dilation.py index ff6db260099..1df062ddb57 100644 --- a/backends/arm/_passes/decompose_maxpool2d_with_dilation.py +++ b/backends/arm/_passes/decompose_maxpool2d_with_dilation.py @@ -6,9 +6,11 @@ # pyre-unsafe import operator +from typing import Set, Type from executorch.backends.arm._passes import ArmPass from executorch.exir.dialects._ops import ops as exir_ops +from executorch.exir.pass_base import ExportPass # We'll decompose only the EXIR edge max_pool2d ops when dilation > 1 EDGE_MAXPOOL2D = ( @@ -22,6 +24,8 @@ class DecomposeMaxPool2DPass(ArmPass): Decompose dilated max_pool2d (EXIR edge ops) into space-to-batch -> maxpool -> batch-to-space. """ + _passes_required_after: Set[Type[ExportPass]] = set() + def call_operator(self, op, args, kwargs, meta): # Only intercept EXIR edge max_pool2d ops if op not in EDGE_MAXPOOL2D: diff --git a/backends/arm/_passes/decompose_meandim_pass.py b/backends/arm/_passes/decompose_meandim_pass.py index a78514b6af5..716924dfbf2 100644 --- a/backends/arm/_passes/decompose_meandim_pass.py +++ b/backends/arm/_passes/decompose_meandim_pass.py @@ -5,12 +5,14 @@ from copy import copy from math import prod +from typing import Set, Type import torch from executorch.backends.arm._passes import ArmPass from executorch.backends.arm._passes.arm_pass_utils import get_node_arg from executorch.exir.backend.utils import WhyNoPartitionReporter from executorch.exir.dialects._ops import ops as exir_ops +from executorch.exir.pass_base import ExportPass def get_meandim_decomposition(op) -> tuple: @@ -62,6 +64,8 @@ class DecomposeMeanDimPass(ArmPass): x = view_copy.default(x, new_shape=(h)) # Squeeze dims since keepdims = False """ + _passes_required_after: Set[Type[ExportPass]] = set() + def __init__(self, graph_module, tosa_spec): super().__init__() self._graph_module = graph_module diff --git a/backends/arm/_passes/decompose_ne_pass.py b/backends/arm/_passes/decompose_ne_pass.py index 16443d5d2fb..3bd4f4540bb 100644 --- a/backends/arm/_passes/decompose_ne_pass.py +++ b/backends/arm/_passes/decompose_ne_pass.py @@ -3,9 +3,12 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +from typing import Set, Type + import torch from executorch.backends.arm._passes import ArmPass from executorch.exir.dialects._ops import ops as exir_ops +from executorch.exir.pass_base import ExportPass edge_ne_ops = (exir_ops.edge.aten.ne.Tensor,) aten_ne_ops = (torch.ops.aten.ne.Tensor, torch.ops.aten.ne_.Tensor) @@ -53,6 +56,8 @@ class DecomposeNotEqualPass(ArmPass): - followed by aten.logical_not.default or its edge equivalent """ + _passes_required_after: Set[Type[ExportPass]] = set() + def call_operator(self, op, args, kwargs, meta): if op not in (edge_ne_ops + aten_ne_ops): return super().call_operator(op, args, kwargs, meta) diff --git a/backends/arm/_passes/decompose_round_pass.py b/backends/arm/_passes/decompose_round_pass.py index edfa3817064..35d36e80396 100644 --- a/backends/arm/_passes/decompose_round_pass.py +++ b/backends/arm/_passes/decompose_round_pass.py @@ -3,10 +3,13 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +from typing import Set, Type + import torch from executorch.backends.arm._passes import ArmPass from executorch.exir.dialects._ops import ops as exir_ops from executorch.exir.dialects.edge._ops import EdgeOpOverload +from executorch.exir.pass_base import ExportPass from torch._ops import OpOverload @@ -56,6 +59,8 @@ class DecomposeRoundPass(ArmPass): %result = where(%is_non_negative, %floor, %ceil) """ + _passes_required_after: Set[Type[ExportPass]] = set() + def call_operator(self, op, args, kwargs, meta, updated=False): if op not in (exir_ops.edge.aten.round.default, torch.ops.aten.round.default): return super().call_operator(op, args, kwargs, meta, updated) diff --git a/backends/arm/_passes/decompose_select.py b/backends/arm/_passes/decompose_select.py index 99c89f474ea..9c65cd1c0a8 100644 --- a/backends/arm/_passes/decompose_select.py +++ b/backends/arm/_passes/decompose_select.py @@ -6,6 +6,8 @@ # pyre-unsafe +from typing import Set, Type + import torch from executorch.backends.arm._passes.arm_pass_utils import ( create_node, @@ -20,6 +22,8 @@ class DecomposeSelectPass(ExportPass): This pass decomposes select into slice + squeeze to ensure that Aten and TOSA outputs has the same rank (input rank -1) """ + _passes_required_after: Set[Type[ExportPass]] = set() + def call(self, graph_module: torch.fx.GraphModule): for node in graph_module.graph.nodes: diff --git a/backends/arm/_passes/decompose_sign_pass.py b/backends/arm/_passes/decompose_sign_pass.py index 1038ff0f3fa..c4cb964316d 100644 --- a/backends/arm/_passes/decompose_sign_pass.py +++ b/backends/arm/_passes/decompose_sign_pass.py @@ -3,10 +3,13 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +from typing import Set, Type + import torch from executorch.backends.arm._passes import ArmPass from executorch.exir.dialects._ops import ops as exir_ops +from executorch.exir.pass_base import ExportPass # For MI case @@ -42,6 +45,8 @@ def get_ops(op): class DecomposeSignPass(ArmPass): """Decomposes the sign operator into a sequence of operations that are supported by the Arm backend.""" + _passes_required_after: Set[Type[ExportPass]] = set() + def call_operator(self, op, args, kwargs, meta): if op not in (edge_sign, aten_sign): return super().call_operator(op, args, kwargs, meta) diff --git a/backends/arm/_passes/decompose_silu_pass.py b/backends/arm/_passes/decompose_silu_pass.py index 68ebb3f4515..cb7b55be520 100644 --- a/backends/arm/_passes/decompose_silu_pass.py +++ b/backends/arm/_passes/decompose_silu_pass.py @@ -5,6 +5,8 @@ # pyre-unsafe +from typing import Set, Type + import torch from executorch.exir.pass_base import ExportPass @@ -22,6 +24,8 @@ class DecomposeSiluPass(ExportPass): y = mul(a,x) """ + _passes_required_after: Set[Type[ExportPass]] = set() + def call_operator(self, op, args, kwargs, meta): if op not in (aten_silu_ops): return super().call_operator(op, args, kwargs, meta) diff --git a/backends/arm/_passes/decompose_sinh_pass.py b/backends/arm/_passes/decompose_sinh_pass.py index 7192eb9bf74..473a263e9a5 100644 --- a/backends/arm/_passes/decompose_sinh_pass.py +++ b/backends/arm/_passes/decompose_sinh_pass.py @@ -4,8 +4,11 @@ # LICENSE file in the root directory of this source tree. +from typing import Set, Type + from executorch.backends.arm._passes import ArmPass from executorch.exir.dialects._ops import ops as exir_ops +from executorch.exir.pass_base import ExportPass # For MI case @@ -24,6 +27,8 @@ class DecomposeSinhPass(ArmPass): and scalar multiplication. """ + _passes_required_after: Set[Type[ExportPass]] = set() + def call_operator(self, op, args, kwargs, meta): if op is not edge_sinh: return super().call_operator(op, args, kwargs, meta) diff --git a/backends/arm/_passes/decompose_softmax_pass.py b/backends/arm/_passes/decompose_softmax_pass.py index a735501f711..47f448ae851 100644 --- a/backends/arm/_passes/decompose_softmax_pass.py +++ b/backends/arm/_passes/decompose_softmax_pass.py @@ -3,6 +3,8 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +from typing import Set, Type + import torch from executorch.exir.dialects._ops import ops as exir_ops from executorch.exir.pass_base import ExportPass @@ -62,6 +64,8 @@ class DecomposeSoftmaxPass(ExportPass): (in logsoftmax case: %op7 = log(%op6)) """ + _passes_required_after: Set[Type[ExportPass]] = set() + def call_operator(self, op, args, kwargs, meta): if op not in torch_softmax + edge_softmax: return super().call_operator(op, args, kwargs, meta) diff --git a/backends/arm/_passes/decompose_softmax_unstable_pass.py b/backends/arm/_passes/decompose_softmax_unstable_pass.py index b6f5e11b66b..5e704585eb0 100644 --- a/backends/arm/_passes/decompose_softmax_unstable_pass.py +++ b/backends/arm/_passes/decompose_softmax_unstable_pass.py @@ -5,9 +5,12 @@ # pyre-unsafe +from typing import Set, Type + import torch from executorch.backends.arm._passes import ArmPass from executorch.exir.dialects._ops import ops as exir_ops +from executorch.exir.pass_base import ExportPass # For BI case torch_softmax = (torch.ops.aten.softmax.int, torch.ops.aten.log_softmax.int) @@ -57,6 +60,8 @@ class DecomposeSoftmaxUnstablePass(ArmPass): (in logsoftmax case: %op5 = log(%op4)) """ + _passes_required_after: Set[Type[ExportPass]] = set() + def call_operator(self, op, args, kwargs, meta): if op not in torch_softmax + edge_softmax: return super().call_operator(op, args, kwargs, meta) diff --git a/backends/arm/_passes/decompose_sqrt_pass.py b/backends/arm/_passes/decompose_sqrt_pass.py index 547d0091e90..c93686901d5 100644 --- a/backends/arm/_passes/decompose_sqrt_pass.py +++ b/backends/arm/_passes/decompose_sqrt_pass.py @@ -4,7 +4,7 @@ # LICENSE file in the root directory of this source tree. # pyre-unsafe -from typing import Tuple, Union +from typing import Set, Tuple, Type, Union import torch from executorch.exir.dialects._ops import ops as exir_ops @@ -27,6 +27,7 @@ def get_sqrt_decomposition(op) -> Union[Tuple, torch._ops.OpOverload]: class DecomposeSqrtPass(ExportPass): + _passes_required_after: Set[Type[ExportPass]] = set() def call_operator(self, op, args, kwargs, meta): """ diff --git a/backends/arm/_passes/decompose_sum_pass.py b/backends/arm/_passes/decompose_sum_pass.py index 52b9c10c49f..16027ccec2b 100644 --- a/backends/arm/_passes/decompose_sum_pass.py +++ b/backends/arm/_passes/decompose_sum_pass.py @@ -3,6 +3,8 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +from typing import Set, Type + import torch from executorch.exir.dialects._ops import ops as exir_ops from executorch.exir.pass_base import ExportPass @@ -40,6 +42,8 @@ class DecomposeSumPass(ExportPass): view(shape = squeezed_shape) -> squeezed_shape """ + _passes_required_after: Set[Type[ExportPass]] = set() + def call_operator(self, op, args, kwargs, meta): if op not in [ exir_ops.edge.aten.sum.dim_IntList, diff --git a/backends/arm/_passes/decompose_var_pass.py b/backends/arm/_passes/decompose_var_pass.py index 15872738f3e..f8396da0420 100644 --- a/backends/arm/_passes/decompose_var_pass.py +++ b/backends/arm/_passes/decompose_var_pass.py @@ -7,10 +7,13 @@ # pyre-unsafe +from typing import Set, Type + import torch from executorch.backends.arm._passes import ArmPass from executorch.backends.arm._passes.arm_pass_utils import get_node_arg from executorch.exir.dialects._ops import ops as exir_ops +from executorch.exir.pass_base import ExportPass def get_var_decomposition(op) -> tuple: @@ -47,6 +50,8 @@ class DecomposeVarPass(ArmPass): y = div(sum, max(0, N-correction)) """ + _passes_required_after: Set[Type[ExportPass]] = set() + def call_operator(self, op, args, kwargs, meta): if op not in ( exir_ops.edge.aten.var.correction, diff --git a/backends/arm/_passes/decorate_fp32_to_int32_casting_pass.py b/backends/arm/_passes/decorate_fp32_to_int32_casting_pass.py index 17a682c0a8e..9d704520302 100644 --- a/backends/arm/_passes/decorate_fp32_to_int32_casting_pass.py +++ b/backends/arm/_passes/decorate_fp32_to_int32_casting_pass.py @@ -6,10 +6,13 @@ # pyre-unsafe +from typing import Set, Type + import torch from executorch.backends.arm._passes import ArmPass from executorch.backends.arm._passes.arm_pass_utils import get_node_arg from executorch.exir.dialects._ops import ops as exir_ops +from executorch.exir.pass_base import ExportPass def _get_decorated_ops(op): @@ -40,6 +43,8 @@ class DecorateFp32toInt32CastingPass(ArmPass): output = to_dim_order_copy(decorated_x, dtype=torch.int32) """ + _passes_required_after: Set[Type[ExportPass]] = set() + targets = [ exir_ops.edge.dim_order_ops._to_dim_order_copy.default, ] diff --git a/backends/arm/_passes/fold_qdq_with_annotated_qparams_pass.py b/backends/arm/_passes/fold_qdq_with_annotated_qparams_pass.py index 491b404f0a4..714543d3908 100644 --- a/backends/arm/_passes/fold_qdq_with_annotated_qparams_pass.py +++ b/backends/arm/_passes/fold_qdq_with_annotated_qparams_pass.py @@ -8,7 +8,7 @@ import copy -from typing import cast, Dict, Set, Tuple +from typing import cast, Dict, Set, Tuple, Type from executorch.backends.arm._passes import ArmPass from executorch.backends.arm._passes.arm_pass_utils import ( @@ -100,6 +100,8 @@ class FoldAndAnnotateQParamsPass(ArmPass): """ + _passes_required_after: Set[Type[ExportPass]] = set() + def fold_and_annotate_arg( self, graph_module: GraphModule, node: Node, arg_list: list[Node], i: int ) -> None: @@ -210,6 +212,8 @@ class QuantizeOperatorArguments(ExportPass): - Makes sure the min and max values to clamp.default are quantized, if it's a quantized operator. """ + _passes_required_after: Set[Type[ExportPass]] = set() + def call(self, graph_module: GraphModule) -> PassResult: modified = False # Loop over the graph nodes and find full.default nodes. @@ -257,6 +261,8 @@ class RetraceFoldedDtypesPass(ExportPass): the output type of that matches the type of the output_qparams. """ + _passes_required_after: Set[Type[ExportPass]] = set() + targeted_ops: Set[EdgeOpOverload] = { exir_ops.edge.aten.sum.dim_IntList, } diff --git a/backends/arm/_passes/fuse_batchnorm2d_pass.py b/backends/arm/_passes/fuse_batchnorm2d_pass.py index 2dbdfa84cec..be884585d4d 100644 --- a/backends/arm/_passes/fuse_batchnorm2d_pass.py +++ b/backends/arm/_passes/fuse_batchnorm2d_pass.py @@ -5,6 +5,8 @@ # pyre-unsafe +from typing import Set, Type + import torch from executorch.backends.arm._passes.arm_pass_utils import ( create_node, @@ -28,6 +30,8 @@ class FuseBatchnorm2DPass(ExportPass): the weights and bias of the convolution and removing the batchnorm. """ + _passes_required_after: Set[Type[ExportPass]] = set() + def __init__(self, exported_program: ExportedProgram): self.exported_program = exported_program super().__init__() diff --git a/backends/arm/_passes/fuse_constant_ops_pass.py b/backends/arm/_passes/fuse_constant_ops_pass.py index f49565e3c38..07f3a4af245 100644 --- a/backends/arm/_passes/fuse_constant_ops_pass.py +++ b/backends/arm/_passes/fuse_constant_ops_pass.py @@ -4,6 +4,7 @@ # LICENSE file in the root directory of this source tree. import logging +from typing import Set, Type import torch._export.utils import torch.fx @@ -41,6 +42,8 @@ def f(): return x """ + _passes_required_after: Set[Type[ExportPass]] = set() + def __init__(self, exported_program: ExportedProgram) -> None: super().__init__() self.exported_program = exported_program @@ -168,6 +171,8 @@ def f(node_name_pre_computed): return node_name_pre_computed """ + _passes_required_after: Set[Type[ExportPass]] = set() + targeted_ops = [ exir_ops.edge.aten.full.default, exir_ops.edge.aten.arange.start_step, diff --git a/backends/arm/_passes/fuse_equal_placeholders_pass.py b/backends/arm/_passes/fuse_equal_placeholders_pass.py index 5631e2f32e9..cf1177a0448 100644 --- a/backends/arm/_passes/fuse_equal_placeholders_pass.py +++ b/backends/arm/_passes/fuse_equal_placeholders_pass.py @@ -5,6 +5,7 @@ import hashlib from collections import defaultdict +from typing import Set, Type import torch from executorch.backends.arm._passes.arm_pass_utils import ( @@ -27,6 +28,8 @@ class FuseEqualPlaceholdersPass(ExportPass): with multiple users, using a cache for faster comparison. """ + _passes_required_after: Set[Type[ExportPass]] = set() + def __init__(self, exported_program: ExportedProgram): self.exported_program = exported_program super().__init__() diff --git a/backends/arm/_passes/fuse_quantized_activation_pass.py b/backends/arm/_passes/fuse_quantized_activation_pass.py index 46a7d7f6f98..d39d7135f9c 100644 --- a/backends/arm/_passes/fuse_quantized_activation_pass.py +++ b/backends/arm/_passes/fuse_quantized_activation_pass.py @@ -5,6 +5,8 @@ # pyre-unsafe +from typing import Set, Type + import torch from executorch.backends.arm._passes.quant_args import QuantArgs from executorch.backends.arm.constants import Q_OPS @@ -14,6 +16,8 @@ class FuseQuantizedActivationPass(ExportPass): + _passes_required_after: Set[Type[ExportPass]] = set() + @staticmethod def _is_fuseable_quantized_activation(node: Node): """Fuse activations that have a 0 lower bound and quantized with a qmin zero-point""" diff --git a/backends/arm/_passes/insert_rescales_pass.py b/backends/arm/_passes/insert_rescales_pass.py index 7f75aecf24c..100ac03c2b0 100644 --- a/backends/arm/_passes/insert_rescales_pass.py +++ b/backends/arm/_passes/insert_rescales_pass.py @@ -4,7 +4,7 @@ # LICENSE file in the root directory of this source tree. from copy import copy -from typing import cast +from typing import cast, Set, Type from executorch.backends.arm._passes.arm_pass_utils import create_node from executorch.backends.arm._passes.quant_args import QuantArgs @@ -24,6 +24,8 @@ class InsertRescalePass(ExportPass): in the fake implementation of. """ + _passes_required_after: Set[Type[ExportPass]] = set() + def fold_dq_q_to_rescale(self, node: Node, user: Node, graph_module: GraphModule): dq_args = QuantArgs.from_operator(node.target, node.args) q_args = QuantArgs.from_operator(user.target, user.args) diff --git a/backends/arm/_passes/insert_table_ops.py b/backends/arm/_passes/insert_table_ops.py index fb5d7de5e12..d838ddc823d 100644 --- a/backends/arm/_passes/insert_table_ops.py +++ b/backends/arm/_passes/insert_table_ops.py @@ -6,7 +6,7 @@ # pyre-unsafe from itertools import chain -from typing import Callable, cast, Dict, Iterator, Set +from typing import Callable, cast, Dict, Iterator, Set, Type import torch from executorch.backends.arm._passes.arm_pass_utils import create_node @@ -117,6 +117,8 @@ class InsertTableOpsPass(ExportPass): which will be used to produce the table values in operators/op_table.py. """ + _passes_required_after: Set[Type[ExportPass]] = set() + def __init__(self, exported_program: ExportedProgram) -> None: super().__init__() self.exported_program = exported_program diff --git a/backends/arm/_passes/match_arg_dtype_pass.py b/backends/arm/_passes/match_arg_dtype_pass.py index e7bf3b2d60e..d482614b03f 100644 --- a/backends/arm/_passes/match_arg_dtype_pass.py +++ b/backends/arm/_passes/match_arg_dtype_pass.py @@ -3,6 +3,8 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +from typing import Set, Type + import torch from executorch.backends.arm._passes.arm_pass_utils import create_node, get_node_arg from executorch.exir.dialects._ops import ops as exir_ops @@ -38,6 +40,8 @@ class MatchArgDtypePass(ExportPass): """ + _passes_required_after: Set[Type[ExportPass]] = set() + targeted_ops = {exir_ops.edge.aten.sub.Tensor, exir_ops.edge.aten.where.self} def call(self, graph_module: torch.fx.GraphModule): diff --git a/backends/arm/_passes/match_arg_ranks_pass.py b/backends/arm/_passes/match_arg_ranks_pass.py index d6cdfacb612..c411f3b8083 100644 --- a/backends/arm/_passes/match_arg_ranks_pass.py +++ b/backends/arm/_passes/match_arg_ranks_pass.py @@ -7,7 +7,7 @@ # pyre-unsafe -from typing import cast +from typing import cast, Set, Type from executorch.backends.arm._passes.arm_pass_utils import ( create_node, @@ -36,6 +36,8 @@ class MatchArgRanksPass(ExportPass): input2 = shape(1, 3, 1) """ + _passes_required_after: Set[Type[ExportPass]] = set() + def __init__(self, exported_program): super().__init__() self.exported_program = exported_program diff --git a/backends/arm/_passes/mm_to_bmm_pass.py b/backends/arm/_passes/mm_to_bmm_pass.py index 69d8573013e..6be0b9e2ac4 100644 --- a/backends/arm/_passes/mm_to_bmm_pass.py +++ b/backends/arm/_passes/mm_to_bmm_pass.py @@ -6,6 +6,8 @@ # pyre-unsafe +from typing import Set, Type + import torch from executorch.backends.arm._passes.arm_pass_utils import ( create_node, @@ -28,6 +30,8 @@ class ConvertMmToBmmPass(ExportPass): 3) Squeeze output tensor to rank 2. """ + _passes_required_after: Set[Type[ExportPass]] = set() + def call(self, graph_module: torch.fx.GraphModule): modified_graph = False graph = graph_module.graph diff --git a/backends/arm/_passes/remove_noop_pass.py b/backends/arm/_passes/remove_noop_pass.py index 623517aac59..55c4f71f0a8 100644 --- a/backends/arm/_passes/remove_noop_pass.py +++ b/backends/arm/_passes/remove_noop_pass.py @@ -7,6 +7,7 @@ # pyre-unsafe import logging +from typing import Set, Type from executorch.exir.dialects._ops import ops as exir_ops from executorch.exir.pass_base import ExportPass @@ -17,6 +18,8 @@ class RemoveNoopPass(ExportPass): """Remove no-ops from graph_module""" + _passes_required_after: Set[Type[ExportPass]] = set() + def call_operator(self, op, args, kwargs, meta): if op not in ( exir_ops.edge.dim_order_ops._clone_dim_order.default, diff --git a/backends/arm/_passes/replace_inf_values_pass.py b/backends/arm/_passes/replace_inf_values_pass.py index 8c721eda3d8..506030d82d7 100644 --- a/backends/arm/_passes/replace_inf_values_pass.py +++ b/backends/arm/_passes/replace_inf_values_pass.py @@ -7,6 +7,8 @@ # This pass is based on backends/qualcomm/_passes/replace_inf_values.py # with some modification to replaced inf values. +from typing import Set, Type + import torch from executorch.exir.pass_base import ExportPass, PassResult @@ -16,6 +18,8 @@ class ReplaceInfValues(ExportPass): Due to limitation in Quantizer, we need to change inf/-inf to more quantizable values. """ + _passes_required_after: Set[Type[ExportPass]] = set() + def __init__(self): super(ReplaceInfValues, self).__init__() diff --git a/backends/arm/_passes/replace_scalar_with_tensor_pass.py b/backends/arm/_passes/replace_scalar_with_tensor_pass.py index 249eb9ffd41..f6ef056f677 100644 --- a/backends/arm/_passes/replace_scalar_with_tensor_pass.py +++ b/backends/arm/_passes/replace_scalar_with_tensor_pass.py @@ -6,7 +6,7 @@ # pyre-unsafe -from typing import Dict, Union +from typing import Dict, Set, Type, Union import torch from executorch.backends.transforms.replace_scalar_with_tensor import ( @@ -15,6 +15,7 @@ from executorch.exir.dialects._ops import ops as exir_ops from executorch.exir.dialects.edge._ops import EdgeOpOverload +from executorch.exir.pass_base import ExportPass # Operators that are included for both TOSA profiles @@ -56,6 +57,8 @@ class ReplaceScalarWithTensorArgPassTOSAMI(ReplaceScalarWithTensorArgPass): + _passes_required_after: Set[Type[ExportPass]] = set() + scalar_to_tensor_ops = _common_ops | { exir_ops.edge.aten.pow.Tensor_Scalar: exir_ops.edge.aten.pow.Tensor_Tensor, torch.ops.aten.pow.Tensor_Scalar: torch.ops.aten.pow.Tensor_Tensor, @@ -66,6 +69,8 @@ def __init__(self): class ReplaceScalarWithTensorArgPassTOSABI(ReplaceScalarWithTensorArgPass): + _passes_required_after: Set[Type[ExportPass]] = set() + scalar_to_tensor_ops = _common_ops def __init__(self): diff --git a/backends/arm/_passes/scalars_to_attribute_pass.py b/backends/arm/_passes/scalars_to_attribute_pass.py index 89468bff1ff..bb2a02cc679 100644 --- a/backends/arm/_passes/scalars_to_attribute_pass.py +++ b/backends/arm/_passes/scalars_to_attribute_pass.py @@ -6,7 +6,7 @@ # pyre-unsafe -from typing import cast, Union +from typing import cast, Set, Type, Union import torch from executorch.backends.arm._passes.arm_pass_utils import get_first_fake_tensor @@ -22,6 +22,8 @@ class ScalarsToAttributePass(ExportPass): to attribute Nodes that output the same value. """ + _passes_required_after: Set[Type[ExportPass]] = set() + targeted_ops = [ torch.ops.aten.add.Tensor, torch.ops.aten.add_.Tensor, diff --git a/backends/arm/_passes/size_adjust_input_pass.py b/backends/arm/_passes/size_adjust_input_pass.py index e87d65c450f..5eb77dc56df 100644 --- a/backends/arm/_passes/size_adjust_input_pass.py +++ b/backends/arm/_passes/size_adjust_input_pass.py @@ -5,7 +5,7 @@ # pyre-unsafe -from typing import cast, TypeAlias +from typing import cast, Set, Type, TypeAlias import torch.fx from executorch.backends.arm._passes.arm_pass_utils import create_node @@ -185,6 +185,8 @@ class SizeAdjustInputPass(ExportPass): input. """ + _passes_required_after: Set[Type[ExportPass]] = set() + def call(self, graph_module: torch.fx.GraphModule) -> PassResult: graph = graph_module.graph modified_graph = False diff --git a/backends/arm/_passes/to_tosa_memory_format_pass.py b/backends/arm/_passes/to_tosa_memory_format_pass.py index ac16cbaf8cb..dcbdfb03f7b 100644 --- a/backends/arm/_passes/to_tosa_memory_format_pass.py +++ b/backends/arm/_passes/to_tosa_memory_format_pass.py @@ -7,6 +7,7 @@ import logging +from typing import Set, Type import torch from executorch.backends.arm._passes.annotate_decomposed_matmul import ( @@ -48,6 +49,14 @@ class ToTosaMemoryFormatPass(ExportPass): The annotated tosa_dim_order is used to permute the node's shape such that it gives a TOSA-compliant shape. """ + _passes_required_after: Set[Type[ExportPass]] = set() + + NHWC_order = (0, 2, 3, 1) + NHWC_inverse_order = (0, 3, 1, 2) + HWCM_order = (2, 3, 0, 1) + NNHWC_order = (0, 1, 3, 4, 2) + NNHWC_inverse_order = (0, 1, 4, 2, 3) + def __init__(self, exported_program: ExportedProgram) -> None: self.exported_program = exported_program super().__init__() diff --git a/backends/arm/_passes/unsqueeze_before_repeat_pass.py b/backends/arm/_passes/unsqueeze_before_repeat_pass.py index 01983baa9ab..66286b6a954 100644 --- a/backends/arm/_passes/unsqueeze_before_repeat_pass.py +++ b/backends/arm/_passes/unsqueeze_before_repeat_pass.py @@ -1,9 +1,11 @@ -# Copyright 2024 Arm Limited and/or its affiliates. +# Copyright 2024-2025 Arm Limited and/or its affiliates. # All rights reserved. # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. # pyre-unsafe +from typing import Set, Type + import torch import torch.fx from executorch.backends.arm._passes.arm_pass_utils import ( @@ -29,6 +31,8 @@ class UnsqueezeBeforeRepeatPass(ExportPass): repeat(multiples) """ + _passes_required_after: Set[Type[ExportPass]] = set() + def call(self, graph_module: torch.fx.GraphModule): modified_graph = False for node in graph_module.graph.nodes: diff --git a/backends/arm/_passes/unsqueeze_scalar_placeholders_pass.py b/backends/arm/_passes/unsqueeze_scalar_placeholders_pass.py index ccae9b503cf..d3932dd1217 100644 --- a/backends/arm/_passes/unsqueeze_scalar_placeholders_pass.py +++ b/backends/arm/_passes/unsqueeze_scalar_placeholders_pass.py @@ -5,6 +5,8 @@ # pyre-unsafe +from typing import Set, Type + import torch from executorch.exir.pass_base import ExportPass, PassResult from torch._export.utils import is_buffer, is_param @@ -16,6 +18,8 @@ class UnsqueezeScalarPlaceholdersPass(ExportPass): This pass unsqueezes the placeholders to make sure shape is at least (1,). """ + _passes_required_after: Set[Type[ExportPass]] = set() + def __init__(self, exported_program): self.exported_program = exported_program super().__init__() diff --git a/backends/arm/test/misc/test_pass_required_order.py b/backends/arm/test/misc/test_pass_required_order.py new file mode 100644 index 00000000000..2745d25a498 --- /dev/null +++ b/backends/arm/test/misc/test_pass_required_order.py @@ -0,0 +1,95 @@ +# Copyright 2025 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import re +from typing import List, Set, Type + +import pytest +from executorch.backends.arm._passes.arm_pass_manager import ArmPass, ArmPassManager +from executorch.backends.arm.tosa.specification import TosaSpecification +from executorch.exir.pass_base import ExportPass + + +class PassC(ArmPass): + _passes_required_after: Set[Type[ExportPass]] = set() + + +class PassB(ArmPass): + _passes_required_after = {PassC} + + +class PassA(ArmPass): + _passes_required_after = {PassB, PassC} + + +class IndependentPass(ArmPass): + _passes_required_after: Set[Type[ExportPass]] = set() + + +def _setup_pass_manager(passes: List[ArmPass] | None = None): + tosa_spec = TosaSpecification.create_from_string("TOSA-1.00+INT") + pass_manager = ArmPassManager(tosa_spec) + if passes is not None: + for p in passes: + pass_manager.add_pass(p) + return pass_manager + + +def test_no_passes(): + pass_manager = _setup_pass_manager() + pass_manager.validate_constraints_mandatory() + + +def test_correct_order(): + pass_manager = _setup_pass_manager([PassA(), PassB(), PassC()]) + pass_manager.validate_constraints_mandatory() + + +def test_run_pass_twice(): + pass_manager = _setup_pass_manager([PassA(), PassB(), PassB(), PassC()]) + pass_manager.validate_constraints_mandatory() + + +def test_independent_pass(): + pass_manager = _setup_pass_manager( + [ + IndependentPass(), + PassA(), + IndependentPass(), + PassB(), + IndependentPass(), + PassC(), + IndependentPass(), + ] + ) + pass_manager.validate_constraints_mandatory() + + +def test_duplicated_requiring_pass_put_last(): + error_msg = """The following constraints for passes are not met: + - PassC must run after PassB +""" + pass_manager = _setup_pass_manager([PassA(), PassB(), PassC(), PassB()]) + with pytest.raises(RuntimeError, match=re.escape(error_msg)): + pass_manager.validate_constraints_mandatory() + + +def test_two_passes_wrong_order(): + error_msg = """The following constraints for passes are not met: + - PassC must run after PassB +""" + pass_manager = _setup_pass_manager([PassC(), PassB()]) + with pytest.raises(RuntimeError, match=re.escape(error_msg)): + pass_manager.validate_constraints_mandatory() + + +def test_missing_passes(): + error_msg = """The following constraints for passes are not met: + - PassC must run after PassA + - PassC must run after PassB +""" + pass_manager = _setup_pass_manager([PassA(), PassB()]) + with pytest.raises(RuntimeError, match=re.escape(error_msg)): + pass_manager.validate_constraints_mandatory() diff --git a/backends/transforms/decompose_sdpa.py b/backends/transforms/decompose_sdpa.py index d49e0da0c9b..6c36d1803fc 100644 --- a/backends/transforms/decompose_sdpa.py +++ b/backends/transforms/decompose_sdpa.py @@ -7,6 +7,7 @@ # pyre-strict import math +from typing import Set, Type import torch from executorch.exir.pass_base import ExportPass, PassResult @@ -19,6 +20,8 @@ class DecomposeScaledDotProductAttention(ExportPass): Decompose from scaled_dot_product_attention to multiple nodes. """ + _passes_required_after: Set[Type[ExportPass]] = set() + def __init__(self, allow_non_fake_inputs: bool = True) -> None: super().__init__() # With allow_non_fake_inputs=False, we don't get _unsafe_view ops diff --git a/backends/transforms/fuse_view_copy.py b/backends/transforms/fuse_view_copy.py index c740515cdcc..1972513d2ef 100644 --- a/backends/transforms/fuse_view_copy.py +++ b/backends/transforms/fuse_view_copy.py @@ -7,6 +7,8 @@ # pyre-strict +from typing import Set, Type + import torch from executorch.exir.dialects._ops import ops as exir_ops from executorch.exir.pass_base import ExportPass, PassResult @@ -62,6 +64,8 @@ def remove_noop_view_copy(graph: torch.fx.Graph) -> tuple[torch.fx.Graph, bool]: class FuseViewCopyTransform(ExportPass): + _passes_required_after: Set[Type[ExportPass]] = set() + def call(self, graph_module: torch.fx.GraphModule) -> PassResult: graph_module.graph, merge_modified = merge_view_copy_chains(graph_module.graph) graph_module.graph, noop_modified = remove_noop_view_copy(graph_module.graph) From 243d006412f30dbba52828b03a833d38786d9295 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Martin=20Lindstr=C3=B6m?= Date: Fri, 5 Sep 2025 14:26:39 +0200 Subject: [PATCH 2/2] Arm backend: Remove DecomposeLinearVectorNormPass from INT pipeline MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit DecomposeLinearVectorNormPass is listed in both the transform_for_annotation_pipeline and _tosa_INT_pipeline stages. The latter is redundant because torch.linalg.vector_norm can only run with floating point inputs, i.e., we should always be in a quantized setting when we enter _tosa_INT_pipeline if the operator is present and _transform_for_annotation_pipeline will always run; therefore, remove DecomposeLinearVectorNormPass from _tosa_INT_pipeline. Signed-off-by: Martin Lindström Change-Id: I647687b51298bbb98087914fbbee053436ffb79f --- backends/arm/_passes/arm_pass_manager.py | 1 - 1 file changed, 1 deletion(-) diff --git a/backends/arm/_passes/arm_pass_manager.py b/backends/arm/_passes/arm_pass_manager.py index 0a42e51e526..c6530357f3b 100644 --- a/backends/arm/_passes/arm_pass_manager.py +++ b/backends/arm/_passes/arm_pass_manager.py @@ -155,7 +155,6 @@ def _tosa_INT_pipeline(self, exported_program: ExportedProgram) -> GraphModule: self.add_pass(RemoveGetItemPass()) self.add_pass(ConvertSplitToSlicePass()) self.add_pass(ConvertMmToBmmPass()) - self.add_pass(DecomposeLinearVectorNormPass()) self.add_pass( DecomposeMeanDimPass(exported_program.graph_module, self.tosa_spec) )