Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 15 additions & 12 deletions backends/arm/_passes/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,14 +18,13 @@
from .convert_full_like_to_full_pass import ConvertFullLikeToFullPass # noqa
from .convert_int64_const_ops_to_int32 import ConvertInt64ConstOpsToInt32Pass # noqa
from .convert_int64_output_ops_to_int32 import ConvertInt64OutputOpsToInt32Pass # noqa
from .convert_int_pow_to_mul import ConvertIntPowToMuls # noqa
from .convert_minmax_pass import ConvertMinMaxPass # noqa
from .convert_permute_singleton_to_view_pass import ( # noqa
ConvertPermuteSingletonToViewPass,
)
from .convert_split_to_slice import ConvertSplitToSlicePass # noqa
from .convert_squeezes_to_view import ConvertSqueezesToViewPass # noqa
from .convert_to_clamp import ConvertToClampPass # noqa
from .convert_to_clamp_pass import ConvertToClampPass # noqa
from .decompose_acosh_pass import DecomposeAcoshPass # noqa
from .decompose_adaptive_avg_pool2d_pass import DecomposeAdaptiveAvgPool2dPass # noqa
from .decompose_add_sub_alpha_pass import DecomposeAddSubAlphaPass # noqa
Expand All @@ -35,7 +34,7 @@
from .decompose_asinh_pass import DecomposeAsinhPass # noqa
from .decompose_atan_pass import DecomposeAtanPass # noqa
from .decompose_atanh_pass import DecomposeAtanhPass # noqa
from .decompose_avg_pool2d import DecomposeAvgPool2d # noqa
from .decompose_avg_pool2d_pass import DecomposeAvgPool2dPass # noqa
from .decompose_batch_norm_no_stats import DecomposeBatchNormNoStatsPass # noqa
from .decompose_cosh_pass import DecomposeCoshPass # noqa
from .decompose_cosine_similarity_pass import DecomposeCosineSimilarityPass # noqa
Expand All @@ -48,23 +47,24 @@
from .decompose_floor_divide_pass import DecomposeFloorDividePass # noqa
from .decompose_gelu_pass import DecomposeGeluPass # noqa
from .decompose_glu_pass import DecomposeGluPass # noqa
from .decompose_grouped_conv import DecomposeGroupedConv # noqa
from .decompose_grouped_conv_pass import DecomposeGroupedConvPass # noqa
from .decompose_groupnorm_pass import DecomposeGroupNormPass # noqa
from .decompose_int16_activation_conv2d_pass import ( # noqa
DecomposeConv2dWithInt16ActivationPass,
)
from .decompose_int_pow_pass import DecomposeIntPowPass # noqa
from .decompose_layernorm_pass import DecomposeLayerNormPass # noqa
from .decompose_leaky_relu_pass import DecomposeLeakyReLUPass # noqa
from .decompose_linalg_vector_norm_pass import DecomposeLinearVectorNormPass # noqa
from .decompose_linalg_vector_norm_pass import DecomposeLinalgVectorNormPass # noqa
from .decompose_linear_pass import DecomposeLinearPass # noqa
from .decompose_logit_pass import DecomposeLogitPass # noqa
from .decompose_masked_fill import DecomposeMaskedFill # noqa
from .decompose_maxpool2d_with_dilation import DecomposeMaxPool2DPass # noqa
from .decompose_masked_fill_pass import DecomposeMaskedFillPass # noqa
from .decompose_maxpool2d_with_dilation_pass import DecomposeMaxPool2dPass # noqa
from .decompose_meandim_pass import DecomposeMeanDimPass # noqa
from .decompose_ne_pass import DecomposeNotEqualPass # noqa
from .decompose_remainder_pass import DecomposeRemainderPass # noqa
from .decompose_round_pass import DecomposeRoundPass # noqa
from .decompose_sdpa_pass import DecomposeScaledDotProductAttention # noqa
from .decompose_sdpa_pass import DecomposeScaledDotProductAttentionPass # noqa
from .decompose_select import DecomposeSelectPass # noqa
from .decompose_sign_pass import DecomposeSignPass # noqa
from .decompose_silu_pass import DecomposeSiluPass # noqa
Expand All @@ -77,10 +77,13 @@
from .decorate_fp32_to_int32_casting_pass import DecorateFp32toInt32CastingPass # noqa
from .fold_qdq_with_annotated_qparams_pass import ( # noqa
FoldAndAnnotateQParamsPass,
QuantizeOperatorArguments,
QuantizeClampArgumentsPass,
)
from .fuse_batch_norm2d_pass import FuseBatchNorm2dPass # noqa
from .fuse_constant_ops_pass import ( # noqa
ComputeConstantOpsAOTPass,
FuseConstantArgsPass,
)
from .fuse_batchnorm2d_pass import FuseBatchnorm2DPass # noqa
from .fuse_constant_ops_pass import ComputeConstantOpsAOT, FuseConstantArgsPass # noqa
from .fuse_duplicate_users_pass import FuseDuplicateUsersPass # noqa
from .fuse_equal_placeholders_pass import FuseEqualPlaceholdersPass # noqa
from .fuse_quantized_activation_pass import FuseQuantizedActivationPass # noqa
Expand All @@ -107,5 +110,5 @@
from .to_tosa_memory_format_pass import ToTosaMemoryFormatPass # noqa
from .unsqueeze_before_repeat_pass import UnsqueezeBeforeRepeatPass # noqa
from .unsqueeze_scalar_placeholders_pass import UnsqueezeScalarPlaceholdersPass # noqa
from .replace_inf_values_pass import ReplaceInfValues # noqa # usort: skip
from .replace_inf_values_pass import ReplaceInfValuesPass # noqa # usort: skip
from .arm_pass_manager import ArmPassManager # noqa # usort: skip
48 changes: 24 additions & 24 deletions backends/arm/_passes/arm_pass_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,14 +17,13 @@
CastBoolToInt8Pass,
CastInt64BuffersToInt32Pass,
CastToInt32Pass,
ComputeConstantOpsAOT,
ComputeConstantOpsAOTPass,
Conv1dUnsqueezePass,
ConvertELUParamsPass,
ConvertExpandCopyToRepeatPass,
ConvertFullLikeToFullPass,
ConvertInt64ConstOpsToInt32Pass,
ConvertInt64OutputOpsToInt32Pass,
ConvertIntPowToMuls,
ConvertMinMaxPass,
ConvertMmToBmmPass,
ConvertPermuteSingletonToViewPass,
Expand All @@ -40,7 +39,7 @@
DecomposeAsinhPass,
DecomposeAtanhPass,
DecomposeAtanPass,
DecomposeAvgPool2d,
DecomposeAvgPool2dPass,
DecomposeBatchNormNoStatsPass,
DecomposeConv2dWithInt16ActivationPass,
DecomposeCoshPass,
Expand All @@ -54,20 +53,21 @@
DecomposeFloorDividePass,
DecomposeGeluPass,
DecomposeGluPass,
DecomposeGroupedConv,
DecomposeGroupedConvPass,
DecomposeGroupNormPass,
DecomposeIntPowPass,
DecomposeLayerNormPass,
DecomposeLeakyReLUPass,
DecomposeLinalgVectorNormPass,
DecomposeLinearPass,
DecomposeLinearVectorNormPass,
DecomposeLogitPass,
DecomposeMaskedFill,
DecomposeMaxPool2DPass,
DecomposeMaskedFillPass,
DecomposeMaxPool2dPass,
DecomposeMeanDimPass,
DecomposeNotEqualPass,
DecomposeRemainderPass,
DecomposeRoundPass,
DecomposeScaledDotProductAttention,
DecomposeScaledDotProductAttentionPass,
DecomposeSelectPass,
DecomposeSignPass,
DecomposeSiluPass,
Expand All @@ -79,7 +79,7 @@
DecomposeVarPass,
DecorateFp32toInt32CastingPass,
FoldAndAnnotateQParamsPass,
FuseBatchnorm2DPass,
FuseBatchNorm2dPass,
FuseConstantArgsPass,
FuseDuplicateUsersPass,
FuseEqualPlaceholdersPass,
Expand All @@ -91,11 +91,11 @@
InsertTableOpsPass,
MatchArgDtypePass,
MatchArgRanksPass,
QuantizeOperatorArguments,
QuantizeClampArgumentsPass,
RemoveGetItemPass,
RemoveGraphAssertsPass,
RemoveNoopPass,
ReplaceInfValues,
ReplaceInfValuesPass,
ReplaceScalarWithTensorByProfilePass,
RewriteConv2dPass,
RewriteMatmulPass,
Expand Down Expand Up @@ -181,7 +181,7 @@ def _tosa_pipeline(
AnnotateDecomposedMatmulPass(),
ConvertELUParamsPass(),
ConvertSplitToSlicePass(),
QuantizeOperatorArguments(),
QuantizeClampArgumentsPass(),
]
)

