diff --git a/backends/transforms/test/test_remove_clone_ops.py b/backends/transforms/test/test_remove_clone_ops.py index d34c522baaa..c14ce4b4ad6 100644 --- a/backends/transforms/test/test_remove_clone_ops.py +++ b/backends/transforms/test/test_remove_clone_ops.py @@ -176,10 +176,13 @@ def test_clone_identity_removed(self): exported, compile_config=EdgeCompileConfig(_skip_dim_order=skip_dim_order), ) + # Early exit if clone_op_str is not present in the graph + # This is because some other pass may have removed the clone op. + before_gm_code = before_epm.exported_program().graph_module.code + if clone_op_str not in before_gm_code: + continue - FileCheck().check_count(clone_op_str, 1, exactly=True).run( - before_epm.exported_program().graph_module.code - ) + FileCheck().check_count(clone_op_str, 1, exactly=True).run(before_gm_code) updated_epm = before_epm.transform([RemoveCloneOpsTransform()]) diff --git a/exir/passes/remove_noop_pass.py b/exir/passes/remove_noop_pass.py index d9b99556636..7f8d167c6f1 100644 --- a/exir/passes/remove_noop_pass.py +++ b/exir/passes/remove_noop_pass.py @@ -104,7 +104,10 @@ def call(self, graph_module: GraphModule) -> PassResult: if node.op != "call_function": continue - if node.target not in (torch.ops.aten._to_copy.default,): + if node.target not in ( + torch.ops.aten._to_copy.default, + torch.ops.aten.clone.default, + ): continue orig_tensor = node.args[0].meta["val"]