From cafe8a363b0dec27d3ed103985dd7e81b31b6e0a Mon Sep 17 00:00:00 2001
From: Gregory Comer
Date: Wed, 19 Nov 2025 16:31:41 -0800
Subject: [PATCH] Remove no-op clones in xnnpack (#15884)

Summary:
Pull Request resolved: https://github.com/pytorch/executorch/pull/15884

Differential Revision: D87405074

Pulled By: GregoryComer
---
 backends/transforms/remove_clone_ops.py       |  21 ++-
 backends/xnnpack/_passes/TARGETS              |   1 +
 backends/xnnpack/_passes/__init__.py          |  10 ++
 backends/xnnpack/operators/__init__.py        |   3 +
 backends/xnnpack/operators/op_clone.py        |  60 ++++++++
 backends/xnnpack/partition/config/__init__.py |   3 +
 .../partition/config/generic_node_configs.py  |  22 +++
 backends/xnnpack/runtime/XNNCompiler.cpp      |  29 ++++
 .../xnnpack/serialization/runtime_schema.fbs  |   1 +
 backends/xnnpack/serialization/schema.fbs     |   1 +
 .../serialization/xnnpack_graph_schema.py     |   6 +
 backends/xnnpack/test/ops/test_clone.py       | 137 ++++++++++++++++++
 12 files changed, 293 insertions(+), 1 deletion(-)
 create mode 100644 backends/xnnpack/operators/op_clone.py
 create mode 100644 backends/xnnpack/test/ops/test_clone.py

diff --git a/backends/transforms/remove_clone_ops.py b/backends/transforms/remove_clone_ops.py
index 01fe2ee26a4..79b93af8beb 100644
--- a/backends/transforms/remove_clone_ops.py
+++ b/backends/transforms/remove_clone_ops.py
@@ -25,8 +25,9 @@ class RemoveCloneOpsTransform(ExportPass):
         exir_ops.edge.dim_order_ops._clone_dim_order.default,
     }
 
-    def __init__(self) -> None:
+    def __init__(self, preserve_input_output_copies: bool = False) -> None:
         super().__init__()
+        self._preserve_input_output_copies = preserve_input_output_copies
 
     def _remove(self, graph_module: torch.fx.GraphModule) -> None:
         dequant_nodes = []
@@ -38,6 +39,11 @@ def _remove(self, graph_module: torch.fx.GraphModule) -> None:
             if self._is_non_identity_clone(n):
                 continue
 
+            # If preserve_input_output_copies is set, don't remove clones that directly
+            # copy from input to output.
+            if self._is_input_output_copy(n) and self._preserve_input_output_copies:
+                continue
+
             to_be_removed = n
             for user_n in list(n.users.keys()):
                 user_n.replace_input_with(n, n.args[0])
@@ -76,3 +82,16 @@ def _is_non_identity_clone(self, node: torch.fx.Node) -> bool:
             )
 
         return False
+
+    def _is_input_output_copy(self, node: torch.fx.Node) -> bool:
+        """Return True if the node input is a graph input and output goes into an output node."""
+
+        input_node = node.args[0]
+        if input_node.op != "placeholder":
+            return False
+
+        for users in node.users:
+            if users.op == "output":
+                return True
+
+        return False
diff --git a/backends/xnnpack/_passes/TARGETS b/backends/xnnpack/_passes/TARGETS
index 6f7b13d8026..4977ad08936 100644
--- a/backends/xnnpack/_passes/TARGETS
+++ b/backends/xnnpack/_passes/TARGETS
@@ -8,6 +8,7 @@ runtime.python_library(
     deps = [
         "//caffe2:torch",
         "//executorch/backends/transforms:addmm_mm_to_linear",
+        "//executorch/backends/transforms:remove_clone_ops",
         "//executorch/backends/transforms:lib",
         "//executorch/backends/xnnpack/partition:partitioner_graphs",
         "//executorch/backends/xnnpack/serialization:xnnpack_schema",
diff --git a/backends/xnnpack/_passes/__init__.py b/backends/xnnpack/_passes/__init__.py
index c48896b3d81..4992d7a4abd 100644
--- a/backends/xnnpack/_passes/__init__.py
+++ b/backends/xnnpack/_passes/__init__.py
@@ -4,8 +4,12 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
+# pyre-unsafe
+
 from typing import List, Optional, Type
 
+from executorch.backends.transforms.remove_clone_ops import RemoveCloneOpsTransform
+
 from executorch.backends.transforms.remove_getitem_op import RemoveGetItemPass
 
 from executorch.backends.xnnpack._passes.channels_last_tagged_reshape_pass import (
@@ -42,6 +46,11 @@
 from torch.export import ExportedProgram
 
 
+class XNNPACKRemoveCloneOpsTransform(RemoveCloneOpsTransform):
+    def __init__(self):
+        super().__init__(preserve_input_output_copies=True)
+
+
 class XNNPACKPassManager:
     def __init__(
         self,
@@ -58,6 +67,7 @@ def __init__(
         if not passes:
             # All the XNNPACK passes
             self.passes = [
+                XNNPACKRemoveCloneOpsTransform,
                 # TODO - remove this pass once we have a better support for dim_order ops lowering
                 DimOrderOpsRevertPass,
                 ConvertToUpsampleBilinear2d,
diff --git a/backends/xnnpack/operators/__init__.py b/backends/xnnpack/operators/__init__.py
index 93424b1c84d..02a46a6fc47 100644
--- a/backends/xnnpack/operators/__init__.py
+++ b/backends/xnnpack/operators/__init__.py
@@ -4,6 +4,8 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
+# pyre-unsafe
+
 from . import (  # noqa
     node_visitor,
     op_abs,
@@ -14,6 +16,7 @@
     op_cat,
     op_ceiling,
     op_clamp,
+    op_clone,
     op_conv2d,
     op_div,
     op_dynamic_dequantize_ops,
diff --git a/backends/xnnpack/operators/op_clone.py b/backends/xnnpack/operators/op_clone.py
new file mode 100644
index 00000000000..e4ddf187ecc
--- /dev/null
+++ b/backends/xnnpack/operators/op_clone.py
@@ -0,0 +1,60 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+# pyre-unsafe
+
+from typing import Dict
+
+import torch
+from executorch.backends.xnnpack.operators.node_visitor import (
+    NodeVisitor,
+    register_node_visitor,
+)
+from executorch.backends.xnnpack.serialization.xnnpack_graph_schema import (
+    XNNCopy,
+    XNNGraph,
+    XNode,
+)
+from executorch.backends.xnnpack.utils.utils import get_input_node
+
+
+@register_node_visitor
+class CloneVisitor(NodeVisitor):
+    target = "aten.clone.default"
+
+    def __init__(self, *args) -> None:
+        super().__init__(*args)
+
+    def define_node(
+        self,
+        node: torch.fx.Node,
+        xnn_graph: XNNGraph,
+        vals_to_ids: Dict[torch.fx.Node, int],
+        debug_handle: int,
+    ) -> None:
+        self.define_nodes_tensor_inputs_outputs(node, xnn_graph, vals_to_ids)
+
+        # Sanity check that the input and output dim order are the same. We don't
+        # handle dim order conversions yet.
+        dim_order = node.kwargs.get("dim_order", None)
+        input_meta = node.args[0].meta["val"]
+        assert dim_order is None or list(input_meta.dim_order()) == dim_order
+
+        # input
+        input_id = vals_to_ids[get_input_node(node, 0)]
+
+        # output
+        output_id = vals_to_ids[node]
+
+        ser_node = XNode(
+            xnode_union=XNNCopy(
+                input_id=input_id,
+                output_id=output_id,
+                flags=0,
+            ),
+            debug_handle=debug_handle,
+        )
+        xnn_graph.xnodes.append(ser_node)
diff --git a/backends/xnnpack/partition/config/__init__.py b/backends/xnnpack/partition/config/__init__.py
index 86baba3e3f7..5427b3a7838 100644
--- a/backends/xnnpack/partition/config/__init__.py
+++ b/backends/xnnpack/partition/config/__init__.py
@@ -4,6 +4,7 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
+# pyre-unsafe
 
 from typing import List, Type
 
@@ -22,6 +23,7 @@
     CatConfig,
     CeilConfig,
     ClampConfig,
+    CloneDimOrderConfig,
    ConstantPadConfig,
     DeQuantizedPerTensorConfig,
     DivConfig,
@@ -77,6 +79,7 @@
     BMMConfig,
     CatConfig,
     CeilConfig,
+    CloneDimOrderConfig,
     ConstantPadConfig,
     ConvolutionConfig,
     ClampConfig,
diff --git a/backends/xnnpack/partition/config/generic_node_configs.py b/backends/xnnpack/partition/config/generic_node_configs.py
index 06024c632c9..434fce1d73a 100644
--- a/backends/xnnpack/partition/config/generic_node_configs.py
+++ b/backends/xnnpack/partition/config/generic_node_configs.py
@@ -643,3 +643,25 @@ class SinConfig(GenericNodePartitionerConfig):
 
     def supported_precision_types(self) -> List[ConfigPrecisionType]:
         return [ConfigPrecisionType.FP32]
+
+
+class CloneDimOrderConfig(GenericNodePartitionerConfig):
+    target_name = "_clone_dim_order.default"
+
+    def supported_precision_types(self) -> List[ConfigPrecisionType]:
+        return [ConfigPrecisionType.FP32]
+
+    def check_constraints(self, node: torch.fx.Node, ep: ExportedProgram) -> bool:
+        if not self.check_common_constraints(node, ep):
+            return False
+
+        # Only partition no-op _clone_dim_order nodes (output dim order = input).
+        # We can relax this in the future.
+        # This is also a conservative check and doesn't consider ambiguity.
+        dim_order = node.kwargs.get("dim_order", None)
+        input_meta = node.args[0].meta["val"]
+        if dim_order is not None and list(input_meta.dim_order()) != dim_order:
+            why(node, reason="Only dim-order preserving clones are supported.")
+            return False
+
+        return True
diff --git a/backends/xnnpack/runtime/XNNCompiler.cpp b/backends/xnnpack/runtime/XNNCompiler.cpp
index 3e697566ce5..ec937a64744 100644
--- a/backends/xnnpack/runtime/XNNCompiler.cpp
+++ b/backends/xnnpack/runtime/XNNCompiler.cpp
@@ -1459,6 +1459,34 @@ Error defineBatchMatrixMultiplyNode(
   return Error::Ok;
 }
 
+/*
+ * Defines a copy node in the XNN subgraph.
+ */
+Error defineCopyNode(
+    xnn_subgraph_t subgraph_ptr,
+    const std::unordered_map<uint32_t, uint32_t>& remapped_ids,
+    const NodePtr node,
+    const fb_xnnpack::XNNGraph* graph) noexcept {
+  MAYBE_UNUSED(graph);
+
+  auto graph_node = node->xnode_union_as_XNNCopy();
+
+  xnn_status status = xnn_define_copy(
+      subgraph_ptr,
+      remapped_ids.at(graph_node->input_id()),
+      remapped_ids.at(graph_node->output_id()),
+      graph_node->flags());
+
+  ET_CHECK_OR_RETURN_ERROR(
+      status == xnn_status_success,
+      Internal,
+      "Failed to create copy node %i with code: %s",
+      node->debug_handle(),
+      xnn_status_to_string(status));
+
+  return Error::Ok;
+}
+
 /*
 Returns not Implemented Error code.
 This function is meant to be called when the compiler encountes a XNodeType from the flatbuffer
@@ -1763,6 +1791,7 @@ DefineNodeFunc getDefineNodeFunc(fb_xnnpack::XNodeUnion nodeType) {
     _DEFINE(Concatenate5)
     _DEFINE(StaticSlice)
     _DEFINE(BatchMatrixMultiply)
+    _DEFINE(Copy)
     case fb_xnnpack::XNodeUnion::NONE:
     default: // Adding here as a catch all, just in case
       return &defineNotImplementedNode;
diff --git a/backends/xnnpack/serialization/runtime_schema.fbs b/backends/xnnpack/serialization/runtime_schema.fbs
index 239f92d899e..939bbd7a82f 100644
--- a/backends/xnnpack/serialization/runtime_schema.fbs
+++ b/backends/xnnpack/serialization/runtime_schema.fbs
@@ -157,6 +157,7 @@ union XNodeUnion {
   XNNTanh: _XNNNode1x1,
   XNNExp: _XNNNode1x1,
   XNNSin: _XNNNode1x1,
+  XNNCopy: _XNNNode1x1,
 }
 
 union XValueUnion {
diff --git a/backends/xnnpack/serialization/schema.fbs b/backends/xnnpack/serialization/schema.fbs
index 92a61c5537b..08d9184b9f5 100644
--- a/backends/xnnpack/serialization/schema.fbs
+++ b/backends/xnnpack/serialization/schema.fbs
@@ -153,6 +153,7 @@ union XNodeUnion {
   XNNTanh: _XNNNode1x1,
   XNNExp: _XNNNode1x1,
   XNNSin: _XNNNode1x1,
+  XNNCopy: _XNNNode1x1,
 }
 
 union XValueUnion {
diff --git a/backends/xnnpack/serialization/xnnpack_graph_schema.py b/backends/xnnpack/serialization/xnnpack_graph_schema.py
index 2b3f8e74202..872056fa82e 100644
--- a/backends/xnnpack/serialization/xnnpack_graph_schema.py
+++ b/backends/xnnpack/serialization/xnnpack_graph_schema.py
@@ -352,6 +352,11 @@ class XNNSin(XNNNode1x1):
     pass
 
 
+@dataclass
+class XNNCopy(XNNNode1x1):
+    pass
+
+
 @dataclass
 class XNNScaledDotProductAttention:
     query_id: int
@@ -409,6 +414,7 @@ class XNNScaledDotProductAttention:
     XNNTanh,
     XNNExp,
     XNNSin,
+    XNNCopy,
 ]
 
 
diff --git a/backends/xnnpack/test/ops/test_clone.py b/backends/xnnpack/test/ops/test_clone.py
new file mode 100644
index 00000000000..0396b9b2bea
--- /dev/null
+++ b/backends/xnnpack/test/ops/test_clone.py
@@ -0,0 +1,137 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+# pyre-unsafe
+
+import unittest
+
+import torch
+from executorch.backends.xnnpack.test.tester import Tester
+
+
+class TestClone(unittest.TestCase):
+    def setUp(self):
+        torch._dynamo.reset()
+
+    class Clone(torch.nn.Module):
+        def __init__(self):
+            super().__init__()
+
+        def forward(self, x):
+            z = torch.clone(x)
+            return z
+
+    class CloneWithMemoryFormat(torch.nn.Module):
+        def __init__(self, memory_format):
+            super().__init__()
+            self.memory_format = memory_format
+
+        def forward(self, x):
+            z = torch.clone(x, memory_format=self.memory_format)
+            return z
+
+    def _test_clone_partitioned(self, inputs):
+        """Test that dim-order preserving clones are partitioned (removed)"""
+        (
+            Tester(self.Clone(), inputs)
+            .export()
+            .check_count({"torch.ops.aten.clone.default": 1})
+            .dump_artifact()
+            .to_edge_transform_and_lower()
+            .dump_artifact()
+            .check_not(
+                [
+                    "executorch_exir_dialects_edge__ops_dim_order_ops__clone_dim_order_default"
+                ]
+            )
+            .to_executorch()
+            .serialize()
+            .run_method_and_compare_outputs()
+        )
+
+    def test_fp16_clone(self):
+        """Test FP16 clone - should be partitioned"""
+        inputs = (torch.randn(2, 3, 4, 5).to(torch.float16),)
+        self._test_clone_partitioned(inputs)
+
+    def test_fp32_clone(self):
+        """Test FP32 clone - should be partitioned"""
+        inputs = (torch.randn(2, 3, 4, 5),)
+        self._test_clone_partitioned(inputs)
+
+    def test_fp32_clone_2d(self):
+        """Test FP32 clone with 2D tensor - should be partitioned"""
+        inputs = (torch.randn(10, 20),)
+        self._test_clone_partitioned(inputs)
+
+    def test_fp32_clone_3d(self):
+        """Test FP32 clone with 3D tensor - should be partitioned"""
+        inputs = (torch.randn(2, 3, 4),)
+        self._test_clone_partitioned(inputs)
+
+    def test_fp32_clone_with_contiguous_format(self):
+        """Test FP32 clone with contiguous memory format - should be partitioned"""
+        inputs = (torch.randn(1, 3, 4, 4),)
+        (
+            Tester(self.CloneWithMemoryFormat(torch.contiguous_format), inputs)
+            .export()
+            .to_edge_transform_and_lower()
+            .dump_artifact()
+            .check_not(
+                [
+                    "executorch_exir_dialects_edge__ops_dim_order_ops__clone_dim_order_default"
+                ]
+            )
+            .to_executorch()
+            .serialize()
+            .run_method_and_compare_outputs()
+        )
+
+    def test_fp32_clone_with_channels_last_not_partitioned(self):
+        """Test FP32 clone with channels_last memory format - should NOT be partitioned"""
+        inputs = (torch.randn(1, 3, 4, 4),)
+        (
+            Tester(self.CloneWithMemoryFormat(torch.channels_last), inputs)
+            .export()
+            .to_edge_transform_and_lower()
+            # Clone with channels_last changes dim order, so should NOT be delegated
+            .check(
+                [
+                    "executorch_exir_dialects_edge__ops_dim_order_ops__clone_dim_order_default"
+                ]
+            )
+            .to_executorch()
+            .serialize()
+            .run_method_and_compare_outputs()
+        )
+
+    def test_fp32_clone_channels_last_to_contiguous_not_partitioned(self):
+        """Test clone from channels_last to contiguous - should NOT be partitioned"""
+
+        class CloneChannelsLastToContiguous(torch.nn.Module):
+            def forward(self, x):
+                # Start with channels_last input
+                y = x.to(memory_format=torch.channels_last)
+                # Clone back to contiguous (changes dim order)
+                z = torch.clone(y, memory_format=torch.contiguous_format)
+                return z
+
+        inputs = (torch.randn(1, 3, 4, 4),)
+        (
+            Tester(CloneChannelsLastToContiguous(), inputs)
+            .export()
+            .to_edge_transform_and_lower()
+            .dump_artifact()
+            # Clone that changes dim order should NOT be delegated
+            .check(
+                [
+                    "executorch_exir_dialects_edge__ops_dim_order_ops__clone_dim_order_default"
+                ]
+            )
+            .to_executorch()
+            .serialize()
+            .run_method_and_compare_outputs()
+        )
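
Reviewer note (not part of the patch): below is a minimal, hypothetical end-to-end sketch of the behavior this change enables, assuming the standard ExecuTorch lowering entry points (torch.export.export, executorch.exir.to_edge_transform_and_lower, and XnnpackPartitioner). The module and variable names are illustrative only; the authoritative coverage is in backends/xnnpack/test/ops/test_clone.py above.

    # Hypothetical sketch: with this change, a dim-order-preserving clone should be
    # absorbed into the XNNPACK delegate (serialized as an XNNCopy node) rather than
    # remaining as a _clone_dim_order op in the top-level edge graph.
    import torch

    from executorch.backends.xnnpack.partition.xnnpack_partitioner import XnnpackPartitioner
    from executorch.exir import to_edge_transform_and_lower


    class CloneOnly(torch.nn.Module):
        def forward(self, x):
            # No-op clone straight from a graph input to a graph output.
            return torch.clone(x)


    ep = torch.export.export(CloneOnly(), (torch.randn(2, 3, 4, 5),))
    edge = to_edge_transform_and_lower(ep, partitioner=[XnnpackPartitioner()])

    # Expect a single delegate call and no leftover _clone_dim_order node.
    print(edge.exported_program().graph_module.code)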