Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion backends/arm/_passes/annotate_decomposed_matmul.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,9 @@

import torch
from executorch.backends.arm._passes.arm_pass_utils import create_node
from executorch.backends.arm._passes.fold_qdq_with_annotated_qparams_pass import (
FoldAndAnnotateQParamsPass,
)

from executorch.backends.arm.constants import DQ_OPS, Q_OPS
from executorch.exir.dialects._ops import ops as exir_ops
Expand All @@ -29,7 +32,7 @@ class AnnotateDecomposedMatmulPass(ExportPass):
matmul-op (can be mm or bmm).
"""

_passes_required_after: Set[Type[ExportPass]] = set()
_passes_required_after: Set[Type[ExportPass]] = {FoldAndAnnotateQParamsPass}

def _match_partition_to_node(
self, node: torch.fx.Node, partitioned_inputs: List[torch.fx.Node]
Expand Down
5 changes: 4 additions & 1 deletion backends/arm/_passes/conv1d_unsqueeze_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,9 @@

from typing import Set, Type

from executorch.backends.arm._passes.add_bias_pass import AddBiasPass
from executorch.backends.arm._passes.size_adjust_input_pass import SizeAdjustInputPass

from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.pass_base import ExportPass

Expand All @@ -23,7 +26,7 @@ class Conv1dUnsqueezePass(ExportPass):
3) squeeze the output back down to 3d.
"""

_passes_required_after: Set[Type[ExportPass]] = set()
_passes_required_after: Set[Type[ExportPass]] = {AddBiasPass, SizeAdjustInputPass}

def call_operator(self, op, args, kwargs, meta):
if op != exir_ops.edge.aten.convolution.default:
Expand Down
5 changes: 4 additions & 1 deletion backends/arm/_passes/convert_any_default_dim_dims_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,9 @@
from typing import Set, Type

import torch
from executorch.backends.arm._passes.convert_squeezes_to_view import (
ConvertSqueezesToViewPass,
)
from executorch.exir.dialects._ops import ( # type: ignore[import-not-found]
ops as exir_ops,
)
Expand Down Expand Up @@ -46,7 +49,7 @@ class ConvertAnyDefaultDimDimsPass(ExportPass):
squeeze(dim = [dim1, dim2])
"""

_passes_required_after: Set[Type[ExportPass]] = set()
_passes_required_after: Set[Type[ExportPass]] = {ConvertSqueezesToViewPass}

def call(self, graph_module: torch.fx.GraphModule):
modified = False
Expand Down
5 changes: 4 additions & 1 deletion backends/arm/_passes/convert_expand_copy_to_repeat.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,9 @@

import torch

from executorch.backends.arm._passes.unsqueeze_before_repeat_pass import (
UnsqueezeBeforeRepeatPass,
)
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.pass_base import ExportPass

Expand Down Expand Up @@ -50,7 +53,7 @@ class ConvertExpandCopyToRepeatPass(ExportPass):
Replace expand copy with repeat since it is a repeat that can only repeat singleton dimensions.
"""

_passes_required_after: Set[Type[ExportPass]] = set()
_passes_required_after: Set[Type[ExportPass]] = {UnsqueezeBeforeRepeatPass}

expand_copy = exir_ops.edge.aten.expand_copy.default
repeat = exir_ops.edge.aten.repeat.default
Expand Down
7 changes: 5 additions & 2 deletions backends/arm/_passes/convert_full_like_to_full_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,14 @@

from typing import Set, Type

from executorch.backends.arm._passes.arm_pass import ArmPass
from executorch.backends.arm._passes.fuse_constant_ops_pass import ComputeConstantOpsAOT

from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.pass_base import ExportPass


class ConvertFullLikeToFullPass(ExportPass):
class ConvertFullLikeToFullPass(ArmPass):
"""As per the full_like pytorch documentation,
`torch.full_like(input, fill_value)` is equivalent to
`torch.full(input.size(),
Expand All @@ -21,7 +24,7 @@ class ConvertFullLikeToFullPass(ExportPass):
Skip layout and device since it's not relevant for our backend.
"""

_passes_required_after: Set[Type[ExportPass]] = set()
_passes_required_after: Set[Type[ExportPass]] = {ComputeConstantOpsAOT}

