3 changes: 2 additions & 1 deletion backends/arm/_passes/__init__.py
@@ -8,6 +8,7 @@
from .arm_pass import ArmPass # noqa # usort: skip
from .add_bias_pass import AddBiasPass # noqa
from .annotate_decomposed_matmul import AnnotateDecomposedMatmulPass # noqa
from .annotate_output_dim_order_pass import AnnotateOutputDimOrderPass # noqa
from .broadcast_args_pass import BroadcastArgsPass # noqa
from .cast_bool_to_int8_pass import CastBoolToInt8Pass # noqa
from .cast_int64_pass import CastInt64BuffersToInt32Pass # noqa
@@ -82,7 +83,7 @@
from .match_arg_dtype_pass import MatchArgDtypePass # noqa
from .match_arg_ranks_pass import MatchArgRanksPass # noqa
from .mm_to_bmm_pass import ConvertMmToBmmPass # noqa
from .remove_clone_pass import RemoveClonePass # noqa
from .remove_noop_pass import RemoveNoopPass # noqa
from .replace_scalar_with_tensor_pass import ( # noqa
ReplaceScalarWithTensorArgPassTOSABI,
ReplaceScalarWithTensorArgPassTOSAMI,
21 changes: 21 additions & 0 deletions backends/arm/_passes/annotate_output_dim_order_pass.py
@@ -0,0 +1,21 @@
# Copyright 2025 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from executorch.backends.arm._passes import ArmPass
from executorch.backends.arm._passes.arm_pass_utils import get_output_dim_orders
from executorch.exir.pass_base import PassResult


class AnnotateOutputDimOrderPass(ArmPass):
"""
Stores the current output dim_orders in the meta dict of the output node. This is used
for verifying that the dim order does not change unexpectedly in later passes.
"""

def call(self, graph_module):
output_node = graph_module.graph.output_node()
output_node.meta["original_dim_orders"] = get_output_dim_orders(graph_module)

return PassResult(graph_module, True)
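
For context, here is a minimal, self-contained sketch of what the new pass records, using a plain `torch.export` program instead of the Arm pass infrastructure; the model, names, and expected printed values are illustrative assumptions, and it presumes a PyTorch build whose fake tensors support `dim_order()`.

```python
import torch
from torch.export import export


class TinyModel(torch.nn.Module):
    def forward(self, x):
        # Second output is converted to channels_last so the two outputs
        # (typically) carry different dim orders.
        return x + 1, (x + 2).to(memory_format=torch.channels_last)


ep = export(TinyModel(), (torch.randn(1, 3, 8, 8),))
gm = ep.graph_module

# Equivalent of get_output_dim_orders(): read the fake tensor of every value
# returned by the graph and note its dim order.
output_node = next(n for n in gm.graph.nodes if n.op == "output")
dim_orders = [arg.meta["val"].dim_order() for arg in output_node.args[0]]

# Equivalent of AnnotateOutputDimOrderPass.call(): stash the orders on the
# output node so a later pass can check nothing reordered the outputs.
output_node.meta["original_dim_orders"] = dim_orders
print(dim_orders)  # expected: [(0, 1, 2, 3), (0, 2, 3, 1)]
```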
11 changes: 7 additions & 4 deletions backends/arm/_passes/arm_pass_manager.py
@@ -11,6 +11,7 @@
from executorch.backends.arm._passes import (
AddBiasPass,
AnnotateDecomposedMatmulPass,
AnnotateOutputDimOrderPass,
BroadcastArgsPass,
CastBoolToInt8Pass,
CastInt64BuffersToInt32Pass,
@@ -81,7 +82,7 @@
MatchArgDtypePass,
MatchArgRanksPass,
QuantizeOperatorArguments,
RemoveClonePass,
RemoveNoopPass,
ReplaceInfValues,
ReplaceScalarWithTensorArgPassTOSABI,
ReplaceScalarWithTensorArgPassTOSAMI,
@@ -119,6 +120,7 @@ def _transform(self, graph_module: GraphModule):
return self(graph_module).graph_module

def _tosa_INT_pipeline(self, exported_program: ExportedProgram) -> GraphModule:
self.add_pass(AnnotateOutputDimOrderPass())
self.add_pass(FuseQuantizedActivationPass())
self.add_pass(RemoveGetItemPass())
self.add_pass(ConvertSplitToSlicePass())
@@ -152,7 +154,6 @@ def _tosa_INT_pipeline(self, exported_program: ExportedProgram) -> GraphModule:
self.add_pass(ComputeConstantOpsAOT(exported_program))

self.add_pass(DecomposeGroupedConv())
self.add_pass(RemoveClonePass())
self.add_pass(ConvertExpandCopyToRepeatPass())
self.add_pass(UnsqueezeBeforeRepeatPass())
self.add_pass(CastInt64BuffersToInt32Pass(exported_program))
@@ -171,11 +172,13 @@ def _tosa_INT_pipeline(self, exported_program: ExportedProgram) -> GraphModule:
self.add_pass(InsertTableOpsPass(exported_program))
self.add_pass(FuseEqualPlaceholdersPass(exported_program))
self.add_pass(ToTosaMemoryFormatPass(exported_program))
self.add_pass(RemoveNoopPass())
self.add_pass(InsertRescalePass())

return self._transform(exported_program.graph_module)

def _tosa_FP_pipeline(self, exported_program: ExportedProgram) -> GraphModule:
self.add_pass(AnnotateOutputDimOrderPass())
self.add_pass(DecomposeExpm1Pass())
self.add_pass(DecomposeLogitPass())
self.add_pass(DecomposeMaskedFill())
@@ -235,10 +238,8 @@ def _tosa_FP_pipeline(self, exported_program: ExportedProgram) -> GraphModule:
self.add_pass(ComputeConstantOpsAOT(exported_program))

self.add_pass(DecomposeGroupedConv())
self.add_pass(RemoveClonePass())
self.add_pass(ConvertExpandCopyToRepeatPass())
self.add_pass(UnsqueezeBeforeRepeatPass())
self.add_pass(CastInt64BuffersToInt32Pass(exported_program))
self.add_pass(DecomposeSumPass())
self.add_pass(DecomposeCumsumPass(exported_program))
self.add_pass(Conv1dUnsqueezePass())
@@ -249,10 +250,12 @@ def _tosa_FP_pipeline(self, exported_program: ExportedProgram) -> GraphModule:

self.add_pass(FuseViewCopyTransform())
self.add_pass(FuseConstantArgsPass(exported_program))
self.add_pass(CastInt64BuffersToInt32Pass(exported_program))
self.add_pass(AddBiasPass(exported_program))
self.add_pass(InsertTableOpsPass(exported_program))
self.add_pass(FuseEqualPlaceholdersPass(exported_program))
self.add_pass(ToTosaMemoryFormatPass(exported_program))
self.add_pass(RemoveNoopPass())
self.add_pass(InsertRescalePass())

return self._transform(exported_program.graph_module)
5 changes: 5 additions & 0 deletions backends/arm/_passes/arm_pass_utils.py
@@ -235,3 +235,8 @@ def set_node_arg(node: torch.fx.Node, i: int | str, value):
node.kwargs = kwargs
else:
raise RuntimeError("Invalid type")


def get_output_dim_orders(graph_module):
output_node = graph_module.graph.output_node()
return [get_first_fake_tensor(node).dim_order() for node in output_node.args[0]]
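
As a quick illustration of the values this helper collects (assuming a PyTorch version that provides `Tensor.dim_order()`), the following shows the dim orders of a contiguous tensor and a channels-last tensor:

```python
import torch

x = torch.randn(1, 3, 8, 8)                  # contiguous NCHW tensor
y = x.to(memory_format=torch.channels_last)  # same data, NHWC strides

print(x.dim_order())  # (0, 1, 2, 3): dims listed from outermost to innermost
print(y.dim_order())  # (0, 2, 3, 1): the channel dim is innermost in memory
```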
4 changes: 2 additions & 2 deletions backends/arm/_passes/convert_int64_output_ops_to_int32.py
@@ -68,10 +68,10 @@ class ConvertInt64OutputOpsToInt32Pass(ExportPass):

def _get_decomposition(self, op):
if op in self.edge_ops:
return exir_ops.edge.aten._to_copy.default
return exir_ops.edge.dim_order_ops._to_dim_order_copy.default

if op in self.aten_ops:
return torch.ops.aten._to_copy.default
return torch.ops.dim_order_ops._to_dim_order_copy.default

raise RuntimeError(
f"[{self.__class__.__name__}] Can't get decomposition for op {op}"
5 changes: 2 additions & 3 deletions backends/arm/_passes/decorate_fp32_to_int32_casting_pass.py
@@ -30,18 +30,17 @@ class DecorateFp32toInt32CastingPass(ArmPass):
To lower pytorch fp32 -> int32 casting to TOSA,
we need to transform the value with Ceil, Floor, and Where.
Before:
output = to_copy(x, dtype=torch.int32)
output = to_dim_order_copy(x, dtype=torch.int32)
After:
%zero = full((1,), 0.0, dtype=torch.float32)
is_non_negative = x >= %zero
floor_x = floor(x)
ceil_x = ceil(x)
decorated_x = where(is_non_negative, floor_x, ceil_x)
output = to_copy(decorated_x, dtype=torch.int32)
output = to_dim_order_copy(decorated_x, dtype=torch.int32)
"""

targets = [
exir_ops.edge.aten._to_copy.default,
exir_ops.edge.dim_order_ops._to_dim_order_copy.default,
]

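
The Before/After docstring above describes the decoration in graph terms; the eager-mode check below (illustrative only, not part of the pass) shows why the floor/ceil/where combination reproduces PyTorch's truncate-toward-zero fp32 → int32 cast semantics.

```python
import torch

x = torch.tensor([-2.7, -0.5, 0.0, 0.5, 2.7])

reference = x.to(torch.int32)  # PyTorch casting truncates toward zero
decorated = torch.where(x >= 0.0, x.floor(), x.ceil()).to(torch.int32)

assert torch.equal(reference, decorated)
print(reference)  # tensor([-2,  0,  0,  0,  2], dtype=torch.int32)
```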
9 changes: 4 additions & 5 deletions backends/arm/_passes/insert_int64_input_cast_pass.py
@@ -31,10 +31,10 @@ class InsertCastForOpsWithInt64InputPass(ExportPass):

def get_decomposition(self, op):
if op in self.edge_ops:
return exir_ops.edge.aten._to_copy.default
return exir_ops.edge.dim_order_ops._to_dim_order_copy.default

if op in self.aten_ops:
return torch.ops.aten._to_copy.default
return torch.ops.dim_order_ops._to_dim_order_copy.default

raise RuntimeError(
f"[{self.__class__.__name__}] Can't get decomposition for op {op}"
@@ -56,15 +56,14 @@ def _check_aten_embedding_within_int32(self, weights, indices, node: torch.fx.No
return True

def _insert_int32_cast_before_node(self, graph, node, original_input):
to_copy_op = self.get_decomposition(node.target)
to_dim_order_copy_op = self.get_decomposition(node.target)
with graph.inserting_before(node):
cast_before = create_node(
graph,
to_copy_op,
to_dim_order_copy_op,
args=(original_input,),
kwargs={
"dtype": torch.int32,
"memory_format": torch.preserve_format,
},
)
node.replace_input_with(original_input, cast_before)
backends/arm/_passes/{remove_clone_pass.py → remove_noop_pass.py}
@@ -14,21 +14,20 @@
logger = logging.getLogger(__name__)


class RemoveClonePass(ExportPass):
"""Remove all clones from graph_module"""
class RemoveNoopPass(ExportPass):
"""Remove no-ops from graph_module"""

def call_operator(self, op, args, kwargs, meta):
if op != exir_ops.edge.dim_order_ops._clone_dim_order.default:
if op not in (
exir_ops.edge.dim_order_ops._clone_dim_order.default,
exir_ops.edge.dim_order_ops._to_dim_order_copy.default,
):
return super().call_operator(op, args, kwargs, meta)

if len(args) != 1:
raise ValueError(
f"clone operator expects exactly one argument, got {len(args)}"
)
input_dtype = args[0].data.dtype
output_dtype = kwargs.get("dtype", input_dtype)

if "memory_format" in kwargs:
logger.warning(
f"Removing clone with memory_format '{kwargs['memory_format']}'."
)
if input_dtype != output_dtype:
return super().call_operator(op, args, kwargs, meta)

return args[0]
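
A standalone sketch of the keep-or-drop decision above, using plain tensors in place of the ProxyValue arguments seen inside `call_operator`; the helper name is made up for illustration.

```python
import torch


def is_noop_copy(input_tensor: torch.Tensor, kwargs: dict) -> bool:
    # A clone/_to_dim_order_copy whose output dtype matches its input dtype
    # contributes nothing to the lowered graph, so the pass drops it.
    input_dtype = input_tensor.dtype
    output_dtype = kwargs.get("dtype", input_dtype)
    return input_dtype == output_dtype


x = torch.randn(2, 3)
print(is_noop_copy(x, {}))                        # True  -> node is removed
print(is_noop_copy(x, {"dtype": torch.float32}))  # True  -> node is removed
print(is_noop_copy(x, {"dtype": torch.int32}))    # False -> node is kept (dtype cast)
```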
53 changes: 53 additions & 0 deletions backends/arm/_passes/to_tosa_memory_format_pass.py
@@ -6,16 +6,22 @@
# pyre-unsafe


import logging

import torch
from executorch.backends.arm._passes import AnnotateOutputDimOrderPass
from executorch.backends.arm._passes.arm_pass_utils import (
create_node,
get_first_fake_tensor,
get_output_dim_orders,
is_param_node,
)
from executorch.exir import ExportedProgram
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.pass_base import ExportPass, PassResult

logger = logging.getLogger(__name__)


def _is_input(node: torch.fx.Node, exported_program: ExportedProgram) -> bool:
"""
@@ -250,10 +256,27 @@ def insert_tosa_transposes(self, graph_module: torch.fx.GraphModule):
node, input_node, graph_module
)

def remove_dim_order_kwargs(
self, graph_module: torch.fx.GraphModule, node: torch.fx.Node
):
if node.op != "call_function":
return

kwargs = dict(node.kwargs)

if "dim_order" in kwargs:
logger.warning(
f"Ignoring dim_order kwarg '{kwargs['dim_order']}' for '{node.name}'."
)
del kwargs["dim_order"]

node.kwargs = kwargs

def call(self, graph_module: torch.fx.GraphModule):
for node in graph_module.graph.nodes:
node_data = get_first_fake_tensor(node).data

self.remove_dim_order_kwargs(graph_module, node)
# Inputs and outputs are always in (N)NCHW format
if _is_input(node, self.exported_program) or node.op == "output":
dim_order = tuple(range(node_data.dim()))
@@ -269,10 +292,40 @@ def call(self, graph_module: torch.fx.GraphModule):
dim_order = tuple(range(node_data.dim())) # type: ignore[assignment]

node.meta["tosa_dim_order"] = dim_order

# Insert TOSA transposes to convert between (N)NCHW and (N)NHWC format.
# See insert_tosa_transposes for insertion conditions.
self.insert_tosa_transposes(graph_module)
graph_module.recompile()
graph_module = super().call(graph_module).graph_module

return PassResult(graph_module, True)

def requires(self, graph_module) -> None:
"""
This is the only pass which handles dim_orders, so verify that the output dim_orders has not changed since the beginning of the lowering pipeline.
"""

dim_orders = get_output_dim_orders(graph_module)
original_dim_orders = graph_module.graph.output_node().meta.get(
"original_dim_orders"
)
output_node = graph_module.graph.output_node()

if original_dim_orders is None:
raise RuntimeError(
f"{AnnotateOutputDimOrderPass.__name__} must be run in the beginning of the pass pipeline to verify that the dim order has not changed unexpectedly during its run."
)

if len(dim_orders) != len(original_dim_orders):
raise RuntimeError(
f"The number of outputs has changed since {AnnotateOutputDimOrderPass.__name__} was run."
)

for node, dim_order, original_dim_order in zip(
output_node.args[0], dim_orders, original_dim_orders
):
if dim_order != original_dim_order:
raise RuntimeError(
f"The dim order of output {node.name} has changed from {original_dim_order} to {dim_order} since {AnnotateOutputDimOrderPass.__name__} was run."
)
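
A self-contained sketch of the verification `requires()` performs, with the graph plumbing stripped out; the function name, signature, and error strings are illustrative.

```python
from typing import Sequence, Tuple

DimOrder = Tuple[int, ...]


def verify_output_dim_orders(
    original: Sequence[DimOrder],
    current: Sequence[DimOrder],
    names: Sequence[str],
) -> None:
    # Mirrors requires(): same number of outputs, and each output keeps the
    # dim order recorded by AnnotateOutputDimOrderPass at the pipeline start.
    if len(original) != len(current):
        raise RuntimeError("The number of outputs has changed since annotation.")
    for name, before, after in zip(names, original, current):
        if before != after:
            raise RuntimeError(
                f"The dim order of output {name} changed from {before} to {after}."
            )


verify_output_dim_orders([(0, 1, 2, 3)], [(0, 1, 2, 3)], ["out0"])  # passes silently

try:
    # Fails: an intermediate pass reordered the output to NHWC.
    verify_output_dim_orders([(0, 1, 2, 3)], [(0, 2, 3, 1)], ["out0"])
except RuntimeError as err:
    print(err)
```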
4 changes: 2 additions & 2 deletions backends/arm/operator_support/__init__.py
@@ -6,7 +6,7 @@
# pyre-unsafe

from . import ( # noqa
clone_support,
clone_dim_order_support,
convolution_support,
embedding_support,
ethos_u55_support,
@@ -18,6 +18,6 @@
right_shift_support,
sin_cos_support,
slice_copy_support,
to_copy_support,
to_dim_order_copy_support,
tosa_supported_operators,
)
backends/arm/operator_support/{to_copy_support.py → to_dim_order_copy_support.py}
@@ -65,26 +65,4 @@ def is_node_tosa_supported(
)
return False

# Check memory format
if "memory_format" in node.kwargs:
if node.kwargs["memory_format"] in (torch.preserve_format,):
self.reporter.report_reject(
node,
f"Argument 'memory_format' is not supported for "
f"{node.target} right now.",
)
return False

# Check dim_order
if "dim_order" in node.kwargs:
dim_order = node.kwargs["dim_order"]
# pyre-ignore[6]
if dim_order != list(range(len(dim_order))): # type: ignore[arg-type]
self.reporter.report_reject(
node,
f"Argument {dim_order=} is not supported for "
f"{node.target} right now.",
)
return False

return True
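
For reference, the deleted dim_order guard rejected nodes whose `dim_order` kwarg was not the identity permutation; the snippet below illustrates that condition. Presumably the guard is no longer needed because `ToTosaMemoryFormatPass.remove_dim_order_kwargs` now strips the kwarg (with a warning) during lowering.

```python
def is_identity_dim_order(dim_order) -> bool:
    # The old partitioner check accepted a node only when its dim_order kwarg
    # equalled the identity permutation (0, 1, ..., n-1).
    return list(dim_order) == list(range(len(dim_order)))


print(is_identity_dim_order((0, 1, 2, 3)))  # True  -> was accepted
print(is_identity_dim_order((0, 2, 3, 1)))  # False -> was rejected
```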