diff --git a/exir/passes/insert_write_back_for_buffers_pass.py b/exir/passes/insert_write_back_for_buffers_pass.py index d9a2acafb42..5ac5f49f2c4 100644 --- a/exir/passes/insert_write_back_for_buffers_pass.py +++ b/exir/passes/insert_write_back_for_buffers_pass.py @@ -7,6 +7,7 @@ from typing import Dict, List, Optional, Tuple import torch +from executorch.exir.operator.convert import is_inplace_variant from torch.export.exported_program import ( ExportedProgram, @@ -17,6 +18,7 @@ ) from torch.export.graph_signature import TensorArgument from torch.utils import _pytree as pytree +from torchgen.model import SchemaKind def _insert_copy( @@ -70,6 +72,44 @@ def _insert_copy( return buffer_output_nodes +def _is_inplace_node(node: torch.fx.Node) -> bool: + """Check if a node is an inplace node.""" + return ( + node.op == "call_function" + and isinstance(node.target, torch._ops.OpOverload) + and is_inplace_variant( + node.target._schema.name, node.target._schema.overload_name + ) + ) + + +def _inplace_lineage( + output_arg: torch.fx.Node, + gm: torch.fx.GraphModule, + gs: ExportGraphSignature, + kind: SchemaKind, +) -> bool: + """ + Walk the graph backwards to see if output_arg is ultimately the same as an input. + """ + if kind != OutputKind.BUFFER_MUTATION and kind != OutputKind.USER_INPUT_MUTATION: + return False + + while output_arg.op != "placeholder": + if _is_inplace_node(output_arg): + # From looking at native_functions.yaml, inplace ops always have self as the first arg + output_arg = output_arg.args[0] # pyre-ignore + else: + return False + + # If the output arg was a buffer then it needs to reach a buffer placeholder + if kind == OutputKind.BUFFER_MUTATION: + return output_arg.target in gs.inputs_to_buffers + # If the output arg was a user input then it needs to reach a user input placeholder + assert kind == OutputKind.USER_INPUT_MUTATION + return output_arg.target in gs.user_inputs + + def insert_write_back_for_buffers_pass( ep: ExportedProgram, ) -> Tuple[torch.fx.GraphModule, ExportGraphSignature]: @@ -99,9 +139,15 @@ def insert_write_back_for_buffers_pass( if lifted_node is not None: input_name_to_node[lifted_node] = input_node + output_node = None + for node in gm.graph.nodes: + if node.op == "output": + output_node = node + break + # Grab the mutable buffer nodes in the outputs, mutated_outputs: List[Optional[str]] = [] - for out_spec in ep.graph_signature.output_specs: + for i, out_spec in enumerate(ep.graph_signature.output_specs): # if the output arg is the input value then all operations on it are in-place # so there's no need to add a copy_ node if ( @@ -112,7 +158,12 @@ def insert_write_back_for_buffers_pass( out_spec.target in input_name_to_node and # if the arg and target are not the same, we add a copy_ node. - out_spec.arg.name != input_name_to_node[out_spec.target].name + not _inplace_lineage( + output_node.args[0][i], + gm, + ep.graph_signature, + ep.graph_signature.output_specs[i].kind, + ) ): mutated_outputs.append(out_spec.target) else: diff --git a/exir/tests/TARGETS b/exir/tests/TARGETS index 2c2ad3e05f0..27256b3b3b4 100644 --- a/exir/tests/TARGETS +++ b/exir/tests/TARGETS @@ -145,6 +145,7 @@ python_unittest( "//caffe2:torch", "//executorch/exir:lib", "//executorch/exir/passes:lib", + "//executorch/extension/pybindings:portable_lib", ], )