From 89fd30aaf1746129d680cfd0dc4c194cd5931cd4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Martin=20Lindstr=C3=B6m?= Date: Thu, 13 Nov 2025 17:13:32 +0100 Subject: [PATCH] Arm backend: Rename passes for consistency MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Add *Pass suffix for passes missing it * Add missing _pass suffixes to python files containing passes * Correct pass name: DecomposeLinearVectorNormPass to DecomposeLinalgVectorNormPass * Rename ConvertIntPowToMuls to DecomposeIntPowPass * Rename QuantizeOperatorArguments to QuantizeClampArgumentsPass Signed-off-by: Martin Lindström Change-Id: Ied2bdce1a5240464db25a0b42eb9cee8a078f73b --- backends/arm/_passes/__init__.py | 27 ++++++----- backends/arm/_passes/arm_pass_manager.py | 48 +++++++++---------- .../_passes/convert_full_like_to_full_pass.py | 6 ++- .../convert_int64_const_ops_to_int32.py | 11 +++-- ...t_to_clamp.py => convert_to_clamp_pass.py} | 4 +- .../decompose_adaptive_avg_pool2d_pass.py | 6 ++- ...pool2d.py => decompose_avg_pool2d_pass.py} | 8 ++-- .../_passes/decompose_batch_norm_no_stats.py | 6 ++- backends/arm/_passes/decompose_expm1_pass.py | 4 +- backends/arm/_passes/decompose_gelu_pass.py | 6 ++- ...conv.py => decompose_grouped_conv_pass.py} | 10 ++-- ...ow_to_mul.py => decompose_int_pow_pass.py} | 2 +- .../arm/_passes/decompose_layernorm_pass.py | 6 ++- .../decompose_linalg_vector_norm_pass.py | 2 +- ..._fill.py => decompose_masked_fill_pass.py} | 2 +- ...decompose_maxpool2d_with_dilation_pass.py} | 2 +- .../arm/_passes/decompose_meandim_pass.py | 6 ++- backends/arm/_passes/decompose_sdpa_pass.py | 2 +- backends/arm/_passes/decompose_var_pass.py | 6 ++- .../fold_qdq_with_annotated_qparams_pass.py | 2 +- ...rm2d_pass.py => fuse_batch_norm2d_pass.py} | 2 +- .../arm/_passes/fuse_constant_ops_pass.py | 2 +- .../_passes/fuse_quantized_activation_pass.py | 2 +- .../arm/_passes/replace_inf_values_pass.py | 2 +- .../arm/_passes/size_adjust_input_pass.py | 6 +-- .../tosa_supported_operators.py | 10 ++-- backends/arm/operators/op_clamp.py | 2 +- .../arm/test/passes/test_convert_to_clamp.py | 2 +- .../passes/test_decompose_avg_pool2d_pass.py | 6 ++- ...muls.py => test_decompose_int_pow_pass.py} | 6 +-- .../test_decompose_linalg_vector_norm_pass.py | 6 +-- .../test/passes/test_fuse_batchnorm_pass.py | 4 +- .../passes/test_fuse_constant_ops_pass.py | 17 +++++-- 33 files changed, 136 insertions(+), 97 deletions(-) rename backends/arm/_passes/{convert_to_clamp.py => convert_to_clamp_pass.py} (91%) rename backends/arm/_passes/{decompose_avg_pool2d.py => decompose_avg_pool2d_pass.py} (98%) rename backends/arm/_passes/{decompose_grouped_conv.py => decompose_grouped_conv_pass.py} (95%) rename backends/arm/_passes/{convert_int_pow_to_mul.py => decompose_int_pow_pass.py} (98%) rename backends/arm/_passes/{decompose_masked_fill.py => decompose_masked_fill_pass.py} (97%) rename backends/arm/_passes/{decompose_maxpool2d_with_dilation.py => decompose_maxpool2d_with_dilation_pass.py} (99%) rename backends/arm/_passes/{fuse_batchnorm2d_pass.py => fuse_batch_norm2d_pass.py} (99%) rename backends/arm/test/passes/{test_convert_int_pow_to_muls.py => test_decompose_int_pow_pass.py} (92%) diff --git a/backends/arm/_passes/__init__.py b/backends/arm/_passes/__init__.py index cc4541120a5..662edc5811f 100644 --- a/backends/arm/_passes/__init__.py +++ b/backends/arm/_passes/__init__.py @@ -18,14 +18,13 @@ from .convert_full_like_to_full_pass import ConvertFullLikeToFullPass # noqa from .convert_int64_const_ops_to_int32 import ConvertInt64ConstOpsToInt32Pass # noqa from .convert_int64_output_ops_to_int32 import ConvertInt64OutputOpsToInt32Pass # noqa -from .convert_int_pow_to_mul import ConvertIntPowToMuls # noqa from .convert_minmax_pass import ConvertMinMaxPass # noqa from .convert_permute_singleton_to_view_pass import ( # noqa ConvertPermuteSingletonToViewPass, ) from .convert_split_to_slice import ConvertSplitToSlicePass # noqa from .convert_squeezes_to_view import ConvertSqueezesToViewPass # noqa -from .convert_to_clamp import ConvertToClampPass # noqa +from .convert_to_clamp_pass import ConvertToClampPass # noqa from .decompose_acosh_pass import DecomposeAcoshPass # noqa from .decompose_adaptive_avg_pool2d_pass import DecomposeAdaptiveAvgPool2dPass # noqa from .decompose_add_sub_alpha_pass import DecomposeAddSubAlphaPass # noqa @@ -35,7 +34,7 @@ from .decompose_asinh_pass import DecomposeAsinhPass # noqa from .decompose_atan_pass import DecomposeAtanPass # noqa from .decompose_atanh_pass import DecomposeAtanhPass # noqa -from .decompose_avg_pool2d import DecomposeAvgPool2d # noqa +from .decompose_avg_pool2d_pass import DecomposeAvgPool2dPass # noqa from .decompose_batch_norm_no_stats import DecomposeBatchNormNoStatsPass # noqa from .decompose_cosh_pass import DecomposeCoshPass # noqa from .decompose_cosine_similarity_pass import DecomposeCosineSimilarityPass # noqa @@ -48,23 +47,24 @@ from .decompose_floor_divide_pass import DecomposeFloorDividePass # noqa from .decompose_gelu_pass import DecomposeGeluPass # noqa from .decompose_glu_pass import DecomposeGluPass # noqa -from .decompose_grouped_conv import DecomposeGroupedConv # noqa +from .decompose_grouped_conv_pass import DecomposeGroupedConvPass # noqa from .decompose_groupnorm_pass import DecomposeGroupNormPass # noqa from .decompose_int16_activation_conv2d_pass import ( # noqa DecomposeConv2dWithInt16ActivationPass, ) +from .decompose_int_pow_pass import DecomposeIntPowPass # noqa from .decompose_layernorm_pass import DecomposeLayerNormPass # noqa from .decompose_leaky_relu_pass import DecomposeLeakyReLUPass # noqa -from .decompose_linalg_vector_norm_pass import DecomposeLinearVectorNormPass # noqa +from .decompose_linalg_vector_norm_pass import DecomposeLinalgVectorNormPass # noqa from .decompose_linear_pass import DecomposeLinearPass # noqa from .decompose_logit_pass import DecomposeLogitPass # noqa -from .decompose_masked_fill import DecomposeMaskedFill # noqa -from .decompose_maxpool2d_with_dilation import DecomposeMaxPool2DPass # noqa +from .decompose_masked_fill_pass import DecomposeMaskedFillPass # noqa +from .decompose_maxpool2d_with_dilation_pass import DecomposeMaxPool2dPass # noqa from .decompose_meandim_pass import DecomposeMeanDimPass # noqa from .decompose_ne_pass import DecomposeNotEqualPass # noqa from .decompose_remainder_pass import DecomposeRemainderPass # noqa from .decompose_round_pass import DecomposeRoundPass # noqa -from .decompose_sdpa_pass import DecomposeScaledDotProductAttention # noqa +from .decompose_sdpa_pass import DecomposeScaledDotProductAttentionPass # noqa from .decompose_select import DecomposeSelectPass # noqa from .decompose_sign_pass import DecomposeSignPass # noqa from .decompose_silu_pass import DecomposeSiluPass # noqa @@ -77,10 +77,13 @@ from .decorate_fp32_to_int32_casting_pass import DecorateFp32toInt32CastingPass # noqa from .fold_qdq_with_annotated_qparams_pass import ( # noqa FoldAndAnnotateQParamsPass, - QuantizeOperatorArguments, + QuantizeClampArgumentsPass, +) +from .fuse_batch_norm2d_pass import FuseBatchNorm2dPass # noqa +from .fuse_constant_ops_pass import ( # noqa + ComputeConstantOpsAOTPass, + FuseConstantArgsPass, ) -from .fuse_batchnorm2d_pass import FuseBatchnorm2DPass # noqa -from .fuse_constant_ops_pass import ComputeConstantOpsAOT, FuseConstantArgsPass # noqa from .fuse_duplicate_users_pass import FuseDuplicateUsersPass # noqa from .fuse_equal_placeholders_pass import FuseEqualPlaceholdersPass # noqa from .fuse_quantized_activation_pass import FuseQuantizedActivationPass # noqa @@ -107,5 +110,5 @@ from .to_tosa_memory_format_pass import ToTosaMemoryFormatPass # noqa from .unsqueeze_before_repeat_pass import UnsqueezeBeforeRepeatPass # noqa from .unsqueeze_scalar_placeholders_pass import UnsqueezeScalarPlaceholdersPass # noqa -from .replace_inf_values_pass import ReplaceInfValues # noqa # usort: skip +from .replace_inf_values_pass import ReplaceInfValuesPass # noqa # usort: skip from .arm_pass_manager import ArmPassManager # noqa # usort: skip diff --git a/backends/arm/_passes/arm_pass_manager.py b/backends/arm/_passes/arm_pass_manager.py index 2ae84802912..53921790391 100644 --- a/backends/arm/_passes/arm_pass_manager.py +++ b/backends/arm/_passes/arm_pass_manager.py @@ -17,14 +17,13 @@ CastBoolToInt8Pass, CastInt64BuffersToInt32Pass, CastToInt32Pass, - ComputeConstantOpsAOT, + ComputeConstantOpsAOTPass, Conv1dUnsqueezePass, ConvertELUParamsPass, ConvertExpandCopyToRepeatPass, ConvertFullLikeToFullPass, ConvertInt64ConstOpsToInt32Pass, ConvertInt64OutputOpsToInt32Pass, - ConvertIntPowToMuls, ConvertMinMaxPass, ConvertMmToBmmPass, ConvertPermuteSingletonToViewPass, @@ -40,7 +39,7 @@ DecomposeAsinhPass, DecomposeAtanhPass, DecomposeAtanPass, - DecomposeAvgPool2d, + DecomposeAvgPool2dPass, DecomposeBatchNormNoStatsPass, DecomposeConv2dWithInt16ActivationPass, DecomposeCoshPass, @@ -54,20 +53,21 @@ DecomposeFloorDividePass, DecomposeGeluPass, DecomposeGluPass, - DecomposeGroupedConv, + DecomposeGroupedConvPass, DecomposeGroupNormPass, + DecomposeIntPowPass, DecomposeLayerNormPass, DecomposeLeakyReLUPass, + DecomposeLinalgVectorNormPass, DecomposeLinearPass, - DecomposeLinearVectorNormPass, DecomposeLogitPass, - DecomposeMaskedFill, - DecomposeMaxPool2DPass, + DecomposeMaskedFillPass, + DecomposeMaxPool2dPass, DecomposeMeanDimPass, DecomposeNotEqualPass, DecomposeRemainderPass, DecomposeRoundPass, - DecomposeScaledDotProductAttention, + DecomposeScaledDotProductAttentionPass, DecomposeSelectPass, DecomposeSignPass, DecomposeSiluPass, @@ -79,7 +79,7 @@ DecomposeVarPass, DecorateFp32toInt32CastingPass, FoldAndAnnotateQParamsPass, - FuseBatchnorm2DPass, + FuseBatchNorm2dPass, FuseConstantArgsPass, FuseDuplicateUsersPass, FuseEqualPlaceholdersPass, @@ -91,11 +91,11 @@ InsertTableOpsPass, MatchArgDtypePass, MatchArgRanksPass, - QuantizeOperatorArguments, + QuantizeClampArgumentsPass, RemoveGetItemPass, RemoveGraphAssertsPass, RemoveNoopPass, - ReplaceInfValues, + ReplaceInfValuesPass, ReplaceScalarWithTensorByProfilePass, RewriteConv2dPass, RewriteMatmulPass, @@ -181,7 +181,7 @@ def _tosa_pipeline( AnnotateDecomposedMatmulPass(), ConvertELUParamsPass(), ConvertSplitToSlicePass(), - QuantizeOperatorArguments(), + QuantizeClampArgumentsPass(), ] ) @@ -202,7 +202,7 @@ def _tosa_pipeline( self.add_passes( [ DecomposeLogitPass(), - DecomposeMaskedFill(), + DecomposeMaskedFillPass(), DecomposeRoundPass(), DecomposeAcoshPass(), DecomposeAsinhPass(), @@ -214,14 +214,14 @@ def _tosa_pipeline( DecomposeAddmmPass(), DecomposeEluPass(), DecomposeExpm1Pass(), - ConvertIntPowToMuls(), + DecomposeIntPowPass(), CastBoolToInt8Pass(), DecomposeSinhPass(), DecomposeSignPass(), DecomposeFloorDividePass(), DecomposeGeluPass(), DecomposeAddSubAlphaPass(), - DecomposeGroupedConv(), + DecomposeGroupedConvPass(), Conv1dUnsqueezePass(), ] ) @@ -247,7 +247,7 @@ def _tosa_pipeline( DecomposeRemainderPass(), DecomposeDivTensorModePass(), DecomposeEmbeddingPass(), - FuseBatchnorm2DPass(exported_program), + FuseBatchNorm2dPass(exported_program), ConvertMmToBmmPass(), DecomposeGluPass(), DecomposeLeakyReLUPass(), @@ -256,13 +256,13 @@ def _tosa_pipeline( ConvertMinMaxPass(), DecomposeAnyPass(), DecomposeAdaptiveAvgPool2dPass(), - DecomposeAvgPool2d(), + DecomposeAvgPool2dPass(), DecorateFp32toInt32CastingPass(), - ComputeConstantOpsAOT(exported_program), + ComputeConstantOpsAOTPass(exported_program), ConvertExpandCopyToRepeatPass(), UnsqueezeBeforeRepeatPass(), DecomposeCumsumPass(exported_program), - DecomposeMaxPool2DPass(), + DecomposeMaxPool2dPass(), SizeAdjustInputPass(), DecomposeSelectPass(), ConvertSqueezesToViewPass(), @@ -324,7 +324,7 @@ def transform_for_annotation_pipeline(self, graph_module: GraphModule): ConvertInt64OutputOpsToInt32Pass(), InsertInt32CastsAfterInt64PlaceholdersPass(), DecomposeEmbeddingPass(), - DecomposeScaledDotProductAttention(), + DecomposeScaledDotProductAttentionPass(), DecomposeRoundPass(), DecomposeLogitPass(), CastBoolToInt8Pass(), @@ -357,10 +357,10 @@ def transform_for_annotation_pipeline(self, graph_module: GraphModule): DecomposeGluPass(), DecomposeDivPass(), DecomposeLeakyReLUPass(), - DecomposeLinearVectorNormPass(), + DecomposeLinalgVectorNormPass(), DecomposeSqrtPass(), DecomposeSiluPass(), - DecomposeAvgPool2d(), + DecomposeAvgPool2dPass(), ( DecomposeSoftmaxUnstablePass() if self.tosa_spec.is_U55_subset @@ -373,8 +373,8 @@ def transform_for_annotation_pipeline(self, graph_module: GraphModule): # Postprocessing passes self.add_passes( [ - ReplaceInfValues(), - DecomposeMaskedFill() if not self.tosa_spec.is_U55_subset else None, + ReplaceInfValuesPass(), + DecomposeMaskedFillPass() if not self.tosa_spec.is_U55_subset else None, ] ) diff --git a/backends/arm/_passes/convert_full_like_to_full_pass.py b/backends/arm/_passes/convert_full_like_to_full_pass.py index 06822a4abcf..becb0b7f971 100644 --- a/backends/arm/_passes/convert_full_like_to_full_pass.py +++ b/backends/arm/_passes/convert_full_like_to_full_pass.py @@ -6,7 +6,9 @@ from typing import Set, Type from executorch.backends.arm._passes.arm_pass import ArmPass -from executorch.backends.arm._passes.fuse_constant_ops_pass import ComputeConstantOpsAOT +from executorch.backends.arm._passes.fuse_constant_ops_pass import ( + ComputeConstantOpsAOTPass, +) from executorch.exir.dialects._ops import ops as exir_ops from executorch.exir.pass_base import ExportPass @@ -24,7 +26,7 @@ class ConvertFullLikeToFullPass(ArmPass): Skip layout and device since it's not relevant for our backend. """ - _passes_required_after: Set[Type[ExportPass]] = {ComputeConstantOpsAOT} + _passes_required_after: Set[Type[ExportPass]] = {ComputeConstantOpsAOTPass} def call_operator(self, op, args, kwargs, meta): if op not in [ diff --git a/backends/arm/_passes/convert_int64_const_ops_to_int32.py b/backends/arm/_passes/convert_int64_const_ops_to_int32.py index dff270fda13..85fcf715f07 100644 --- a/backends/arm/_passes/convert_int64_const_ops_to_int32.py +++ b/backends/arm/_passes/convert_int64_const_ops_to_int32.py @@ -9,7 +9,9 @@ import torch from executorch.backends.arm._passes import ArmPass -from executorch.backends.arm._passes.fuse_constant_ops_pass import ComputeConstantOpsAOT +from executorch.backends.arm._passes.fuse_constant_ops_pass import ( + ComputeConstantOpsAOTPass, +) from executorch.exir.pass_base import ExportPass, PassResult @@ -30,7 +32,7 @@ class ConvertInt64ConstOpsToInt32Pass(ArmPass): 5. `torch.tensor` """ - _passes_required_after: Set[Type[ExportPass]] = {ComputeConstantOpsAOT} + _passes_required_after: Set[Type[ExportPass]] = {ComputeConstantOpsAOTPass} torch_ops = [ torch.ops.aten.full.default, @@ -47,7 +49,10 @@ def call(self, graph_module: torch.fx.GraphModule): if node.op != "call_function": continue - if node.target not in ComputeConstantOpsAOT.targeted_ops + self.torch_ops: + if ( + node.target + not in ComputeConstantOpsAOTPass.targeted_ops + self.torch_ops + ): continue data = node.target(*node.args, **node.kwargs) diff --git a/backends/arm/_passes/convert_to_clamp.py b/backends/arm/_passes/convert_to_clamp_pass.py similarity index 91% rename from backends/arm/_passes/convert_to_clamp.py rename to backends/arm/_passes/convert_to_clamp_pass.py index 1ada1efe69b..4b28f993acd 100644 --- a/backends/arm/_passes/convert_to_clamp.py +++ b/backends/arm/_passes/convert_to_clamp_pass.py @@ -8,7 +8,7 @@ from executorch.backends.arm._passes import ArmPass from executorch.backends.arm._passes.fold_qdq_with_annotated_qparams_pass import ( - QuantizeOperatorArguments, + QuantizeClampArgumentsPass, ) from executorch.exir.dialects._ops import ops as exir_ops @@ -30,7 +30,7 @@ def get_clamp_params(op, args) -> Tuple[float | None, float | None]: class ConvertToClampPass(ArmPass): - _passes_required_after: Set[Type[ExportPass]] = {QuantizeOperatorArguments} + _passes_required_after: Set[Type[ExportPass]] = {QuantizeClampArgumentsPass} def call_operator(self, op, args, kwargs, meta): if op not in edge_operators: diff --git a/backends/arm/_passes/decompose_adaptive_avg_pool2d_pass.py b/backends/arm/_passes/decompose_adaptive_avg_pool2d_pass.py index 9f9aaa65271..5905e8f4496 100644 --- a/backends/arm/_passes/decompose_adaptive_avg_pool2d_pass.py +++ b/backends/arm/_passes/decompose_adaptive_avg_pool2d_pass.py @@ -9,7 +9,9 @@ import torch from executorch.backends.arm._passes import ArmPass -from executorch.backends.arm._passes.decompose_avg_pool2d import DecomposeAvgPool2d +from executorch.backends.arm._passes.decompose_avg_pool2d_pass import ( + DecomposeAvgPool2dPass, +) from executorch.exir.dialects._ops import ops as exir_ops from executorch.exir.pass_base import ExportPass, NodeMetadata @@ -44,7 +46,7 @@ class DecomposeAdaptiveAvgPool2dPass(ArmPass): The output is of size output_size_h x output_size_w for any input. """ - _passes_required_after: Set[Type[ExportPass]] = {DecomposeAvgPool2d} + _passes_required_after: Set[Type[ExportPass]] = {DecomposeAvgPool2dPass} def call_operator(self, op, args, kwargs, meta, updated=False): if op not in (edge_ops + aten_ops): diff --git a/backends/arm/_passes/decompose_avg_pool2d.py b/backends/arm/_passes/decompose_avg_pool2d_pass.py similarity index 98% rename from backends/arm/_passes/decompose_avg_pool2d.py rename to backends/arm/_passes/decompose_avg_pool2d_pass.py index 0187ee45a1e..14b03cf6243 100644 --- a/backends/arm/_passes/decompose_avg_pool2d.py +++ b/backends/arm/_passes/decompose_avg_pool2d_pass.py @@ -8,7 +8,9 @@ import torch from executorch.backends.arm._passes.arm_pass import ArmPass -from executorch.backends.arm._passes.fuse_constant_ops_pass import ComputeConstantOpsAOT +from executorch.backends.arm._passes.fuse_constant_ops_pass import ( + ComputeConstantOpsAOTPass, +) from executorch.backends.arm.operators.operator_validation_utils import ( adjust_pooling_pad_if_needed, ) @@ -37,8 +39,8 @@ def get_decomposition(op) -> tuple: raise RuntimeError(f"Can't get avg_pool2d decomposition for op {op}") -class DecomposeAvgPool2d(ArmPass): - _passes_required_after: Set[Type[ExportPass]] = {ComputeConstantOpsAOT} +class DecomposeAvgPool2dPass(ArmPass): + _passes_required_after: Set[Type[ExportPass]] = {ComputeConstantOpsAOTPass} def call_operator(self, op, args, kwargs, meta): if op not in (edge_div_ops + aten_div_ops): diff --git a/backends/arm/_passes/decompose_batch_norm_no_stats.py b/backends/arm/_passes/decompose_batch_norm_no_stats.py index ef9b9f859cd..9a486376617 100644 --- a/backends/arm/_passes/decompose_batch_norm_no_stats.py +++ b/backends/arm/_passes/decompose_batch_norm_no_stats.py @@ -10,7 +10,9 @@ import torch from executorch.backends.arm._passes import ArmPass from executorch.backends.arm._passes.arm_pass_utils import create_node -from executorch.backends.arm._passes.fuse_constant_ops_pass import ComputeConstantOpsAOT +from executorch.backends.arm._passes.fuse_constant_ops_pass import ( + ComputeConstantOpsAOTPass, +) from executorch.backends.arm._passes.insert_table_ops import InsertTableOpsPass from executorch.exir.dialects._ops import ops as exir_ops @@ -37,7 +39,7 @@ class DecomposeBatchNormNoStatsPass(ArmPass): """ _passes_required_after: Set[Type[ExportPass]] = { - ComputeConstantOpsAOT, + ComputeConstantOpsAOTPass, InsertTableOpsPass, } diff --git a/backends/arm/_passes/decompose_expm1_pass.py b/backends/arm/_passes/decompose_expm1_pass.py index 09a891c34dc..d2eb908e925 100644 --- a/backends/arm/_passes/decompose_expm1_pass.py +++ b/backends/arm/_passes/decompose_expm1_pass.py @@ -6,8 +6,8 @@ from typing import Set, Type from executorch.backends.arm._passes import ArmPass -from executorch.backends.arm._passes.convert_int_pow_to_mul import ConvertIntPowToMuls from executorch.backends.arm._passes.decompose_div_pass import DecomposeDivPass +from executorch.backends.arm._passes.decompose_int_pow_pass import DecomposeIntPowPass from executorch.backends.arm._passes.insert_table_ops import InsertTableOpsPass from executorch.backends.arm._passes.match_arg_dtype_pass import MatchArgDtypePass from executorch.backends.arm._passes.match_arg_ranks_pass import MatchArgRanksPass @@ -80,7 +80,7 @@ class DecomposeExpm1Pass(ArmPass): """ _passes_required_after: Set[Type[ExportPass]] = { - ConvertIntPowToMuls, + DecomposeIntPowPass, InsertTableOpsPass, DecomposeDivPass, ReplaceScalarWithTensorByProfilePass, diff --git a/backends/arm/_passes/decompose_gelu_pass.py b/backends/arm/_passes/decompose_gelu_pass.py index 2a25e6dbb6d..5bf39370835 100644 --- a/backends/arm/_passes/decompose_gelu_pass.py +++ b/backends/arm/_passes/decompose_gelu_pass.py @@ -8,7 +8,9 @@ import torch from executorch.backends.arm._passes import ArmPass from executorch.backends.arm._passes.arm_pass_utils import get_node_arg -from executorch.backends.arm._passes.fuse_constant_ops_pass import ComputeConstantOpsAOT +from executorch.backends.arm._passes.fuse_constant_ops_pass import ( + ComputeConstantOpsAOTPass, +) from executorch.backends.arm._passes.insert_table_ops import InsertTableOpsPass from executorch.backends.arm._passes.match_arg_dtype_pass import MatchArgDtypePass from executorch.backends.arm._passes.match_arg_ranks_pass import MatchArgRanksPass @@ -85,7 +87,7 @@ class DecomposeGeluPass(ArmPass): """ _passes_required_after: Set[Type[ExportPass]] = { - ComputeConstantOpsAOT, + ComputeConstantOpsAOTPass, InsertTableOpsPass, MatchArgDtypePass, MatchArgRanksPass, diff --git a/backends/arm/_passes/decompose_grouped_conv.py b/backends/arm/_passes/decompose_grouped_conv_pass.py similarity index 95% rename from backends/arm/_passes/decompose_grouped_conv.py rename to backends/arm/_passes/decompose_grouped_conv_pass.py index 11d9f605127..a0765b865fc 100644 --- a/backends/arm/_passes/decompose_grouped_conv.py +++ b/backends/arm/_passes/decompose_grouped_conv_pass.py @@ -14,7 +14,7 @@ from executorch.exir.pass_base import ExportPass -class DecomposeGroupedConv(ArmPass): +class DecomposeGroupedConvPass(ArmPass): """ Splits a grouped convolution which is not supported by TOSA into multiple convolutions using slice->conv->cat. @@ -81,7 +81,7 @@ def _get_meta_copy(meta, i, output_slice_size): new_qparams = meta.data.get("input_qparams").copy() # Get quantization params of the weights and slice them. qarg = new_qparams[1] - new_qparams[1] = DecomposeGroupedConv._split_per_channel_qparams( + new_qparams[1] = DecomposeGroupedConvPass._split_per_channel_qparams( qarg, index=i, output_slice_size=output_slice_size ) @@ -117,7 +117,7 @@ def call_operator(self, op, args, kwargs, meta): no_q_dq_meta.data = {} no_q_dq_meta.data = {} - slice_op, conv_op, cat_op = DecomposeGroupedConv._get_decomposition(op) + slice_op, conv_op, cat_op = DecomposeGroupedConvPass._get_decomposition(op) input_slices = [] for i in range(groups): @@ -163,7 +163,9 @@ def call_operator(self, op, args, kwargs, meta): zip(input_slices, filter_slices, bias_slices) ): - meta_copy = DecomposeGroupedConv._get_meta_copy(meta, i, output_slice_size) + meta_copy = DecomposeGroupedConvPass._get_meta_copy( + meta, i, output_slice_size + ) if op == exir_ops.edge.aten.convolution.default: conv_args = (input_slice, filter_slice, bias_slice, *args[3:8], 1) diff --git a/backends/arm/_passes/convert_int_pow_to_mul.py b/backends/arm/_passes/decompose_int_pow_pass.py similarity index 98% rename from backends/arm/_passes/convert_int_pow_to_mul.py rename to backends/arm/_passes/decompose_int_pow_pass.py index e2c8bd0c4d6..4db5e45c120 100644 --- a/backends/arm/_passes/convert_int_pow_to_mul.py +++ b/backends/arm/_passes/decompose_int_pow_pass.py @@ -11,7 +11,7 @@ from executorch.exir.pass_base import ExportPass -class ConvertIntPowToMuls(ArmPass): +class DecomposeIntPowPass(ArmPass): """ Replaces pow with integer exponent with a series of multiplications. Only handles pow.Tensor_Scalar and not pow.Tensor_Tensor. diff --git a/backends/arm/_passes/decompose_layernorm_pass.py b/backends/arm/_passes/decompose_layernorm_pass.py index 7623e410cf9..5f56de92512 100644 --- a/backends/arm/_passes/decompose_layernorm_pass.py +++ b/backends/arm/_passes/decompose_layernorm_pass.py @@ -12,7 +12,9 @@ from executorch.backends.arm._passes.arm_pass_utils import create_node from executorch.backends.arm._passes.decompose_meandim_pass import DecomposeMeanDimPass from executorch.backends.arm._passes.decompose_var_pass import DecomposeVarPass -from executorch.backends.arm._passes.fuse_constant_ops_pass import ComputeConstantOpsAOT +from executorch.backends.arm._passes.fuse_constant_ops_pass import ( + ComputeConstantOpsAOTPass, +) from executorch.backends.arm._passes.insert_table_ops import InsertTableOpsPass from executorch.exir.dialects._ops import ops as exir_ops from executorch.exir.pass_base import ExportPass, PassResult @@ -61,7 +63,7 @@ class DecomposeLayerNormPass(ArmPass): """ _passes_required_after: Set[Type[ExportPass]] = { - ComputeConstantOpsAOT, + ComputeConstantOpsAOTPass, DecomposeMeanDimPass, DecomposeVarPass, InsertTableOpsPass, diff --git a/backends/arm/_passes/decompose_linalg_vector_norm_pass.py b/backends/arm/_passes/decompose_linalg_vector_norm_pass.py index 5c6c8fc0ec5..83bbc6669ef 100644 --- a/backends/arm/_passes/decompose_linalg_vector_norm_pass.py +++ b/backends/arm/_passes/decompose_linalg_vector_norm_pass.py @@ -12,7 +12,7 @@ from executorch.exir.pass_base import ExportPass -class DecomposeLinearVectorNormPass(ArmPass): +class DecomposeLinalgVectorNormPass(ArmPass): """ This pass decomposes aten.linalg_vector_norm.default into more primitive ops. We need to add this pass before quantization for graph annotation. diff --git a/backends/arm/_passes/decompose_masked_fill.py b/backends/arm/_passes/decompose_masked_fill_pass.py similarity index 97% rename from backends/arm/_passes/decompose_masked_fill.py rename to backends/arm/_passes/decompose_masked_fill_pass.py index 5a0f12348ec..09a3492a0c6 100644 --- a/backends/arm/_passes/decompose_masked_fill.py +++ b/backends/arm/_passes/decompose_masked_fill_pass.py @@ -34,7 +34,7 @@ def _get_decomposition(op) -> tuple: raise RuntimeError(f"Unable to get decomposition for op {op}") -class DecomposeMaskedFill(ArmPass): +class DecomposeMaskedFillPass(ArmPass): """ Masked fill takes in a boolean mask, a tensor and a scalar value. Fills the tensor with the scalar value according to the boolean mask. diff --git a/backends/arm/_passes/decompose_maxpool2d_with_dilation.py b/backends/arm/_passes/decompose_maxpool2d_with_dilation_pass.py similarity index 99% rename from backends/arm/_passes/decompose_maxpool2d_with_dilation.py rename to backends/arm/_passes/decompose_maxpool2d_with_dilation_pass.py index 60f1dd4a500..bf3f6afc418 100644 --- a/backends/arm/_passes/decompose_maxpool2d_with_dilation.py +++ b/backends/arm/_passes/decompose_maxpool2d_with_dilation_pass.py @@ -19,7 +19,7 @@ ) -class DecomposeMaxPool2DPass(ArmPass): +class DecomposeMaxPool2dPass(ArmPass): """ Decompose dilated max_pool2d (EXIR edge ops) into space-to-batch -> maxpool -> batch-to-space. """ diff --git a/backends/arm/_passes/decompose_meandim_pass.py b/backends/arm/_passes/decompose_meandim_pass.py index 1360fc44f98..9bff06b4dfe 100644 --- a/backends/arm/_passes/decompose_meandim_pass.py +++ b/backends/arm/_passes/decompose_meandim_pass.py @@ -11,7 +11,9 @@ from executorch.backends.arm._passes import ArmPass from executorch.backends.arm._passes.arm_pass_utils import get_node_arg from executorch.backends.arm._passes.decompose_sum_pass import DecomposeSumPass -from executorch.backends.arm._passes.fuse_constant_ops_pass import ComputeConstantOpsAOT +from executorch.backends.arm._passes.fuse_constant_ops_pass import ( + ComputeConstantOpsAOTPass, +) from executorch.backends.arm._passes.size_adjust_input_pass import SizeAdjustInputPass from executorch.backends.arm.constants import DQ_OPS, Q_OPS from executorch.exir.backend.utils import WhyNoPartitionReporter @@ -78,7 +80,7 @@ class DecomposeMeanDimPass(ArmPass): """ _passes_required_after: Set[Type[ExportPass]] = { - ComputeConstantOpsAOT, + ComputeConstantOpsAOTPass, DecomposeSumPass, SizeAdjustInputPass, } diff --git a/backends/arm/_passes/decompose_sdpa_pass.py b/backends/arm/_passes/decompose_sdpa_pass.py index f372ffe8680..566b43d5aa3 100644 --- a/backends/arm/_passes/decompose_sdpa_pass.py +++ b/backends/arm/_passes/decompose_sdpa_pass.py @@ -10,7 +10,7 @@ from executorch.exir.pass_base import ExportPass -class DecomposeScaledDotProductAttention( +class DecomposeScaledDotProductAttentionPass( ArmPass, decompose_sdpa.DecomposeScaledDotProductAttention ): _passes_required_after: Set[Type[ExportPass]] = set() diff --git a/backends/arm/_passes/decompose_var_pass.py b/backends/arm/_passes/decompose_var_pass.py index f5903d61135..bb2e2066a06 100644 --- a/backends/arm/_passes/decompose_var_pass.py +++ b/backends/arm/_passes/decompose_var_pass.py @@ -12,7 +12,9 @@ from executorch.backends.arm._passes.arm_pass_utils import get_node_arg from executorch.backends.arm._passes.decompose_meandim_pass import DecomposeMeanDimPass from executorch.backends.arm._passes.decompose_sum_pass import DecomposeSumPass -from executorch.backends.arm._passes.fuse_constant_ops_pass import ComputeConstantOpsAOT +from executorch.backends.arm._passes.fuse_constant_ops_pass import ( + ComputeConstantOpsAOTPass, +) from executorch.exir.dialects._ops import ops as exir_ops from executorch.exir.pass_base import ExportPass @@ -52,7 +54,7 @@ class DecomposeVarPass(ArmPass): """ _passes_required_after: Set[Type[ExportPass]] = { - ComputeConstantOpsAOT, + ComputeConstantOpsAOTPass, DecomposeMeanDimPass, DecomposeSumPass, } diff --git a/backends/arm/_passes/fold_qdq_with_annotated_qparams_pass.py b/backends/arm/_passes/fold_qdq_with_annotated_qparams_pass.py index 2a0e889f87c..0bbf04e1463 100644 --- a/backends/arm/_passes/fold_qdq_with_annotated_qparams_pass.py +++ b/backends/arm/_passes/fold_qdq_with_annotated_qparams_pass.py @@ -218,7 +218,7 @@ def call(self, graph_module: GraphModule) -> PassResult: # noqa: C901 return PassResult(graph_module, True) -class QuantizeOperatorArguments(ArmPass): +class QuantizeClampArgumentsPass(ArmPass): """ This pass makes sure that the arguments to clamp.default are quantized correctly. More specifically, this pass: diff --git a/backends/arm/_passes/fuse_batchnorm2d_pass.py b/backends/arm/_passes/fuse_batch_norm2d_pass.py similarity index 99% rename from backends/arm/_passes/fuse_batchnorm2d_pass.py rename to backends/arm/_passes/fuse_batch_norm2d_pass.py index 250cac230d8..d9ae706f503 100644 --- a/backends/arm/_passes/fuse_batchnorm2d_pass.py +++ b/backends/arm/_passes/fuse_batch_norm2d_pass.py @@ -26,7 +26,7 @@ from torch.nn.utils.fusion import fuse_conv_bn_weights -class FuseBatchnorm2DPass(ArmPass): +class FuseBatchNorm2dPass(ArmPass): """Fuses the pattern convolution -> batchnorm by updating the weights and bias of the convolution and removing the batchnorm. """ diff --git a/backends/arm/_passes/fuse_constant_ops_pass.py b/backends/arm/_passes/fuse_constant_ops_pass.py index 2c8986114db..a574ef554ad 100644 --- a/backends/arm/_passes/fuse_constant_ops_pass.py +++ b/backends/arm/_passes/fuse_constant_ops_pass.py @@ -164,7 +164,7 @@ def call(self, graph_module): return PassResult(graph_module, True) -class ComputeConstantOpsAOT(ArmPass): +class ComputeConstantOpsAOTPass(ArmPass): """ Evaluates call_functions that produce constant tensor outputs and replaces them with placeholders. diff --git a/backends/arm/_passes/fuse_quantized_activation_pass.py b/backends/arm/_passes/fuse_quantized_activation_pass.py index f50216153a5..09e989cd3aa 100644 --- a/backends/arm/_passes/fuse_quantized_activation_pass.py +++ b/backends/arm/_passes/fuse_quantized_activation_pass.py @@ -8,7 +8,7 @@ import torch from executorch.backends.arm._passes import ArmPass -from executorch.backends.arm._passes.convert_to_clamp import ConvertToClampPass +from executorch.backends.arm._passes.convert_to_clamp_pass import ConvertToClampPass from executorch.backends.arm._passes.fold_qdq_with_annotated_qparams_pass import ( FoldAndAnnotateQParamsPass, ) diff --git a/backends/arm/_passes/replace_inf_values_pass.py b/backends/arm/_passes/replace_inf_values_pass.py index 7a42d08dd61..d1f58fe148c 100644 --- a/backends/arm/_passes/replace_inf_values_pass.py +++ b/backends/arm/_passes/replace_inf_values_pass.py @@ -14,7 +14,7 @@ from executorch.exir.pass_base import ExportPass, PassResult -class ReplaceInfValues(ArmPass): +class ReplaceInfValuesPass(ArmPass): """ Due to limitation in Quantizer, we need to change inf/-inf to more quantizable values. """ diff --git a/backends/arm/_passes/size_adjust_input_pass.py b/backends/arm/_passes/size_adjust_input_pass.py index 529d5262b00..9460c8f199a 100644 --- a/backends/arm/_passes/size_adjust_input_pass.py +++ b/backends/arm/_passes/size_adjust_input_pass.py @@ -113,17 +113,17 @@ def is_valid_operator(node: torch.fx.Node) -> bool: dilation = node.args[4] if len(node.args) >= 5 else 1 ceil_mode = node.args[5] if len(node.args) >= 6 else False - # Dilation should be handled first by DecomposeMaxPool2DPass + # Dilation should be handled first by DecomposeMaxPool2dPass if isinstance(dilation, int): if dilation > 1: raise ValueError( - "Expected max_pool2d with dilation = 1, has DecomposeMaxPool2DPass been run?" + "Expected max_pool2d with dilation = 1, has DecomposeMaxPool2dPass been run?" ) else: dilation = cast(list, dilation) if dilation[0] > 1 or dilation[1] > 1: raise ValueError( - "Expected max_pool2d with dilation = [1, 1], has DecomposeMaxPool2DPass been run?" + "Expected max_pool2d with dilation = [1, 1], has DecomposeMaxPool2dPass been run?" ) # If using ceil mode for rounding, the input does not need adjusting diff --git a/backends/arm/operator_support/tosa_supported_operators.py b/backends/arm/operator_support/tosa_supported_operators.py index fa0b106c00d..728ef68c666 100644 --- a/backends/arm/operator_support/tosa_supported_operators.py +++ b/backends/arm/operator_support/tosa_supported_operators.py @@ -22,7 +22,9 @@ get_first_fake_tensor, is_submodule_node, ) -from executorch.backends.arm._passes.fuse_constant_ops_pass import ComputeConstantOpsAOT +from executorch.backends.arm._passes.fuse_constant_ops_pass import ( + ComputeConstantOpsAOTPass, +) from executorch.backends.arm._passes.fuse_quantized_activation_pass import ( FuseQuantizedActivationPass, ) @@ -293,7 +295,7 @@ def _is_quantized(self, node: torch.fx.Node) -> bool: return True if node.target in ( exir_ops.edge.aten.full_like.default, - *ComputeConstantOpsAOT.targeted_ops, + *ComputeConstantOpsAOTPass.targeted_ops, ): # Special cases where nodes have been created in to_edge, e.g. # .Scalar operations that have been promoted .Tensor operations @@ -546,7 +548,7 @@ def is_node_supported( for output_node in node.users ) if ( - node.target in ComputeConstantOpsAOT.targeted_ops + node.target in ComputeConstantOpsAOTPass.targeted_ops and users_output_non_int64 ): if not self.inside_int32_bounds(node): @@ -588,7 +590,7 @@ def is_node_supported( continue # Constant operator if input_node.op == "call_function": - if input_node.target in ComputeConstantOpsAOT.targeted_ops: + if input_node.target in ComputeConstantOpsAOTPass.targeted_ops: # This is not perfect since the input_node can still be rejected by other checks but # this should cover the majority of cases. if self.is_node_supported({}, input_node): diff --git a/backends/arm/operators/op_clamp.py b/backends/arm/operators/op_clamp.py index 76aa75cd9fd..badf76c9384 100644 --- a/backends/arm/operators/op_clamp.py +++ b/backends/arm/operators/op_clamp.py @@ -96,7 +96,7 @@ def define_node( ) node_input_dtype = node.meta["val"].dtype - # NOTE: Quantization of the min/max arguments is handled by QuantizeOperatorArguments + # NOTE: Quantization of the min/max arguments is handled by QuantizeClampArgumentsPass min_val, max_val = self._get_min_max_arguments(node, node_input_dtype) attr = ts.TosaSerializerAttribute() diff --git a/backends/arm/test/passes/test_convert_to_clamp.py b/backends/arm/test/passes/test_convert_to_clamp.py index 5072af000b0..b54c177e52f 100644 --- a/backends/arm/test/passes/test_convert_to_clamp.py +++ b/backends/arm/test/passes/test_convert_to_clamp.py @@ -7,7 +7,7 @@ from typing import ClassVar, Dict, Tuple import torch -from executorch.backends.arm._passes.convert_to_clamp import ConvertToClampPass +from executorch.backends.arm._passes.convert_to_clamp_pass import ConvertToClampPass from executorch.backends.arm.test import common from executorch.backends.arm.test.tester.test_pipeline import PassPipeline diff --git a/backends/arm/test/passes/test_decompose_avg_pool2d_pass.py b/backends/arm/test/passes/test_decompose_avg_pool2d_pass.py index 405c3d7ca8f..c4aebae2292 100644 --- a/backends/arm/test/passes/test_decompose_avg_pool2d_pass.py +++ b/backends/arm/test/passes/test_decompose_avg_pool2d_pass.py @@ -6,7 +6,9 @@ from typing import cast, Dict, Protocol, Tuple import torch -from executorch.backends.arm._passes.decompose_avg_pool2d import DecomposeAvgPool2d +from executorch.backends.arm._passes.decompose_avg_pool2d_pass import ( + DecomposeAvgPool2dPass, +) from executorch.backends.arm.test import common from executorch.backends.arm.test.tester.test_pipeline import PassPipeline @@ -75,6 +77,6 @@ def test_decompose_avg_pool2d_tosa_MI(module: ModuleWithInputs) -> None: # After decomposition, we should still see avg_pool2d (transformed) "executorch_exir_dialects_edge__ops_aten_avg_pool2d_default": 1, }, - pass_list=[DecomposeAvgPool2d], + pass_list=[DecomposeAvgPool2dPass], ) pipeline.run() diff --git a/backends/arm/test/passes/test_convert_int_pow_to_muls.py b/backends/arm/test/passes/test_decompose_int_pow_pass.py similarity index 92% rename from backends/arm/test/passes/test_convert_int_pow_to_muls.py rename to backends/arm/test/passes/test_decompose_int_pow_pass.py index bccde782f55..a9a74c633e1 100644 --- a/backends/arm/test/passes/test_convert_int_pow_to_muls.py +++ b/backends/arm/test/passes/test_decompose_int_pow_pass.py @@ -6,7 +6,7 @@ from typing import cast, Dict, Protocol, Tuple import torch -from executorch.backends.arm._passes import ConvertIntPowToMuls +from executorch.backends.arm._passes import DecomposeIntPowPass from executorch.backends.arm.test import common @@ -60,7 +60,7 @@ def get_inputs(self) -> input_t: @common.parametrize("data", test_data) -def test_convert_pow_to_muls(data: TestParam) -> None: +def test_decompose_int_pow(data: TestParam) -> None: module_with_inputs, nbr_muls = data module = cast(torch.nn.Module, module_with_inputs) pipeline = PassPipeline[input_t]( @@ -75,6 +75,6 @@ def test_convert_pow_to_muls(data: TestParam) -> None: "executorch_exir_dialects_edge__ops_aten_mul_Tensor": nbr_muls, }, ops_not_after_pass=["executorch_exir_dialects_edge__ops_pow_Tensor_Scalar"], - pass_list=[ConvertIntPowToMuls], + pass_list=[DecomposeIntPowPass], ) pipeline.run() diff --git a/backends/arm/test/passes/test_decompose_linalg_vector_norm_pass.py b/backends/arm/test/passes/test_decompose_linalg_vector_norm_pass.py index bd83bfc9a22..b926e15b92a 100644 --- a/backends/arm/test/passes/test_decompose_linalg_vector_norm_pass.py +++ b/backends/arm/test/passes/test_decompose_linalg_vector_norm_pass.py @@ -8,7 +8,7 @@ import torch from executorch.backends.arm._passes.decompose_linalg_vector_norm_pass import ( - DecomposeLinearVectorNormPass, + DecomposeLinalgVectorNormPass, ) from executorch.backends.arm.test import common from executorch.backends.arm.test.tester.test_pipeline import PassPipeline @@ -65,7 +65,7 @@ def get_inputs(self) -> input_t: @common.parametrize("module", modules) def test_decompose_vector_norm_tosa_INT(module: ModuleWithInputs) -> None: """ - This test creates a PassPipeline that applies the DecomposeLinearVectorNormPass. + This test creates a PassPipeline that applies the DecomposeLinalgVectorNormPass. The expected primitive ops vary depending on the norm order: - p == 1: should decompose to ABS and SUM. - p == 2 (default): should decompose to MUL, SUM, and SQRT. @@ -102,6 +102,6 @@ def test_decompose_vector_norm_tosa_INT(module: ModuleWithInputs) -> None: ops_not_after_pass=[ "executorch_exir_dialects_edge__ops_aten_linarg_vector_norm_default", ], - pass_list=[DecomposeLinearVectorNormPass], + pass_list=[DecomposeLinalgVectorNormPass], ) pipeline.run() diff --git a/backends/arm/test/passes/test_fuse_batchnorm_pass.py b/backends/arm/test/passes/test_fuse_batchnorm_pass.py index 08bf960da7d..eb073265a63 100644 --- a/backends/arm/test/passes/test_fuse_batchnorm_pass.py +++ b/backends/arm/test/passes/test_fuse_batchnorm_pass.py @@ -6,7 +6,7 @@ from typing import cast, ClassVar, Dict, Protocol, Tuple import torch -from executorch.backends.arm._passes.fuse_batchnorm2d_pass import FuseBatchnorm2DPass +from executorch.backends.arm._passes.fuse_batch_norm2d_pass import FuseBatchNorm2dPass from executorch.backends.arm.test import common from executorch.backends.arm.test.tester.test_pipeline import PassPipeline @@ -159,6 +159,6 @@ def test_fuse_batchnorm_tosa_FP(module: ModuleWithBatchNormAttrs) -> None: quantize=False, ops_before_pass=module.ops_before_pass, ops_after_pass=module.ops_after_pass, - passes_with_exported_program=[FuseBatchnorm2DPass], + passes_with_exported_program=[FuseBatchNorm2dPass], ) pipeline.run() diff --git a/backends/arm/test/passes/test_fuse_constant_ops_pass.py b/backends/arm/test/passes/test_fuse_constant_ops_pass.py index 95492075c0d..deb017bf662 100644 --- a/backends/arm/test/passes/test_fuse_constant_ops_pass.py +++ b/backends/arm/test/passes/test_fuse_constant_ops_pass.py @@ -8,7 +8,7 @@ import torch from executorch.backends.arm._passes.fuse_constant_ops_pass import ( - ComputeConstantOpsAOT, + ComputeConstantOpsAOTPass, FuseConstantArgsPass, ) from executorch.backends.arm.test import common @@ -157,7 +157,10 @@ def test_fuse_const_ops_tosa_FP(module: ModuleWithFuseAttrs) -> None: ops_before_pass=module.ops_before_pass, ops_after_pass=module.ops_after_pass, ops_not_after_pass=module.ops_not_after_pass, - passes_with_exported_program=[ComputeConstantOpsAOT, FuseConstantArgsPass], + passes_with_exported_program=[ + ComputeConstantOpsAOTPass, + FuseConstantArgsPass, + ], ) pipeline.run() @@ -170,7 +173,10 @@ def test_fuse_const_ops_tosa_INT(module: ModuleWithFuseAttrs) -> None: quantize=True, ops_before_pass=module.ops_before_pass, ops_after_pass=module.ops_after_pass, - passes_with_exported_program=[ComputeConstantOpsAOT, FuseConstantArgsPass], + passes_with_exported_program=[ + ComputeConstantOpsAOTPass, + FuseConstantArgsPass, + ], ) pipeline.run() @@ -183,7 +189,10 @@ def test_fuse_const_ops_tosa_BI_cat(module: ModuleWithFuseAttrs) -> None: quantize=True, ops_before_pass=module.ops_before_pass, ops_after_pass=module.ops_after_pass, - passes_with_exported_program=[ComputeConstantOpsAOT, FuseConstantArgsPass], + passes_with_exported_program=[ + ComputeConstantOpsAOTPass, + FuseConstantArgsPass, + ], ) pipeline.run()