diff --git a/backends/nxp/tests/ir/converter/node_converter/test_clone_converter.py b/backends/nxp/tests/ir/converter/node_converter/test_clone_converter.py
index 427ddaf14a5..d732903c2ff 100644
--- a/backends/nxp/tests/ir/converter/node_converter/test_clone_converter.py
+++ b/backends/nxp/tests/ir/converter/node_converter/test_clone_converter.py
@@ -76,6 +76,7 @@ def forward(self, x):
         return self.block(x)
 
 
+@unittest.skip("Clones are optimized out of the graph.")
 class TestCloneConverter(unittest.TestCase):
     __test__ = False  # Prevent interfering with PyTest tests
 
diff --git a/backends/transforms/test/test_remove_clone_ops.py b/backends/transforms/test/test_remove_clone_ops.py
index d34c522baaa..bcd8a4eea48 100644
--- a/backends/transforms/test/test_remove_clone_ops.py
+++ b/backends/transforms/test/test_remove_clone_ops.py
@@ -4,6 +4,8 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
+# pyre-unsafe
+
 import unittest
 
 import torch
@@ -164,34 +166,6 @@ def test_clone_non_identity_survives(self):
         assert torch.allclose(actual, expected)
         assert is_channel_last_dim_order(actual)
 
-    def test_clone_identity_removed(self):
-        """Verify identity clone ops are removed by RemoveCloneOpsTransform."""
-
-        for skip_dim_order, clone_op_str in self.CLONE_OP_CASES:
-            model = SimpleCloneChannelsLastModule()
-            x = torch.randn(3, 4, 5, 6).to(memory_format=torch.channels_last)
-
-            exported = export(model.eval(), (x,), strict=True)
-            before_epm = to_edge(
-                exported,
-                compile_config=EdgeCompileConfig(_skip_dim_order=skip_dim_order),
-            )
-
-            FileCheck().check_count(clone_op_str, 1, exactly=True).run(
-                before_epm.exported_program().graph_module.code
-            )
-
-            updated_epm = before_epm.transform([RemoveCloneOpsTransform()])
-
-            FileCheck().check_not(clone_op_str).run(
-                updated_epm.exported_program().graph_module.code
-            )
-
-            expected = before_epm.exported_program().module()(x)
-            actual = updated_epm.exported_program().module()(x)
-            assert torch.allclose(actual, expected)
-            assert is_channel_last_dim_order(actual)
-
 
 if __name__ == "__main__":
     unittest.main()
diff --git a/exir/passes/remove_noop_pass.py b/exir/passes/remove_noop_pass.py
index d9b99556636..3efb2a28fcb 100644
--- a/exir/passes/remove_noop_pass.py
+++ b/exir/passes/remove_noop_pass.py
@@ -56,35 +56,10 @@ def call(self, graph_module: GraphModule) -> PassResult:
         dequant_nodes = []
 
         for node in graph_module.graph.nodes:
-            if node.op != "call_function":
-                continue
-
-            if node.target not in (
-                torch.ops.aten.to.dtype,
-                torch.ops.aten.dropout.default,
-                torch.ops.aten.slice_copy.Tensor,
-            ):
-                continue
-
-            orig_tensor = node.args[0].meta["val"]
-
-            if orig_tensor is node.meta["val"]:
-                # If the graph is quantized, we must remove the entire pattern consisting of dq->op->q.
-                # Otherwise, removing only the op will suffice.
+            if RemoveNoopPass._should_remove_node(node):
                 if node.args[0].target in _DEQUANT_OPS:
                     dequant_nodes += [node.args[0]]
                 node.replace_all_uses_with(node.args[0])
-                continue
-
-            if node.target == torch.ops.aten.slice_copy.Tensor:
-                # Only do this check if all the dims are static.
-                if all(isinstance(dim, int) for dim in orig_tensor.size()):
-                    if orig_tensor.shape == node.meta["val"].shape:
-                        # If the graph is quantized, we must remove the entire pattern consisting of dq->op->q.
-                        # Otherwise, removing only the op will suffice.
-                        if node.args[0].target in _DEQUANT_OPS:
-                            dequant_nodes += [node.args[0]]
-                        node.replace_all_uses_with(node.args[0])
 
         graph_module.graph.eliminate_dead_code()
         eliminate_dq_q(graph_module, dequant_nodes)
@@ -93,6 +68,41 @@ def call(self, graph_module: GraphModule) -> PassResult:
 
         return PassResult(graph_module, True)
 
+    @staticmethod
+    def _should_remove_node(node: torch.fx.Node) -> bool:
+        if node.op != "call_function":
+            return False
+
+        input_meta_val = (
+            node.args[0].meta.get("val", None)
+            if len(node.args) > 0 and hasattr(node.args[0], "meta")
+            else None
+        )
+
+        if input_meta_val is not None:
+            if node.target in (
+                torch.ops.aten.to.dtype,
+                torch.ops.aten.dropout.default,
+            ):
+                return input_meta_val is node.meta["val"]
+            elif node.target == torch.ops.aten.slice_copy.Tensor:
+                # Only do this check if all the dims are static.
+                return (
+                    all(isinstance(dim, int) for dim in input_meta_val.size())
+                    and input_meta_val.shape == node.meta["val"].shape
+                )
+            elif node.target == torch.ops.aten.clone.default:
+                # Remove if memory_format is None or preserve_format, or if the input already has the target memory format.
+                dest_memory_format = (
+                    node.kwargs.get("memory_format", None) or torch.preserve_format
+                )
+                return (
+                    dest_memory_format == torch.preserve_format
+                    or input_meta_val.is_contiguous(memory_format=dest_memory_format)
+                )
+
+        return False
+
 
 class RemoveToCopyPass(ExportPass):
     """
diff --git a/exir/tests/test_passes.py b/exir/tests/test_passes.py
index 14f105e8205..b46068b4038 100644
--- a/exir/tests/test_passes.py
+++ b/exir/tests/test_passes.py
@@ -2093,3 +2093,35 @@ def forward(self, x):
             prop_tensor.is_contiguous(),
             f"Propagated tensor is not contiguous: {prop_tensor.stride()}",
         )
+
+    def test_remove_noop_pass_clone(self) -> None:
+        """
+        Verify that no-op clones are removed from the graph.
+        """
+
+        class CloneModel(torch.nn.Module):
+            def forward(self, x):
+                return x.clone() + x.clone()
+
+        model = CloneModel()
+        inputs = (torch.randn(1, 16),)
+
+        ep = torch.export.export(model, inputs)
+        lowered = to_edge_transform_and_lower(ep)
+
+        # Sanity-check the test: clone ops should appear in the exported program.
+        self.assertTrue(
+            any(
+                n.op == "call_function" and n.target == torch.ops.aten.clone.default
+                for n in ep.graph.nodes
+            )
+        )
+
+        # Since the clone ops are no-ops, they should be gone after lowering.
+        self.assertFalse(
+            any(
+                n.op == "call_function"
+                and n.target == exir_ops.edge.dim_order_ops._clone_dim_order.default
+                for n in lowered.exported_program().graph.nodes
+            )
+        )
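A minimal standalone sketch (not part of the diff above, assuming only stock PyTorch) of the memory-format reasoning behind the new `aten.clone.default` branch in `_should_remove_node`: a clone is a removable no-op when no memory format is requested (`None` collapses to `preserve_format`), or when the input already satisfies the requested layout; a clone that would actually reorder memory must survive.

```python
# Illustrative sketch of the clone no-op condition; not executed by the pass.
import torch

# A channels-last input, like the tensors in the clone tests above.
x = torch.randn(1, 3, 4, 4).to(memory_format=torch.channels_last)

# node.kwargs.get("memory_format", None) or torch.preserve_format:
# a missing/None memory_format collapses to preserve_format, so the
# clone only copies data and is removable.
dest = None or torch.preserve_format
assert dest == torch.preserve_format

# The input already satisfies the requested layout, so a clone to
# channels_last changes nothing and is removable.
assert x.is_contiguous(memory_format=torch.channels_last)

# A clone to contiguous_format would reorder memory, so the pass
# must keep it.
assert not x.is_contiguous(memory_format=torch.contiguous_format)
```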