diff --git a/backends/arm/_passes/__init__.py b/backends/arm/_passes/__init__.py
index cb90eef01d1..f033e0d5322 100644
--- a/backends/arm/_passes/__init__.py
+++ b/backends/arm/_passes/__init__.py
@@ -8,6 +8,7 @@
 from .arm_pass import ArmPass  # noqa  # usort: skip
 from .add_bias_pass import AddBiasPass  # noqa
 from .annotate_decomposed_matmul import AnnotateDecomposedMatmulPass  # noqa
+from .annotate_output_dim_order_pass import AnnotateOutputDimOrderPass  # noqa
 from .broadcast_args_pass import BroadcastArgsPass  # noqa
 from .cast_bool_to_int8_pass import CastBoolToInt8Pass  # noqa
 from .cast_int64_pass import CastInt64BuffersToInt32Pass  # noqa
@@ -82,7 +83,7 @@
 from .match_arg_dtype_pass import MatchArgDtypePass  # noqa
 from .match_arg_ranks_pass import MatchArgRanksPass  # noqa
 from .mm_to_bmm_pass import ConvertMmToBmmPass  # noqa
-from .remove_clone_pass import RemoveClonePass  # noqa
+from .remove_noop_pass import RemoveNoopPass  # noqa
 from .replace_scalar_with_tensor_pass import (  # noqa
     ReplaceScalarWithTensorArgPassTOSABI,
     ReplaceScalarWithTensorArgPassTOSAMI,
diff --git a/backends/arm/_passes/annotate_output_dim_order_pass.py b/backends/arm/_passes/annotate_output_dim_order_pass.py
new file mode 100644
index 00000000000..08f93383a9c
--- /dev/null
+++ b/backends/arm/_passes/annotate_output_dim_order_pass.py
@@ -0,0 +1,21 @@
+# Copyright 2025 Arm Limited and/or its affiliates.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+from executorch.backends.arm._passes import ArmPass
+from executorch.backends.arm._passes.arm_pass_utils import get_output_dim_orders
+from executorch.exir.pass_base import PassResult
+
+
+class AnnotateOutputDimOrderPass(ArmPass):
+    """
+    Stores the current output dim_orders in the meta dict of the output node. This is used
+    for verifying that the dim order does not change unexpectedly in later passes.
+    """
+
+    def call(self, graph_module):
+        output_node = graph_module.graph.output_node()
+        output_node.meta["original_dim_orders"] = get_output_dim_orders(graph_module)
+
+        return PassResult(graph_module, True)
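Reviewer note (not part of the diff): `dim_order()` describes how a tensor's dimensions are ordered in memory, which is the value this pass records per output. A minimal plain-PyTorch illustration, assuming a recent torch build that provides `Tensor.dim_order()`:

```python
import torch

x = torch.randn(1, 2, 20, 20)
# Contiguous NCHW tensor: dims are laid out in declaration order.
print(x.dim_order())  # (0, 1, 2, 3)
# A channels-last copy keeps the logical shape but permutes the layout.
print(x.to(memory_format=torch.channels_last).dim_order())  # (0, 2, 3, 1)
```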
+ """ + + def call(self, graph_module): + output_node = graph_module.graph.output_node() + output_node.meta["original_dim_orders"] = get_output_dim_orders(graph_module) + + return PassResult(graph_module, True) diff --git a/backends/arm/_passes/arm_pass_manager.py b/backends/arm/_passes/arm_pass_manager.py index 6aae943881d..c26cd8fb078 100644 --- a/backends/arm/_passes/arm_pass_manager.py +++ b/backends/arm/_passes/arm_pass_manager.py @@ -11,6 +11,7 @@ from executorch.backends.arm._passes import ( AddBiasPass, AnnotateDecomposedMatmulPass, + AnnotateOutputDimOrderPass, BroadcastArgsPass, CastBoolToInt8Pass, CastInt64BuffersToInt32Pass, @@ -81,7 +82,7 @@ MatchArgDtypePass, MatchArgRanksPass, QuantizeOperatorArguments, - RemoveClonePass, + RemoveNoopPass, ReplaceInfValues, ReplaceScalarWithTensorArgPassTOSABI, ReplaceScalarWithTensorArgPassTOSAMI, @@ -119,6 +120,7 @@ def _transform(self, graph_module: GraphModule): return self(graph_module).graph_module def _tosa_INT_pipeline(self, exported_program: ExportedProgram) -> GraphModule: + self.add_pass(AnnotateOutputDimOrderPass()) self.add_pass(FuseQuantizedActivationPass()) self.add_pass(RemoveGetItemPass()) self.add_pass(ConvertSplitToSlicePass()) @@ -152,7 +154,6 @@ def _tosa_INT_pipeline(self, exported_program: ExportedProgram) -> GraphModule: self.add_pass(ComputeConstantOpsAOT(exported_program)) self.add_pass(DecomposeGroupedConv()) - self.add_pass(RemoveClonePass()) self.add_pass(ConvertExpandCopyToRepeatPass()) self.add_pass(UnsqueezeBeforeRepeatPass()) self.add_pass(CastInt64BuffersToInt32Pass(exported_program)) @@ -171,11 +172,13 @@ def _tosa_INT_pipeline(self, exported_program: ExportedProgram) -> GraphModule: self.add_pass(InsertTableOpsPass(exported_program)) self.add_pass(FuseEqualPlaceholdersPass(exported_program)) self.add_pass(ToTosaMemoryFormatPass(exported_program)) + self.add_pass(RemoveNoopPass()) self.add_pass(InsertRescalePass()) return self._transform(exported_program.graph_module) def _tosa_FP_pipeline(self, exported_program: ExportedProgram) -> GraphModule: + self.add_pass(AnnotateOutputDimOrderPass()) self.add_pass(DecomposeExpm1Pass()) self.add_pass(DecomposeLogitPass()) self.add_pass(DecomposeMaskedFill()) @@ -235,10 +238,8 @@ def _tosa_FP_pipeline(self, exported_program: ExportedProgram) -> GraphModule: self.add_pass(ComputeConstantOpsAOT(exported_program)) self.add_pass(DecomposeGroupedConv()) - self.add_pass(RemoveClonePass()) self.add_pass(ConvertExpandCopyToRepeatPass()) self.add_pass(UnsqueezeBeforeRepeatPass()) - self.add_pass(CastInt64BuffersToInt32Pass(exported_program)) self.add_pass(DecomposeSumPass()) self.add_pass(DecomposeCumsumPass(exported_program)) self.add_pass(Conv1dUnsqueezePass()) @@ -249,10 +250,12 @@ def _tosa_FP_pipeline(self, exported_program: ExportedProgram) -> GraphModule: self.add_pass(FuseViewCopyTransform()) self.add_pass(FuseConstantArgsPass(exported_program)) + self.add_pass(CastInt64BuffersToInt32Pass(exported_program)) self.add_pass(AddBiasPass(exported_program)) self.add_pass(InsertTableOpsPass(exported_program)) self.add_pass(FuseEqualPlaceholdersPass(exported_program)) self.add_pass(ToTosaMemoryFormatPass(exported_program)) + self.add_pass(RemoveNoopPass()) self.add_pass(InsertRescalePass()) return self._transform(exported_program.graph_module) diff --git a/backends/arm/_passes/arm_pass_utils.py b/backends/arm/_passes/arm_pass_utils.py index 00eb395be9f..71e2030958f 100644 --- a/backends/arm/_passes/arm_pass_utils.py +++ b/backends/arm/_passes/arm_pass_utils.py @@ -235,3 +235,8 
diff --git a/backends/arm/_passes/arm_pass_utils.py b/backends/arm/_passes/arm_pass_utils.py
index 00eb395be9f..71e2030958f 100644
--- a/backends/arm/_passes/arm_pass_utils.py
+++ b/backends/arm/_passes/arm_pass_utils.py
@@ -235,3 +235,8 @@ def set_node_arg(node: torch.fx.Node, i: int | str, value):
         node.kwargs = kwargs
     else:
         raise RuntimeError("Invalid type")
+
+
+def get_output_dim_orders(graph_module):
+    output_node = graph_module.graph.output_node()
+    return [get_first_fake_tensor(node).dim_order() for node in output_node.args[0]]
diff --git a/backends/arm/_passes/convert_int64_output_ops_to_int32.py b/backends/arm/_passes/convert_int64_output_ops_to_int32.py
index d3803c82ffc..788201be6c8 100644
--- a/backends/arm/_passes/convert_int64_output_ops_to_int32.py
+++ b/backends/arm/_passes/convert_int64_output_ops_to_int32.py
@@ -68,10 +68,10 @@ class ConvertInt64OutputOpsToInt32Pass(ExportPass):
 
     def _get_decomposition(self, op):
         if op in self.edge_ops:
-            return exir_ops.edge.aten._to_copy.default
+            return exir_ops.edge.dim_order_ops._to_dim_order_copy.default
 
         if op in self.aten_ops:
-            return torch.ops.aten._to_copy.default
+            return torch.ops.dim_order_ops._to_dim_order_copy.default
 
         raise RuntimeError(
             f"[{self.__class__.__name__}] Can't get decomposition for op {op}"
diff --git a/backends/arm/_passes/decorate_fp32_to_int32_casting_pass.py b/backends/arm/_passes/decorate_fp32_to_int32_casting_pass.py
index d6f7ac2ceac..17a682c0a8e 100644
--- a/backends/arm/_passes/decorate_fp32_to_int32_casting_pass.py
+++ b/backends/arm/_passes/decorate_fp32_to_int32_casting_pass.py
@@ -30,18 +30,17 @@ class DecorateFp32toInt32CastingPass(ArmPass):
     To lower pytorch fp32 -> int32 casting to TOSA,
     we need to transform the value with Ceil, Floor, and Where.
     Before:
-        output = to_copy(x, dtype=torch.int32)
+        output = to_dim_order_copy(x, dtype=torch.int32)
     After:
         %zero = full((1,), 0.0, dtype=torch.float32)
         is_non_negative = x >= %zero
         floor_x = floor(x)
         ceil_x = ceil(x)
         decorated_x = where(is_non_negative, floor_x, ceil_x)
-        output = to_copy(decorated_x, dtype=torch.int32)
+        output = to_dim_order_copy(decorated_x, dtype=torch.int32)
     """
 
     targets = [
-        exir_ops.edge.aten._to_copy.default,
         exir_ops.edge.dim_order_ops._to_dim_order_copy.default,
     ]
diff --git a/backends/arm/_passes/insert_int64_input_cast_pass.py b/backends/arm/_passes/insert_int64_input_cast_pass.py
index 8864d6bb4eb..9577c920c1c 100644
--- a/backends/arm/_passes/insert_int64_input_cast_pass.py
+++ b/backends/arm/_passes/insert_int64_input_cast_pass.py
@@ -31,10 +31,10 @@ class InsertCastForOpsWithInt64InputPass(ExportPass):
 
     def get_decomposition(self, op):
         if op in self.edge_ops:
-            return exir_ops.edge.aten._to_copy.default
+            return exir_ops.edge.dim_order_ops._to_dim_order_copy.default
 
         if op in self.aten_ops:
-            return torch.ops.aten._to_copy.default
+            return torch.ops.dim_order_ops._to_dim_order_copy.default
 
         raise RuntimeError(
             f"[{self.__class__.__name__}] Can't get decomposition for op {op}"
@@ -56,15 +56,14 @@ def _check_aten_embedding_within_int32(self, weights, indices, node: torch.fx.No
         return True
 
     def _insert_int32_cast_before_node(self, graph, node, original_input):
-        to_copy_op = self.get_decomposition(node.target)
+        to_dim_order_copy_op = self.get_decomposition(node.target)
         with graph.inserting_before(node):
             cast_before = create_node(
                 graph,
-                to_copy_op,
+                to_dim_order_copy_op,
                 args=(original_input,),
                 kwargs={
                     "dtype": torch.int32,
-                    "memory_format": torch.preserve_format,
                 },
             )
             node.replace_input_with(original_input, cast_before)
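Reviewer note: all three passes now emit the dim-order-aware cast and drop the `memory_format` kwarg, which `_to_dim_order_copy` does not take. A sketch of the op's behavior for a pure dtype cast; it assumes the `dim_order_ops` library has been registered (this diff relies on that registration, but the exact import that triggers it is an assumption here):

```python
import torch
import executorch.exir  # assumption: pulls in the dim_order_ops registration

x = torch.randint(0, 10, (2, 3), dtype=torch.int64)
y = torch.ops.dim_order_ops._to_dim_order_copy(x, dtype=torch.int32)
assert y.dtype == torch.int32
assert y.dim_order() == x.dim_order()  # layout preserved when no dim_order is given
```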
diff --git a/backends/arm/_passes/remove_clone_pass.py b/backends/arm/_passes/remove_noop_pass.py
similarity index 53%
rename from backends/arm/_passes/remove_clone_pass.py
rename to backends/arm/_passes/remove_noop_pass.py
index 15e8a9e9201..623517aac59 100644
--- a/backends/arm/_passes/remove_clone_pass.py
+++ b/backends/arm/_passes/remove_noop_pass.py
@@ -14,21 +14,20 @@
 
 logger = logging.getLogger(__name__)
 
 
-class RemoveClonePass(ExportPass):
-    """Remove all clones from graph_module"""
+class RemoveNoopPass(ExportPass):
+    """Remove no-ops from graph_module"""
 
     def call_operator(self, op, args, kwargs, meta):
-        if op != exir_ops.edge.dim_order_ops._clone_dim_order.default:
+        if op not in (
+            exir_ops.edge.dim_order_ops._clone_dim_order.default,
+            exir_ops.edge.dim_order_ops._to_dim_order_copy.default,
+        ):
             return super().call_operator(op, args, kwargs, meta)
 
-        if len(args) != 1:
-            raise ValueError(
-                f"clone operator expects exactly one argument, got {len(args)}"
-            )
+        input_dtype = args[0].data.dtype
+        output_dtype = kwargs.get("dtype", input_dtype)
 
-        if "memory_format" in kwargs:
-            logger.warning(
-                f"Removing clone with memory_format '{kwargs['memory_format']}'."
-            )
+        if input_dtype != output_dtype:
+            return super().call_operator(op, args, kwargs, meta)
 
         return args[0]
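Reviewer note: the new rule in one line is "a `_clone_dim_order`/`_to_dim_order_copy` node is removable exactly when it keeps the dtype". A standalone restatement of that predicate (hypothetical helper mirroring the pass, not code from this diff):

```python
import torch
from typing import Optional

def is_removable_noop(
    input_dtype: torch.dtype, target_dtype: Optional[torch.dtype] = None
) -> bool:
    """Mirrors RemoveNoopPass: a clone/cast is a no-op iff the dtype is unchanged."""
    return (target_dtype or input_dtype) == input_dtype

assert is_removable_noop(torch.float32, torch.float32)    # dropped
assert is_removable_noop(torch.int8)                      # no dtype kwarg: dropped
assert not is_removable_noop(torch.float32, torch.int32)  # real cast: kept
```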
+ """ + + dim_orders = get_output_dim_orders(graph_module) + original_dim_orders = graph_module.graph.output_node().meta.get( + "original_dim_orders" + ) + output_node = graph_module.graph.output_node() + + if original_dim_orders is None: + raise RuntimeError( + f"{AnnotateOutputDimOrderPass.__name__} must be run in the beginning of the pass pipeline to verify that the dim order has not changed unexpectedly during its run." + ) + + if len(dim_orders) != len(original_dim_orders): + raise RuntimeError( + f"The number of outputs has changed since {AnnotateOutputDimOrderPass.__name__} was run." + ) + + for node, dim_order, original_dim_order in zip( + output_node.args[0], dim_orders, original_dim_orders + ): + if dim_order != original_dim_order: + raise RuntimeError( + f"The dim order of output {node.name} has changed from {original_dim_order} to {dim_order} since {AnnotateOutputDimOrderPass.__name__} was run." + ) diff --git a/backends/arm/operator_support/__init__.py b/backends/arm/operator_support/__init__.py index b62cc83ed8f..7b73cddad37 100644 --- a/backends/arm/operator_support/__init__.py +++ b/backends/arm/operator_support/__init__.py @@ -6,7 +6,7 @@ # pyre-unsafe from . import ( # noqa - clone_support, + clone_dim_order_support, convolution_support, embedding_support, ethos_u55_support, @@ -18,6 +18,6 @@ right_shift_support, sin_cos_support, slice_copy_support, - to_copy_support, + to_dim_order_copy_support, tosa_supported_operators, ) diff --git a/backends/arm/operator_support/clone_support.py b/backends/arm/operator_support/clone_dim_order_support.py similarity index 73% rename from backends/arm/operator_support/clone_support.py rename to backends/arm/operator_support/clone_dim_order_support.py index b3eea92fdad..1397b74bf38 100644 --- a/backends/arm/operator_support/clone_support.py +++ b/backends/arm/operator_support/clone_dim_order_support.py @@ -65,26 +65,4 @@ def is_node_tosa_supported( ) return False - # Check memory format - if "memory_format" in node.kwargs: - if node.kwargs["memory_format"] in (torch.preserve_format,): - self.reporter.report_reject( - node, - f"Argument 'memory_format' is not supported for " - f"{node.target} right now.", - ) - return False - - # Check dim_order - if "dim_order" in node.kwargs: - dim_order = node.kwargs["dim_order"] - # pyre-ignore[6] - if dim_order != list(range(len(dim_order))): # type: ignore[arg-type] - self.reporter.report_reject( - node, - f"Argument {dim_order=} is not supported for " - f"{node.target} right now.", - ) - return False - return True diff --git a/backends/arm/operator_support/to_copy_support.py b/backends/arm/operator_support/to_dim_order_copy_support.py similarity index 73% rename from backends/arm/operator_support/to_copy_support.py rename to backends/arm/operator_support/to_dim_order_copy_support.py index e357efe3f24..e21f8a68ad6 100644 --- a/backends/arm/operator_support/to_copy_support.py +++ b/backends/arm/operator_support/to_dim_order_copy_support.py @@ -26,7 +26,6 @@ @register_tosa_support_check class ToCopySupported(SupportedTOSAOperatorCheck): targets = [ - exir_ops.edge.aten._to_copy.default, exir_ops.edge.dim_order_ops._to_dim_order_copy.default, ] @@ -49,16 +48,16 @@ def _merge_supported_types( return merged_dtypes SUPPORTED_INT_TYPES: SupportedTypeDict = { - torch.bool: [torch.int8, torch.int16, torch.int32], - torch.int8: [torch.bool, torch.int16, torch.int32], - torch.int16: [torch.bool, torch.int8, torch.int32], - torch.int32: [torch.bool, torch.int8, torch.int16], + torch.bool: [torch.bool, 
diff --git a/backends/arm/operator_support/to_copy_support.py b/backends/arm/operator_support/to_dim_order_copy_support.py
similarity index 73%
rename from backends/arm/operator_support/to_copy_support.py
rename to backends/arm/operator_support/to_dim_order_copy_support.py
index e357efe3f24..e21f8a68ad6 100644
--- a/backends/arm/operator_support/to_copy_support.py
+++ b/backends/arm/operator_support/to_dim_order_copy_support.py
@@ -26,7 +26,6 @@
 @register_tosa_support_check
 class ToCopySupported(SupportedTOSAOperatorCheck):
     targets = [
-        exir_ops.edge.aten._to_copy.default,
         exir_ops.edge.dim_order_ops._to_dim_order_copy.default,
     ]
 
@@ -49,16 +48,16 @@ def _merge_supported_types(
         return merged_dtypes
 
     SUPPORTED_INT_TYPES: SupportedTypeDict = {
-        torch.bool: [torch.int8, torch.int16, torch.int32],
-        torch.int8: [torch.bool, torch.int16, torch.int32],
-        torch.int16: [torch.bool, torch.int8, torch.int32],
-        torch.int32: [torch.bool, torch.int8, torch.int16],
+        torch.bool: [torch.bool, torch.int8, torch.int16, torch.int32],
+        torch.int8: [torch.bool, torch.int8, torch.int16, torch.int32],
+        torch.int16: [torch.bool, torch.int8, torch.int16, torch.int32],
+        torch.int32: [torch.bool, torch.int8, torch.int16, torch.int32],
         torch.int64: [torch.bool, torch.int8, torch.int16, torch.int32],
     }
     SUPPORTED_FLOAT_TYPES: SupportedTypeDict = {
-        torch.int8: [torch.float16, torch.bfloat16, torch.float32],
-        torch.int16: [torch.float16, torch.bfloat16, torch.float32],
-        torch.int32: [torch.float16, torch.bfloat16, torch.float32],
+        torch.int8: [torch.int8, torch.float16, torch.bfloat16, torch.float32],
+        torch.int16: [torch.int16, torch.float16, torch.bfloat16, torch.float32],
+        torch.int32: [torch.int32, torch.float16, torch.bfloat16, torch.float32],
         # INT64 inputs to casts *should* be ok, since they should be rejected by
         # CheckInt64InputsAndOutputs if the cast can't be done AOT.
         torch.int64: [
@@ -69,9 +68,22 @@ def _merge_supported_types(
             torch.bfloat16,
             torch.float32,
         ],
-        torch.bfloat16: [torch.int8, torch.int16, torch.int32, torch.float32],
-        torch.float16: [torch.int8, torch.int16, torch.int32, torch.float32],
+        torch.bfloat16: [
+            torch.int8,
+            torch.int16,
+            torch.int32,
+            torch.bfloat16,
+            torch.float32,
+        ],
+        torch.float16: [
+            torch.int8,
+            torch.int16,
+            torch.int32,
+            torch.float16,
+            torch.float32,
+        ],
         torch.float32: [
+            torch.float32,
             torch.int8,
             torch.int16,
             torch.int32,
@@ -149,30 +161,4 @@ def is_node_tosa_supported(
             )
             return False
 
-        # Check memory format (to_copy)
-        if "memory_format" in node.kwargs:
-            if node.kwargs["memory_format"] in (torch.preserve_format,):
-                self.reporter.report_reject(
-                    node,
-                    (
-                        "Argument 'memory_format' is not supported for "
-                        f"{node.target} right now."
-                    ),
-                )
-                return False
-
-        # Check dim_order (to_dim_order_copy)
-        if "dim_order" in node.kwargs:
-            dim_order = node.kwargs["dim_order"]
-            # pyre-ignore[6]
-            if dim_order is not None and dim_order != list(range(len(dim_order))):  # type: ignore[arg-type]
-                self.reporter.report_reject(
-                    node,
-                    (
-                        f"Argument {dim_order=} is not supported for "
-                        f"{node.target} right now."
-                    ),
-                )
-                return False
-
         return True
diff --git a/backends/arm/operators/__init__.py b/backends/arm/operators/__init__.py
index aed65bda812..f7a9638254e 100644
--- a/backends/arm/operators/__init__.py
+++ b/backends/arm/operators/__init__.py
@@ -51,7 +51,6 @@
     op_sum,
     op_table,
     op_tanh,
-    op_to_copy,
     op_to_dim_order_copy,
     op_transpose,
     op_upsample_bilinear2d,
diff --git a/backends/arm/operators/op_to_copy.py b/backends/arm/operators/op_to_copy.py
deleted file mode 100644
index dcf8fc64119..00000000000
--- a/backends/arm/operators/op_to_copy.py
+++ /dev/null
@@ -1,49 +0,0 @@
-# Copyright 2024-2025 Arm Limited and/or its affiliates.
-#
-# This source code is licensed under the BSD-style license found in the
-# LICENSE file in the root directory of this source tree.
-
-# pyre-unsafe
-from typing import Any, List
-
-import torch
-
-from executorch.backends.arm.operators.node_visitor import (
-    NodeVisitor,
-    register_node_visitor,
-)
-from executorch.backends.arm.operators.operator_validation_utils import (
-    validate_num_inputs,
-)
-from executorch.backends.arm.tosa.mapping import TosaArg
-
-
-@register_node_visitor
-class ToCopyVisitor(NodeVisitor):
-    """
-    Implement the type cast functionality of _to_copy.
-
-    Other features like setting of the memory_format or moving a tensor to a
-    different device are not supported.
-
-    Also note that the node should not be quantized.
-    """
-
-    target = "aten._to_copy.default"
-
-    tosa_specs = NodeVisitor.tosa_specs
-
-    def define_node(
-        self,
-        node: torch.fx.Node,
-        tosa_graph: Any,
-        inputs: List[TosaArg],
-        output: TosaArg,
-    ) -> None:
-        import serializer.tosa_serializer as ts  # type: ignore
-
-        validate_num_inputs(self.target, inputs, 1)
-
-        self._serialize_operator(
-            node, tosa_graph, ts.TosaOp.Op().CAST, [inputs[0].name], [output.name]
-        )
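Reviewer note: the pattern in the table edits above is that every input dtype gained itself as a legal target, so identity casts now pass the support check (and are later deleted by `RemoveNoopPass`) rather than being rejected; with `aten._to_copy` gone from the targets, its dedicated CAST visitor could be deleted as well. A condensed, runnable illustration of the table semantics (abbreviated table, same lookup logic):

```python
import torch

# Two rows of the table above; note each dtype now lists itself.
SUPPORTED_INT_TYPES = {
    torch.bool: [torch.bool, torch.int8, torch.int16, torch.int32],
    torch.int8: [torch.bool, torch.int8, torch.int16, torch.int32],
}

def cast_is_supported(in_dtype: torch.dtype, out_dtype: torch.dtype) -> bool:
    return out_dtype in SUPPORTED_INT_TYPES.get(in_dtype, [])

assert cast_is_supported(torch.int8, torch.int8)       # identity: newly supported
assert not cast_is_supported(torch.int8, torch.int64)  # still rejected
```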
- """ - - target = "aten._to_copy.default" - - tosa_specs = NodeVisitor.tosa_specs - - def define_node( - self, - node: torch.fx.Node, - tosa_graph: Any, - inputs: List[TosaArg], - output: TosaArg, - ) -> None: - import serializer.tosa_serializer as ts # type: ignore - - validate_num_inputs(self.target, inputs, 1) - - self._serialize_operator( - node, tosa_graph, ts.TosaOp.Op().CAST, [inputs[0].name], [output.name] - ) diff --git a/backends/arm/test/misc/test_dim_order_guards.py b/backends/arm/test/misc/test_dim_order_guards.py index b291aaa52cf..80a3c014abc 100644 --- a/backends/arm/test/misc/test_dim_order_guards.py +++ b/backends/arm/test/misc/test_dim_order_guards.py @@ -22,7 +22,7 @@ class Conv2D(torch.nn.Module): inputs: dict[str, input_t1] = { - "randn": (torch.randn(1, 2, 20, 20),), + "randn": (torch.randn(1, 2, 20, 20).to(memory_format=torch.channels_last),), } def __init__(self): @@ -30,7 +30,7 @@ def __init__(self): self.conv2d = torch.nn.Conv2d(in_channels=2, out_channels=3, kernel_size=(3, 3)) def forward(self, x): - return self.conv2d(x.to(memory_format=torch.channels_last)) + return self.conv2d(x) @common.parametrize("test_data", Conv2D.inputs) diff --git a/backends/arm/test/models/stable_diffusion/test_CLIPTextModelWithProjection.py b/backends/arm/test/models/stable_diffusion/test_CLIPTextModelWithProjection.py index f89e06deda0..aa0f194590c 100644 --- a/backends/arm/test/models/stable_diffusion/test_CLIPTextModelWithProjection.py +++ b/backends/arm/test/models/stable_diffusion/test_CLIPTextModelWithProjection.py @@ -6,6 +6,7 @@ import unittest +import pytest import torch from executorch.backends.arm._passes import ( ConvertInt64ConstOpsToInt32Pass, @@ -32,11 +33,10 @@ class TestCLIPTextModelWithProjection(unittest.TestCase): # for that is some assert ops are removed by passes in the # .to_executorch step, i.e. after Arm partitioner. ops_after_partitioner = { - "executorch_exir_dialects_edge__ops_aten__to_copy_default": 4, - "executorch_exir_dialects_edge__ops_aten_argmax_default": 1, + "executorch_exir_dialects_edge__ops_dim_order_ops__to_dim_order_copy_default": 3, "executorch_exir_dialects_edge__ops_aten_view_copy_default": 1, - "executorch_exir_dialects_edge__ops_dim_order_ops__to_dim_order_copy_default": 1, - "torch.ops.higher_order.executorch_call_delegate": 2, + "executorch_exir_dialects_edge__ops_aten_argmax_default": 1, + "torch.ops.higher_order.executorch_call_delegate": 1, } def _prepare_inputs( @@ -86,9 +86,7 @@ def test_CLIPTextModelWithProjection_tosa_FP(self): ) ) - # MLETORCH-867, MLETORCH-1059 - # Failures: "Fatal Python error: Aborted, Dependency cycles, KeyError in CastInt64BuffersToInt32Pass") - @unittest.expectedFailure + @pytest.mark.xfail(raises=AssertionError, reason="Output difference.") def test_CLIPTextModelWithProjection_tosa_INT(self): text_encoder_model, text_encoder_model_inputs = self.prepare_model_and_inputs() with torch.no_grad(): diff --git a/backends/arm/test/models/stable_diffusion/test_T5EncoderModel.py b/backends/arm/test/models/stable_diffusion/test_T5EncoderModel.py index 953f14e0b57..7c1a45f27cb 100644 --- a/backends/arm/test/models/stable_diffusion/test_T5EncoderModel.py +++ b/backends/arm/test/models/stable_diffusion/test_T5EncoderModel.py @@ -32,10 +32,9 @@ class TestT5EncoderModel(unittest.TestCase): # for that is some assert ops are removed by passes in the # .to_executorch step, i.e. after Arm partitioner. 
diff --git a/backends/arm/test/ops/test_to_copy.py b/backends/arm/test/ops/test_to_copy.py
index 21197aa14fa..5c01788c805 100644
--- a/backends/arm/test/ops/test_to_copy.py
+++ b/backends/arm/test/ops/test_to_copy.py
@@ -15,6 +15,7 @@
 from executorch.backends.arm.test.tester.test_pipeline import (
     OpNotSupportedPipeline,
     TosaPipelineFP,
+    TosaPipelineINT,
     VgfPipeline,
 )
 
@@ -30,6 +31,15 @@ def forward(self, x: torch.Tensor):
         return x.to(dtype=self.target_dtype)
 
 
+class CastAdd(torch.nn.Module):
+    def __init__(self, target_dtype):
+        super().__init__()
+        self.target_dtype = target_dtype
+
+    def forward(self, x: torch.Tensor):
+        return x.to(dtype=self.target_dtype) + x.to(dtype=self.target_dtype)
+
+
 """
 Tests the _to_copy operation.
 
@@ -61,7 +71,7 @@
 
 @common.parametrize("test_data", _TO_COPY_TEST_DATA_FP)
-def test_copy_tosa_FP(test_data: Tuple):
+def test_to_tosa_FP(test_data: Tuple):
     test_tensor, new_dtype = test_data()
 
     pipeline = TosaPipelineFP[input_t1](
@@ -85,7 +95,7 @@
 
 @common.parametrize("test_data", _TO_COPY_TEST_DATA_FP)
 @common.SkipIfNoModelConverter
-def test_copy_vgf_FP(test_data: Tuple):
+def test_to_vgf_FP(test_data: Tuple):
     test_tensor, new_dtype = test_data()
     pipeline = VgfPipeline[input_t1](
         Cast(new_dtype),
@@ -138,7 +148,7 @@
 
 @common.parametrize("test_data", _TO_COPY_TEST_DATA_INT)
-def test_copy_tosa_INT(test_data: Tuple):
+def test_to_tosa_INT_not_delegated(test_data: Tuple):
     test_tensor, new_dtype = test_data()
 
     pipeline = OpNotSupportedPipeline[input_t1](
@@ -154,6 +164,83 @@
 
 @common.parametrize("test_data", _TO_COPY_TEST_DATA_INT)
 @common.SkipIfNoModelConverter
-def test_copy_vgf_INT(test_data: Tuple):
+def test_to_vgf_INT(test_data: Tuple):
     # Op not supported
     pass
+
+
+_TO_COPY_TEST_DATA_REDUNDANT_CAST = {
+    "rand_fp16_fp16": lambda: (
+        torch.rand((1, 2, 3, 4), dtype=torch.float16),
+        torch.float16,
+    ),
+    "rand_fp32_fp32": lambda: (
+        torch.rand((1, 2, 3, 4), dtype=torch.float32),
+        torch.float32,
+    ),
+    "rand_int8_int8": lambda: (
+        torch.randint(-127, 128, (1, 2, 3, 4), dtype=torch.int8),
+        torch.int8,
+    ),
+    "rand_int16_int16": lambda: (
+        torch.randint(-127, 128, (1, 2, 3, 4), dtype=torch.int16),
+        torch.int16,
+    ),
+    "rand_int32_int32": lambda: (
+        torch.randint(-127, 128, (1, 2, 3, 4), dtype=torch.int32),
+        torch.int32,
+    ),
+}
+
+redundant_xfails_FP = {
+    "rand_fp16_fp16": "FP16 is not supported",
+    "rand_int8_int8": "Tracing graph with quantized input is not supported.",
+    "rand_int16_int16": "Tracing graph with quantized input is not supported.",
+}
+
+redundant_xfails_INT = {
+    "rand_fp16_fp16": "FP16 is not supported",
+    "rand_int8_int8": "Tracing graph with quantized input is not supported.",
+}
+
+
+@common.parametrize(
+    "test_data", _TO_COPY_TEST_DATA_REDUNDANT_CAST, xfails=redundant_xfails_FP
+)
+def test_to_tosa_FP_REDUNDANT_CAST(test_data: Tuple):
+    test_tensor, new_dtype = test_data()
+    pipeline = TosaPipelineFP[input_t1](
+        CastAdd(new_dtype),
+        (test_tensor,),
+        aten_op=[],
+        exir_op=[],
+    )
+    pipeline.pop_stage("run_method_and_compare_outputs")
+    pipeline.run()
+
+
+@common.parametrize(
+    "test_data", _TO_COPY_TEST_DATA_REDUNDANT_CAST, xfails=redundant_xfails_INT
+)
+def test_to_tosa_INT_REDUNDANT_CAST(test_data: Tuple):
+    test_tensor, new_dtype = test_data()
+    pipeline = TosaPipelineINT[input_t1](
+        CastAdd(new_dtype),
+        (test_tensor,),
+        aten_op=[],
+        exir_op=[],
+    )
+    pipeline.pop_stage("run_method_and_compare_outputs")
+    pipeline.pop_stage("check.quant_nodes")
+    pipeline.run()
+
+
+@common.parametrize("test_data", _TO_COPY_TEST_DATA_REDUNDANT_CAST)
+def test_to_tosa_INT_not_delegated_REDUNDANT_CAST(test_data: Tuple):
+    test_tensor, new_dtype = test_data()
+    pipeline = OpNotSupportedPipeline[input_t1](
+        Cast(new_dtype),
+        (test_tensor,),
+        non_delegated_ops={},  # These are removed outside of the Arm backend so the graph is empty
+    )
+    pipeline.run()
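Reviewer note on the redundant-cast cases: in eager PyTorch, `Tensor.to` with a matching dtype returns the input tensor itself, so `CastAdd` with a matching dtype is numerically just `x + x`; the tests above check that lowering treats these casts as no-ops as well. A plain-PyTorch demonstration:

```python
import torch

x = torch.rand(1, 2, 3, 4, dtype=torch.float32)
assert x.to(dtype=torch.float32) is x        # identity cast: returns self
assert x.to(dtype=torch.int32) is not x      # real cast: allocates a new tensor
assert torch.equal(x.to(dtype=torch.float32) + x.to(dtype=torch.float32), x + x)
```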
diff --git a/backends/arm/test/passes/test_insert_int64_to_int32_cast_pass.py b/backends/arm/test/passes/test_insert_int64_to_int32_cast_pass.py
index da6eeb59459..6125e9b01cc 100644
--- a/backends/arm/test/passes/test_insert_int64_to_int32_cast_pass.py
+++ b/backends/arm/test/passes/test_insert_int64_to_int32_cast_pass.py
@@ -31,7 +31,7 @@ def test_int64_model_tosa_FP():
         "executorch_exir_dialects_edge__ops_aten_embedding_default": 1,
     }
     op_checks_after = {
-        "executorch_exir_dialects_edge__ops_aten__to_copy_default": 1,
+        "executorch_exir_dialects_edge__ops_dim_order_ops__to_dim_order_copy_default": 1,
         "executorch_exir_dialects_edge__ops_aten_embedding_default": 1,
     }
 
diff --git a/backends/arm/test/passes/test_remove_clone_pass.py b/backends/arm/test/passes/test_remove_clone_pass.py
index 5c2171795f7..180f9c2ffe5 100755
--- a/backends/arm/test/passes/test_remove_clone_pass.py
+++ b/backends/arm/test/passes/test_remove_clone_pass.py
@@ -6,7 +6,7 @@
 from typing import Tuple
 
 import torch
-from executorch.backends.arm._passes.remove_clone_pass import RemoveClonePass
+from executorch.backends.arm._passes.remove_noop_pass import RemoveNoopPass
 
 from executorch.backends.arm.test.tester.test_pipeline import PassPipeline
 
@@ -15,7 +15,7 @@
 
 class Clone(torch.nn.Module):
     """
-    Basic remove layer model to test RemoveClonePass
+    Basic remove layer model to test RemoveNoopPass
     """
 
     def __init__(self):
@@ -40,6 +40,6 @@ def test_remove_clone_tosa_INT():
         ops_not_after_pass=[
             "executorch_exir_dialects_edge__ops_dim_order_ops__clone_dim_order_default"
         ],
-        pass_list=[RemoveClonePass],
+        pass_list=[RemoveNoopPass],
     )
     pipeline.run()
diff --git a/backends/arm/tosa/partitioner.py b/backends/arm/tosa/partitioner.py
index 898ea588666..c0f546fe50a 100644
--- a/backends/arm/tosa/partitioner.py
+++ b/backends/arm/tosa/partitioner.py
@@ -35,13 +35,20 @@
 
 def is_noop_clone(node: torch.fx.node.Node) -> bool:
-    return node.target == exir_ops.edge.aten.clone.default
+    return node.target == exir_ops.edge.dim_order_ops._clone_dim_order.default
 
 
 def is_noop_alias_copy(node: torch.fx.node.Node) -> bool:
     return node.target == exir_ops.edge.aten.alias_copy.default
 
 
+def is_noop_to_dim_order_copy(node: torch.fx.node.Node) -> bool:
+    if node.target != exir_ops.edge.dim_order_ops._to_dim_order_copy.default:
+        return False
+    else:
+        return node.meta.get("dtype") == get_first_fake_tensor(node.args[0]).dtype  # type: ignore[arg-type]
+
+
 def is_noop_expand(node: torch.fx.node.Node) -> bool:
     if node.target != exir_ops.edge.aten.expand_copy.default:
         return False
@@ -145,6 +152,7 @@ def is_partitioned(node: torch.fx.Node, tag=tag) -> bool:
                     is_noop_clone(node)
                     or is_noop_alias_copy(node)
                     or is_noop_expand(node)
+                    or is_noop_to_dim_order_copy(node)
                     or node.target in Q_OPS
                     or node.target in DQ_OPS
                     for node in partition.nodes
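Reviewer note: `is_noop_to_dim_order_copy` lets the partitioner treat dtype-preserving casts the same way as no-op clones, so a candidate partition made up solely of such nodes carries no real work. A runnable toy of that predicate and the all-no-ops check (stand-in `Node` type; the real code walks `torch.fx` nodes as shown above):

```python
from dataclasses import dataclass

import torch

@dataclass
class Node:
    target: str
    dtype: torch.dtype        # stand-in for node.meta.get("dtype")
    input_dtype: torch.dtype  # stand-in for the input's fake-tensor dtype

def is_noop_to_dim_order_copy(node: Node) -> bool:
    if node.target != "_to_dim_order_copy":
        return False
    return node.dtype == node.input_dtype

partition = [Node("_to_dim_order_copy", torch.float32, torch.float32)]
# A partition consisting only of no-ops would be skipped rather than delegated.
print(all(is_noop_to_dim_order_copy(n) for n in partition))  # True
```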