Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
19 commits
Select commit Hold shift + click to select a range
4da1550
Revert "Arm backend: Merge RetraceFoldedDtypesPass into FoldAndAnnota…
GregoryComer Oct 27, 2025
5b1a947
Revert "Arm backend: Serialize controlflow submodules. (#15381)"
GregoryComer Oct 27, 2025
4250b49
Revert "Arm backend: Remove pyre-unsafe from remaining files (#15391)"
GregoryComer Oct 27, 2025
d975528
Revert "Arm backend: Move rescales from ADD & SUB visitors to pass (#…
GregoryComer Oct 27, 2025
5c59fa0
Revert "Arm backend: Remove pyre-unsafe from operators/ (#15376)"
GregoryComer Oct 27, 2025
efae4a8
Revert "Arm backend: Align operator_validation_utils docstrings with …
GregoryComer Oct 27, 2025
3a50a0b
Revert "Arm backend: Deprecate internal models using aot_arm_compiler…
GregoryComer Oct 27, 2025
e83e676
Revert "Arm backend: Remove pyre-unsafe from quantizer/ & backends/ar…
GregoryComer Oct 27, 2025
ddfa961
Revert "Arm backend: Add docstrings for operator_support/embedding_su…
GregoryComer Oct 27, 2025
a4c3cd7
Revert "Arm backend: Merge passes that replace scalars (#15298)"
GregoryComer Oct 27, 2025
16f7f7a
Revert "Arm backend: Use reshape instead of view before edge (#15269)"
GregoryComer Oct 27, 2025
a71332d
Revert "Arm backend: Fix arg-type MyPy errors (#15367)"
GregoryComer Oct 27, 2025
e204ea6
Revert "Arm backend: Remove pyre-unsafe from _passes/ (#15351)"
GregoryComer Oct 27, 2025
eb2c876
Revert "Arm backend: Remove pyre-unsafe from tosa/, vgf/ and ethosu/ …
GregoryComer Oct 27, 2025
8167327
Revert "Arm backend: Move rescales from SUM visitor to pass (#15299)"
GregoryComer Oct 27, 2025
008a014
Revert "Arm backend: Tag control flow submodules in partitioner (#153…
GregoryComer Oct 27, 2025
a3ff326
Revert "Arm backend: support mean.default (#15363)"
GregoryComer Oct 27, 2025
53bb98b
Revert "Arm backend: Support per-channel in TOSA.RESCALE (#15267)"
GregoryComer Oct 27, 2025
36f25ce
Revert "ArBackend: Enable Pybindings for tosa_serialization lib (#153…
GregoryComer Oct 27, 2025
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 0 additions & 6 deletions .mypy.ini
Original file line number Diff line number Diff line change
Expand Up @@ -83,12 +83,6 @@ ignore_missing_imports = True
[mypy-tosa_tools.*]
ignore_missing_imports = True

[mypy-tosa_serializer]
ignore_missing_imports = True

[mypy-tosa_serializer.*]
ignore_missing_imports = True

[mypy-setuptools.*]
ignore_missing_imports = True

Expand Down
4 changes: 3 additions & 1 deletion backends/arm/_passes/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@
from .fold_qdq_with_annotated_qparams_pass import ( # noqa
FoldAndAnnotateQParamsPass,
QuantizeOperatorArguments,
RetraceFoldedDtypesPass,
)
from .fuse_batchnorm2d_pass import FuseBatchnorm2DPass # noqa
from .fuse_constant_ops_pass import ComputeConstantOpsAOT, FuseConstantArgsPass # noqa
Expand All @@ -87,7 +88,8 @@
from .mm_to_bmm_pass import ConvertMmToBmmPass # noqa
from .remove_noop_pass import RemoveNoopPass # noqa
from .replace_scalar_with_tensor_pass import ( # noqa
ReplaceScalarWithTensorByProfilePass,
ReplaceScalarWithTensorArgPassTOSABI,
ReplaceScalarWithTensorArgPassTOSAMI,
)
from .rewrite_conv2d_pass import RewriteConv2dPass # noqa
from .rewrite_matmul import RewriteMatmulPass # noqa
Expand Down
9 changes: 5 additions & 4 deletions backends/arm/_passes/annotate_decomposed_matmul.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-unsafe