def call_operator(self, op, args, kwargs, meta):
if op not in [
Expand Down
2 changes: 1 addition & 1 deletion backends/arm/_passes/convert_int64_const_ops_to_int32.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ class ConvertInt64ConstOpsToInt32Pass(ExportPass):
5. `torch.tensor`
"""

_passes_required_after: Set[Type[ExportPass]] = set()
_passes_required_after: Set[Type[ExportPass]] = {ComputeConstantOpsAOT}

torch_ops = [
torch.ops.aten.full.default,
Expand Down
5 changes: 4 additions & 1 deletion backends/arm/_passes/convert_minmax_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,9 @@
from typing import Set, Type

import torch
from executorch.backends.arm._passes.convert_squeezes_to_view import (
ConvertSqueezesToViewPass,
)
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.pass_base import ExportPass, PassResult

Expand All @@ -31,7 +34,7 @@ class ConvertMinMaxPass(ExportPass):
squeeze(dim = [dim1, dim2])
"""

_passes_required_after: Set[Type[ExportPass]] = set()
_passes_required_after: Set[Type[ExportPass]] = {ConvertSqueezesToViewPass}

def check_argmax(self, node):
"""
Expand Down
4 changes: 3 additions & 1 deletion backends/arm/_passes/convert_squeezes_to_view.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@

from typing import Set, Type

from executorch.backends.transforms.fuse_view_copy import FuseViewCopyTransform

from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.pass_base import ExportPass

Expand All @@ -17,7 +19,7 @@ class ConvertSqueezesToViewPass(ExportPass):
Replaces squeeze/unsqueeze operators with view. These are simply special cases of the view op, so removing them gives us less cases to handle in the node visitiors.
"""

_passes_required_after: Set[Type[ExportPass]] = set()
_passes_required_after: Set[Type[ExportPass]] = {FuseViewCopyTransform}

def call_operator(self, op, args, kwargs, meta):
if op not in [
Expand Down
6 changes: 5 additions & 1 deletion backends/arm/_passes/convert_to_clamp.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,10 @@

from typing import Set, Tuple, Type

from executorch.backends.arm._passes.fold_qdq_with_annotated_qparams_pass import (
QuantizeOperatorArguments,
)

from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.pass_base import ExportPass

Expand All @@ -24,7 +28,7 @@ def get_clamp_params(op, args) -> Tuple[float | None, float | None]:


class ConvertToClampPass(ExportPass):
_passes_required_after: Set[Type[ExportPass]] = set()
_passes_required_after: Set[Type[ExportPass]] = {QuantizeOperatorArguments}

def call_operator(self, op, args, kwargs, meta):
if op not in edge_operators:
Expand Down
15 changes: 14 additions & 1 deletion backends/arm/_passes/decompose_acosh_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,13 @@
from typing import Set, Type

from executorch.backends.arm._passes import ArmPass
from executorch.backends.arm._passes.decompose_sqrt_pass import DecomposeSqrtPass
from executorch.backends.arm._passes.insert_table_ops import InsertTableOpsPass # noqa
from executorch.backends.arm._passes.match_arg_dtype_pass import MatchArgDtypePass
from executorch.backends.arm._passes.match_arg_ranks_pass import MatchArgRanksPass
from executorch.backends.arm._passes.replace_scalar_with_tensor_pass import (
ReplaceScalarWithTensorArgPassTOSAMI,
)
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.pass_base import ExportPass

Expand All @@ -22,7 +29,13 @@ class DecomposeAcoshPass(ArmPass):
acosh(x) = log(x + sqrt((x-1)(x+1))
"""

_passes_required_after: Set[Type[ExportPass]] = set()
_passes_required_after: Set[Type[ExportPass]] = {
DecomposeSqrtPass,
InsertTableOpsPass,
MatchArgRanksPass,
ReplaceScalarWithTensorArgPassTOSAMI,
MatchArgDtypePass,
}

def call_operator(self, op, args, kwargs, meta, updated=False):

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import torch

from executorch.backends.arm._passes import ArmPass
from executorch.backends.arm._passes.decompose_avg_pool2d import DecomposeAvgPool2d

from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.pass_base import ExportPass
Expand Down Expand Up @@ -43,7 +44,7 @@ class DecomposeAdaptiveAvgPool2dPass(ArmPass):
The output is of size output_size_h x output_size_w for any input.
"""

_passes_required_after: Set[Type[ExportPass]] = set()
_passes_required_after: Set[Type[ExportPass]] = {DecomposeAvgPool2d}

def call_operator(self, op, args, kwargs, meta, updated=False):
if op not in (edge_ops + aten_ops):
Expand Down
9 changes: 8 additions & 1 deletion backends/arm/_passes/decompose_addmm_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,9 @@
import torch

from executorch.backends.arm._passes import ArmPass
from executorch.backends.arm._passes.match_arg_dtype_pass import MatchArgDtypePass
from executorch.backends.arm._passes.match_arg_ranks_pass import MatchArgRanksPass
from executorch.backends.arm._passes.mm_to_bmm_pass import ConvertMmToBmmPass # noqa
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.pass_base import ExportPass

Expand Down Expand Up @@ -39,7 +42,11 @@ def get_ops(op):
class DecomposeAddmmPass(ArmPass):
"""Decomposes the addmm operator into tensor multiplication and addition."""

_passes_required_after: Set[Type[ExportPass]] = set()
_passes_required_after: Set[Type[ExportPass]] = {
ConvertMmToBmmPass,
MatchArgRanksPass,
MatchArgDtypePass,
}

def call_operator(self, op, args, kwargs, meta):
if op not in [edge_addmm, aten_addmm]:
Expand Down
19 changes: 18 additions & 1 deletion backends/arm/_passes/decompose_asin_and_acos_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,16 @@
import torch

from executorch.backends.arm._passes import ArmPass
from executorch.backends.arm._passes.convert_full_like_to_full_pass import (
ConvertFullLikeToFullPass,
)
from executorch.backends.arm._passes.decompose_div_pass import DecomposeDivPass
from executorch.backends.arm._passes.decompose_sqrt_pass import DecomposeSqrtPass
from executorch.backends.arm._passes.match_arg_dtype_pass import MatchArgDtypePass
from executorch.backends.arm._passes.match_arg_ranks_pass import MatchArgRanksPass
from executorch.backends.arm._passes.replace_scalar_with_tensor_pass import (
ReplaceScalarWithTensorArgPassTOSAMI,
)
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.pass_base import ExportPass

Expand Down Expand Up @@ -56,7 +66,14 @@ class DecomposeAsinAndAcosPass(ArmPass):

"""

_passes_required_after: Set[Type[ExportPass]] = set()
_passes_required_after: Set[Type[ExportPass]] = {
DecomposeSqrtPass,
DecomposeDivPass,
ConvertFullLikeToFullPass,
MatchArgRanksPass,
MatchArgDtypePass,
ReplaceScalarWithTensorArgPassTOSAMI,
}

def _build_polynomial(
self, coefficients: list[float], variable: torch.Tensor, meta: dict[str, str]
Expand Down
15 changes: 14 additions & 1 deletion backends/arm/_passes/decompose_asinh_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,13 @@
from typing import Set, Type

from executorch.backends.arm._passes import ArmPass
from executorch.backends.arm._passes.decompose_sqrt_pass import DecomposeSqrtPass
from executorch.backends.arm._passes.insert_table_ops import InsertTableOpsPass
from executorch.backends.arm._passes.match_arg_dtype_pass import MatchArgDtypePass
from executorch.backends.arm._passes.match_arg_ranks_pass import MatchArgRanksPass
from executorch.backends.arm._passes.replace_scalar_with_tensor_pass import (
ReplaceScalarWithTensorArgPassTOSAMI,
)
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.pass_base import ExportPass

Expand All @@ -23,7 +30,13 @@ class DecomposeAsinhPass(ArmPass):
asinh(x) = log(x + sqrt(x^2 + 1))
"""

_passes_required_after: Set[Type[ExportPass]] = set()
_passes_required_after: Set[Type[ExportPass]] = {
DecomposeSqrtPass,
InsertTableOpsPass,
MatchArgRanksPass,
ReplaceScalarWithTensorArgPassTOSAMI,
MatchArgDtypePass,
}

def call_operator(self, op, args, kwargs, meta):
if op not in edge_asinh_op:
Expand Down
13 changes: 12 additions & 1 deletion backends/arm/_passes/decompose_atan_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,12 @@
from typing import Set, Type

from executorch.backends.arm._passes import ArmPass
from executorch.backends.arm._passes.insert_table_ops import InsertTableOpsPass
from executorch.backends.arm._passes.match_arg_dtype_pass import MatchArgDtypePass
from executorch.backends.arm._passes.match_arg_ranks_pass import MatchArgRanksPass
from executorch.backends.arm._passes.replace_scalar_with_tensor_pass import (
ReplaceScalarWithTensorArgPassTOSAMI,
)
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.pass_base import ExportPass

Expand Down Expand Up @@ -37,7 +43,12 @@ def _get_atan_ops(op):
class DecomposeAtanPass(ArmPass):
"""Decomposes the atan operator into a rational (Padé) approximation."""

_passes_required_after: Set[Type[ExportPass]] = set()
_passes_required_after: Set[Type[ExportPass]] = {
InsertTableOpsPass,
MatchArgRanksPass,
MatchArgDtypePass,
ReplaceScalarWithTensorArgPassTOSAMI,
}

def _rational_approximation(self, z, ops, meta):
"""Creates a (2,1) Padé approximation for atan(x) on [-1, 1]."""
Expand Down
13 changes: 12 additions & 1 deletion backends/arm/_passes/decompose_atanh_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,12 @@
from typing import Set, Type

from executorch.backends.arm._passes import ArmPass
from executorch.backends.arm._passes.insert_table_ops import InsertTableOpsPass
from executorch.backends.arm._passes.match_arg_dtype_pass import MatchArgDtypePass
from executorch.backends.arm._passes.match_arg_ranks_pass import MatchArgRanksPass
from executorch.backends.arm._passes.replace_scalar_with_tensor_pass import (
ReplaceScalarWithTensorArgPassTOSAMI,
)
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.pass_base import ExportPass

Expand Down Expand Up @@ -33,7 +39,12 @@ class DecomposeAtanhPass(ArmPass):
atanh(x) = 0.5 * log((1 + x) / (1 - x))
"""

_passes_required_after: Set[Type[ExportPass]] = set()
_passes_required_after: Set[Type[ExportPass]] = {
InsertTableOpsPass,
MatchArgRanksPass,
MatchArgDtypePass,
ReplaceScalarWithTensorArgPassTOSAMI,
}

def call_operator(self, op, args, kwargs, meta):
if op is not edge_atanh:
Expand Down
5 changes: 3 additions & 2 deletions backends/arm/_passes/decompose_avg_pool2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from typing import Set, Type

import torch
from executorch.backends.arm._passes.fuse_constant_ops_pass import ComputeConstantOpsAOT
from executorch.backends.arm.operators.operator_validation_utils import (
adjust_pooling_pad_if_needed,
)
Expand All @@ -32,11 +33,11 @@ def get_decomposition(op) -> tuple:
torch.ops.aten.avg_pool2d.default,
torch.ops.aten.mul.Tensor,
)
raise RuntimeError(f"Can't get div decomposition for op {op}")
raise RuntimeError(f"Can't get avg_pool2d decomposition for op {op}")


class DecomposeAvgPool2d(ExportPass):
_passes_required_after: Set[Type[ExportPass]] = set()
_passes_required_after: Set[Type[ExportPass]] = {ComputeConstantOpsAOT}

def call_operator(self, op, args, kwargs, meta):
if op not in (edge_div_ops + aten_div_ops):
Expand Down
8 changes: 7 additions & 1 deletion backends/arm/_passes/decompose_batch_norm_no_stats.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,9 @@
import torch
from executorch.backends.arm._passes import ArmPass
from executorch.backends.arm._passes.arm_pass_utils import create_node
from executorch.backends.arm._passes.fuse_constant_ops_pass import ComputeConstantOpsAOT

from executorch.backends.arm._passes.insert_table_ops import InsertTableOpsPass
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.pass_base import ExportPass, PassResult

Expand All @@ -34,7 +37,10 @@ class DecomposeBatchNormNoStatsPass(ArmPass):
Source: https://pytorch.org/docs/stable/generated/torch.nn.BatchNorm2d.html
"""

_passes_required_after: Set[Type[ExportPass]] = set()
_passes_required_after: Set[Type[ExportPass]] = {
ComputeConstantOpsAOT,
InsertTableOpsPass,
}

def call(self, graph_module: torch.fx.GraphModule) -> PassResult: # noqa: C901
bn_ops = (
Expand Down
Loading
Loading