diff --git a/backends/arm/_passes/add_bias_pass.py b/backends/arm/_passes/add_bias_pass.py index 31c0c0505cb..a8a76c0a47b 100644 --- a/backends/arm/_passes/add_bias_pass.py +++ b/backends/arm/_passes/add_bias_pass.py @@ -3,13 +3,15 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +from typing import Set, Type + import torch from executorch.backends.arm._passes import ArmPass from executorch.backends.arm._passes.arm_pass_utils import get_first_fake_tensor from executorch.backends.transforms.utils import create_constant_placeholder from executorch.exir.dialects._ops import ops as exir_ops -from executorch.exir.pass_base import PassResult +from executorch.exir.pass_base import ExportPass, PassResult from torch.export.graph_signature import InputKind @@ -19,6 +21,8 @@ class AddBiasPass(ArmPass): The bias is set to zero. """ + _passes_required_after: Set[Type[ExportPass]] = set() + targeted_ops = (exir_ops.edge.aten.convolution.default,) def call(self, graph_module): diff --git a/backends/arm/_passes/annotate_decomposed_matmul.py b/backends/arm/_passes/annotate_decomposed_matmul.py index 8156ca0b89d..81b7b36cc0b 100644 --- a/backends/arm/_passes/annotate_decomposed_matmul.py +++ b/backends/arm/_passes/annotate_decomposed_matmul.py @@ -7,7 +7,7 @@ import itertools import operator -from typing import cast, List +from typing import cast, List, Set, Type import torch from executorch.backends.arm._passes.arm_pass_utils import create_node @@ -29,6 +29,8 @@ class AnnotateDecomposedMatmulPass(ExportPass): matmul-op (can be mm or bmm). """ + _passes_required_after: Set[Type[ExportPass]] = set() + def _match_partition_to_node( self, node: torch.fx.Node, partitioned_inputs: List[torch.fx.Node] ) -> torch.fx.Node: diff --git a/backends/arm/_passes/annotate_output_dim_order_pass.py b/backends/arm/_passes/annotate_output_dim_order_pass.py index 08f93383a9c..8dc13326e4a 100644 --- a/backends/arm/_passes/annotate_output_dim_order_pass.py +++ b/backends/arm/_passes/annotate_output_dim_order_pass.py @@ -3,9 +3,12 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. + +from typing import Set, Type + from executorch.backends.arm._passes import ArmPass from executorch.backends.arm._passes.arm_pass_utils import get_output_dim_orders -from executorch.exir.pass_base import PassResult +from executorch.exir.pass_base import ExportPass, PassResult class AnnotateOutputDimOrderPass(ArmPass): @@ -14,6 +17,8 @@ class AnnotateOutputDimOrderPass(ArmPass): for verifying that the dim order does not change unexpectedly in later passes. """ + _passes_required_after: Set[Type[ExportPass]] = set() + def call(self, graph_module): output_node = graph_module.graph.output_node() output_node.meta["original_dim_orders"] = get_output_dim_orders(graph_module) diff --git a/backends/arm/_passes/arm_pass.py b/backends/arm/_passes/arm_pass.py index 085267a174e..c76b5d157a7 100644 --- a/backends/arm/_passes/arm_pass.py +++ b/backends/arm/_passes/arm_pass.py @@ -6,7 +6,8 @@ # pyre-unsafe import traceback -from typing import Optional +from abc import abstractmethod +from typing import List, Optional, Set, Type import torch from executorch.exir.pass_base import ExportPass, NodeMetadata @@ -19,6 +20,36 @@ def __init__(self, exported_program: Optional[torch.export.ExportedProgram] = No super(ArmPass, self).__init__() self.exported_program = exported_program + @property + @abstractmethod + def _passes_required_after(self) -> Set[Type[ExportPass]]: + """The subclass defines passes that must run after it""" + pass + + @staticmethod + def get_required_passes(pass_) -> List[str]: + """ + Returns the list of passes that must be run after this pass, sorted by name. + """ + if hasattr(pass_, "_passes_required_after"): + return sorted([ArmPass.get_name(p) for p in pass_._passes_required_after]) + else: + return [] + + @staticmethod + def get_name(pass_) -> str: + """ + Returns the name of the pass. + """ + if isinstance(pass_, ExportPass): + return pass_.__class__.__name__ + elif hasattr(pass_, "__name__"): + return pass_.__name__ + else: + raise ValueError( + f"Cannot get name for pass: {pass_}. It must be an instance of ExportPass or have a __name__ attribute." + ) + def call_operator(self, op, args, kwargs, meta, updated: Optional[bool] = False): if not updated: return super().call_operator(op, args, kwargs, meta) diff --git a/backends/arm/_passes/arm_pass_manager.py b/backends/arm/_passes/arm_pass_manager.py index f49206da67e..c6530357f3b 100644 --- a/backends/arm/_passes/arm_pass_manager.py +++ b/backends/arm/_passes/arm_pass_manager.py @@ -7,6 +7,9 @@ # pyre-unsafe + +from collections import defaultdict + import executorch.backends.arm.tosa.dialect # noqa: unused from executorch.backends.arm._passes import ( AddBiasPass, @@ -94,6 +97,7 @@ UnsqueezeScalarPlaceholdersPass, ) +from executorch.backends.arm._passes.arm_pass import ArmPass from executorch.backends.arm.tosa.specification import ( TosaLoweringContext, TosaSpecification, @@ -115,6 +119,32 @@ def __init__(self, tosa_spec: TosaSpecification) -> None: self.tosa_spec = tosa_spec super().__init__() + def validate_constraints_mandatory(self): + """ + Validates that necessary passes have run before transforming to backend. + + Note that this differs from the original validate_constraints function, which + only checks the order of passes. + """ + passes_to_run = defaultdict(list) + + for current_pass in self.passes: + current_pass_name = ArmPass.get_name(current_pass) + for required_pass_name in ArmPass.get_required_passes(current_pass): + passes_to_run[required_pass_name].append(current_pass_name) + + passes_to_run.pop(current_pass_name, None) + + if len(passes_to_run) > 0: + error_msg = "The following constraints for passes are not met:\n" + for required_pass, requiring_passes in passes_to_run.items(): + for requiring_pass in requiring_passes: + error_msg += ( + f" - {required_pass} must run after {requiring_pass}\n" + ) + + raise RuntimeError(error_msg) + def _transform(self, graph_module: GraphModule): with TosaLoweringContext(self.tosa_spec): return self(graph_module).graph_module @@ -125,7 +155,6 @@ def _tosa_INT_pipeline(self, exported_program: ExportedProgram) -> GraphModule: self.add_pass(RemoveGetItemPass()) self.add_pass(ConvertSplitToSlicePass()) self.add_pass(ConvertMmToBmmPass()) - self.add_pass(DecomposeLinearVectorNormPass()) self.add_pass( DecomposeMeanDimPass(exported_program.graph_module, self.tosa_spec) ) @@ -175,6 +204,7 @@ def _tosa_INT_pipeline(self, exported_program: ExportedProgram) -> GraphModule: self.add_pass(RemoveNoopPass()) self.add_pass(InsertRescalePass()) + self.validate_constraints_mandatory() return self._transform(exported_program.graph_module) def _tosa_FP_pipeline(self, exported_program: ExportedProgram) -> GraphModule: @@ -258,6 +288,7 @@ def _tosa_FP_pipeline(self, exported_program: ExportedProgram) -> GraphModule: self.add_pass(RemoveNoopPass()) self.add_pass(InsertRescalePass()) + self.validate_constraints_mandatory() return self._transform(exported_program.graph_module) def transform_to_backend_pipeline(self, exported_program: ExportedProgram): diff --git a/backends/arm/_passes/broadcast_args_pass.py b/backends/arm/_passes/broadcast_args_pass.py index f125ba13ff4..659e6aca686 100644 --- a/backends/arm/_passes/broadcast_args_pass.py +++ b/backends/arm/_passes/broadcast_args_pass.py @@ -3,6 +3,8 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +from typing import Set, Type + from executorch.backends.arm._passes import ArmPass from executorch.backends.arm._passes.arm_pass_utils import ( @@ -12,7 +14,7 @@ from executorch.exir.dialects._ops import ops as exir_ops -from executorch.exir.pass_base import PassResult +from executorch.exir.pass_base import ExportPass, PassResult from torch.fx import GraphModule, Node @@ -22,6 +24,8 @@ class BroadcastArgsPass(ArmPass): This is done when more than one arg needs broadcasting. """ + _passes_required_after: Set[Type[ExportPass]] = set() + targeted_ops = { exir_ops.edge.aten.add.Tensor, exir_ops.edge.aten.sub.Tensor, diff --git a/backends/arm/_passes/cast_bool_to_int8_pass.py b/backends/arm/_passes/cast_bool_to_int8_pass.py index 1352671b01e..771b6d9e174 100644 --- a/backends/arm/_passes/cast_bool_to_int8_pass.py +++ b/backends/arm/_passes/cast_bool_to_int8_pass.py @@ -6,6 +6,8 @@ # The TOSA BITWISE_AND, BITWISE_OR, and BITWISE_XOR don't handle bool as input # If input/output is bool lest add a cast/conversion pass before/after to/from int8. +from typing import Set, Type + import torch from executorch.exir.dialects._ops import ops as exir_ops @@ -15,6 +17,8 @@ class CastBoolToInt8Pass(ExportPass): """Casts the input to int8 if it is not already and casts back the output to the original input dtype.""" + _passes_required_after: Set[Type[ExportPass]] = set() + targeted_ops = { exir_ops.edge.aten.bitwise_and.Tensor, exir_ops.edge.aten.bitwise_or.Tensor, diff --git a/backends/arm/_passes/cast_int64_pass.py b/backends/arm/_passes/cast_int64_pass.py index 8052c8fd2ce..d7b2a6b6b43 100644 --- a/backends/arm/_passes/cast_int64_pass.py +++ b/backends/arm/_passes/cast_int64_pass.py @@ -6,6 +6,7 @@ # pyre-unsafe import logging +from typing import Set, Type import torch from executorch.exir.pass_base import ExportPass, PassResult @@ -19,6 +20,8 @@ class CastInt64BuffersToInt32Pass(ExportPass): Cast int64 buffers to int32 if the int64 data is in int32 range. """ + _passes_required_after: Set[Type[ExportPass]] = set() + def __init__(self, exported_program: torch.export.ExportedProgram): super(CastInt64BuffersToInt32Pass, self).__init__() self.exported_program = exported_program diff --git a/backends/arm/_passes/cast_to_int32_pass.py b/backends/arm/_passes/cast_to_int32_pass.py index c4b009e2b88..2e574568235 100644 --- a/backends/arm/_passes/cast_to_int32_pass.py +++ b/backends/arm/_passes/cast_to_int32_pass.py @@ -3,6 +3,8 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +from typing import Set, Type + import torch from executorch.exir.dialects._ops import ops as exir_ops @@ -12,6 +14,8 @@ class CastToInt32Pass(ExportPass): """Casts the input to int32 if it is not already and casts back the output to the original input dtype.""" + _passes_required_after: Set[Type[ExportPass]] = set() + targeted_ops = { exir_ops.edge.aten.bitwise_left_shift.Tensor, exir_ops.edge.aten.bitwise_right_shift.Tensor, diff --git a/backends/arm/_passes/conv1d_unsqueeze_pass.py b/backends/arm/_passes/conv1d_unsqueeze_pass.py index 56f674e9066..718c94fc196 100644 --- a/backends/arm/_passes/conv1d_unsqueeze_pass.py +++ b/backends/arm/_passes/conv1d_unsqueeze_pass.py @@ -6,6 +6,8 @@ # LICENSE file in the root directory of this source tree. +from typing import Set, Type + from executorch.exir.dialects._ops import ops as exir_ops from executorch.exir.pass_base import ExportPass @@ -21,6 +23,8 @@ class Conv1dUnsqueezePass(ExportPass): 3) squeeze the output back down to 3d. """ + _passes_required_after: Set[Type[ExportPass]] = set() + def call_operator(self, op, args, kwargs, meta): if op != exir_ops.edge.aten.convolution.default: return super().call_operator(op, args, kwargs, meta) diff --git a/backends/arm/_passes/convert_any_default_dim_dims_pass.py b/backends/arm/_passes/convert_any_default_dim_dims_pass.py index 7085f17add0..f4ec0c57b2a 100644 --- a/backends/arm/_passes/convert_any_default_dim_dims_pass.py +++ b/backends/arm/_passes/convert_any_default_dim_dims_pass.py @@ -3,6 +3,8 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +from typing import Set, Type + import torch from executorch.exir.dialects._ops import ( # type: ignore[import-not-found] ops as exir_ops, @@ -44,6 +46,8 @@ class ConvertAnyDefaultDimDimsPass(ExportPass): squeeze(dim = [dim1, dim2]) """ + _passes_required_after: Set[Type[ExportPass]] = set() + def call(self, graph_module: torch.fx.GraphModule): modified = False for node in graph_module.graph.nodes: diff --git a/backends/arm/_passes/convert_expand_copy_to_repeat.py b/backends/arm/_passes/convert_expand_copy_to_repeat.py index ee509c7ebb5..1c6b52b150a 100644 --- a/backends/arm/_passes/convert_expand_copy_to_repeat.py +++ b/backends/arm/_passes/convert_expand_copy_to_repeat.py @@ -6,7 +6,7 @@ # pyre-unsafe import logging -from typing import cast +from typing import cast, Set, Type import torch @@ -50,6 +50,8 @@ class ConvertExpandCopyToRepeatPass(ExportPass): Replace expand copy with repeat since it is a repeat that can only repeat singleton dimensions. """ + _passes_required_after: Set[Type[ExportPass]] = set() + expand_copy = exir_ops.edge.aten.expand_copy.default repeat = exir_ops.edge.aten.repeat.default diff --git a/backends/arm/_passes/convert_full_like_to_full_pass.py b/backends/arm/_passes/convert_full_like_to_full_pass.py index 234e2ecda82..2f46e19005a 100644 --- a/backends/arm/_passes/convert_full_like_to_full_pass.py +++ b/backends/arm/_passes/convert_full_like_to_full_pass.py @@ -3,6 +3,8 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +from typing import Set, Type + from executorch.exir.dialects._ops import ops as exir_ops from executorch.exir.pass_base import ExportPass @@ -19,6 +21,8 @@ class ConvertFullLikeToFullPass(ExportPass): Skip layout and device since it's not relevant for our backend. """ + _passes_required_after: Set[Type[ExportPass]] = set() + def call_operator(self, op, args, kwargs, meta): if op not in [ exir_ops.edge.aten.full_like.default, diff --git a/backends/arm/_passes/convert_int64_const_ops_to_int32.py b/backends/arm/_passes/convert_int64_const_ops_to_int32.py index 704c89dbd78..9af44f56f11 100644 --- a/backends/arm/_passes/convert_int64_const_ops_to_int32.py +++ b/backends/arm/_passes/convert_int64_const_ops_to_int32.py @@ -7,6 +7,7 @@ import logging +from typing import Set, Type import torch from executorch.backends.arm._passes.fuse_constant_ops_pass import ComputeConstantOpsAOT @@ -30,6 +31,8 @@ class ConvertInt64ConstOpsToInt32Pass(ExportPass): 5. `torch.tensor` """ + _passes_required_after: Set[Type[ExportPass]] = set() + torch_ops = [ torch.ops.aten.full.default, torch.ops.aten.arange.default, diff --git a/backends/arm/_passes/convert_int64_output_ops_to_int32.py b/backends/arm/_passes/convert_int64_output_ops_to_int32.py index 788201be6c8..d0d29d14e30 100644 --- a/backends/arm/_passes/convert_int64_output_ops_to_int32.py +++ b/backends/arm/_passes/convert_int64_output_ops_to_int32.py @@ -7,6 +7,7 @@ import logging +from typing import Set, Type import torch from executorch.backends.arm._passes.arm_pass_utils import ( @@ -44,6 +45,8 @@ class ConvertInt64OutputOpsToInt32Pass(ExportPass): the int32 range. """ + _passes_required_after: Set[Type[ExportPass]] = set() + aten_cast_ops = ( torch.ops.aten.to.dtype, torch.ops.aten.to.dtype_layout, diff --git a/backends/arm/_passes/convert_int_pow_to_mul.py b/backends/arm/_passes/convert_int_pow_to_mul.py index f22a2fd0b3c..8f9b3a9cb4b 100644 --- a/backends/arm/_passes/convert_int_pow_to_mul.py +++ b/backends/arm/_passes/convert_int_pow_to_mul.py @@ -5,8 +5,11 @@ # pyre-unsafe +from typing import Set, Type + from executorch.backends.arm._passes import ArmPass from executorch.exir.dialects._ops import ops as exir_ops +from executorch.exir.pass_base import ExportPass class ConvertIntPowToMuls(ArmPass): @@ -16,6 +19,8 @@ class ConvertIntPowToMuls(ArmPass): Needs to be run before doing scalar to tensor conversion. """ + _passes_required_after: Set[Type[ExportPass]] = set() + def call_operator(self, op, args, kwargs, meta): if op != exir_ops.edge.aten.pow.Tensor_Scalar: return super().call_operator(op, args, kwargs, meta) diff --git a/backends/arm/_passes/convert_minmax_pass.py b/backends/arm/_passes/convert_minmax_pass.py index 9f409632c20..2cf59ab2300 100644 --- a/backends/arm/_passes/convert_minmax_pass.py +++ b/backends/arm/_passes/convert_minmax_pass.py @@ -3,6 +3,8 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +from typing import Set, Type + import torch from executorch.exir.dialects._ops import ops as exir_ops from executorch.exir.pass_base import ExportPass, PassResult @@ -29,6 +31,8 @@ class ConvertMinMaxPass(ExportPass): squeeze(dim = [dim1, dim2]) """ + _passes_required_after: Set[Type[ExportPass]] = set() + def check_argmax(self, node): """ Raises a RuntimeError if the argmax value returned by the min/max op is used in the graph. diff --git a/backends/arm/_passes/convert_split_to_slice.py b/backends/arm/_passes/convert_split_to_slice.py index 67bd9d73e81..7578c07ca53 100644 --- a/backends/arm/_passes/convert_split_to_slice.py +++ b/backends/arm/_passes/convert_split_to_slice.py @@ -5,6 +5,8 @@ # pyre-unsafe +from typing import Set, Type + import torch.fx from executorch.backends.arm._passes.arm_pass_utils import ( create_node, @@ -19,6 +21,8 @@ class ConvertSplitToSlicePass(ExportPass): Replace a split operation with many slice operations. """ + _passes_required_after: Set[Type[ExportPass]] = set() + split_ops = ( exir_ops.edge.aten.split_with_sizes_copy.default, exir_ops.edge.aten.split_copy.Tensor, diff --git a/backends/arm/_passes/convert_squeezes_to_view.py b/backends/arm/_passes/convert_squeezes_to_view.py index 889dbe74172..9c5d26a7c22 100644 --- a/backends/arm/_passes/convert_squeezes_to_view.py +++ b/backends/arm/_passes/convert_squeezes_to_view.py @@ -6,6 +6,8 @@ # pyre-unsafe +from typing import Set, Type + from executorch.exir.dialects._ops import ops as exir_ops from executorch.exir.pass_base import ExportPass @@ -15,6 +17,8 @@ class ConvertSqueezesToViewPass(ExportPass): Replaces squeeze/unsqueeze operators with view. These are simply special cases of the view op, so removing them gives us less cases to handle in the node visitiors. """ + _passes_required_after: Set[Type[ExportPass]] = set() + def call_operator(self, op, args, kwargs, meta): if op not in [ exir_ops.edge.aten.squeeze_copy.dims, diff --git a/backends/arm/_passes/convert_to_clamp.py b/backends/arm/_passes/convert_to_clamp.py index 8f2c9b16f9a..3f8cac30b96 100644 --- a/backends/arm/_passes/convert_to_clamp.py +++ b/backends/arm/_passes/convert_to_clamp.py @@ -3,7 +3,7 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -from typing import Tuple +from typing import Set, Tuple, Type from executorch.exir.dialects._ops import ops as exir_ops from executorch.exir.pass_base import ExportPass @@ -24,6 +24,8 @@ def get_clamp_params(op, args) -> Tuple[float | None, float | None]: class ConvertToClampPass(ExportPass): + _passes_required_after: Set[Type[ExportPass]] = set() + def call_operator(self, op, args, kwargs, meta): if op not in edge_operators: return super().call_operator(op, args, kwargs, meta) diff --git a/backends/arm/_passes/decompose_acosh_pass.py b/backends/arm/_passes/decompose_acosh_pass.py index 1d92dd68c4a..30c5c137482 100644 --- a/backends/arm/_passes/decompose_acosh_pass.py +++ b/backends/arm/_passes/decompose_acosh_pass.py @@ -5,8 +5,11 @@ # pyre-unsafe +from typing import Set, Type + from executorch.backends.arm._passes import ArmPass from executorch.exir.dialects._ops import ops as exir_ops +from executorch.exir.pass_base import ExportPass # For MI case edge_acosh_op = exir_ops.edge.aten.acosh.default @@ -19,6 +22,8 @@ class DecomposeAcoshPass(ArmPass): acosh(x) = log(x + sqrt((x-1)(x+1)) """ + _passes_required_after: Set[Type[ExportPass]] = set() + def call_operator(self, op, args, kwargs, meta, updated=False): if op is not edge_acosh_op: diff --git a/backends/arm/_passes/decompose_adaptive_avg_pool2d_pass.py b/backends/arm/_passes/decompose_adaptive_avg_pool2d_pass.py index abfcc8e3945..f1623b4aca7 100644 --- a/backends/arm/_passes/decompose_adaptive_avg_pool2d_pass.py +++ b/backends/arm/_passes/decompose_adaptive_avg_pool2d_pass.py @@ -4,12 +4,14 @@ # LICENSE file in the root directory of this source tree. from math import ceil, floor +from typing import Set, Type import torch from executorch.backends.arm._passes import ArmPass from executorch.exir.dialects._ops import ops as exir_ops +from executorch.exir.pass_base import ExportPass edge_ops = (exir_ops.edge.aten._adaptive_avg_pool2d.default,) aten_ops = (torch.ops.aten.adaptive_avg_pool2d.default,) @@ -41,6 +43,8 @@ class DecomposeAdaptiveAvgPool2dPass(ArmPass): The output is of size output_size_h x output_size_w for any input. """ + _passes_required_after: Set[Type[ExportPass]] = set() + def call_operator(self, op, args, kwargs, meta, updated=False): if op not in (edge_ops + aten_ops): return super().call_operator(op, args, kwargs, meta, updated) diff --git a/backends/arm/_passes/decompose_addmm_pass.py b/backends/arm/_passes/decompose_addmm_pass.py index b59a8cb02d3..142f3143f38 100644 --- a/backends/arm/_passes/decompose_addmm_pass.py +++ b/backends/arm/_passes/decompose_addmm_pass.py @@ -3,10 +3,13 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +from typing import Set, Type + import torch from executorch.backends.arm._passes import ArmPass from executorch.exir.dialects._ops import ops as exir_ops +from executorch.exir.pass_base import ExportPass # For MI case @@ -36,6 +39,8 @@ def get_ops(op): class DecomposeAddmmPass(ArmPass): """Decomposes the addmm operator into tensor multiplication and addition.""" + _passes_required_after: Set[Type[ExportPass]] = set() + def call_operator(self, op, args, kwargs, meta): if op not in [edge_addmm, aten_addmm]: return super().call_operator(op, args, kwargs, meta) diff --git a/backends/arm/_passes/decompose_asin_and_acos_pass.py b/backends/arm/_passes/decompose_asin_and_acos_pass.py index e067f17b0ca..c083cc669c2 100644 --- a/backends/arm/_passes/decompose_asin_and_acos_pass.py +++ b/backends/arm/_passes/decompose_asin_and_acos_pass.py @@ -7,11 +7,13 @@ import logging from math import pi +from typing import Set, Type import torch from executorch.backends.arm._passes import ArmPass from executorch.exir.dialects._ops import ops as exir_ops +from executorch.exir.pass_base import ExportPass # For MI case edge_asin_op = (exir_ops.edge.aten.asin.default,) @@ -54,6 +56,8 @@ class DecomposeAsinAndAcosPass(ArmPass): """ + _passes_required_after: Set[Type[ExportPass]] = set() + def _build_polynomial( self, coefficients: list[float], variable: torch.Tensor, meta: dict[str, str] ) -> torch.Tensor: diff --git a/backends/arm/_passes/decompose_asinh_pass.py b/backends/arm/_passes/decompose_asinh_pass.py index a0b78c51a77..b8f7300beb5 100644 --- a/backends/arm/_passes/decompose_asinh_pass.py +++ b/backends/arm/_passes/decompose_asinh_pass.py @@ -6,8 +6,11 @@ # pyre-unsafe +from typing import Set, Type + from executorch.backends.arm._passes import ArmPass from executorch.exir.dialects._ops import ops as exir_ops +from executorch.exir.pass_base import ExportPass # For MI case edge_asinh_op = (exir_ops.edge.aten.asinh.default,) @@ -20,6 +23,8 @@ class DecomposeAsinhPass(ArmPass): asinh(x) = log(x + sqrt(x^2 + 1)) """ + _passes_required_after: Set[Type[ExportPass]] = set() + def call_operator(self, op, args, kwargs, meta): if op not in edge_asinh_op: return super().call_operator(op, args, kwargs, meta) diff --git a/backends/arm/_passes/decompose_atan_pass.py b/backends/arm/_passes/decompose_atan_pass.py index 57b9dde5216..7faef26a245 100644 --- a/backends/arm/_passes/decompose_atan_pass.py +++ b/backends/arm/_passes/decompose_atan_pass.py @@ -5,9 +5,11 @@ import logging from math import pi +from typing import Set, Type from executorch.backends.arm._passes import ArmPass from executorch.exir.dialects._ops import ops as exir_ops +from executorch.exir.pass_base import ExportPass edge_atan = exir_ops.edge.aten.atan.default # MI case @@ -35,6 +37,8 @@ def _get_atan_ops(op): class DecomposeAtanPass(ArmPass): """Decomposes the atan operator into a rational (Padé) approximation.""" + _passes_required_after: Set[Type[ExportPass]] = set() + def _rational_approximation(self, z, ops, meta): """Creates a (2,1) Padé approximation for atan(x) on [-1, 1].""" diff --git a/backends/arm/_passes/decompose_atanh_pass.py b/backends/arm/_passes/decompose_atanh_pass.py index dfdad41e556..d06598923b3 100644 --- a/backends/arm/_passes/decompose_atanh_pass.py +++ b/backends/arm/_passes/decompose_atanh_pass.py @@ -3,8 +3,11 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +from typing import Set, Type + from executorch.backends.arm._passes import ArmPass from executorch.exir.dialects._ops import ops as exir_ops +from executorch.exir.pass_base import ExportPass edge_atanh = exir_ops.edge.aten.atanh.default # MI case @@ -30,6 +33,8 @@ class DecomposeAtanhPass(ArmPass): atanh(x) = 0.5 * log((1 + x) / (1 - x)) """ + _passes_required_after: Set[Type[ExportPass]] = set() + def call_operator(self, op, args, kwargs, meta): if op is not edge_atanh: return super().call_operator(op, args, kwargs, meta, updated=False) diff --git a/backends/arm/_passes/decompose_avg_pool2d.py b/backends/arm/_passes/decompose_avg_pool2d.py index 21ed6b518c7..0240661053b 100644 --- a/backends/arm/_passes/decompose_avg_pool2d.py +++ b/backends/arm/_passes/decompose_avg_pool2d.py @@ -4,6 +4,8 @@ # LICENSE file in the root directory of this source tree. +from typing import Set, Type + import torch from executorch.backends.arm.operators.operator_validation_utils import ( adjust_pooling_pad_if_needed, @@ -34,7 +36,7 @@ def get_decomposition(op) -> tuple: class DecomposeAvgPool2d(ExportPass): - """ """ + _passes_required_after: Set[Type[ExportPass]] = set() def call_operator(self, op, args, kwargs, meta): if op not in (edge_div_ops + aten_div_ops): diff --git a/backends/arm/_passes/decompose_batch_norm_no_stats.py b/backends/arm/_passes/decompose_batch_norm_no_stats.py index 5fdb8db2d7c..82937241369 100644 --- a/backends/arm/_passes/decompose_batch_norm_no_stats.py +++ b/backends/arm/_passes/decompose_batch_norm_no_stats.py @@ -6,12 +6,13 @@ # pyre-unsafe import operator +from typing import Set, Type import torch from executorch.backends.arm._passes import ArmPass from executorch.backends.arm._passes.arm_pass_utils import create_node from executorch.exir.dialects._ops import ops as exir_ops -from executorch.exir.pass_base import PassResult +from executorch.exir.pass_base import ExportPass, PassResult class DecomposeBatchNormNoStatsPass(ArmPass): @@ -33,6 +34,8 @@ class DecomposeBatchNormNoStatsPass(ArmPass): Source: https://pytorch.org/docs/stable/generated/torch.nn.BatchNorm2d.html """ + _passes_required_after: Set[Type[ExportPass]] = set() + def call(self, graph_module: torch.fx.GraphModule) -> PassResult: # noqa: C901 bn_ops = ( exir_ops.edge.aten._native_batch_norm_legit.no_stats, diff --git a/backends/arm/_passes/decompose_cosh_pass.py b/backends/arm/_passes/decompose_cosh_pass.py index a94cf9ecff0..b71ca388651 100644 --- a/backends/arm/_passes/decompose_cosh_pass.py +++ b/backends/arm/_passes/decompose_cosh_pass.py @@ -3,8 +3,11 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +from typing import Set, Type + from executorch.backends.arm._passes import ArmPass from executorch.exir.dialects._ops import ops as exir_ops +from executorch.exir.pass_base import ExportPass # For MI case edge_cosh = exir_ops.edge.aten.cosh.default @@ -19,6 +22,8 @@ class DecomposeCoshPass(ArmPass): """ + _passes_required_after: Set[Type[ExportPass]] = set() + def call_operator(self, op, args, kwargs, meta, updated=False): if op is not edge_cosh: return super().call_operator(op, args, kwargs, meta, updated) diff --git a/backends/arm/_passes/decompose_cosine_similarity_pass.py b/backends/arm/_passes/decompose_cosine_similarity_pass.py index 9978e653408..e2ab01b345f 100644 --- a/backends/arm/_passes/decompose_cosine_similarity_pass.py +++ b/backends/arm/_passes/decompose_cosine_similarity_pass.py @@ -3,6 +3,8 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +from typing import Set, Type + import torch from executorch.exir.pass_base import ExportPass @@ -22,6 +24,8 @@ class DecomposeCosineSimilarityPass(ExportPass): out = div(dot, denom) """ + _passes_required_after: Set[Type[ExportPass]] = set() + def call_operator(self, op, args, kwargs, meta): if op not in torch_cosine_similarity: return super().call_operator(op, args, kwargs, meta) diff --git a/backends/arm/_passes/decompose_cumsum_pass.py b/backends/arm/_passes/decompose_cumsum_pass.py index 155ccd11594..04e6275c6c1 100644 --- a/backends/arm/_passes/decompose_cumsum_pass.py +++ b/backends/arm/_passes/decompose_cumsum_pass.py @@ -4,6 +4,7 @@ # LICENSE file in the root directory of this source tree. from math import prod +from typing import Set, Type import torch from executorch.backends.arm._passes import ArmPass @@ -12,7 +13,7 @@ from executorch.backends.transforms.utils import create_constant_placeholder from executorch.exir.dialects._ops import ops as exir_ops -from executorch.exir.pass_base import PassResult +from executorch.exir.pass_base import ExportPass, PassResult from torch.export.graph_signature import InputKind @@ -39,6 +40,8 @@ class DecomposeCumsumPass(ArmPass): And the convolution is applied over dimension H. """ + _passes_required_after: Set[Type[ExportPass]] = set() + def call(self, graph_module): graph = graph_module.graph targets = (exir_ops.edge.aten.cumsum.default, torch.ops.aten.cumsum.default) diff --git a/backends/arm/_passes/decompose_div_pass.py b/backends/arm/_passes/decompose_div_pass.py index 893531dac69..b6e289ff049 100644 --- a/backends/arm/_passes/decompose_div_pass.py +++ b/backends/arm/_passes/decompose_div_pass.py @@ -1,4 +1,4 @@ -# Copyright 2024 Arm Limited and/or its affiliates. +# Copyright 2024-2025 Arm Limited and/or its affiliates. # All rights reserved. # # This source code is licensed under the BSD-style license found in the @@ -6,6 +6,8 @@ # pyre-unsafe +from typing import Set, Type + import torch from executorch.exir.dialects._ops import ops as exir_ops from executorch.exir.pass_base import ExportPass @@ -37,6 +39,8 @@ class DecomposeDivPass(ExportPass): y = mul(a,x) """ + _passes_required_after: Set[Type[ExportPass]] = set() + def call_operator(self, op, args, kwargs, meta): if op not in (edge_div_ops + aten_div_ops): return super().call_operator(op, args, kwargs, meta) diff --git a/backends/arm/_passes/decompose_elu_pass.py b/backends/arm/_passes/decompose_elu_pass.py index 743f1b46f4d..ba3d32b7529 100644 --- a/backends/arm/_passes/decompose_elu_pass.py +++ b/backends/arm/_passes/decompose_elu_pass.py @@ -3,8 +3,11 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +from typing import Set, Type + from executorch.backends.arm._passes import ArmPass from executorch.exir.dialects._ops import ops as exir_ops +from executorch.exir.pass_base import ExportPass edge_elu_ops = (exir_ops.edge.aten.elu.default,) @@ -55,6 +58,8 @@ class DecomposeEluPass(ArmPass): - exir_ops.edge.aten.mul.Scalar """ + _passes_required_after: Set[Type[ExportPass]] = set() + def call_operator(self, op, args, kwargs, meta): if op not in edge_elu_ops: return super().call_operator(op, args, kwargs, meta, updated=False) diff --git a/backends/arm/_passes/decompose_embedding_pass.py b/backends/arm/_passes/decompose_embedding_pass.py index 6de971f402f..5b2ad27eaf6 100644 --- a/backends/arm/_passes/decompose_embedding_pass.py +++ b/backends/arm/_passes/decompose_embedding_pass.py @@ -8,6 +8,7 @@ import logging from math import prod +from typing import Set, Type import torch from executorch.exir.dialects._ops import ops as exir_ops @@ -33,6 +34,8 @@ class DecomposeEmbeddingPass(ExportPass): i = indices is expected to be int32 before this pass """ + _passes_required_after: Set[Type[ExportPass]] = set() + aten_ops = (torch.ops.aten.embedding.default,) edge_ops = (exir_ops.edge.aten.embedding.default,) diff --git a/backends/arm/_passes/decompose_expm1_pass.py b/backends/arm/_passes/decompose_expm1_pass.py index 5b1b90495b5..21d3c975de3 100644 --- a/backends/arm/_passes/decompose_expm1_pass.py +++ b/backends/arm/_passes/decompose_expm1_pass.py @@ -3,8 +3,11 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +from typing import Set, Type + from executorch.backends.arm._passes import ArmPass from executorch.exir.dialects._ops import ops as exir_ops +from executorch.exir.pass_base import ExportPass edge_expm1_ops = (exir_ops.edge.aten.expm1.default,) # MI case @@ -68,6 +71,8 @@ class DecomposeExpm1Pass(ArmPass): - exir_ops.edge.aten.logical_and.default """ + _passes_required_after: Set[Type[ExportPass]] = set() + def call_operator(self, op, args, kwargs, meta): if op not in edge_expm1_ops: return super().call_operator(op, args, kwargs, meta, updated=False) diff --git a/backends/arm/_passes/decompose_gelu_pass.py b/backends/arm/_passes/decompose_gelu_pass.py index 6e72175e68b..ef6a4753b8c 100644 --- a/backends/arm/_passes/decompose_gelu_pass.py +++ b/backends/arm/_passes/decompose_gelu_pass.py @@ -3,6 +3,8 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +from typing import Set, Type + import torch from executorch.backends.arm._passes.arm_pass_utils import get_node_arg from executorch.exir.dialects._ops import ops as exir_ops @@ -77,6 +79,8 @@ class DecomposeGeluPass(ExportPass): %op7 = mul(%op6, %FULL_0_5) """ + _passes_required_after: Set[Type[ExportPass]] = set() + def call_operator(self, op, args, kwargs, meta): if op not in torch_gelu + edge_gelu: return super().call_operator(op, args, kwargs, meta) diff --git a/backends/arm/_passes/decompose_glu_pass.py b/backends/arm/_passes/decompose_glu_pass.py index 183dc89cf61..6b53609c951 100644 --- a/backends/arm/_passes/decompose_glu_pass.py +++ b/backends/arm/_passes/decompose_glu_pass.py @@ -3,9 +3,12 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +from typing import Set, Type + import torch from executorch.backends.arm._passes import ArmPass from executorch.exir.dialects._ops import ops as exir_ops +from executorch.exir.pass_base import ExportPass # For FP case @@ -36,6 +39,8 @@ def get_ops(op): class DecomposeGluPass(ArmPass): """Decomposes the GLU operator into hadamard product and sigmoid.""" + _passes_required_after: Set[Type[ExportPass]] = set() + def call_operator(self, op, args, kwargs, meta): if op not in [edge_glu, aten_glu]: return super().call_operator(op, args, kwargs, meta) diff --git a/backends/arm/_passes/decompose_grouped_conv.py b/backends/arm/_passes/decompose_grouped_conv.py index ce9fe9c9937..2f0d7b4d72c 100644 --- a/backends/arm/_passes/decompose_grouped_conv.py +++ b/backends/arm/_passes/decompose_grouped_conv.py @@ -4,6 +4,7 @@ # LICENSE file in the root directory of this source tree. from copy import copy +from typing import Set, Type import torch from executorch.backends.arm._passes.quant_args import QuantArgs @@ -33,6 +34,8 @@ class DecomposeGroupedConv(ExportPass): x = cat(x1, x2) """ + _passes_required_after: Set[Type[ExportPass]] = set() + @staticmethod def _get_decomposition(op): match op: diff --git a/backends/arm/_passes/decompose_groupnorm_pass.py b/backends/arm/_passes/decompose_groupnorm_pass.py index c6cb1b05e40..7f0d7fdeafd 100644 --- a/backends/arm/_passes/decompose_groupnorm_pass.py +++ b/backends/arm/_passes/decompose_groupnorm_pass.py @@ -6,12 +6,13 @@ # pyre-unsafe import operator +from typing import Set, Type import torch from executorch.backends.arm._passes import ArmPass from executorch.backends.arm._passes.arm_pass_utils import create_node from executorch.exir.dialects._ops import ops as exir_ops -from executorch.exir.pass_base import PassResult +from executorch.exir.pass_base import ExportPass, PassResult def get_group_norm_decomposition(op) -> tuple: @@ -57,6 +58,8 @@ class DecomposeGroupNormPass(ArmPass): Source: https://pytorch.org/docs/stable/generated/torch.nn.GroupNorm.html """ + _passes_required_after: Set[Type[ExportPass]] = set() + def call(self, graph_module: torch.fx.GraphModule): modified = False for node in graph_module.graph.nodes: diff --git a/backends/arm/_passes/decompose_layernorm_pass.py b/backends/arm/_passes/decompose_layernorm_pass.py index e6cbdfb91a0..0710ed37b45 100644 --- a/backends/arm/_passes/decompose_layernorm_pass.py +++ b/backends/arm/_passes/decompose_layernorm_pass.py @@ -6,12 +6,13 @@ # pyre-unsafe import operator +from typing import Set, Type import torch from executorch.backends.arm._passes import ArmPass from executorch.backends.arm._passes.arm_pass_utils import create_node from executorch.exir.dialects._ops import ops as exir_ops -from executorch.exir.pass_base import PassResult +from executorch.exir.pass_base import ExportPass, PassResult def get_layer_norm_decomposition(op) -> tuple: @@ -56,6 +57,8 @@ class DecomposeLayerNormPass(ArmPass): Source: https://pytorch.org/docs/stable/generated/torch.nn.LayerNorm.html """ + _passes_required_after: Set[Type[ExportPass]] = set() + def call(self, graph_module: torch.fx.GraphModule): for node in graph_module.graph.nodes: if node.op != "call_function" or node.target not in ( diff --git a/backends/arm/_passes/decompose_leaky_relu_pass.py b/backends/arm/_passes/decompose_leaky_relu_pass.py index e896cc584be..8ae13a76eb0 100644 --- a/backends/arm/_passes/decompose_leaky_relu_pass.py +++ b/backends/arm/_passes/decompose_leaky_relu_pass.py @@ -6,9 +6,12 @@ # pyre-unsafe +from typing import Set, Type + import torch from executorch.backends.arm._passes import ArmPass from executorch.exir.dialects._ops import ops as exir_ops +from executorch.exir.pass_base import ExportPass edge_ops = (exir_ops.edge.aten.leaky_relu.default,) torch_ops = (torch.ops.aten.leaky_relu.default,) @@ -46,6 +49,8 @@ class DecomposeLeakyReLUPass(ArmPass): %op5 = add(%op1,%op4) """ + _passes_required_after: Set[Type[ExportPass]] = set() + def call_operator(self, op, args, kwargs, meta): if op not in (edge_ops + torch_ops): return super().call_operator(op, args, kwargs, meta) diff --git a/backends/arm/_passes/decompose_linalg_vector_norm_pass.py b/backends/arm/_passes/decompose_linalg_vector_norm_pass.py index 9f036c0524f..17441981654 100644 --- a/backends/arm/_passes/decompose_linalg_vector_norm_pass.py +++ b/backends/arm/_passes/decompose_linalg_vector_norm_pass.py @@ -3,6 +3,8 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +from typing import Set, Type + import torch from executorch.exir.pass_base import ExportPass @@ -28,6 +30,8 @@ class DecomposeLinearVectorNormPass(ExportPass): dtype prior, but we dont know this from FX graph. """ + _passes_required_after: Set[Type[ExportPass]] = set() + torch_linalg_vector_norm = (torch.ops.aten.linalg_vector_norm.default,) def call_operator(self, op, args, kwargs, meta): diff --git a/backends/arm/_passes/decompose_linear_pass.py b/backends/arm/_passes/decompose_linear_pass.py index 3d154d9b81e..70268c77a1d 100644 --- a/backends/arm/_passes/decompose_linear_pass.py +++ b/backends/arm/_passes/decompose_linear_pass.py @@ -5,6 +5,8 @@ # pyre-unsafe +from typing import Set, Type + import numpy as np from executorch.backends.arm._passes import ArmPass from executorch.backends.arm._passes.arm_pass_utils import ( @@ -12,7 +14,7 @@ get_first_fake_tensor, ) from executorch.exir.dialects._ops import ops as exir_ops -from executorch.exir.pass_base import PassResult +from executorch.exir.pass_base import ExportPass, PassResult class DecomposeLinearPass(ArmPass): @@ -25,6 +27,8 @@ class DecomposeLinearPass(ArmPass): output = view(conv2d) """ + _passes_required_after: Set[Type[ExportPass]] = set() + def call(self, graph_module): for node in graph_module.graph.nodes: if node.op != "call_function": diff --git a/backends/arm/_passes/decompose_logit_pass.py b/backends/arm/_passes/decompose_logit_pass.py index 40e2b22cb54..a82650f0b9e 100644 --- a/backends/arm/_passes/decompose_logit_pass.py +++ b/backends/arm/_passes/decompose_logit_pass.py @@ -3,10 +3,13 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +from typing import Set, Type + import torch from executorch.backends.arm._passes import ArmPass from executorch.exir.dialects._ops import ops as exir_ops +from executorch.exir.pass_base import ExportPass # For FP case @@ -60,6 +63,8 @@ class DecomposeLogitPass(ArmPass): log(y * reciprocal((-1) * y + 1)) """ + _passes_required_after: Set[Type[ExportPass]] = set() + def call_operator(self, op, args, kwargs, meta): if op not in [edge_logit, aten_logit]: return super().call_operator(op, args, kwargs, meta) diff --git a/backends/arm/_passes/decompose_masked_fill.py b/backends/arm/_passes/decompose_masked_fill.py index fbf3079c92b..ced58aa3920 100644 --- a/backends/arm/_passes/decompose_masked_fill.py +++ b/backends/arm/_passes/decompose_masked_fill.py @@ -6,10 +6,13 @@ # pyre-unsafe +from typing import Set, Type + import torch from executorch.backends.arm._passes import ArmPass from executorch.exir.dialects._ops import ops as exir_ops +from executorch.exir.pass_base import ExportPass edge_ops = (exir_ops.edge.aten.masked_fill.Scalar,) @@ -37,6 +40,8 @@ class DecomposeMaskedFill(ArmPass): Decomposed to a where and a full_like operator. """ + _passes_required_after: Set[Type[ExportPass]] = set() + def call_operator(self, op, args, kwargs, meta, updated=False): if op not in (edge_ops + aten_ops): return super().call_operator(op, args, kwargs, meta, updated) diff --git a/backends/arm/_passes/decompose_maxpool2d_with_dilation.py b/backends/arm/_passes/decompose_maxpool2d_with_dilation.py index ff6db260099..1df062ddb57 100644 --- a/backends/arm/_passes/decompose_maxpool2d_with_dilation.py +++ b/backends/arm/_passes/decompose_maxpool2d_with_dilation.py @@ -6,9 +6,11 @@ # pyre-unsafe import operator +from typing import Set, Type from executorch.backends.arm._passes import ArmPass from executorch.exir.dialects._ops import ops as exir_ops +from executorch.exir.pass_base import ExportPass # We'll decompose only the EXIR edge max_pool2d ops when dilation > 1 EDGE_MAXPOOL2D = ( @@ -22,6 +24,8 @@ class DecomposeMaxPool2DPass(ArmPass): Decompose dilated max_pool2d (EXIR edge ops) into space-to-batch -> maxpool -> batch-to-space. """ + _passes_required_after: Set[Type[ExportPass]] = set() + def call_operator(self, op, args, kwargs, meta): # Only intercept EXIR edge max_pool2d ops if op not in EDGE_MAXPOOL2D: diff --git a/backends/arm/_passes/decompose_meandim_pass.py b/backends/arm/_passes/decompose_meandim_pass.py index a78514b6af5..716924dfbf2 100644 --- a/backends/arm/_passes/decompose_meandim_pass.py +++ b/backends/arm/_passes/decompose_meandim_pass.py @@ -5,12 +5,14 @@ from copy import copy from math import prod +from typing import Set, Type import torch from executorch.backends.arm._passes import ArmPass from executorch.backends.arm._passes.arm_pass_utils import get_node_arg from executorch.exir.backend.utils import WhyNoPartitionReporter from executorch.exir.dialects._ops import ops as exir_ops +from executorch.exir.pass_base import ExportPass def get_meandim_decomposition(op) -> tuple: @@ -62,6 +64,8 @@ class DecomposeMeanDimPass(ArmPass): x = view_copy.default(x, new_shape=(h)) # Squeeze dims since keepdims = False """ + _passes_required_after: Set[Type[ExportPass]] = set() + def __init__(self, graph_module, tosa_spec): super().__init__() self._graph_module = graph_module diff --git a/backends/arm/_passes/decompose_ne_pass.py b/backends/arm/_passes/decompose_ne_pass.py index 16443d5d2fb..3bd4f4540bb 100644 --- a/backends/arm/_passes/decompose_ne_pass.py +++ b/backends/arm/_passes/decompose_ne_pass.py @@ -3,9 +3,12 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +from typing import Set, Type + import torch from executorch.backends.arm._passes import ArmPass from executorch.exir.dialects._ops import ops as exir_ops +from executorch.exir.pass_base import ExportPass edge_ne_ops = (exir_ops.edge.aten.ne.Tensor,) aten_ne_ops = (torch.ops.aten.ne.Tensor, torch.ops.aten.ne_.Tensor) @@ -53,6 +56,8 @@ class DecomposeNotEqualPass(ArmPass): - followed by aten.logical_not.default or its edge equivalent """ + _passes_required_after: Set[Type[ExportPass]] = set() + def call_operator(self, op, args, kwargs, meta): if op not in (edge_ne_ops + aten_ne_ops): return super().call_operator(op, args, kwargs, meta) diff --git a/backends/arm/_passes/decompose_round_pass.py b/backends/arm/_passes/decompose_round_pass.py index edfa3817064..35d36e80396 100644 --- a/backends/arm/_passes/decompose_round_pass.py +++ b/backends/arm/_passes/decompose_round_pass.py @@ -3,10 +3,13 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +from typing import Set, Type + import torch from executorch.backends.arm._passes import ArmPass from executorch.exir.dialects._ops import ops as exir_ops from executorch.exir.dialects.edge._ops import EdgeOpOverload +from executorch.exir.pass_base import ExportPass from torch._ops import OpOverload @@ -56,6 +59,8 @@ class DecomposeRoundPass(ArmPass): %result = where(%is_non_negative, %floor, %ceil) """ + _passes_required_after: Set[Type[ExportPass]] = set() + def call_operator(self, op, args, kwargs, meta, updated=False): if op not in (exir_ops.edge.aten.round.default, torch.ops.aten.round.default): return super().call_operator(op, args, kwargs, meta, updated) diff --git a/backends/arm/_passes/decompose_select.py b/backends/arm/_passes/decompose_select.py index 99c89f474ea..9c65cd1c0a8 100644 --- a/backends/arm/_passes/decompose_select.py +++ b/backends/arm/_passes/decompose_select.py @@ -6,6 +6,8 @@ # pyre-unsafe +from typing import Set, Type + import torch from executorch.backends.arm._passes.arm_pass_utils import ( create_node, @@ -20,6 +22,8 @@ class DecomposeSelectPass(ExportPass): This pass decomposes select into slice + squeeze to ensure that Aten and TOSA outputs has the same rank (input rank -1) """ + _passes_required_after: Set[Type[ExportPass]] = set() + def call(self, graph_module: torch.fx.GraphModule): for node in graph_module.graph.nodes: diff --git a/backends/arm/_passes/decompose_sign_pass.py b/backends/arm/_passes/decompose_sign_pass.py index 1038ff0f3fa..c4cb964316d 100644 --- a/backends/arm/_passes/decompose_sign_pass.py +++ b/backends/arm/_passes/decompose_sign_pass.py @@ -3,10 +3,13 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +from typing import Set, Type + import torch from executorch.backends.arm._passes import ArmPass from executorch.exir.dialects._ops import ops as exir_ops +from executorch.exir.pass_base import ExportPass # For MI case @@ -42,6 +45,8 @@ def get_ops(op): class DecomposeSignPass(ArmPass): """Decomposes the sign operator into a sequence of operations that are supported by the Arm backend.""" + _passes_required_after: Set[Type[ExportPass]] = set() + def call_operator(self, op, args, kwargs, meta): if op not in (edge_sign, aten_sign): return super().call_operator(op, args, kwargs, meta) diff --git a/backends/arm/_passes/decompose_silu_pass.py b/backends/arm/_passes/decompose_silu_pass.py index 68ebb3f4515..cb7b55be520 100644 --- a/backends/arm/_passes/decompose_silu_pass.py +++ b/backends/arm/_passes/decompose_silu_pass.py @@ -5,6 +5,8 @@ # pyre-unsafe +from typing import Set, Type + import torch from executorch.exir.pass_base import ExportPass @@ -22,6 +24,8 @@ class DecomposeSiluPass(ExportPass): y = mul(a,x) """ + _passes_required_after: Set[Type[ExportPass]] = set() + def call_operator(self, op, args, kwargs, meta): if op not in (aten_silu_ops): return super().call_operator(op, args, kwargs, meta) diff --git a/backends/arm/_passes/decompose_sinh_pass.py b/backends/arm/_passes/decompose_sinh_pass.py index 7192eb9bf74..473a263e9a5 100644 --- a/backends/arm/_passes/decompose_sinh_pass.py +++ b/backends/arm/_passes/decompose_sinh_pass.py @@ -4,8 +4,11 @@ # LICENSE file in the root directory of this source tree. +from typing import Set, Type + from executorch.backends.arm._passes import ArmPass from executorch.exir.dialects._ops import ops as exir_ops +from executorch.exir.pass_base import ExportPass # For MI case @@ -24,6 +27,8 @@ class DecomposeSinhPass(ArmPass): and scalar multiplication. """ + _passes_required_after: Set[Type[ExportPass]] = set() + def call_operator(self, op, args, kwargs, meta): if op is not edge_sinh: return super().call_operator(op, args, kwargs, meta) diff --git a/backends/arm/_passes/decompose_softmax_pass.py b/backends/arm/_passes/decompose_softmax_pass.py index a735501f711..47f448ae851 100644 --- a/backends/arm/_passes/decompose_softmax_pass.py +++ b/backends/arm/_passes/decompose_softmax_pass.py @@ -3,6 +3,8 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +from typing import Set, Type + import torch from executorch.exir.dialects._ops import ops as exir_ops from executorch.exir.pass_base import ExportPass @@ -62,6 +64,8 @@ class DecomposeSoftmaxPass(ExportPass): (in logsoftmax case: %op7 = log(%op6)) """ + _passes_required_after: Set[Type[ExportPass]] = set() + def call_operator(self, op, args, kwargs, meta): if op not in torch_softmax + edge_softmax: return super().call_operator(op, args, kwargs, meta) diff --git a/backends/arm/_passes/decompose_softmax_unstable_pass.py b/backends/arm/_passes/decompose_softmax_unstable_pass.py index b6f5e11b66b..5e704585eb0 100644 --- a/backends/arm/_passes/decompose_softmax_unstable_pass.py +++ b/backends/arm/_passes/decompose_softmax_unstable_pass.py @@ -5,9 +5,12 @@ # pyre-unsafe +from typing import Set, Type + import torch from executorch.backends.arm._passes import ArmPass from executorch.exir.dialects._ops import ops as exir_ops +from executorch.exir.pass_base import ExportPass # For BI case torch_softmax = (torch.ops.aten.softmax.int, torch.ops.aten.log_softmax.int) @@ -57,6 +60,8 @@ class DecomposeSoftmaxUnstablePass(ArmPass): (in logsoftmax case: %op5 = log(%op4)) """ + _passes_required_after: Set[Type[ExportPass]] = set() + def call_operator(self, op, args, kwargs, meta): if op not in torch_softmax + edge_softmax: return super().call_operator(op, args, kwargs, meta) diff --git a/backends/arm/_passes/decompose_sqrt_pass.py b/backends/arm/_passes/decompose_sqrt_pass.py index 547d0091e90..c93686901d5 100644 --- a/backends/arm/_passes/decompose_sqrt_pass.py +++ b/backends/arm/_passes/decompose_sqrt_pass.py @@ -4,7 +4,7 @@ # LICENSE file in the root directory of this source tree. # pyre-unsafe -from typing import Tuple, Union +from typing import Set, Tuple, Type, Union import torch from executorch.exir.dialects._ops import ops as exir_ops @@ -27,6 +27,7 @@ def get_sqrt_decomposition(op) -> Union[Tuple, torch._ops.OpOverload]: class DecomposeSqrtPass(ExportPass): + _passes_required_after: Set[Type[ExportPass]] = set() def call_operator(self, op, args, kwargs, meta): """ diff --git a/backends/arm/_passes/decompose_sum_pass.py b/backends/arm/_passes/decompose_sum_pass.py index 52b9c10c49f..16027ccec2b 100644 --- a/backends/arm/_passes/decompose_sum_pass.py +++ b/backends/arm/_passes/decompose_sum_pass.py @@ -3,6 +3,8 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +from typing import Set, Type + import torch from executorch.exir.dialects._ops import ops as exir_ops from executorch.exir.pass_base import ExportPass @@ -40,6 +42,8 @@ class DecomposeSumPass(ExportPass): view(shape = squeezed_shape) -> squeezed_shape """ + _passes_required_after: Set[Type[ExportPass]] = set() + def call_operator(self, op, args, kwargs, meta): if op not in [ exir_ops.edge.aten.sum.dim_IntList, diff --git a/backends/arm/_passes/decompose_var_pass.py b/backends/arm/_passes/decompose_var_pass.py index 15872738f3e..f8396da0420 100644 --- a/backends/arm/_passes/decompose_var_pass.py +++ b/backends/arm/_passes/decompose_var_pass.py @@ -7,10 +7,13 @@ # pyre-unsafe +from typing import Set, Type + import torch from executorch.backends.arm._passes import ArmPass from executorch.backends.arm._passes.arm_pass_utils import get_node_arg from executorch.exir.dialects._ops import ops as exir_ops +from executorch.exir.pass_base import ExportPass def get_var_decomposition(op) -> tuple: @@ -47,6 +50,8 @@ class DecomposeVarPass(ArmPass): y = div(sum, max(0, N-correction)) """ + _passes_required_after: Set[Type[ExportPass]] = set() + def call_operator(self, op, args, kwargs, meta): if op not in ( exir_ops.edge.aten.var.correction, diff --git a/backends/arm/_passes/decorate_fp32_to_int32_casting_pass.py b/backends/arm/_passes/decorate_fp32_to_int32_casting_pass.py index 17a682c0a8e..9d704520302 100644 --- a/backends/arm/_passes/decorate_fp32_to_int32_casting_pass.py +++ b/backends/arm/_passes/decorate_fp32_to_int32_casting_pass.py @@ -6,10 +6,13 @@ # pyre-unsafe +from typing import Set, Type + import torch from executorch.backends.arm._passes import ArmPass from executorch.backends.arm._passes.arm_pass_utils import get_node_arg from executorch.exir.dialects._ops import ops as exir_ops +from executorch.exir.pass_base import ExportPass def _get_decorated_ops(op): @@ -40,6 +43,8 @@ class DecorateFp32toInt32CastingPass(ArmPass): output = to_dim_order_copy(decorated_x, dtype=torch.int32) """ + _passes_required_after: Set[Type[ExportPass]] = set() + targets = [ exir_ops.edge.dim_order_ops._to_dim_order_copy.default, ] diff --git a/backends/arm/_passes/fold_qdq_with_annotated_qparams_pass.py b/backends/arm/_passes/fold_qdq_with_annotated_qparams_pass.py index 491b404f0a4..714543d3908 100644 --- a/backends/arm/_passes/fold_qdq_with_annotated_qparams_pass.py +++ b/backends/arm/_passes/fold_qdq_with_annotated_qparams_pass.py @@ -8,7 +8,7 @@ import copy -from typing import cast, Dict, Set, Tuple +from typing import cast, Dict, Set, Tuple, Type from executorch.backends.arm._passes import ArmPass from executorch.backends.arm._passes.arm_pass_utils import ( @@ -100,6 +100,8 @@ class FoldAndAnnotateQParamsPass(ArmPass): """ + _passes_required_after: Set[Type[ExportPass]] = set() + def fold_and_annotate_arg( self, graph_module: GraphModule, node: Node, arg_list: list[Node], i: int ) -> None: @@ -210,6 +212,8 @@ class QuantizeOperatorArguments(ExportPass): - Makes sure the min and max values to clamp.default are quantized, if it's a quantized operator. """ + _passes_required_after: Set[Type[ExportPass]] = set() + def call(self, graph_module: GraphModule) -> PassResult: modified = False # Loop over the graph nodes and find full.default nodes. @@ -257,6 +261,8 @@ class RetraceFoldedDtypesPass(ExportPass): the output type of that matches the type of the output_qparams. """ + _passes_required_after: Set[Type[ExportPass]] = set() + targeted_ops: Set[EdgeOpOverload] = { exir_ops.edge.aten.sum.dim_IntList, } diff --git a/backends/arm/_passes/fuse_batchnorm2d_pass.py b/backends/arm/_passes/fuse_batchnorm2d_pass.py index 2dbdfa84cec..be884585d4d 100644 --- a/backends/arm/_passes/fuse_batchnorm2d_pass.py +++ b/backends/arm/_passes/fuse_batchnorm2d_pass.py @@ -5,6 +5,8 @@ # pyre-unsafe +from typing import Set, Type + import torch from executorch.backends.arm._passes.arm_pass_utils import ( create_node, @@ -28,6 +30,8 @@ class FuseBatchnorm2DPass(ExportPass): the weights and bias of the convolution and removing the batchnorm. """ + _passes_required_after: Set[Type[ExportPass]] = set() + def __init__(self, exported_program: ExportedProgram): self.exported_program = exported_program super().__init__() diff --git a/backends/arm/_passes/fuse_constant_ops_pass.py b/backends/arm/_passes/fuse_constant_ops_pass.py index f49565e3c38..07f3a4af245 100644 --- a/backends/arm/_passes/fuse_constant_ops_pass.py +++ b/backends/arm/_passes/fuse_constant_ops_pass.py @@ -4,6 +4,7 @@ # LICENSE file in the root directory of this source tree. import logging +from typing import Set, Type import torch._export.utils import torch.fx @@ -41,6 +42,8 @@ def f(): return x """ + _passes_required_after: Set[Type[ExportPass]] = set() + def __init__(self, exported_program: ExportedProgram) -> None: super().__init__() self.exported_program = exported_program @@ -168,6 +171,8 @@ def f(node_name_pre_computed): return node_name_pre_computed """ + _passes_required_after: Set[Type[ExportPass]] = set() + targeted_ops = [ exir_ops.edge.aten.full.default, exir_ops.edge.aten.arange.start_step, diff --git a/backends/arm/_passes/fuse_equal_placeholders_pass.py b/backends/arm/_passes/fuse_equal_placeholders_pass.py index 5631e2f32e9..cf1177a0448 100644 --- a/backends/arm/_passes/fuse_equal_placeholders_pass.py +++ b/backends/arm/_passes/fuse_equal_placeholders_pass.py @@ -5,6 +5,7 @@ import hashlib from collections import defaultdict +from typing import Set, Type import torch from executorch.backends.arm._passes.arm_pass_utils import ( @@ -27,6 +28,8 @@ class FuseEqualPlaceholdersPass(ExportPass): with multiple users, using a cache for faster comparison. """ + _passes_required_after: Set[Type[ExportPass]] = set() + def __init__(self, exported_program: ExportedProgram): self.exported_program = exported_program super().__init__() diff --git a/backends/arm/_passes/fuse_quantized_activation_pass.py b/backends/arm/_passes/fuse_quantized_activation_pass.py index 46a7d7f6f98..d39d7135f9c 100644 --- a/backends/arm/_passes/fuse_quantized_activation_pass.py +++ b/backends/arm/_passes/fuse_quantized_activation_pass.py @@ -5,6 +5,8 @@ # pyre-unsafe +from typing import Set, Type + import torch from executorch.backends.arm._passes.quant_args import QuantArgs from executorch.backends.arm.constants import Q_OPS @@ -14,6 +16,8 @@ class FuseQuantizedActivationPass(ExportPass): + _passes_required_after: Set[Type[ExportPass]] = set() + @staticmethod def _is_fuseable_quantized_activation(node: Node): """Fuse activations that have a 0 lower bound and quantized with a qmin zero-point""" diff --git a/backends/arm/_passes/insert_rescales_pass.py b/backends/arm/_passes/insert_rescales_pass.py index 7f75aecf24c..100ac03c2b0 100644 --- a/backends/arm/_passes/insert_rescales_pass.py +++ b/backends/arm/_passes/insert_rescales_pass.py @@ -4,7 +4,7 @@ # LICENSE file in the root directory of this source tree. from copy import copy -from typing import cast +from typing import cast, Set, Type from executorch.backends.arm._passes.arm_pass_utils import create_node from executorch.backends.arm._passes.quant_args import QuantArgs @@ -24,6 +24,8 @@ class InsertRescalePass(ExportPass): in the fake implementation of. """ + _passes_required_after: Set[Type[ExportPass]] = set() + def fold_dq_q_to_rescale(self, node: Node, user: Node, graph_module: GraphModule): dq_args = QuantArgs.from_operator(node.target, node.args) q_args = QuantArgs.from_operator(user.target, user.args) diff --git a/backends/arm/_passes/insert_table_ops.py b/backends/arm/_passes/insert_table_ops.py index fb5d7de5e12..d838ddc823d 100644 --- a/backends/arm/_passes/insert_table_ops.py +++ b/backends/arm/_passes/insert_table_ops.py @@ -6,7 +6,7 @@ # pyre-unsafe from itertools import chain -from typing import Callable, cast, Dict, Iterator, Set +from typing import Callable, cast, Dict, Iterator, Set, Type import torch from executorch.backends.arm._passes.arm_pass_utils import create_node @@ -117,6 +117,8 @@ class InsertTableOpsPass(ExportPass): which will be used to produce the table values in operators/op_table.py. """ + _passes_required_after: Set[Type[ExportPass]] = set() + def __init__(self, exported_program: ExportedProgram) -> None: super().__init__() self.exported_program = exported_program diff --git a/backends/arm/_passes/match_arg_dtype_pass.py b/backends/arm/_passes/match_arg_dtype_pass.py index e7bf3b2d60e..d482614b03f 100644 --- a/backends/arm/_passes/match_arg_dtype_pass.py +++ b/backends/arm/_passes/match_arg_dtype_pass.py @@ -3,6 +3,8 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +from typing import Set, Type + import torch from executorch.backends.arm._passes.arm_pass_utils import create_node, get_node_arg from executorch.exir.dialects._ops import ops as exir_ops @@ -38,6 +40,8 @@ class MatchArgDtypePass(ExportPass): """ + _passes_required_after: Set[Type[ExportPass]] = set() + targeted_ops = {exir_ops.edge.aten.sub.Tensor, exir_ops.edge.aten.where.self} def call(self, graph_module: torch.fx.GraphModule): diff --git a/backends/arm/_passes/match_arg_ranks_pass.py b/backends/arm/_passes/match_arg_ranks_pass.py index d6cdfacb612..c411f3b8083 100644 --- a/backends/arm/_passes/match_arg_ranks_pass.py +++ b/backends/arm/_passes/match_arg_ranks_pass.py @@ -7,7 +7,7 @@ # pyre-unsafe -from typing import cast +from typing import cast, Set, Type from executorch.backends.arm._passes.arm_pass_utils import ( create_node, @@ -36,6 +36,8 @@ class MatchArgRanksPass(ExportPass): input2 = shape(1, 3, 1) """ + _passes_required_after: Set[Type[ExportPass]] = set() + def __init__(self, exported_program): super().__init__() self.exported_program = exported_program diff --git a/backends/arm/_passes/mm_to_bmm_pass.py b/backends/arm/_passes/mm_to_bmm_pass.py index 69d8573013e..6be0b9e2ac4 100644 --- a/backends/arm/_passes/mm_to_bmm_pass.py +++ b/backends/arm/_passes/mm_to_bmm_pass.py @@ -6,6 +6,8 @@ # pyre-unsafe +from typing import Set, Type + import torch from executorch.backends.arm._passes.arm_pass_utils import ( create_node, @@ -28,6 +30,8 @@ class ConvertMmToBmmPass(ExportPass): 3) Squeeze output tensor to rank 2. """ + _passes_required_after: Set[Type[ExportPass]] = set() + def call(self, graph_module: torch.fx.GraphModule): modified_graph = False graph = graph_module.graph diff --git a/backends/arm/_passes/remove_noop_pass.py b/backends/arm/_passes/remove_noop_pass.py index 623517aac59..55c4f71f0a8 100644 --- a/backends/arm/_passes/remove_noop_pass.py +++ b/backends/arm/_passes/remove_noop_pass.py @@ -7,6 +7,7 @@ # pyre-unsafe import logging +from typing import Set, Type from executorch.exir.dialects._ops import ops as exir_ops from executorch.exir.pass_base import ExportPass @@ -17,6 +18,8 @@ class RemoveNoopPass(ExportPass): """Remove no-ops from graph_module""" + _passes_required_after: Set[Type[ExportPass]] = set() + def call_operator(self, op, args, kwargs, meta): if op not in ( exir_ops.edge.dim_order_ops._clone_dim_order.default, diff --git a/backends/arm/_passes/replace_inf_values_pass.py b/backends/arm/_passes/replace_inf_values_pass.py index 8c721eda3d8..506030d82d7 100644 --- a/backends/arm/_passes/replace_inf_values_pass.py +++ b/backends/arm/_passes/replace_inf_values_pass.py @@ -7,6 +7,8 @@ # This pass is based on backends/qualcomm/_passes/replace_inf_values.py # with some modification to replaced inf values. +from typing import Set, Type + import torch from executorch.exir.pass_base import ExportPass, PassResult @@ -16,6 +18,8 @@ class ReplaceInfValues(ExportPass): Due to limitation in Quantizer, we need to change inf/-inf to more quantizable values. """ + _passes_required_after: Set[Type[ExportPass]] = set() + def __init__(self): super(ReplaceInfValues, self).__init__() diff --git a/backends/arm/_passes/replace_scalar_with_tensor_pass.py b/backends/arm/_passes/replace_scalar_with_tensor_pass.py index 249eb9ffd41..f6ef056f677 100644 --- a/backends/arm/_passes/replace_scalar_with_tensor_pass.py +++ b/backends/arm/_passes/replace_scalar_with_tensor_pass.py @@ -6,7 +6,7 @@ # pyre-unsafe -from typing import Dict, Union +from typing import Dict, Set, Type, Union import torch from executorch.backends.transforms.replace_scalar_with_tensor import ( @@ -15,6 +15,7 @@ from executorch.exir.dialects._ops import ops as exir_ops from executorch.exir.dialects.edge._ops import EdgeOpOverload +from executorch.exir.pass_base import ExportPass # Operators that are included for both TOSA profiles @@ -56,6 +57,8 @@ class ReplaceScalarWithTensorArgPassTOSAMI(ReplaceScalarWithTensorArgPass): + _passes_required_after: Set[Type[ExportPass]] = set() + scalar_to_tensor_ops = _common_ops | { exir_ops.edge.aten.pow.Tensor_Scalar: exir_ops.edge.aten.pow.Tensor_Tensor, torch.ops.aten.pow.Tensor_Scalar: torch.ops.aten.pow.Tensor_Tensor, @@ -66,6 +69,8 @@ def __init__(self): class ReplaceScalarWithTensorArgPassTOSABI(ReplaceScalarWithTensorArgPass): + _passes_required_after: Set[Type[ExportPass]] = set() + scalar_to_tensor_ops = _common_ops def __init__(self): diff --git a/backends/arm/_passes/scalars_to_attribute_pass.py b/backends/arm/_passes/scalars_to_attribute_pass.py index 89468bff1ff..bb2a02cc679 100644 --- a/backends/arm/_passes/scalars_to_attribute_pass.py +++ b/backends/arm/_passes/scalars_to_attribute_pass.py @@ -6,7 +6,7 @@ # pyre-unsafe -from typing import cast, Union +from typing import cast, Set, Type, Union import torch from executorch.backends.arm._passes.arm_pass_utils import get_first_fake_tensor @@ -22,6 +22,8 @@ class ScalarsToAttributePass(ExportPass): to attribute Nodes that output the same value. """ + _passes_required_after: Set[Type[ExportPass]] = set() + targeted_ops = [ torch.ops.aten.add.Tensor, torch.ops.aten.add_.Tensor, diff --git a/backends/arm/_passes/size_adjust_input_pass.py b/backends/arm/_passes/size_adjust_input_pass.py index e87d65c450f..5eb77dc56df 100644 --- a/backends/arm/_passes/size_adjust_input_pass.py +++ b/backends/arm/_passes/size_adjust_input_pass.py @@ -5,7 +5,7 @@ # pyre-unsafe -from typing import cast, TypeAlias +from typing import cast, Set, Type, TypeAlias import torch.fx from executorch.backends.arm._passes.arm_pass_utils import create_node @@ -185,6 +185,8 @@ class SizeAdjustInputPass(ExportPass): input. """ + _passes_required_after: Set[Type[ExportPass]] = set() + def call(self, graph_module: torch.fx.GraphModule) -> PassResult: graph = graph_module.graph modified_graph = False diff --git a/backends/arm/_passes/to_tosa_memory_format_pass.py b/backends/arm/_passes/to_tosa_memory_format_pass.py index ac16cbaf8cb..dcbdfb03f7b 100644 --- a/backends/arm/_passes/to_tosa_memory_format_pass.py +++ b/backends/arm/_passes/to_tosa_memory_format_pass.py @@ -7,6 +7,7 @@ import logging +from typing import Set, Type import torch from executorch.backends.arm._passes.annotate_decomposed_matmul import ( @@ -48,6 +49,14 @@ class ToTosaMemoryFormatPass(ExportPass): The annotated tosa_dim_order is used to permute the node's shape such that it gives a TOSA-compliant shape. """ + _passes_required_after: Set[Type[ExportPass]] = set() + + NHWC_order = (0, 2, 3, 1) + NHWC_inverse_order = (0, 3, 1, 2) + HWCM_order = (2, 3, 0, 1) + NNHWC_order = (0, 1, 3, 4, 2) + NNHWC_inverse_order = (0, 1, 4, 2, 3) + def __init__(self, exported_program: ExportedProgram) -> None: self.exported_program = exported_program super().__init__() diff --git a/backends/arm/_passes/unsqueeze_before_repeat_pass.py b/backends/arm/_passes/unsqueeze_before_repeat_pass.py index 01983baa9ab..66286b6a954 100644 --- a/backends/arm/_passes/unsqueeze_before_repeat_pass.py +++ b/backends/arm/_passes/unsqueeze_before_repeat_pass.py @@ -1,9 +1,11 @@ -# Copyright 2024 Arm Limited and/or its affiliates. +# Copyright 2024-2025 Arm Limited and/or its affiliates. # All rights reserved. # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. # pyre-unsafe +from typing import Set, Type + import torch import torch.fx from executorch.backends.arm._passes.arm_pass_utils import ( @@ -29,6 +31,8 @@ class UnsqueezeBeforeRepeatPass(ExportPass): repeat(multiples) """ + _passes_required_after: Set[Type[ExportPass]] = set() + def call(self, graph_module: torch.fx.GraphModule): modified_graph = False for node in graph_module.graph.nodes: diff --git a/backends/arm/_passes/unsqueeze_scalar_placeholders_pass.py b/backends/arm/_passes/unsqueeze_scalar_placeholders_pass.py index ccae9b503cf..d3932dd1217 100644 --- a/backends/arm/_passes/unsqueeze_scalar_placeholders_pass.py +++ b/backends/arm/_passes/unsqueeze_scalar_placeholders_pass.py @@ -5,6 +5,8 @@ # pyre-unsafe +from typing import Set, Type + import torch from executorch.exir.pass_base import ExportPass, PassResult from torch._export.utils import is_buffer, is_param @@ -16,6 +18,8 @@ class UnsqueezeScalarPlaceholdersPass(ExportPass): This pass unsqueezes the placeholders to make sure shape is at least (1,). """ + _passes_required_after: Set[Type[ExportPass]] = set() + def __init__(self, exported_program): self.exported_program = exported_program super().__init__() diff --git a/backends/arm/test/misc/test_pass_required_order.py b/backends/arm/test/misc/test_pass_required_order.py new file mode 100644 index 00000000000..2745d25a498 --- /dev/null +++ b/backends/arm/test/misc/test_pass_required_order.py @@ -0,0 +1,95 @@ +# Copyright 2025 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import re +from typing import List, Set, Type + +import pytest +from executorch.backends.arm._passes.arm_pass_manager import ArmPass, ArmPassManager +from executorch.backends.arm.tosa.specification import TosaSpecification +from executorch.exir.pass_base import ExportPass + + +class PassC(ArmPass): + _passes_required_after: Set[Type[ExportPass]] = set() + + +class PassB(ArmPass): + _passes_required_after = {PassC} + + +class PassA(ArmPass): + _passes_required_after = {PassB, PassC} + + +class IndependentPass(ArmPass): + _passes_required_after: Set[Type[ExportPass]] = set() + + +def _setup_pass_manager(passes: List[ArmPass] | None = None): + tosa_spec = TosaSpecification.create_from_string("TOSA-1.00+INT") + pass_manager = ArmPassManager(tosa_spec) + if passes is not None: + for p in passes: + pass_manager.add_pass(p) + return pass_manager + + +def test_no_passes(): + pass_manager = _setup_pass_manager() + pass_manager.validate_constraints_mandatory() + + +def test_correct_order(): + pass_manager = _setup_pass_manager([PassA(), PassB(), PassC()]) + pass_manager.validate_constraints_mandatory() + + +def test_run_pass_twice(): + pass_manager = _setup_pass_manager([PassA(), PassB(), PassB(), PassC()]) + pass_manager.validate_constraints_mandatory() + + +def test_independent_pass(): + pass_manager = _setup_pass_manager( + [ + IndependentPass(), + PassA(), + IndependentPass(), + PassB(), + IndependentPass(), + PassC(), + IndependentPass(), + ] + ) + pass_manager.validate_constraints_mandatory() + + +def test_duplicated_requiring_pass_put_last(): + error_msg = """The following constraints for passes are not met: + - PassC must run after PassB +""" + pass_manager = _setup_pass_manager([PassA(), PassB(), PassC(), PassB()]) + with pytest.raises(RuntimeError, match=re.escape(error_msg)): + pass_manager.validate_constraints_mandatory() + + +def test_two_passes_wrong_order(): + error_msg = """The following constraints for passes are not met: + - PassC must run after PassB +""" + pass_manager = _setup_pass_manager([PassC(), PassB()]) + with pytest.raises(RuntimeError, match=re.escape(error_msg)): + pass_manager.validate_constraints_mandatory() + + +def test_missing_passes(): + error_msg = """The following constraints for passes are not met: + - PassC must run after PassA + - PassC must run after PassB +""" + pass_manager = _setup_pass_manager([PassA(), PassB()]) + with pytest.raises(RuntimeError, match=re.escape(error_msg)): + pass_manager.validate_constraints_mandatory() diff --git a/backends/transforms/decompose_sdpa.py b/backends/transforms/decompose_sdpa.py index d49e0da0c9b..6c36d1803fc 100644 --- a/backends/transforms/decompose_sdpa.py +++ b/backends/transforms/decompose_sdpa.py @@ -7,6 +7,7 @@ # pyre-strict import math +from typing import Set, Type import torch from executorch.exir.pass_base import ExportPass, PassResult @@ -19,6 +20,8 @@ class DecomposeScaledDotProductAttention(ExportPass): Decompose from scaled_dot_product_attention to multiple nodes. """ + _passes_required_after: Set[Type[ExportPass]] = set() + def __init__(self, allow_non_fake_inputs: bool = True) -> None: super().__init__() # With allow_non_fake_inputs=False, we don't get _unsafe_view ops diff --git a/backends/transforms/fuse_view_copy.py b/backends/transforms/fuse_view_copy.py index c740515cdcc..1972513d2ef 100644 --- a/backends/transforms/fuse_view_copy.py +++ b/backends/transforms/fuse_view_copy.py @@ -7,6 +7,8 @@ # pyre-strict +from typing import Set, Type + import torch from executorch.exir.dialects._ops import ops as exir_ops from executorch.exir.pass_base import ExportPass, PassResult @@ -62,6 +64,8 @@ def remove_noop_view_copy(graph: torch.fx.Graph) -> tuple[torch.fx.Graph, bool]: class FuseViewCopyTransform(ExportPass): + _passes_required_after: Set[Type[ExportPass]] = set() + def call(self, graph_module: torch.fx.GraphModule) -> PassResult: graph_module.graph, merge_modified = merge_view_copy_chains(graph_module.graph) graph_module.graph, noop_modified = remove_noop_view_copy(graph_module.graph)