diff --git a/backends/arm/_passes/annotate_decomposed_matmul.py b/backends/arm/_passes/annotate_decomposed_matmul.py index 81b7b36cc0b..6b89b0c3c4a 100644 --- a/backends/arm/_passes/annotate_decomposed_matmul.py +++ b/backends/arm/_passes/annotate_decomposed_matmul.py @@ -11,6 +11,9 @@ import torch from executorch.backends.arm._passes.arm_pass_utils import create_node +from executorch.backends.arm._passes.fold_qdq_with_annotated_qparams_pass import ( + FoldAndAnnotateQParamsPass, +) from executorch.backends.arm.constants import DQ_OPS, Q_OPS from executorch.exir.dialects._ops import ops as exir_ops @@ -29,7 +32,7 @@ class AnnotateDecomposedMatmulPass(ExportPass): matmul-op (can be mm or bmm). """ - _passes_required_after: Set[Type[ExportPass]] = set() + _passes_required_after: Set[Type[ExportPass]] = {FoldAndAnnotateQParamsPass} def _match_partition_to_node( self, node: torch.fx.Node, partitioned_inputs: List[torch.fx.Node] diff --git a/backends/arm/_passes/conv1d_unsqueeze_pass.py b/backends/arm/_passes/conv1d_unsqueeze_pass.py index 718c94fc196..b228da6766f 100644 --- a/backends/arm/_passes/conv1d_unsqueeze_pass.py +++ b/backends/arm/_passes/conv1d_unsqueeze_pass.py @@ -8,6 +8,9 @@ from typing import Set, Type +from executorch.backends.arm._passes.add_bias_pass import AddBiasPass +from executorch.backends.arm._passes.size_adjust_input_pass import SizeAdjustInputPass + from executorch.exir.dialects._ops import ops as exir_ops from executorch.exir.pass_base import ExportPass @@ -23,7 +26,7 @@ class Conv1dUnsqueezePass(ExportPass): 3) squeeze the output back down to 3d. """ - _passes_required_after: Set[Type[ExportPass]] = set() + _passes_required_after: Set[Type[ExportPass]] = {AddBiasPass, SizeAdjustInputPass} def call_operator(self, op, args, kwargs, meta): if op != exir_ops.edge.aten.convolution.default: diff --git a/backends/arm/_passes/convert_any_default_dim_dims_pass.py b/backends/arm/_passes/convert_any_default_dim_dims_pass.py index f4ec0c57b2a..8c8e5086b6d 100644 --- a/backends/arm/_passes/convert_any_default_dim_dims_pass.py +++ b/backends/arm/_passes/convert_any_default_dim_dims_pass.py @@ -6,6 +6,9 @@ from typing import Set, Type import torch +from executorch.backends.arm._passes.convert_squeezes_to_view import ( + ConvertSqueezesToViewPass, +) from executorch.exir.dialects._ops import ( # type: ignore[import-not-found] ops as exir_ops, ) @@ -46,7 +49,7 @@ class ConvertAnyDefaultDimDimsPass(ExportPass): squeeze(dim = [dim1, dim2]) """ - _passes_required_after: Set[Type[ExportPass]] = set() + _passes_required_after: Set[Type[ExportPass]] = {ConvertSqueezesToViewPass} def call(self, graph_module: torch.fx.GraphModule): modified = False diff --git a/backends/arm/_passes/convert_expand_copy_to_repeat.py b/backends/arm/_passes/convert_expand_copy_to_repeat.py index 1c6b52b150a..83b47d31755 100644 --- a/backends/arm/_passes/convert_expand_copy_to_repeat.py +++ b/backends/arm/_passes/convert_expand_copy_to_repeat.py @@ -10,6 +10,9 @@ import torch +from executorch.backends.arm._passes.unsqueeze_before_repeat_pass import ( + UnsqueezeBeforeRepeatPass, +) from executorch.exir.dialects._ops import ops as exir_ops from executorch.exir.pass_base import ExportPass @@ -50,7 +53,7 @@ class ConvertExpandCopyToRepeatPass(ExportPass): Replace expand copy with repeat since it is a repeat that can only repeat singleton dimensions. """ - _passes_required_after: Set[Type[ExportPass]] = set() + _passes_required_after: Set[Type[ExportPass]] = {UnsqueezeBeforeRepeatPass} expand_copy = exir_ops.edge.aten.expand_copy.default repeat = exir_ops.edge.aten.repeat.default diff --git a/backends/arm/_passes/convert_full_like_to_full_pass.py b/backends/arm/_passes/convert_full_like_to_full_pass.py index 2f46e19005a..06822a4abcf 100644 --- a/backends/arm/_passes/convert_full_like_to_full_pass.py +++ b/backends/arm/_passes/convert_full_like_to_full_pass.py @@ -5,11 +5,14 @@ from typing import Set, Type +from executorch.backends.arm._passes.arm_pass import ArmPass +from executorch.backends.arm._passes.fuse_constant_ops_pass import ComputeConstantOpsAOT + from executorch.exir.dialects._ops import ops as exir_ops from executorch.exir.pass_base import ExportPass -class ConvertFullLikeToFullPass(ExportPass): +class ConvertFullLikeToFullPass(ArmPass): """As per the full_like pytorch documentation, `torch.full_like(input, fill_value)` is equivalent to `torch.full(input.size(), @@ -21,7 +24,7 @@ class ConvertFullLikeToFullPass(ExportPass): Skip layout and device since it's not relevant for our backend. """ - _passes_required_after: Set[Type[ExportPass]] = set() + _passes_required_after: Set[Type[ExportPass]] = {ComputeConstantOpsAOT} def call_operator(self, op, args, kwargs, meta): if op not in [ diff --git a/backends/arm/_passes/convert_int64_const_ops_to_int32.py b/backends/arm/_passes/convert_int64_const_ops_to_int32.py index 9af44f56f11..2bf305a13f6 100644 --- a/backends/arm/_passes/convert_int64_const_ops_to_int32.py +++ b/backends/arm/_passes/convert_int64_const_ops_to_int32.py @@ -31,7 +31,7 @@ class ConvertInt64ConstOpsToInt32Pass(ExportPass): 5. `torch.tensor` """ - _passes_required_after: Set[Type[ExportPass]] = set() + _passes_required_after: Set[Type[ExportPass]] = {ComputeConstantOpsAOT} torch_ops = [ torch.ops.aten.full.default, diff --git a/backends/arm/_passes/convert_minmax_pass.py b/backends/arm/_passes/convert_minmax_pass.py index 2cf59ab2300..f1c81dbc41e 100644 --- a/backends/arm/_passes/convert_minmax_pass.py +++ b/backends/arm/_passes/convert_minmax_pass.py @@ -6,6 +6,9 @@ from typing import Set, Type import torch +from executorch.backends.arm._passes.convert_squeezes_to_view import ( + ConvertSqueezesToViewPass, +) from executorch.exir.dialects._ops import ops as exir_ops from executorch.exir.pass_base import ExportPass, PassResult @@ -31,7 +34,7 @@ class ConvertMinMaxPass(ExportPass): squeeze(dim = [dim1, dim2]) """ - _passes_required_after: Set[Type[ExportPass]] = set() + _passes_required_after: Set[Type[ExportPass]] = {ConvertSqueezesToViewPass} def check_argmax(self, node): """ diff --git a/backends/arm/_passes/convert_squeezes_to_view.py b/backends/arm/_passes/convert_squeezes_to_view.py index 9c5d26a7c22..70f4625f0ff 100644 --- a/backends/arm/_passes/convert_squeezes_to_view.py +++ b/backends/arm/_passes/convert_squeezes_to_view.py @@ -8,6 +8,8 @@ from typing import Set, Type +from executorch.backends.transforms.fuse_view_copy import FuseViewCopyTransform + from executorch.exir.dialects._ops import ops as exir_ops from executorch.exir.pass_base import ExportPass @@ -17,7 +19,7 @@ class ConvertSqueezesToViewPass(ExportPass): Replaces squeeze/unsqueeze operators with view. These are simply special cases of the view op, so removing them gives us less cases to handle in the node visitiors. """ - _passes_required_after: Set[Type[ExportPass]] = set() + _passes_required_after: Set[Type[ExportPass]] = {FuseViewCopyTransform} def call_operator(self, op, args, kwargs, meta): if op not in [ diff --git a/backends/arm/_passes/convert_to_clamp.py b/backends/arm/_passes/convert_to_clamp.py index 3f8cac30b96..0199d6798bc 100644 --- a/backends/arm/_passes/convert_to_clamp.py +++ b/backends/arm/_passes/convert_to_clamp.py @@ -5,6 +5,10 @@ from typing import Set, Tuple, Type +from executorch.backends.arm._passes.fold_qdq_with_annotated_qparams_pass import ( + QuantizeOperatorArguments, +) + from executorch.exir.dialects._ops import ops as exir_ops from executorch.exir.pass_base import ExportPass @@ -24,7 +28,7 @@ def get_clamp_params(op, args) -> Tuple[float | None, float | None]: class ConvertToClampPass(ExportPass): - _passes_required_after: Set[Type[ExportPass]] = set() + _passes_required_after: Set[Type[ExportPass]] = {QuantizeOperatorArguments} def call_operator(self, op, args, kwargs, meta): if op not in edge_operators: diff --git a/backends/arm/_passes/decompose_acosh_pass.py b/backends/arm/_passes/decompose_acosh_pass.py index 30c5c137482..509849fce4e 100644 --- a/backends/arm/_passes/decompose_acosh_pass.py +++ b/backends/arm/_passes/decompose_acosh_pass.py @@ -8,6 +8,13 @@ from typing import Set, Type from executorch.backends.arm._passes import ArmPass +from executorch.backends.arm._passes.decompose_sqrt_pass import DecomposeSqrtPass +from executorch.backends.arm._passes.insert_table_ops import InsertTableOpsPass # noqa +from executorch.backends.arm._passes.match_arg_dtype_pass import MatchArgDtypePass +from executorch.backends.arm._passes.match_arg_ranks_pass import MatchArgRanksPass +from executorch.backends.arm._passes.replace_scalar_with_tensor_pass import ( + ReplaceScalarWithTensorArgPassTOSAMI, +) from executorch.exir.dialects._ops import ops as exir_ops from executorch.exir.pass_base import ExportPass @@ -22,7 +29,13 @@ class DecomposeAcoshPass(ArmPass): acosh(x) = log(x + sqrt((x-1)(x+1)) """ - _passes_required_after: Set[Type[ExportPass]] = set() + _passes_required_after: Set[Type[ExportPass]] = { + DecomposeSqrtPass, + InsertTableOpsPass, + MatchArgRanksPass, + ReplaceScalarWithTensorArgPassTOSAMI, + MatchArgDtypePass, + } def call_operator(self, op, args, kwargs, meta, updated=False): diff --git a/backends/arm/_passes/decompose_adaptive_avg_pool2d_pass.py b/backends/arm/_passes/decompose_adaptive_avg_pool2d_pass.py index f1623b4aca7..52ddb77151d 100644 --- a/backends/arm/_passes/decompose_adaptive_avg_pool2d_pass.py +++ b/backends/arm/_passes/decompose_adaptive_avg_pool2d_pass.py @@ -9,6 +9,7 @@ import torch from executorch.backends.arm._passes import ArmPass +from executorch.backends.arm._passes.decompose_avg_pool2d import DecomposeAvgPool2d from executorch.exir.dialects._ops import ops as exir_ops from executorch.exir.pass_base import ExportPass @@ -43,7 +44,7 @@ class DecomposeAdaptiveAvgPool2dPass(ArmPass): The output is of size output_size_h x output_size_w for any input. """ - _passes_required_after: Set[Type[ExportPass]] = set() + _passes_required_after: Set[Type[ExportPass]] = {DecomposeAvgPool2d} def call_operator(self, op, args, kwargs, meta, updated=False): if op not in (edge_ops + aten_ops): diff --git a/backends/arm/_passes/decompose_addmm_pass.py b/backends/arm/_passes/decompose_addmm_pass.py index 142f3143f38..a95c1cc7fec 100644 --- a/backends/arm/_passes/decompose_addmm_pass.py +++ b/backends/arm/_passes/decompose_addmm_pass.py @@ -8,6 +8,9 @@ import torch from executorch.backends.arm._passes import ArmPass +from executorch.backends.arm._passes.match_arg_dtype_pass import MatchArgDtypePass +from executorch.backends.arm._passes.match_arg_ranks_pass import MatchArgRanksPass +from executorch.backends.arm._passes.mm_to_bmm_pass import ConvertMmToBmmPass # noqa from executorch.exir.dialects._ops import ops as exir_ops from executorch.exir.pass_base import ExportPass @@ -39,7 +42,11 @@ def get_ops(op): class DecomposeAddmmPass(ArmPass): """Decomposes the addmm operator into tensor multiplication and addition.""" - _passes_required_after: Set[Type[ExportPass]] = set() + _passes_required_after: Set[Type[ExportPass]] = { + ConvertMmToBmmPass, + MatchArgRanksPass, + MatchArgDtypePass, + } def call_operator(self, op, args, kwargs, meta): if op not in [edge_addmm, aten_addmm]: diff --git a/backends/arm/_passes/decompose_asin_and_acos_pass.py b/backends/arm/_passes/decompose_asin_and_acos_pass.py index c083cc669c2..5b1c575e9c9 100644 --- a/backends/arm/_passes/decompose_asin_and_acos_pass.py +++ b/backends/arm/_passes/decompose_asin_and_acos_pass.py @@ -12,6 +12,16 @@ import torch from executorch.backends.arm._passes import ArmPass +from executorch.backends.arm._passes.convert_full_like_to_full_pass import ( + ConvertFullLikeToFullPass, +) +from executorch.backends.arm._passes.decompose_div_pass import DecomposeDivPass +from executorch.backends.arm._passes.decompose_sqrt_pass import DecomposeSqrtPass +from executorch.backends.arm._passes.match_arg_dtype_pass import MatchArgDtypePass +from executorch.backends.arm._passes.match_arg_ranks_pass import MatchArgRanksPass +from executorch.backends.arm._passes.replace_scalar_with_tensor_pass import ( + ReplaceScalarWithTensorArgPassTOSAMI, +) from executorch.exir.dialects._ops import ops as exir_ops from executorch.exir.pass_base import ExportPass @@ -56,7 +66,14 @@ class DecomposeAsinAndAcosPass(ArmPass): """ - _passes_required_after: Set[Type[ExportPass]] = set() + _passes_required_after: Set[Type[ExportPass]] = { + DecomposeSqrtPass, + DecomposeDivPass, + ConvertFullLikeToFullPass, + MatchArgRanksPass, + MatchArgDtypePass, + ReplaceScalarWithTensorArgPassTOSAMI, + } def _build_polynomial( self, coefficients: list[float], variable: torch.Tensor, meta: dict[str, str] diff --git a/backends/arm/_passes/decompose_asinh_pass.py b/backends/arm/_passes/decompose_asinh_pass.py index b8f7300beb5..088230ca4b2 100644 --- a/backends/arm/_passes/decompose_asinh_pass.py +++ b/backends/arm/_passes/decompose_asinh_pass.py @@ -9,6 +9,13 @@ from typing import Set, Type from executorch.backends.arm._passes import ArmPass +from executorch.backends.arm._passes.decompose_sqrt_pass import DecomposeSqrtPass +from executorch.backends.arm._passes.insert_table_ops import InsertTableOpsPass +from executorch.backends.arm._passes.match_arg_dtype_pass import MatchArgDtypePass +from executorch.backends.arm._passes.match_arg_ranks_pass import MatchArgRanksPass +from executorch.backends.arm._passes.replace_scalar_with_tensor_pass import ( + ReplaceScalarWithTensorArgPassTOSAMI, +) from executorch.exir.dialects._ops import ops as exir_ops from executorch.exir.pass_base import ExportPass @@ -23,7 +30,13 @@ class DecomposeAsinhPass(ArmPass): asinh(x) = log(x + sqrt(x^2 + 1)) """ - _passes_required_after: Set[Type[ExportPass]] = set() + _passes_required_after: Set[Type[ExportPass]] = { + DecomposeSqrtPass, + InsertTableOpsPass, + MatchArgRanksPass, + ReplaceScalarWithTensorArgPassTOSAMI, + MatchArgDtypePass, + } def call_operator(self, op, args, kwargs, meta): if op not in edge_asinh_op: diff --git a/backends/arm/_passes/decompose_atan_pass.py b/backends/arm/_passes/decompose_atan_pass.py index 7faef26a245..03ed62e7870 100644 --- a/backends/arm/_passes/decompose_atan_pass.py +++ b/backends/arm/_passes/decompose_atan_pass.py @@ -8,6 +8,12 @@ from typing import Set, Type from executorch.backends.arm._passes import ArmPass +from executorch.backends.arm._passes.insert_table_ops import InsertTableOpsPass +from executorch.backends.arm._passes.match_arg_dtype_pass import MatchArgDtypePass +from executorch.backends.arm._passes.match_arg_ranks_pass import MatchArgRanksPass +from executorch.backends.arm._passes.replace_scalar_with_tensor_pass import ( + ReplaceScalarWithTensorArgPassTOSAMI, +) from executorch.exir.dialects._ops import ops as exir_ops from executorch.exir.pass_base import ExportPass @@ -37,7 +43,12 @@ def _get_atan_ops(op): class DecomposeAtanPass(ArmPass): """Decomposes the atan operator into a rational (Padé) approximation.""" - _passes_required_after: Set[Type[ExportPass]] = set() + _passes_required_after: Set[Type[ExportPass]] = { + InsertTableOpsPass, + MatchArgRanksPass, + MatchArgDtypePass, + ReplaceScalarWithTensorArgPassTOSAMI, + } def _rational_approximation(self, z, ops, meta): """Creates a (2,1) Padé approximation for atan(x) on [-1, 1].""" diff --git a/backends/arm/_passes/decompose_atanh_pass.py b/backends/arm/_passes/decompose_atanh_pass.py index d06598923b3..2c8347e7e9f 100644 --- a/backends/arm/_passes/decompose_atanh_pass.py +++ b/backends/arm/_passes/decompose_atanh_pass.py @@ -6,6 +6,12 @@ from typing import Set, Type from executorch.backends.arm._passes import ArmPass +from executorch.backends.arm._passes.insert_table_ops import InsertTableOpsPass +from executorch.backends.arm._passes.match_arg_dtype_pass import MatchArgDtypePass +from executorch.backends.arm._passes.match_arg_ranks_pass import MatchArgRanksPass +from executorch.backends.arm._passes.replace_scalar_with_tensor_pass import ( + ReplaceScalarWithTensorArgPassTOSAMI, +) from executorch.exir.dialects._ops import ops as exir_ops from executorch.exir.pass_base import ExportPass @@ -33,7 +39,12 @@ class DecomposeAtanhPass(ArmPass): atanh(x) = 0.5 * log((1 + x) / (1 - x)) """ - _passes_required_after: Set[Type[ExportPass]] = set() + _passes_required_after: Set[Type[ExportPass]] = { + InsertTableOpsPass, + MatchArgRanksPass, + MatchArgDtypePass, + ReplaceScalarWithTensorArgPassTOSAMI, + } def call_operator(self, op, args, kwargs, meta): if op is not edge_atanh: diff --git a/backends/arm/_passes/decompose_avg_pool2d.py b/backends/arm/_passes/decompose_avg_pool2d.py index 0240661053b..bbb8ceba129 100644 --- a/backends/arm/_passes/decompose_avg_pool2d.py +++ b/backends/arm/_passes/decompose_avg_pool2d.py @@ -7,6 +7,7 @@ from typing import Set, Type import torch +from executorch.backends.arm._passes.fuse_constant_ops_pass import ComputeConstantOpsAOT from executorch.backends.arm.operators.operator_validation_utils import ( adjust_pooling_pad_if_needed, ) @@ -32,11 +33,11 @@ def get_decomposition(op) -> tuple: torch.ops.aten.avg_pool2d.default, torch.ops.aten.mul.Tensor, ) - raise RuntimeError(f"Can't get div decomposition for op {op}") + raise RuntimeError(f"Can't get avg_pool2d decomposition for op {op}") class DecomposeAvgPool2d(ExportPass): - _passes_required_after: Set[Type[ExportPass]] = set() + _passes_required_after: Set[Type[ExportPass]] = {ComputeConstantOpsAOT} def call_operator(self, op, args, kwargs, meta): if op not in (edge_div_ops + aten_div_ops): diff --git a/backends/arm/_passes/decompose_batch_norm_no_stats.py b/backends/arm/_passes/decompose_batch_norm_no_stats.py index 82937241369..b18bd4d9ac8 100644 --- a/backends/arm/_passes/decompose_batch_norm_no_stats.py +++ b/backends/arm/_passes/decompose_batch_norm_no_stats.py @@ -11,6 +11,9 @@ import torch from executorch.backends.arm._passes import ArmPass from executorch.backends.arm._passes.arm_pass_utils import create_node +from executorch.backends.arm._passes.fuse_constant_ops_pass import ComputeConstantOpsAOT + +from executorch.backends.arm._passes.insert_table_ops import InsertTableOpsPass from executorch.exir.dialects._ops import ops as exir_ops from executorch.exir.pass_base import ExportPass, PassResult @@ -34,7 +37,10 @@ class DecomposeBatchNormNoStatsPass(ArmPass): Source: https://pytorch.org/docs/stable/generated/torch.nn.BatchNorm2d.html """ - _passes_required_after: Set[Type[ExportPass]] = set() + _passes_required_after: Set[Type[ExportPass]] = { + ComputeConstantOpsAOT, + InsertTableOpsPass, + } def call(self, graph_module: torch.fx.GraphModule) -> PassResult: # noqa: C901 bn_ops = ( diff --git a/backends/arm/_passes/decompose_cosh_pass.py b/backends/arm/_passes/decompose_cosh_pass.py index b71ca388651..cbfbd5783e2 100644 --- a/backends/arm/_passes/decompose_cosh_pass.py +++ b/backends/arm/_passes/decompose_cosh_pass.py @@ -6,6 +6,12 @@ from typing import Set, Type from executorch.backends.arm._passes import ArmPass +from executorch.backends.arm._passes.insert_table_ops import InsertTableOpsPass +from executorch.backends.arm._passes.match_arg_dtype_pass import MatchArgDtypePass +from executorch.backends.arm._passes.match_arg_ranks_pass import MatchArgRanksPass +from executorch.backends.arm._passes.replace_scalar_with_tensor_pass import ( + ReplaceScalarWithTensorArgPassTOSAMI, +) from executorch.exir.dialects._ops import ops as exir_ops from executorch.exir.pass_base import ExportPass @@ -22,7 +28,12 @@ class DecomposeCoshPass(ArmPass): """ - _passes_required_after: Set[Type[ExportPass]] = set() + _passes_required_after: Set[Type[ExportPass]] = { + InsertTableOpsPass, + MatchArgRanksPass, + ReplaceScalarWithTensorArgPassTOSAMI, + MatchArgDtypePass, + } def call_operator(self, op, args, kwargs, meta, updated=False): if op is not edge_cosh: diff --git a/backends/arm/_passes/decompose_cosine_similarity_pass.py b/backends/arm/_passes/decompose_cosine_similarity_pass.py index e2ab01b345f..965dad54697 100644 --- a/backends/arm/_passes/decompose_cosine_similarity_pass.py +++ b/backends/arm/_passes/decompose_cosine_similarity_pass.py @@ -6,6 +6,13 @@ from typing import Set, Type import torch +from executorch.backends.arm._passes.convert_full_like_to_full_pass import ( + ConvertFullLikeToFullPass, +) + +from executorch.backends.arm._passes.decompose_div_pass import DecomposeDivPass +from executorch.backends.arm._passes.decompose_sum_pass import DecomposeSumPass +from executorch.backends.arm._passes.insert_table_ops import InsertTableOpsPass from executorch.exir.pass_base import ExportPass torch_cosine_similarity = (torch.ops.aten.cosine_similarity.default,) @@ -24,7 +31,12 @@ class DecomposeCosineSimilarityPass(ExportPass): out = div(dot, denom) """ - _passes_required_after: Set[Type[ExportPass]] = set() + _passes_required_after: Set[Type[ExportPass]] = { + DecomposeDivPass, + DecomposeSumPass, + ConvertFullLikeToFullPass, + InsertTableOpsPass, + } def call_operator(self, op, args, kwargs, meta): if op not in torch_cosine_similarity: diff --git a/backends/arm/_passes/decompose_cumsum_pass.py b/backends/arm/_passes/decompose_cumsum_pass.py index 04e6275c6c1..32c59f6d793 100644 --- a/backends/arm/_passes/decompose_cumsum_pass.py +++ b/backends/arm/_passes/decompose_cumsum_pass.py @@ -8,6 +8,7 @@ import torch from executorch.backends.arm._passes import ArmPass +from executorch.backends.arm._passes.add_bias_pass import AddBiasPass from executorch.backends.arm._passes.arm_pass_utils import create_node from executorch.backends.arm._passes.quant_args import QuantArgs @@ -40,7 +41,7 @@ class DecomposeCumsumPass(ArmPass): And the convolution is applied over dimension H. """ - _passes_required_after: Set[Type[ExportPass]] = set() + _passes_required_after: Set[Type[ExportPass]] = {AddBiasPass} def call(self, graph_module): graph = graph_module.graph diff --git a/backends/arm/_passes/decompose_div_pass.py b/backends/arm/_passes/decompose_div_pass.py index b6e289ff049..b6db103930e 100644 --- a/backends/arm/_passes/decompose_div_pass.py +++ b/backends/arm/_passes/decompose_div_pass.py @@ -9,6 +9,7 @@ from typing import Set, Type import torch +from executorch.backends.arm._passes.insert_table_ops import InsertTableOpsPass from executorch.exir.dialects._ops import ops as exir_ops from executorch.exir.pass_base import ExportPass @@ -39,7 +40,7 @@ class DecomposeDivPass(ExportPass): y = mul(a,x) """ - _passes_required_after: Set[Type[ExportPass]] = set() + _passes_required_after: Set[Type[ExportPass]] = {InsertTableOpsPass} def call_operator(self, op, args, kwargs, meta): if op not in (edge_div_ops + aten_div_ops): diff --git a/backends/arm/_passes/decompose_div_tensor_mode.py b/backends/arm/_passes/decompose_div_tensor_mode.py index 0e6b40afbb2..b5352475d51 100644 --- a/backends/arm/_passes/decompose_div_tensor_mode.py +++ b/backends/arm/_passes/decompose_div_tensor_mode.py @@ -5,7 +5,10 @@ # pyre-unsafe +from typing import Set, Type + import torch +from executorch.backends.arm._passes.decompose_div_pass import DecomposeDivPass from executorch.exir.dialects._ops import ops as exir_ops from executorch.exir.pass_base import ExportPass @@ -48,6 +51,8 @@ class DecomposeDivTensorModePass(ExportPass): rounding_mode='trunc' -> where(div(a,b) < 0, ceil(div(a,b)), floor(div(a,b))) """ + _passes_required_after: Set[Type[ExportPass]] = {DecomposeDivPass} + def call_operator(self, op, args, kwargs, meta): if op not in (edge_div_mode_ops + aten_div_mode_ops): return super().call_operator(op, args, kwargs, meta) diff --git a/backends/arm/_passes/decompose_embedding_pass.py b/backends/arm/_passes/decompose_embedding_pass.py index 5b2ad27eaf6..01226a7a38e 100644 --- a/backends/arm/_passes/decompose_embedding_pass.py +++ b/backends/arm/_passes/decompose_embedding_pass.py @@ -11,6 +11,7 @@ from typing import Set, Type import torch +from executorch.backends.transforms.fuse_view_copy import FuseViewCopyTransform from executorch.exir.dialects._ops import ops as exir_ops from executorch.exir.pass_base import ExportPass, PassResult @@ -34,7 +35,7 @@ class DecomposeEmbeddingPass(ExportPass): i = indices is expected to be int32 before this pass """ - _passes_required_after: Set[Type[ExportPass]] = set() + _passes_required_after: Set[Type[ExportPass]] = {FuseViewCopyTransform} aten_ops = (torch.ops.aten.embedding.default,) edge_ops = (exir_ops.edge.aten.embedding.default,) diff --git a/backends/arm/_passes/decompose_expm1_pass.py b/backends/arm/_passes/decompose_expm1_pass.py index 21d3c975de3..5de03cbf102 100644 --- a/backends/arm/_passes/decompose_expm1_pass.py +++ b/backends/arm/_passes/decompose_expm1_pass.py @@ -6,6 +6,14 @@ from typing import Set, Type from executorch.backends.arm._passes import ArmPass +from executorch.backends.arm._passes.convert_int_pow_to_mul import ConvertIntPowToMuls +from executorch.backends.arm._passes.decompose_div_pass import DecomposeDivPass +from executorch.backends.arm._passes.insert_table_ops import InsertTableOpsPass +from executorch.backends.arm._passes.match_arg_dtype_pass import MatchArgDtypePass +from executorch.backends.arm._passes.match_arg_ranks_pass import MatchArgRanksPass +from executorch.backends.arm._passes.replace_scalar_with_tensor_pass import ( + ReplaceScalarWithTensorArgPassTOSAMI, +) from executorch.exir.dialects._ops import ops as exir_ops from executorch.exir.pass_base import ExportPass @@ -71,7 +79,14 @@ class DecomposeExpm1Pass(ArmPass): - exir_ops.edge.aten.logical_and.default """ - _passes_required_after: Set[Type[ExportPass]] = set() + _passes_required_after: Set[Type[ExportPass]] = { + ConvertIntPowToMuls, + InsertTableOpsPass, + DecomposeDivPass, + ReplaceScalarWithTensorArgPassTOSAMI, + MatchArgDtypePass, + MatchArgRanksPass, + } def call_operator(self, op, args, kwargs, meta): if op not in edge_expm1_ops: diff --git a/backends/arm/_passes/decompose_gelu_pass.py b/backends/arm/_passes/decompose_gelu_pass.py index ef6a4753b8c..237b8199e82 100644 --- a/backends/arm/_passes/decompose_gelu_pass.py +++ b/backends/arm/_passes/decompose_gelu_pass.py @@ -7,6 +7,10 @@ import torch from executorch.backends.arm._passes.arm_pass_utils import get_node_arg +from executorch.backends.arm._passes.fuse_constant_ops_pass import ComputeConstantOpsAOT +from executorch.backends.arm._passes.insert_table_ops import InsertTableOpsPass +from executorch.backends.arm._passes.match_arg_dtype_pass import MatchArgDtypePass +from executorch.backends.arm._passes.match_arg_ranks_pass import MatchArgRanksPass from executorch.exir.dialects._ops import ops as exir_ops from executorch.exir.pass_base import ExportPass @@ -79,7 +83,12 @@ class DecomposeGeluPass(ExportPass): %op7 = mul(%op6, %FULL_0_5) """ - _passes_required_after: Set[Type[ExportPass]] = set() + _passes_required_after: Set[Type[ExportPass]] = { + ComputeConstantOpsAOT, + InsertTableOpsPass, + MatchArgDtypePass, + MatchArgRanksPass, + } def call_operator(self, op, args, kwargs, meta): if op not in torch_gelu + edge_gelu: diff --git a/backends/arm/_passes/decompose_glu_pass.py b/backends/arm/_passes/decompose_glu_pass.py index 6b53609c951..373b31c5995 100644 --- a/backends/arm/_passes/decompose_glu_pass.py +++ b/backends/arm/_passes/decompose_glu_pass.py @@ -7,6 +7,7 @@ import torch from executorch.backends.arm._passes import ArmPass +from executorch.backends.arm._passes.insert_table_ops import InsertTableOpsPass from executorch.exir.dialects._ops import ops as exir_ops from executorch.exir.pass_base import ExportPass @@ -39,7 +40,7 @@ def get_ops(op): class DecomposeGluPass(ArmPass): """Decomposes the GLU operator into hadamard product and sigmoid.""" - _passes_required_after: Set[Type[ExportPass]] = set() + _passes_required_after: Set[Type[ExportPass]] = {InsertTableOpsPass} def call_operator(self, op, args, kwargs, meta): if op not in [edge_glu, aten_glu]: diff --git a/backends/arm/_passes/decompose_grouped_conv.py b/backends/arm/_passes/decompose_grouped_conv.py index 2f0d7b4d72c..916e43ee9a4 100644 --- a/backends/arm/_passes/decompose_grouped_conv.py +++ b/backends/arm/_passes/decompose_grouped_conv.py @@ -7,6 +7,7 @@ from typing import Set, Type import torch +from executorch.backends.arm._passes.conv1d_unsqueeze_pass import Conv1dUnsqueezePass from executorch.backends.arm._passes.quant_args import QuantArgs from executorch.exir.dialects._ops import ops as exir_ops from executorch.exir.pass_base import ExportPass @@ -34,7 +35,7 @@ class DecomposeGroupedConv(ExportPass): x = cat(x1, x2) """ - _passes_required_after: Set[Type[ExportPass]] = set() + _passes_required_after: Set[Type[ExportPass]] = {Conv1dUnsqueezePass} @staticmethod def _get_decomposition(op): diff --git a/backends/arm/_passes/decompose_groupnorm_pass.py b/backends/arm/_passes/decompose_groupnorm_pass.py index 7f0d7fdeafd..29d68234b29 100644 --- a/backends/arm/_passes/decompose_groupnorm_pass.py +++ b/backends/arm/_passes/decompose_groupnorm_pass.py @@ -11,6 +11,10 @@ import torch from executorch.backends.arm._passes import ArmPass from executorch.backends.arm._passes.arm_pass_utils import create_node +from executorch.backends.arm._passes.decompose_meandim_pass import DecomposeMeanDimPass +from executorch.backends.arm._passes.decompose_var_pass import DecomposeVarPass +from executorch.backends.arm._passes.insert_table_ops import InsertTableOpsPass +from executorch.backends.arm._passes.size_adjust_input_pass import SizeAdjustInputPass from executorch.exir.dialects._ops import ops as exir_ops from executorch.exir.pass_base import ExportPass, PassResult @@ -58,7 +62,12 @@ class DecomposeGroupNormPass(ArmPass): Source: https://pytorch.org/docs/stable/generated/torch.nn.GroupNorm.html """ - _passes_required_after: Set[Type[ExportPass]] = set() + _passes_required_after: Set[Type[ExportPass]] = { + InsertTableOpsPass, + DecomposeMeanDimPass, + DecomposeVarPass, + SizeAdjustInputPass, + } def call(self, graph_module: torch.fx.GraphModule): modified = False diff --git a/backends/arm/_passes/decompose_layernorm_pass.py b/backends/arm/_passes/decompose_layernorm_pass.py index 0710ed37b45..c73806b0022 100644 --- a/backends/arm/_passes/decompose_layernorm_pass.py +++ b/backends/arm/_passes/decompose_layernorm_pass.py @@ -11,6 +11,10 @@ import torch from executorch.backends.arm._passes import ArmPass from executorch.backends.arm._passes.arm_pass_utils import create_node +from executorch.backends.arm._passes.decompose_meandim_pass import DecomposeMeanDimPass +from executorch.backends.arm._passes.decompose_var_pass import DecomposeVarPass +from executorch.backends.arm._passes.fuse_constant_ops_pass import ComputeConstantOpsAOT +from executorch.backends.arm._passes.insert_table_ops import InsertTableOpsPass from executorch.exir.dialects._ops import ops as exir_ops from executorch.exir.pass_base import ExportPass, PassResult @@ -57,7 +61,12 @@ class DecomposeLayerNormPass(ArmPass): Source: https://pytorch.org/docs/stable/generated/torch.nn.LayerNorm.html """ - _passes_required_after: Set[Type[ExportPass]] = set() + _passes_required_after: Set[Type[ExportPass]] = { + ComputeConstantOpsAOT, + DecomposeMeanDimPass, + DecomposeVarPass, + InsertTableOpsPass, + } def call(self, graph_module: torch.fx.GraphModule): for node in graph_module.graph.nodes: diff --git a/backends/arm/_passes/decompose_linalg_vector_norm_pass.py b/backends/arm/_passes/decompose_linalg_vector_norm_pass.py index 17441981654..ea5dd2d9b55 100644 --- a/backends/arm/_passes/decompose_linalg_vector_norm_pass.py +++ b/backends/arm/_passes/decompose_linalg_vector_norm_pass.py @@ -6,6 +6,8 @@ from typing import Set, Type import torch +from executorch.backends.arm._passes.decompose_sqrt_pass import DecomposeSqrtPass +from executorch.backends.arm._passes.decompose_sum_pass import DecomposeSumPass from executorch.exir.pass_base import ExportPass @@ -30,7 +32,10 @@ class DecomposeLinearVectorNormPass(ExportPass): dtype prior, but we dont know this from FX graph. """ - _passes_required_after: Set[Type[ExportPass]] = set() + _passes_required_after: Set[Type[ExportPass]] = { + DecomposeSqrtPass, + DecomposeSumPass, + } torch_linalg_vector_norm = (torch.ops.aten.linalg_vector_norm.default,) diff --git a/backends/arm/_passes/decompose_logit_pass.py b/backends/arm/_passes/decompose_logit_pass.py index a82650f0b9e..213b8f038e8 100644 --- a/backends/arm/_passes/decompose_logit_pass.py +++ b/backends/arm/_passes/decompose_logit_pass.py @@ -8,6 +8,12 @@ import torch from executorch.backends.arm._passes import ArmPass +from executorch.backends.arm._passes.insert_table_ops import InsertTableOpsPass +from executorch.backends.arm._passes.match_arg_dtype_pass import MatchArgDtypePass +from executorch.backends.arm._passes.match_arg_ranks_pass import MatchArgRanksPass +from executorch.backends.arm._passes.replace_scalar_with_tensor_pass import ( + ReplaceScalarWithTensorArgPassTOSAMI, +) from executorch.exir.dialects._ops import ops as exir_ops from executorch.exir.pass_base import ExportPass @@ -63,7 +69,12 @@ class DecomposeLogitPass(ArmPass): log(y * reciprocal((-1) * y + 1)) """ - _passes_required_after: Set[Type[ExportPass]] = set() + _passes_required_after: Set[Type[ExportPass]] = { + InsertTableOpsPass, + MatchArgRanksPass, + MatchArgDtypePass, + ReplaceScalarWithTensorArgPassTOSAMI, + } def call_operator(self, op, args, kwargs, meta): if op not in [edge_logit, aten_logit]: diff --git a/backends/arm/_passes/decompose_masked_fill.py b/backends/arm/_passes/decompose_masked_fill.py index ced58aa3920..8c41c1a11bc 100644 --- a/backends/arm/_passes/decompose_masked_fill.py +++ b/backends/arm/_passes/decompose_masked_fill.py @@ -11,6 +11,9 @@ import torch from executorch.backends.arm._passes import ArmPass +from executorch.backends.arm._passes.convert_full_like_to_full_pass import ( + ConvertFullLikeToFullPass, +) from executorch.exir.dialects._ops import ops as exir_ops from executorch.exir.pass_base import ExportPass @@ -40,7 +43,7 @@ class DecomposeMaskedFill(ArmPass): Decomposed to a where and a full_like operator. """ - _passes_required_after: Set[Type[ExportPass]] = set() + _passes_required_after: Set[Type[ExportPass]] = {ConvertFullLikeToFullPass} def call_operator(self, op, args, kwargs, meta, updated=False): if op not in (edge_ops + aten_ops): diff --git a/backends/arm/_passes/decompose_maxpool2d_with_dilation.py b/backends/arm/_passes/decompose_maxpool2d_with_dilation.py index 1df062ddb57..22d2ec1d85b 100644 --- a/backends/arm/_passes/decompose_maxpool2d_with_dilation.py +++ b/backends/arm/_passes/decompose_maxpool2d_with_dilation.py @@ -9,6 +9,7 @@ from typing import Set, Type from executorch.backends.arm._passes import ArmPass +from executorch.backends.arm._passes.size_adjust_input_pass import SizeAdjustInputPass from executorch.exir.dialects._ops import ops as exir_ops from executorch.exir.pass_base import ExportPass @@ -24,7 +25,9 @@ class DecomposeMaxPool2DPass(ArmPass): Decompose dilated max_pool2d (EXIR edge ops) into space-to-batch -> maxpool -> batch-to-space. """ - _passes_required_after: Set[Type[ExportPass]] = set() + _passes_required_after: Set[Type[ExportPass]] = { + SizeAdjustInputPass, + } def call_operator(self, op, args, kwargs, meta): # Only intercept EXIR edge max_pool2d ops diff --git a/backends/arm/_passes/decompose_meandim_pass.py b/backends/arm/_passes/decompose_meandim_pass.py index 716924dfbf2..e3e0a873020 100644 --- a/backends/arm/_passes/decompose_meandim_pass.py +++ b/backends/arm/_passes/decompose_meandim_pass.py @@ -10,6 +10,9 @@ import torch from executorch.backends.arm._passes import ArmPass from executorch.backends.arm._passes.arm_pass_utils import get_node_arg +from executorch.backends.arm._passes.decompose_sum_pass import DecomposeSumPass +from executorch.backends.arm._passes.fuse_constant_ops_pass import ComputeConstantOpsAOT +from executorch.backends.arm._passes.size_adjust_input_pass import SizeAdjustInputPass from executorch.exir.backend.utils import WhyNoPartitionReporter from executorch.exir.dialects._ops import ops as exir_ops from executorch.exir.pass_base import ExportPass @@ -64,7 +67,11 @@ class DecomposeMeanDimPass(ArmPass): x = view_copy.default(x, new_shape=(h)) # Squeeze dims since keepdims = False """ - _passes_required_after: Set[Type[ExportPass]] = set() + _passes_required_after: Set[Type[ExportPass]] = { + ComputeConstantOpsAOT, + DecomposeSumPass, + SizeAdjustInputPass, + } def __init__(self, graph_module, tosa_spec): super().__init__() diff --git a/backends/arm/_passes/decompose_select.py b/backends/arm/_passes/decompose_select.py index 9c65cd1c0a8..049409af6fd 100644 --- a/backends/arm/_passes/decompose_select.py +++ b/backends/arm/_passes/decompose_select.py @@ -13,6 +13,9 @@ create_node, get_first_fake_tensor, ) +from executorch.backends.arm._passes.convert_squeezes_to_view import ( + ConvertSqueezesToViewPass, +) from executorch.exir.dialects._ops import ops as exir_ops from executorch.exir.pass_base import ExportPass, PassResult @@ -22,7 +25,7 @@ class DecomposeSelectPass(ExportPass): This pass decomposes select into slice + squeeze to ensure that Aten and TOSA outputs has the same rank (input rank -1) """ - _passes_required_after: Set[Type[ExportPass]] = set() + _passes_required_after: Set[Type[ExportPass]] = {ConvertSqueezesToViewPass} def call(self, graph_module: torch.fx.GraphModule): for node in graph_module.graph.nodes: diff --git a/backends/arm/_passes/decompose_silu_pass.py b/backends/arm/_passes/decompose_silu_pass.py index cb7b55be520..3d31552cf35 100644 --- a/backends/arm/_passes/decompose_silu_pass.py +++ b/backends/arm/_passes/decompose_silu_pass.py @@ -8,6 +8,7 @@ from typing import Set, Type import torch +from executorch.backends.arm._passes.insert_table_ops import InsertTableOpsPass from executorch.exir.pass_base import ExportPass aten_silu_ops = (torch.ops.aten.silu.default, torch.ops.aten.silu_.default) @@ -24,7 +25,7 @@ class DecomposeSiluPass(ExportPass): y = mul(a,x) """ - _passes_required_after: Set[Type[ExportPass]] = set() + _passes_required_after: Set[Type[ExportPass]] = {InsertTableOpsPass} def call_operator(self, op, args, kwargs, meta): if op not in (aten_silu_ops): diff --git a/backends/arm/_passes/decompose_sinh_pass.py b/backends/arm/_passes/decompose_sinh_pass.py index 473a263e9a5..acb18df3134 100644 --- a/backends/arm/_passes/decompose_sinh_pass.py +++ b/backends/arm/_passes/decompose_sinh_pass.py @@ -7,6 +7,12 @@ from typing import Set, Type from executorch.backends.arm._passes import ArmPass +from executorch.backends.arm._passes.insert_table_ops import InsertTableOpsPass +from executorch.backends.arm._passes.match_arg_dtype_pass import MatchArgDtypePass +from executorch.backends.arm._passes.match_arg_ranks_pass import MatchArgRanksPass +from executorch.backends.arm._passes.replace_scalar_with_tensor_pass import ( + ReplaceScalarWithTensorArgPassTOSAMI, +) from executorch.exir.dialects._ops import ops as exir_ops from executorch.exir.pass_base import ExportPass @@ -27,7 +33,12 @@ class DecomposeSinhPass(ArmPass): and scalar multiplication. """ - _passes_required_after: Set[Type[ExportPass]] = set() + _passes_required_after: Set[Type[ExportPass]] = { + InsertTableOpsPass, + MatchArgRanksPass, + ReplaceScalarWithTensorArgPassTOSAMI, + MatchArgDtypePass, + } def call_operator(self, op, args, kwargs, meta): if op is not edge_sinh: diff --git a/backends/arm/_passes/decompose_softmax_pass.py b/backends/arm/_passes/decompose_softmax_pass.py index 47f448ae851..52df7cf6700 100644 --- a/backends/arm/_passes/decompose_softmax_pass.py +++ b/backends/arm/_passes/decompose_softmax_pass.py @@ -6,6 +6,8 @@ from typing import Set, Type import torch +from executorch.backends.arm._passes.decompose_sum_pass import DecomposeSumPass +from executorch.backends.arm._passes.insert_table_ops import InsertTableOpsPass from executorch.exir.dialects._ops import ops as exir_ops from executorch.exir.pass_base import ExportPass @@ -64,7 +66,10 @@ class DecomposeSoftmaxPass(ExportPass): (in logsoftmax case: %op7 = log(%op6)) """ - _passes_required_after: Set[Type[ExportPass]] = set() + _passes_required_after: Set[Type[ExportPass]] = { + DecomposeSumPass, + InsertTableOpsPass, + } def call_operator(self, op, args, kwargs, meta): if op not in torch_softmax + edge_softmax: diff --git a/backends/arm/_passes/decompose_softmax_unstable_pass.py b/backends/arm/_passes/decompose_softmax_unstable_pass.py index 5e704585eb0..04e99a46b3e 100644 --- a/backends/arm/_passes/decompose_softmax_unstable_pass.py +++ b/backends/arm/_passes/decompose_softmax_unstable_pass.py @@ -9,6 +9,8 @@ import torch from executorch.backends.arm._passes import ArmPass +from executorch.backends.arm._passes.decompose_sum_pass import DecomposeSumPass +from executorch.backends.arm._passes.insert_table_ops import InsertTableOpsPass from executorch.exir.dialects._ops import ops as exir_ops from executorch.exir.pass_base import ExportPass @@ -60,7 +62,10 @@ class DecomposeSoftmaxUnstablePass(ArmPass): (in logsoftmax case: %op5 = log(%op4)) """ - _passes_required_after: Set[Type[ExportPass]] = set() + _passes_required_after: Set[Type[ExportPass]] = { + DecomposeSumPass, + InsertTableOpsPass, + } def call_operator(self, op, args, kwargs, meta): if op not in torch_softmax + edge_softmax: diff --git a/backends/arm/_passes/decompose_sqrt_pass.py b/backends/arm/_passes/decompose_sqrt_pass.py index c93686901d5..3f4e608c4b9 100644 --- a/backends/arm/_passes/decompose_sqrt_pass.py +++ b/backends/arm/_passes/decompose_sqrt_pass.py @@ -7,6 +7,7 @@ from typing import Set, Tuple, Type, Union import torch +from executorch.backends.arm._passes.insert_table_ops import InsertTableOpsPass from executorch.exir.dialects._ops import ops as exir_ops from executorch.exir.pass_base import ExportPass @@ -27,7 +28,7 @@ def get_sqrt_decomposition(op) -> Union[Tuple, torch._ops.OpOverload]: class DecomposeSqrtPass(ExportPass): - _passes_required_after: Set[Type[ExportPass]] = set() + _passes_required_after: Set[Type[ExportPass]] = {InsertTableOpsPass} def call_operator(self, op, args, kwargs, meta): """ diff --git a/backends/arm/_passes/decompose_var_pass.py b/backends/arm/_passes/decompose_var_pass.py index f8396da0420..db5d820ac70 100644 --- a/backends/arm/_passes/decompose_var_pass.py +++ b/backends/arm/_passes/decompose_var_pass.py @@ -12,6 +12,9 @@ import torch from executorch.backends.arm._passes import ArmPass from executorch.backends.arm._passes.arm_pass_utils import get_node_arg +from executorch.backends.arm._passes.decompose_meandim_pass import DecomposeMeanDimPass +from executorch.backends.arm._passes.decompose_sum_pass import DecomposeSumPass +from executorch.backends.arm._passes.fuse_constant_ops_pass import ComputeConstantOpsAOT from executorch.exir.dialects._ops import ops as exir_ops from executorch.exir.pass_base import ExportPass @@ -50,7 +53,11 @@ class DecomposeVarPass(ArmPass): y = div(sum, max(0, N-correction)) """ - _passes_required_after: Set[Type[ExportPass]] = set() + _passes_required_after: Set[Type[ExportPass]] = { + ComputeConstantOpsAOT, + DecomposeMeanDimPass, + DecomposeSumPass, + } def call_operator(self, op, args, kwargs, meta): if op not in ( diff --git a/backends/arm/_passes/fold_qdq_with_annotated_qparams_pass.py b/backends/arm/_passes/fold_qdq_with_annotated_qparams_pass.py index 714543d3908..477e007b8bf 100644 --- a/backends/arm/_passes/fold_qdq_with_annotated_qparams_pass.py +++ b/backends/arm/_passes/fold_qdq_with_annotated_qparams_pass.py @@ -15,8 +15,10 @@ get_param_tensor, is_param_node, ) +from executorch.backends.arm._passes.insert_table_ops import InsertTableOpsPass from executorch.backends.arm._passes.quant_args import QuantArgs +from executorch.backends.arm._passes.remove_noop_pass import RemoveNoopPass from executorch.backends.arm.constants import DQ_OPS, Q_OPS from executorch.exir.dialects._ops import ops as exir_ops @@ -70,6 +72,44 @@ def get_output_qparams(node: Node) -> dict[int, QuantArgs]: return output_qparams +class RetraceFoldedDtypesPass(ExportPass): + """ + FoldAndAnnotateQParamsPass folds dq and q nodes. When the graph is retraced + some operators are retraced to types that cannot be handled by TOSA. One + such example is sum.dim_IntList: + q (int8) -> dq (fp32) -> sum (fp32) -> q (int8) ... + After folding it becomes: + q (int8) -> sum (int64) -> ... + This pass changes types of ops in self.targeted_ops, such as sum, so that + the output type of that matches the type of the output_qparams. + """ + + _passes_required_after: Set[Type[ExportPass]] = set() + + targeted_ops: Set[EdgeOpOverload] = { + exir_ops.edge.aten.sum.dim_IntList, + } + + def call_operator( + self, + op, # pyre-ignore + args: Tuple[Argument, ...], + kwargs: Dict[str, Argument], + meta: NodeMetadata, + ) -> ProxyValue: + if op not in self.targeted_ops: + return super().call_operator(op, args, kwargs, meta) + + node_kwargs = kwargs.copy() + output_qparams = meta["output_qparams"] + if len(output_qparams) == 0: + return super().call_operator(op, args, kwargs, meta) + + output_dtype = output_qparams[0].dtype + node_kwargs["dtype"] = output_dtype + return super().call_operator(op, args, node_kwargs, meta) + + class FoldAndAnnotateQParamsPass(ArmPass): """ A pass that walks the graph and removes any DQ and Q nodes before and after the target @@ -100,7 +140,11 @@ class FoldAndAnnotateQParamsPass(ArmPass): """ - _passes_required_after: Set[Type[ExportPass]] = set() + _passes_required_after: Set[Type[ExportPass]] = { + RetraceFoldedDtypesPass, + InsertTableOpsPass, + RemoveNoopPass, + } def fold_and_annotate_arg( self, graph_module: GraphModule, node: Node, arg_list: list[Node], i: int @@ -212,7 +256,7 @@ class QuantizeOperatorArguments(ExportPass): - Makes sure the min and max values to clamp.default are quantized, if it's a quantized operator. """ - _passes_required_after: Set[Type[ExportPass]] = set() + _passes_required_after: Set[Type[ExportPass]] = {FoldAndAnnotateQParamsPass} def call(self, graph_module: GraphModule) -> PassResult: modified = False @@ -247,41 +291,3 @@ def call(self, graph_module: GraphModule) -> PassResult: modified = True return PassResult(graph_module, modified) - - -class RetraceFoldedDtypesPass(ExportPass): - """ - FoldAndAnnotateQParamsPass folds dq and q nodes. When the graph is retraced - some operators are retraced to types that cannot be handled by TOSA. One - such example is sum.dim_IntList: - q (int8) -> dq (fp32) -> sum (fp32) -> q (int8) ... - After folding it becomes: - q (int8) -> sum (int64) -> ... - This pass changes types of ops in self.targeted_ops, such as sum, so that - the output type of that matches the type of the output_qparams. - """ - - _passes_required_after: Set[Type[ExportPass]] = set() - - targeted_ops: Set[EdgeOpOverload] = { - exir_ops.edge.aten.sum.dim_IntList, - } - - def call_operator( - self, - op, # pyre-ignore - args: Tuple[Argument, ...], - kwargs: Dict[str, Argument], - meta: NodeMetadata, - ) -> ProxyValue: - if op not in self.targeted_ops: - return super().call_operator(op, args, kwargs, meta) - - node_kwargs = kwargs.copy() - output_qparams = meta["output_qparams"] - if len(output_qparams) == 0: - return super().call_operator(op, args, kwargs, meta) - - output_dtype = output_qparams[0].dtype - node_kwargs["dtype"] = output_dtype - return super().call_operator(op, args, node_kwargs, meta) diff --git a/backends/arm/_passes/fuse_constant_ops_pass.py b/backends/arm/_passes/fuse_constant_ops_pass.py index 07f3a4af245..c7afe2af151 100644 --- a/backends/arm/_passes/fuse_constant_ops_pass.py +++ b/backends/arm/_passes/fuse_constant_ops_pass.py @@ -14,6 +14,9 @@ get_param_tensor, is_persistent_buffer, ) +from executorch.backends.arm._passes.fuse_equal_placeholders_pass import ( + FuseEqualPlaceholdersPass, +) from executorch.backends.transforms.utils import ( create_constant_placeholder, delete_constant_placeholder, @@ -171,7 +174,7 @@ def f(node_name_pre_computed): return node_name_pre_computed """ - _passes_required_after: Set[Type[ExportPass]] = set() + _passes_required_after: Set[Type[ExportPass]] = {FuseEqualPlaceholdersPass} targeted_ops = [ exir_ops.edge.aten.full.default, diff --git a/backends/arm/_passes/fuse_quantized_activation_pass.py b/backends/arm/_passes/fuse_quantized_activation_pass.py index d39d7135f9c..1076a3df658 100644 --- a/backends/arm/_passes/fuse_quantized_activation_pass.py +++ b/backends/arm/_passes/fuse_quantized_activation_pass.py @@ -8,15 +8,24 @@ from typing import Set, Type import torch +from executorch.backends.arm._passes.convert_to_clamp import ConvertToClampPass +from executorch.backends.arm._passes.fold_qdq_with_annotated_qparams_pass import ( + FoldAndAnnotateQParamsPass, +) from executorch.backends.arm._passes.quant_args import QuantArgs from executorch.backends.arm.constants import Q_OPS +from executorch.backends.transforms.remove_getitem_op import RemoveGetItemPass from executorch.exir.dialects._ops import ops as exir_ops from executorch.exir.pass_base import ExportPass, PassResult from torch.fx import Node class FuseQuantizedActivationPass(ExportPass): - _passes_required_after: Set[Type[ExportPass]] = set() + _passes_required_after: Set[Type[ExportPass]] = { + ConvertToClampPass, + FoldAndAnnotateQParamsPass, + RemoveGetItemPass, + } @staticmethod def _is_fuseable_quantized_activation(node: Node): diff --git a/backends/arm/_passes/insert_int32_casts_after_int64_placeholders.py b/backends/arm/_passes/insert_int32_casts_after_int64_placeholders.py index 4b619af790c..c6e6f70a630 100644 --- a/backends/arm/_passes/insert_int32_casts_after_int64_placeholders.py +++ b/backends/arm/_passes/insert_int32_casts_after_int64_placeholders.py @@ -8,8 +8,13 @@ import logging +from typing import Set, Type + import torch from executorch.backends.arm._passes.arm_pass_utils import create_node +from executorch.backends.arm._passes.decompose_embedding_pass import ( + DecomposeEmbeddingPass, +) from executorch.exir.dialects._ops import ops as exir_ops from executorch.exir.pass_base import EdgeOpOverload, ExportPass, PassResult from torch._subclasses.fake_tensor import FakeTensor @@ -26,6 +31,8 @@ class InsertInt32CastsAfterInt64PlaceholdersPass(ExportPass): the int32 range. """ + _passes_required_after: Set[Type[ExportPass]] = {DecomposeEmbeddingPass} + # Ops that require i64 inputs → positions of args to upcast. # Key: op overload; Value: zero-based indices of positional args that must be i64. I64_INPUT_ARG_POSITIONS = { diff --git a/backends/arm/_passes/mm_to_bmm_pass.py b/backends/arm/_passes/mm_to_bmm_pass.py index 6be0b9e2ac4..c6f4786365d 100644 --- a/backends/arm/_passes/mm_to_bmm_pass.py +++ b/backends/arm/_passes/mm_to_bmm_pass.py @@ -14,6 +14,12 @@ get_first_fake_tensor, insert_q_dq_pair, ) +from executorch.backends.arm._passes.convert_squeezes_to_view import ( + ConvertSqueezesToViewPass, +) +from executorch.backends.arm._passes.fold_qdq_with_annotated_qparams_pass import ( + FoldAndAnnotateQParamsPass, +) from executorch.backends.arm.constants import DQ_OPS, Q_OPS from executorch.exir.dialects._ops import ops as exir_ops from executorch.exir.pass_base import ExportPass, PassResult @@ -30,7 +36,10 @@ class ConvertMmToBmmPass(ExportPass): 3) Squeeze output tensor to rank 2. """ - _passes_required_after: Set[Type[ExportPass]] = set() + _passes_required_after: Set[Type[ExportPass]] = { + ConvertSqueezesToViewPass, + FoldAndAnnotateQParamsPass, + } def call(self, graph_module: torch.fx.GraphModule): modified_graph = False diff --git a/backends/arm/_passes/scalars_to_attribute_pass.py b/backends/arm/_passes/scalars_to_attribute_pass.py index bb2a02cc679..9ad3e318011 100644 --- a/backends/arm/_passes/scalars_to_attribute_pass.py +++ b/backends/arm/_passes/scalars_to_attribute_pass.py @@ -10,6 +10,7 @@ import torch from executorch.backends.arm._passes.arm_pass_utils import get_first_fake_tensor +from executorch.backends.arm._passes.match_arg_ranks_pass import MatchArgRanksPass from executorch.exir.pass_base import ExportPass, PassResult from torch.fx import GraphModule, Node @@ -22,7 +23,7 @@ class ScalarsToAttributePass(ExportPass): to attribute Nodes that output the same value. """ - _passes_required_after: Set[Type[ExportPass]] = set() + _passes_required_after: Set[Type[ExportPass]] = {MatchArgRanksPass} targeted_ops = [ torch.ops.aten.add.Tensor,