import itertools
import operator
Expand Down Expand Up @@ -51,7 +52,7 @@ def _match_partition_to_node(
raise RuntimeError(f"Cannot find an input node which matches, {node}.")

def call(self, graph_module: GraphModule) -> PassResult:
matmul_partitions_map = get_source_partitions(
matmul_partitions = get_source_partitions(
graph_module.graph,
[
torch.matmul,
Expand All @@ -60,7 +61,7 @@ def call(self, graph_module: GraphModule) -> PassResult:
None,
)
matmul_partitions = list(
itertools.chain.from_iterable(matmul_partitions_map.values())
itertools.chain.from_iterable(matmul_partitions.values())
)
matmul_targets = {
exir_ops.edge.aten.bmm.default,
Expand Down Expand Up @@ -88,7 +89,7 @@ def call(self, graph_module: GraphModule) -> PassResult:
# Create new dq-node before matmul
dq_node = create_node(
graph=graph_module.graph,
op_target=cast(EdgeOpOverload, input_node.target),
op_target=cast(EdgeOpOverload, input_node.target), # type: ignore[arg-type]
)
dq_node.args = (node, *input_node.args[1:])
matmul_node.replace_input_with(node, dq_node)
Expand All @@ -109,7 +110,7 @@ def call(self, graph_module: GraphModule) -> PassResult:
# Create q-node after matmul
q_node = create_node(
graph=graph_module.graph,
op_target=cast(EdgeOpOverload, partition_output.target),
op_target=cast(EdgeOpOverload, partition_output.target), # type: ignore[arg-type]
)
matmul_node.replace_all_uses_with(q_node)
q_node.args = (matmul_node, *partition_output.args[1:])
Expand Down
1 change: 1 addition & 0 deletions backends/arm/_passes/arm_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-unsafe

import traceback
from abc import abstractmethod
Expand Down
47 changes: 26 additions & 21 deletions backends/arm/_passes/arm_pass_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-unsafe


from collections import defaultdict

Expand Down Expand Up @@ -87,7 +89,9 @@
QuantizeOperatorArguments,
RemoveNoopPass,
ReplaceInfValues,
ReplaceScalarWithTensorByProfilePass,
ReplaceScalarWithTensorArgPassTOSABI,
ReplaceScalarWithTensorArgPassTOSAMI,
RetraceFoldedDtypesPass,
RewriteConv2dPass,
RewriteMatmulPass,
RewriteUpsamplePass,
Expand Down Expand Up @@ -152,15 +156,15 @@ def _transform(self, graph_module: GraphModule):
with TosaLoweringContext(self.tosa_spec):
return self(graph_module).graph_module

def _tosa_INT_pipeline(
self, exported_program: ExportedProgram, graph_module: GraphModule
) -> GraphModule:
def _tosa_INT_pipeline(self, exported_program: ExportedProgram) -> GraphModule:
self.add_pass(AnnotateOutputDimOrderPass())
self.add_pass(FuseQuantizedActivationPass())
self.add_pass(RemoveGetItemPass())
self.add_pass(ConvertSplitToSlicePass())
self.add_pass(ConvertMmToBmmPass())
self.add_pass(DecomposeMeanDimPass(graph_module, self.tosa_spec))
self.add_pass(
DecomposeMeanDimPass(exported_program.graph_module, self.tosa_spec)
)
self.add_pass(ConvertFullLikeToFullPass())
self.add_pass(ConvertToClampPass())
self.add_pass(ConvertMinMaxPass())
Expand All @@ -170,11 +174,12 @@ def _tosa_INT_pipeline(
self.add_pass(CastToInt32Pass())

self.add_pass(CastBoolToInt8Pass())
self.add_pass(ReplaceScalarWithTensorByProfilePass())
self.add_pass(ReplaceScalarWithTensorArgPassTOSABI())
self.add_pass(AnnotateDecomposedMatmulPass())
self.add_pass(QuantizeOperatorArguments())
self.add_pass(ConvertELUParamsPass())
self.add_pass(FoldAndAnnotateQParamsPass(exported_program)) # type: ignore[call-arg]
self.add_pass(RetraceFoldedDtypesPass())
self.add_pass(UnsqueezeScalarPlaceholdersPass(exported_program))
self.add_pass(MatchArgRanksPass(exported_program))
if self.tosa_spec.is_U55_subset:
Expand All @@ -189,6 +194,7 @@ def _tosa_INT_pipeline(
self.add_pass(ConvertExpandCopyToRepeatPass())
self.add_pass(UnsqueezeBeforeRepeatPass())
self.add_pass(CastInt64BuffersToInt32Pass(exported_program))
self.add_pass(DecomposeSumPass())
self.add_pass(DecomposeCumsumPass(exported_program))
self.add_pass(Conv1dUnsqueezePass())
self.add_pass(DecomposeMaxPool2DPass())
Expand All @@ -209,18 +215,15 @@ def _tosa_INT_pipeline(
self.add_pass(RewriteMatmulPass())
self.add_pass(RewriteUpsamplePass())
self.add_pass(FuseEqualPlaceholdersPass(exported_program))
self.add_pass(InsertRescaleInt32Pass())
self.add_pass(DecomposeSumPass())
self.add_pass(ToTosaMemoryFormatPass(exported_program))
self.add_pass(RemoveNoopPass())
self.add_pass(InsertRescalePass())
self.add_pass(InsertRescaleInt32Pass())

self.validate_constraints_mandatory()
return self._transform(graph_module)
return self._transform(exported_program.graph_module)

def _tosa_FP_pipeline(
self, exported_program: ExportedProgram, graph_module: GraphModule
) -> GraphModule:
def _tosa_FP_pipeline(self, exported_program: ExportedProgram) -> GraphModule:
self.add_pass(AnnotateOutputDimOrderPass())
self.add_pass(DecomposeExpm1Pass())
self.add_pass(DecomposeLogitPass())
Expand All @@ -241,7 +244,7 @@ def _tosa_FP_pipeline(
self.add_pass(DecomposeSinhPass())
self.add_pass(DecomposeSignPass())
self.add_pass(DecomposeDivTensorModePass())
self.add_pass(ReplaceScalarWithTensorByProfilePass())
self.add_pass(ReplaceScalarWithTensorArgPassTOSAMI())
self.add_pass(DecomposeEmbeddingPass())
self.add_pass(FuseQuantizedActivationPass())
self.add_pass(RemoveGetItemPass())
Expand All @@ -255,7 +258,9 @@ def _tosa_FP_pipeline(
self.add_pass(DecomposeLayerNormPass())
self.add_pass(DecomposeBatchNormNoStatsPass())
self.add_pass(DecomposeVarPass())
self.add_pass(DecomposeMeanDimPass(graph_module, self.tosa_spec))
self.add_pass(
DecomposeMeanDimPass(exported_program.graph_module, self.tosa_spec)
)
self.add_pass(DecomposeNotEqualPass())
self.add_pass(DecomposeDivPass())
self.add_pass(DecomposeAddSubAlphaPass())
Expand All @@ -269,6 +274,7 @@ def _tosa_FP_pipeline(
self.add_pass(AnnotateDecomposedMatmulPass())
self.add_pass(QuantizeOperatorArguments())
self.add_pass(FoldAndAnnotateQParamsPass(exported_program)) # type: ignore[call-arg]
self.add_pass(RetraceFoldedDtypesPass())
self.add_pass(UnsqueezeScalarPlaceholdersPass(exported_program))
self.add_pass(MatchArgRanksPass(exported_program))
self.add_pass(DecomposeAdaptiveAvgPool2dPass())
Expand Down Expand Up @@ -302,16 +308,14 @@ def _tosa_FP_pipeline(
self.add_pass(InsertRescalePass())

self.validate_constraints_mandatory()
return self._transform(graph_module)
return self._transform(exported_program.graph_module)

def transform_to_backend_pipeline(
self, exported_program: ExportedProgram, graph_module: GraphModule
):
def transform_to_backend_pipeline(self, exported_program: ExportedProgram):
"""Apply passes before transforming program to backend"""
if self.tosa_spec == TosaSpecification.create_from_string("TOSA-1.0+FP"):
return self._tosa_FP_pipeline(exported_program, graph_module)
return self._tosa_FP_pipeline(exported_program)
elif self.tosa_spec == TosaSpecification.create_from_string("TOSA-1.0+INT"):
return self._tosa_INT_pipeline(exported_program, graph_module)
return self._tosa_INT_pipeline(exported_program)
else:
raise NotImplementedError(
f"No pass pipeline implemented for {self.tosa_spec=}"
Expand All @@ -333,7 +337,7 @@ def transform_for_annotation_pipeline(self, graph_module: GraphModule):
self.add_pass(DecomposeAddmmPass())
self.add_pass(DecomposeDivTensorModePass())
self.add_pass(DecomposeAddSubAlphaPass())
self.add_pass(ReplaceScalarWithTensorByProfilePass())
self.add_pass(ReplaceScalarWithTensorArgPassTOSABI())
self.add_pass(ScalarsToAttributePass())
self.add_pass(DecomposeGroupNormPass())
self.add_pass(DecomposeLayerNormPass())
Expand All @@ -357,6 +361,7 @@ def transform_for_annotation_pipeline(self, graph_module: GraphModule):

self.add_pass(ConvertMinMaxPass())
self.add_pass(ReplaceInfValues())
self.add_pass(DecomposeSumPass())

if not self.tosa_spec.is_U55_subset:
# Uses where which is not supported on Ethos-U55
Expand Down
10 changes: 4 additions & 6 deletions backends/arm/_passes/arm_pass_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-unsafe

import traceback
from inspect import isclass
Expand All @@ -13,10 +14,8 @@
import torch
import torch.fx
from executorch.backends.arm.common.debug import get_node_debug_info
from executorch.backends.arm.common.type import ensure_type
from executorch.exir import ExportedProgram
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.dialects.edge._ops import EdgeOpOverload

from torch._export.utils import (
get_buffer,
Expand Down Expand Up @@ -83,18 +82,17 @@ def get_param_tensor(
elif is_lifted_tensor_constant(exp_prog, node):
return get_lifted_tensor_constant(exp_prog, node)
elif is_get_attr_node(node):
target_node = ensure_type(str, node.target)
# This is a hack to support both lifted and unlifted graph
try:
return getattr(node.graph.owning_module, target_node)
return getattr(node.graph.owning_module, node.target) # type: ignore[arg-type]
except AttributeError:
return getattr(exp_prog.graph_module, target_node)
return getattr(exp_prog.graph_module, node.target) # type: ignore[arg-type]
raise RuntimeError(f"unsupported param type, {node.op}.")


def create_node(
graph: torch.fx.Graph,
op_target: OpOverload | EdgeOpOverload,
op_target: OpOverload,
args: tuple = (),
kwargs: Optional[dict] = None,
quantize: bool = False,
Expand Down
1 change: 1 addition & 0 deletions backends/arm/_passes/cast_int64_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-unsafe

import logging
from typing import Set, Type
Expand Down
1 change: 1 addition & 0 deletions backends/arm/_passes/convert_expand_copy_to_repeat.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-unsafe

import logging
from typing import cast, Set, Type
Expand Down
2 changes: 2 additions & 0 deletions backends/arm/_passes/convert_int64_const_ops_to_int32.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-unsafe


import logging
from typing import Set, Type
Expand Down
2 changes: 2 additions & 0 deletions backends/arm/_passes/convert_int64_output_ops_to_int32.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-unsafe


import logging
from typing import Set, Type
Expand Down
1 change: 1 addition & 0 deletions backends/arm/_passes/convert_int_pow_to_mul.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-unsafe

from typing import Set, Type

Expand Down
1 change: 1 addition & 0 deletions backends/arm/_passes/convert_split_to_slice.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-unsafe

from typing import Set, Type

Expand Down
1 change: 1 addition & 0 deletions backends/arm/_passes/convert_squeezes_to_view.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-unsafe

from typing import Set, Type

Expand Down
5 changes: 3 additions & 2 deletions backends/arm/_passes/decompose_acosh_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-unsafe

from typing import Set, Type

Expand All @@ -12,7 +13,7 @@
from executorch.backends.arm._passes.match_arg_dtype_pass import MatchArgDtypePass
from executorch.backends.arm._passes.match_arg_ranks_pass import MatchArgRanksPass
from executorch.backends.arm._passes.replace_scalar_with_tensor_pass import (
ReplaceScalarWithTensorByProfilePass,
ReplaceScalarWithTensorArgPassTOSAMI,
)
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.pass_base import ExportPass
Expand All @@ -32,7 +33,7 @@ class DecomposeAcoshPass(ArmPass):
DecomposeSqrtPass,
InsertTableOpsPass,
MatchArgRanksPass,
ReplaceScalarWithTensorByProfilePass,
ReplaceScalarWithTensorArgPassTOSAMI,
MatchArgDtypePass,
}

Expand Down
5 changes: 3 additions & 2 deletions backends/arm/_passes/decompose_asin_and_acos_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-unsafe

import logging
from math import pi
Expand All @@ -19,7 +20,7 @@
from executorch.backends.arm._passes.match_arg_dtype_pass import MatchArgDtypePass
from executorch.backends.arm._passes.match_arg_ranks_pass import MatchArgRanksPass
from executorch.backends.arm._passes.replace_scalar_with_tensor_pass import (
ReplaceScalarWithTensorByProfilePass,
ReplaceScalarWithTensorArgPassTOSAMI,
)
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.pass_base import ExportPass
Expand Down Expand Up @@ -71,7 +72,7 @@ class DecomposeAsinAndAcosPass(ArmPass):
ConvertFullLikeToFullPass,
MatchArgRanksPass,
MatchArgDtypePass,
ReplaceScalarWithTensorByProfilePass,
ReplaceScalarWithTensorArgPassTOSAMI,
}

def _build_polynomial(
Expand Down
Loading
Loading