From 2bbc88a69196506f2b9bd80100d595e0dfbbb3ad Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Martin=20Lindstr=C3=B6m?= Date: Wed, 5 Nov 2025 16:06:32 +0100 Subject: [PATCH] Arm backend: Add add_passes method MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Add a method, for ArmPassManager, called add_passes. This serves to group blocks of passes together more clearly for the reader. Signed-off-by: Martin Lindström Change-Id: I313f47314ee5b74164446c091859f6002e12989d --- backends/arm/_passes/arm_pass_manager.py | 322 ++++++++++-------- .../arm/_passes/remove_graph_asserts_pass.py | 6 +- 2 files changed, 181 insertions(+), 147 deletions(-) diff --git a/backends/arm/_passes/arm_pass_manager.py b/backends/arm/_passes/arm_pass_manager.py index ce8015a18b5..2ae84802912 100644 --- a/backends/arm/_passes/arm_pass_manager.py +++ b/backends/arm/_passes/arm_pass_manager.py @@ -7,6 +7,7 @@ from collections import defaultdict +from collections.abc import Sequence import executorch.backends.arm.tosa.dialect # noqa: unused from executorch.backends.arm._passes import ( @@ -112,6 +113,7 @@ TosaSpecification, ) from executorch.exir import ExportedProgram +from executorch.exir.pass_base import ExportPass from executorch.exir.pass_manager import PassManager from torch.fx import GraphModule from torch.fx.passes.infra.pass_base import PassResult @@ -150,6 +152,11 @@ def validate_constraints_mandatory(self): raise RuntimeError(error_msg) + def add_passes(self, passes: Sequence[ExportPass | None]): + for p in passes: + if p is not None: + self.add_pass(p) + def _transform(self, graph_module: GraphModule): with TosaLoweringContext(self.tosa_spec): return self(graph_module).graph_module @@ -158,120 +165,136 @@ def _tosa_pipeline( self, exported_program: ExportedProgram, graph_module: GraphModule ) -> GraphModule: # Preprocessing passes - self.add_pass(AnnotateOutputDimOrderPass()) # Node transformation passes (pre q/dq folding) - - self.add_pass(FuseQuantizedActivationPass()) - self.add_pass(RemoveGetItemPass()) - self.add_pass(ConvertToClampPass()) - self.add_pass(DecomposeGroupNormPass()) - self.add_pass(DecomposeLayerNormPass()) - self.add_pass(DecomposeBatchNormNoStatsPass()) - self.add_pass(DecomposeVarPass()) - self.add_pass( - DecomposeMeanDimPass(exported_program.graph_module, self.tosa_spec) + self.add_passes( + [ + FuseQuantizedActivationPass(), + RemoveGetItemPass(), + ConvertToClampPass(), + DecomposeGroupNormPass(), + DecomposeLayerNormPass(), + DecomposeBatchNormNoStatsPass(), + DecomposeVarPass(), + DecomposeMeanDimPass(exported_program.graph_module, self.tosa_spec), + AnnotateDecomposedMatmulPass(), + ConvertELUParamsPass(), + ConvertSplitToSlicePass(), + QuantizeOperatorArguments(), + ] ) - self.add_pass(AnnotateDecomposedMatmulPass()) - self.add_pass(ConvertELUParamsPass()) - self.add_pass(ConvertSplitToSlicePass()) - self.add_pass(QuantizeOperatorArguments()) # Fold Q/DQ nodes, insert INT8/INT32 rescales. - - self.add_pass(FoldAndAnnotateQParamsPass(exported_program)) # type: ignore[call-arg] - self.add_pass(FuseDuplicateUsersPass()) - # TODO: DecomposeLinearPass should run after InsertRescaleInt32Pass or - # before FoldAndAnnotateQParamsPass but is unable to at the moment. - # Ticket: MLETORCH-1539 - self.add_pass(DecomposeLinearPass()) - self.add_pass(InsertRescaleInt32Pass()) + self.add_passes( + [ + FoldAndAnnotateQParamsPass(exported_program), + FuseDuplicateUsersPass(), + # TODO: DecomposeLinearPass should run after InsertRescaleInt32Pass or + # before FoldAndAnnotateQParamsPass but is unable to at the moment. + # Ticket: MLETORCH-1539 + DecomposeLinearPass(), + InsertRescaleInt32Pass(), + ] + ) # Node transformation passes (post q/dq folding) - - self.add_pass(DecomposeLogitPass()) - self.add_pass(DecomposeMaskedFill()) - self.add_pass(DecomposeRoundPass()) - self.add_pass(DecomposeAcoshPass()) - self.add_pass(DecomposeAsinhPass()) - self.add_pass(DecomposeCoshPass()) - self.add_pass(DecomposeAsinAndAcosPass()) - self.add_pass(DecomposeSqrtPass()) - self.add_pass(DecomposeAtanPass()) - self.add_pass(DecomposeAtanhPass()) - self.add_pass(DecomposeAddmmPass()) - self.add_pass(DecomposeEluPass()) - self.add_pass(DecomposeExpm1Pass()) - self.add_pass(ConvertIntPowToMuls()) - self.add_pass(CastBoolToInt8Pass()) - self.add_pass(DecomposeSinhPass()) - self.add_pass(DecomposeSignPass()) - self.add_pass(DecomposeFloorDividePass()) - self.add_pass(DecomposeGeluPass()) - self.add_pass(DecomposeAddSubAlphaPass()) - self.add_pass(DecomposeGroupedConv()) - self.add_pass(Conv1dUnsqueezePass()) + self.add_passes( + [ + DecomposeLogitPass(), + DecomposeMaskedFill(), + DecomposeRoundPass(), + DecomposeAcoshPass(), + DecomposeAsinhPass(), + DecomposeCoshPass(), + DecomposeAsinAndAcosPass(), + DecomposeSqrtPass(), + DecomposeAtanPass(), + DecomposeAtanhPass(), + DecomposeAddmmPass(), + DecomposeEluPass(), + DecomposeExpm1Pass(), + ConvertIntPowToMuls(), + CastBoolToInt8Pass(), + DecomposeSinhPass(), + DecomposeSignPass(), + DecomposeFloorDividePass(), + DecomposeGeluPass(), + DecomposeAddSubAlphaPass(), + DecomposeGroupedConv(), + Conv1dUnsqueezePass(), + ] + ) # Scalars -> tensors, match tensor dtypes and ranks. - - self.add_pass(ReplaceScalarWithTensorByProfilePass()) - self.add_pass(ConvertFullLikeToFullPass()) - self.add_pass(MatchArgDtypePass()) - self.add_pass(UnsqueezeScalarPlaceholdersPass(exported_program)) - # TODO: Move DecomposeNotEqualPass to before or after this block of - # passes. Ticket: MLETORCH-1540 - self.add_pass(DecomposeNotEqualPass()) - self.add_pass(MatchArgRanksPass(exported_program)) - self.add_pass(FuseConstantArgsPass(exported_program)) + self.add_passes( + [ + ReplaceScalarWithTensorByProfilePass(), + ConvertFullLikeToFullPass(), + MatchArgDtypePass(), + UnsqueezeScalarPlaceholdersPass(exported_program), + # TODO: Move DecomposeNotEqualPass to before or after this block of + # passes. Ticket: MLETORCH-1540 + DecomposeNotEqualPass(), + MatchArgRanksPass(exported_program), + FuseConstantArgsPass(exported_program), + ] + ) # Node transformation passes (post scalar-removal) - - self.add_pass(DecomposeRemainderPass()) - self.add_pass(DecomposeDivTensorModePass()) - self.add_pass(DecomposeEmbeddingPass()) - self.add_pass(FuseBatchnorm2DPass(exported_program)) - self.add_pass(ConvertMmToBmmPass()) - self.add_pass(DecomposeGluPass()) - self.add_pass(DecomposeLeakyReLUPass()) - self.add_pass(DecomposeDivPass()) - self.add_pass(DecomposeSoftmaxPass()) - self.add_pass(ConvertMinMaxPass()) - self.add_pass(DecomposeAnyPass()) - self.add_pass(DecomposeAdaptiveAvgPool2dPass()) - self.add_pass(DecomposeAvgPool2d()) - self.add_pass( - DecorateFp32toInt32CastingPass() - ) # Require that no new fp32->int32 is introduced after this pass - self.add_pass(ComputeConstantOpsAOT(exported_program)) - self.add_pass(ConvertExpandCopyToRepeatPass()) - self.add_pass(UnsqueezeBeforeRepeatPass()) - self.add_pass(DecomposeCumsumPass(exported_program)) - self.add_pass(DecomposeMaxPool2DPass()) - self.add_pass(SizeAdjustInputPass()) - self.add_pass(DecomposeSelectPass()) - self.add_pass(ConvertSqueezesToViewPass()) - self.add_pass(CastToInt32Pass()) - self.add_pass(BroadcastArgsPass()) - self.add_pass(ConvertPermuteSingletonToViewPass()) - self.add_pass(FuseViewCopyTransformPass()) - self.add_pass(DecomposeConv2dWithInt16ActivationPass()) - self.add_pass(DecomposeSumPass()) - self.add_pass(InsertTableOpsPass(exported_program)) + self.add_passes( + [ + DecomposeRemainderPass(), + DecomposeDivTensorModePass(), + DecomposeEmbeddingPass(), + FuseBatchnorm2DPass(exported_program), + ConvertMmToBmmPass(), + DecomposeGluPass(), + DecomposeLeakyReLUPass(), + DecomposeDivPass(), + DecomposeSoftmaxPass(), + ConvertMinMaxPass(), + DecomposeAnyPass(), + DecomposeAdaptiveAvgPool2dPass(), + DecomposeAvgPool2d(), + DecorateFp32toInt32CastingPass(), + ComputeConstantOpsAOT(exported_program), + ConvertExpandCopyToRepeatPass(), + UnsqueezeBeforeRepeatPass(), + DecomposeCumsumPass(exported_program), + DecomposeMaxPool2DPass(), + SizeAdjustInputPass(), + DecomposeSelectPass(), + ConvertSqueezesToViewPass(), + CastToInt32Pass(), + BroadcastArgsPass(), + ConvertPermuteSingletonToViewPass(), + FuseViewCopyTransformPass(), + DecomposeConv2dWithInt16ActivationPass(), + DecomposeSumPass(), + InsertTableOpsPass(exported_program), + ] + ) # Aten -> TOSA transformation passes - - self.add_pass(RewriteUpsamplePass()) - self.add_pass(RewriteConv2dPass(exported_program)) - self.add_pass(RewriteMatmulPass()) + self.add_passes( + [ + RewriteUpsamplePass(), + RewriteConv2dPass(exported_program), + RewriteMatmulPass(), + ] + ) # Postprocessing/cleanup passes - - self.add_pass(CastInt64BuffersToInt32Pass(exported_program)) - self.add_pass(FuseEqualPlaceholdersPass(exported_program)) - self.add_pass(ToTosaMemoryFormatPass(exported_program)) - self.add_pass(RemoveNoopPass()) - self.add_pass(InsertRescalePass()) + self.add_passes( + [ + CastInt64BuffersToInt32Pass(exported_program), + FuseEqualPlaceholdersPass(exported_program), + ToTosaMemoryFormatPass(exported_program), + RemoveNoopPass(), + InsertRescalePass(), + ] + ) self.validate_constraints_mandatory() return self._transform(graph_module) @@ -287,66 +310,73 @@ def transform_to_backend_pipeline( return self._tosa_pipeline(exported_program, graph_module) else: raise NotImplementedError( - f"No pass pipeline implemented for {self.tosa_spec=}" + f"No pass pipeline implemented for {self.tosa_spec}" ) def transform_for_annotation_pipeline(self, graph_module: GraphModule): # Preprocessing passes - - self.add_pass( - RemoveGraphAssertsPass() - ) # ConvertInt64ConstOpsToInt32Pass requires this pass to remove the assertation in Graph + self.add_pass(RemoveGraphAssertsPass()) # Transformation passes (pre scalar -> tensor) - - self.add_pass(ConvertInt64ConstOpsToInt32Pass()) - self.add_pass(ConvertInt64OutputOpsToInt32Pass()) - self.add_pass(InsertInt32CastsAfterInt64PlaceholdersPass()) - self.add_pass(DecomposeEmbeddingPass()) - self.add_pass(DecomposeScaledDotProductAttention()) - self.add_pass(DecomposeRoundPass()) - self.add_pass(DecomposeLogitPass()) - self.add_pass(CastBoolToInt8Pass()) - self.add_pass(DecomposeSignPass()) - self.add_pass(DecomposeAddmmPass()) - self.add_pass(DecomposeRemainderPass()) - self.add_pass(DecomposeFloorDividePass()) - self.add_pass(DecomposeDivTensorModePass()) + self.add_passes( + [ + ConvertInt64ConstOpsToInt32Pass(), + ConvertInt64OutputOpsToInt32Pass(), + InsertInt32CastsAfterInt64PlaceholdersPass(), + DecomposeEmbeddingPass(), + DecomposeScaledDotProductAttention(), + DecomposeRoundPass(), + DecomposeLogitPass(), + CastBoolToInt8Pass(), + DecomposeSignPass(), + DecomposeAddmmPass(), + DecomposeRemainderPass(), + DecomposeFloorDividePass(), + DecomposeDivTensorModePass(), + ] + ) # Scalars -> tensors - - self.add_pass(ReplaceScalarWithTensorByProfilePass()) - self.add_pass(ScalarsToAttributePass()) + self.add_passes( + [ + ReplaceScalarWithTensorByProfilePass(), + ScalarsToAttributePass(), + ] + ) # Transformation passes (post scalar removal) - - self.add_pass(DecomposeAddSubAlphaPass()) - self.add_pass(DecomposeGroupNormPass()) - self.add_pass(DecomposeLayerNormPass()) - self.add_pass(DecomposeVarPass()) - self.add_pass(DecomposeMeanDimPass(graph_module, self.tosa_spec)) - self.add_pass(DecomposeNotEqualPass()) - self.add_pass(DecomposeCosineSimilarityPass()) - self.add_pass(DecomposeGluPass()) - self.add_pass(DecomposeDivPass()) - self.add_pass(DecomposeLeakyReLUPass()) - self.add_pass(DecomposeLinearVectorNormPass()) - self.add_pass(DecomposeSqrtPass()) - self.add_pass(DecomposeSiluPass()) - self.add_pass(DecomposeAvgPool2d()) - if self.tosa_spec.is_U55_subset: - # Numerically stable softmax uses amax which is not supported on Ethos-U55 - self.add_pass(DecomposeSoftmaxUnstablePass()) - else: - self.add_pass(DecomposeSoftmaxPass()) - self.add_pass(ConvertMinMaxPass()) + self.add_passes( + [ + DecomposeAddSubAlphaPass(), + DecomposeGroupNormPass(), + DecomposeLayerNormPass(), + DecomposeVarPass(), + DecomposeMeanDimPass(graph_module, self.tosa_spec), + DecomposeNotEqualPass(), + DecomposeCosineSimilarityPass(), + DecomposeGluPass(), + DecomposeDivPass(), + DecomposeLeakyReLUPass(), + DecomposeLinearVectorNormPass(), + DecomposeSqrtPass(), + DecomposeSiluPass(), + DecomposeAvgPool2d(), + ( + DecomposeSoftmaxUnstablePass() + if self.tosa_spec.is_U55_subset + else DecomposeSoftmaxPass() + ), + ConvertMinMaxPass(), + ] + ) # Postprocessing passes - - self.add_pass(ReplaceInfValues()) - if not self.tosa_spec.is_U55_subset: - # Uses where which is not supported on Ethos-U55 - self.add_pass(DecomposeMaskedFill()) + self.add_passes( + [ + ReplaceInfValues(), + DecomposeMaskedFill() if not self.tosa_spec.is_U55_subset else None, + ] + ) return self._transform(graph_module) diff --git a/backends/arm/_passes/remove_graph_asserts_pass.py b/backends/arm/_passes/remove_graph_asserts_pass.py index 595fa55c5d1..a462c1182ee 100644 --- a/backends/arm/_passes/remove_graph_asserts_pass.py +++ b/backends/arm/_passes/remove_graph_asserts_pass.py @@ -6,9 +6,13 @@ from typing import Set, Type from executorch.backends.arm._passes.arm_pass import ArmPass + +from executorch.backends.arm._passes.convert_int64_const_ops_to_int32 import ( + ConvertInt64ConstOpsToInt32Pass, +) from executorch.exir.pass_base import ExportPass from executorch.exir.passes import remove_graph_asserts_pass class RemoveGraphAssertsPass(remove_graph_asserts_pass.RemoveGraphAssertsPass, ArmPass): - _passes_required_after: Set[Type[ExportPass]] = set() + _passes_required_after: Set[Type[ExportPass]] = {ConvertInt64ConstOpsToInt32Pass}