Expand All @@ -202,7 +202,7 @@ def _tosa_pipeline(
self.add_passes(
[
DecomposeLogitPass(),
DecomposeMaskedFill(),
DecomposeMaskedFillPass(),
DecomposeRoundPass(),
DecomposeAcoshPass(),
DecomposeAsinhPass(),
Expand All @@ -214,14 +214,14 @@ def _tosa_pipeline(
DecomposeAddmmPass(),
DecomposeEluPass(),
DecomposeExpm1Pass(),
ConvertIntPowToMuls(),
DecomposeIntPowPass(),
CastBoolToInt8Pass(),
DecomposeSinhPass(),
DecomposeSignPass(),
DecomposeFloorDividePass(),
DecomposeGeluPass(),
DecomposeAddSubAlphaPass(),
DecomposeGroupedConv(),
DecomposeGroupedConvPass(),
Conv1dUnsqueezePass(),
]
)
Expand All @@ -247,7 +247,7 @@ def _tosa_pipeline(
DecomposeRemainderPass(),
DecomposeDivTensorModePass(),
DecomposeEmbeddingPass(),
FuseBatchnorm2DPass(exported_program),
FuseBatchNorm2dPass(exported_program),
ConvertMmToBmmPass(),
DecomposeGluPass(),
DecomposeLeakyReLUPass(),
Expand All @@ -256,13 +256,13 @@ def _tosa_pipeline(
ConvertMinMaxPass(),
DecomposeAnyPass(),
DecomposeAdaptiveAvgPool2dPass(),
DecomposeAvgPool2d(),
DecomposeAvgPool2dPass(),
DecorateFp32toInt32CastingPass(),
ComputeConstantOpsAOT(exported_program),
ComputeConstantOpsAOTPass(exported_program),
ConvertExpandCopyToRepeatPass(),
UnsqueezeBeforeRepeatPass(),
DecomposeCumsumPass(exported_program),
DecomposeMaxPool2DPass(),
DecomposeMaxPool2dPass(),
SizeAdjustInputPass(),
DecomposeSelectPass(),
ConvertSqueezesToViewPass(),
Expand Down Expand Up @@ -324,7 +324,7 @@ def transform_for_annotation_pipeline(self, graph_module: GraphModule):
ConvertInt64OutputOpsToInt32Pass(),
InsertInt32CastsAfterInt64PlaceholdersPass(),
DecomposeEmbeddingPass(),
DecomposeScaledDotProductAttention(),
DecomposeScaledDotProductAttentionPass(),
DecomposeRoundPass(),
DecomposeLogitPass(),
CastBoolToInt8Pass(),
Expand Down Expand Up @@ -357,10 +357,10 @@ def transform_for_annotation_pipeline(self, graph_module: GraphModule):
DecomposeGluPass(),
DecomposeDivPass(),
DecomposeLeakyReLUPass(),
DecomposeLinearVectorNormPass(),
DecomposeLinalgVectorNormPass(),
DecomposeSqrtPass(),
DecomposeSiluPass(),
DecomposeAvgPool2d(),
DecomposeAvgPool2dPass(),
(
DecomposeSoftmaxUnstablePass()
if self.tosa_spec.is_U55_subset
Expand All @@ -373,8 +373,8 @@ def transform_for_annotation_pipeline(self, graph_module: GraphModule):
# Postprocessing passes
self.add_passes(
[
ReplaceInfValues(),
DecomposeMaskedFill() if not self.tosa_spec.is_U55_subset else None,
ReplaceInfValuesPass(),
DecomposeMaskedFillPass() if not self.tosa_spec.is_U55_subset else None,
]
)

Expand Down
6 changes: 4 additions & 2 deletions backends/arm/_passes/convert_full_like_to_full_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,9 @@
from typing import Set, Type

from executorch.backends.arm._passes.arm_pass import ArmPass
from executorch.backends.arm._passes.fuse_constant_ops_pass import ComputeConstantOpsAOT
from executorch.backends.arm._passes.fuse_constant_ops_pass import (
ComputeConstantOpsAOTPass,
)

from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.pass_base import ExportPass
Expand All @@ -24,7 +26,7 @@ class ConvertFullLikeToFullPass(ArmPass):
Skip layout and device since it's not relevant for our backend.
"""

_passes_required_after: Set[Type[ExportPass]] = {ComputeConstantOpsAOT}
_passes_required_after: Set[Type[ExportPass]] = {ComputeConstantOpsAOTPass}

def call_operator(self, op, args, kwargs, meta):
if op not in [
Expand Down
11 changes: 8 additions & 3 deletions backends/arm/_passes/convert_int64_const_ops_to_int32.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,9 @@

import torch
from executorch.backends.arm._passes import ArmPass
from executorch.backends.arm._passes.fuse_constant_ops_pass import ComputeConstantOpsAOT
from executorch.backends.arm._passes.fuse_constant_ops_pass import (
ComputeConstantOpsAOTPass,
)
from executorch.exir.pass_base import ExportPass, PassResult


Expand All @@ -30,7 +32,7 @@ class ConvertInt64ConstOpsToInt32Pass(ArmPass):
5. `torch.tensor`
"""

_passes_required_after: Set[Type[ExportPass]] = {ComputeConstantOpsAOT}
_passes_required_after: Set[Type[ExportPass]] = {ComputeConstantOpsAOTPass}

torch_ops = [
torch.ops.aten.full.default,
Expand All @@ -47,7 +49,10 @@ def call(self, graph_module: torch.fx.GraphModule):
if node.op != "call_function":
continue

if node.target not in ComputeConstantOpsAOT.targeted_ops + self.torch_ops:
if (
node.target
not in ComputeConstantOpsAOTPass.targeted_ops + self.torch_ops
):
continue

data = node.target(*node.args, **node.kwargs)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from executorch.backends.arm._passes import ArmPass

from executorch.backends.arm._passes.fold_qdq_with_annotated_qparams_pass import (
QuantizeOperatorArguments,
QuantizeClampArgumentsPass,
)

from executorch.exir.dialects._ops import ops as exir_ops
Expand All @@ -30,7 +30,7 @@ def get_clamp_params(op, args) -> Tuple[float | None, float | None]:


class ConvertToClampPass(ArmPass):
_passes_required_after: Set[Type[ExportPass]] = {QuantizeOperatorArguments}
_passes_required_after: Set[Type[ExportPass]] = {QuantizeClampArgumentsPass}

def call_operator(self, op, args, kwargs, meta):
if op not in edge_operators:
Expand Down
6 changes: 4 additions & 2 deletions backends/arm/_passes/decompose_adaptive_avg_pool2d_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,9 @@
import torch

from executorch.backends.arm._passes import ArmPass
from executorch.backends.arm._passes.decompose_avg_pool2d import DecomposeAvgPool2d
from executorch.backends.arm._passes.decompose_avg_pool2d_pass import (
DecomposeAvgPool2dPass,
)

from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.pass_base import ExportPass, NodeMetadata
Expand Down Expand Up @@ -44,7 +46,7 @@ class DecomposeAdaptiveAvgPool2dPass(ArmPass):
The output is of size output_size_h x output_size_w for any input.
"""

_passes_required_after: Set[Type[ExportPass]] = {DecomposeAvgPool2d}
_passes_required_after: Set[Type[ExportPass]] = {DecomposeAvgPool2dPass}

def call_operator(self, op, args, kwargs, meta, updated=False):
if op not in (edge_ops + aten_ops):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,9 @@

import torch
from executorch.backends.arm._passes.arm_pass import ArmPass
from executorch.backends.arm._passes.fuse_constant_ops_pass import ComputeConstantOpsAOT
from executorch.backends.arm._passes.fuse_constant_ops_pass import (
ComputeConstantOpsAOTPass,
)
from executorch.backends.arm.operators.operator_validation_utils import (
adjust_pooling_pad_if_needed,
)
Expand Down Expand Up @@ -37,8 +39,8 @@ def get_decomposition(op) -> tuple:
raise RuntimeError(f"Can't get avg_pool2d decomposition for op {op}")


class DecomposeAvgPool2d(ArmPass):
_passes_required_after: Set[Type[ExportPass]] = {ComputeConstantOpsAOT}
class DecomposeAvgPool2dPass(ArmPass):
_passes_required_after: Set[Type[ExportPass]] = {ComputeConstantOpsAOTPass}

def call_operator(self, op, args, kwargs, meta):
if op not in (edge_div_ops + aten_div_ops):
Expand Down
6 changes: 4 additions & 2 deletions backends/arm/_passes/decompose_batch_norm_no_stats.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,9 @@
import torch
from executorch.backends.arm._passes import ArmPass
from executorch.backends.arm._passes.arm_pass_utils import create_node
from executorch.backends.arm._passes.fuse_constant_ops_pass import ComputeConstantOpsAOT
from executorch.backends.arm._passes.fuse_constant_ops_pass import (
ComputeConstantOpsAOTPass,
)

from executorch.backends.arm._passes.insert_table_ops import InsertTableOpsPass
from executorch.exir.dialects._ops import ops as exir_ops
Expand All @@ -37,7 +39,7 @@ class DecomposeBatchNormNoStatsPass(ArmPass):
"""

_passes_required_after: Set[Type[ExportPass]] = {
ComputeConstantOpsAOT,
ComputeConstantOpsAOTPass,
InsertTableOpsPass,
}

Expand Down
4 changes: 2 additions & 2 deletions backends/arm/_passes/decompose_expm1_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@
from typing import Set, Type

from executorch.backends.arm._passes import ArmPass
from executorch.backends.arm._passes.convert_int_pow_to_mul import ConvertIntPowToMuls
from executorch.backends.arm._passes.decompose_div_pass import DecomposeDivPass
from executorch.backends.arm._passes.decompose_int_pow_pass import DecomposeIntPowPass
from executorch.backends.arm._passes.insert_table_ops import InsertTableOpsPass
from executorch.backends.arm._passes.match_arg_dtype_pass import MatchArgDtypePass
from executorch.backends.arm._passes.match_arg_ranks_pass import MatchArgRanksPass
Expand Down Expand Up @@ -80,7 +80,7 @@ class DecomposeExpm1Pass(ArmPass):
"""

_passes_required_after: Set[Type[ExportPass]] = {
ConvertIntPowToMuls,
DecomposeIntPowPass,
InsertTableOpsPass,
DecomposeDivPass,
ReplaceScalarWithTensorByProfilePass,
Expand Down
Loading
Loading