Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
55 changes: 53 additions & 2 deletions exir/passes/insert_write_back_for_buffers_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from typing import Dict, List, Optional, Tuple

import torch
from executorch.exir.operator.convert import is_inplace_variant

from torch.export.exported_program import (
ExportedProgram,
Expand All @@ -17,6 +18,7 @@
)
from torch.export.graph_signature import TensorArgument
from torch.utils import _pytree as pytree
from torchgen.model import SchemaKind


def _insert_copy(
Expand Down Expand Up @@ -70,6 +72,44 @@ def _insert_copy(
return buffer_output_nodes


def _is_inplace_node(node: torch.fx.Node) -> bool:
    """Return True if ``node`` is a call to an in-place ATen operator variant."""
    # Only call_function nodes can be operator calls at all.
    if node.op != "call_function":
        return False
    target = node.target
    # Higher-order / non-ATen targets carry no schema to inspect.
    if not isinstance(target, torch._ops.OpOverload):
        return False
    schema = target._schema
    return is_inplace_variant(schema.name, schema.overload_name)


def _inplace_lineage(
output_arg: torch.fx.Node,
gm: torch.fx.GraphModule,
gs: ExportGraphSignature,
kind: SchemaKind,
) -> bool:
"""
Walk the graph backwards to see if output_arg is ultimately the same as an input.
"""
if kind != OutputKind.BUFFER_MUTATION and kind != OutputKind.USER_INPUT_MUTATION:
return False

while output_arg.op != "placeholder":
if _is_inplace_node(output_arg):
# From looking at native_functions.yaml, inplace ops always have self as the first arg
output_arg = output_arg.args[0] # pyre-ignore
else:
return False

# If the output arg was a buffer then it needs to reach a buffer placeholder
if kind == OutputKind.BUFFER_MUTATION:
return output_arg.target in gs.inputs_to_buffers
# If the output arg was a user input then it needs to reach a user input placeholder
assert kind == OutputKind.USER_INPUT_MUTATION
return output_arg.target in gs.user_inputs


def insert_write_back_for_buffers_pass(
ep: ExportedProgram,
) -> Tuple[torch.fx.GraphModule, ExportGraphSignature]:
Expand Down Expand Up @@ -99,9 +139,15 @@ def insert_write_back_for_buffers_pass(
if lifted_node is not None:
input_name_to_node[lifted_node] = input_node

output_node = None
for node in gm.graph.nodes:
if node.op == "output":
output_node = node
break

# Grab the mutable buffer nodes in the outputs,
mutated_outputs: List[Optional[str]] = []
for out_spec in ep.graph_signature.output_specs:
for i, out_spec in enumerate(ep.graph_signature.output_specs):
# if the output arg is the input value then all operations on it are in-place
# so there's no need to add a copy_ node
if (
Expand All @@ -112,7 +158,12 @@ def insert_write_back_for_buffers_pass(
out_spec.target in input_name_to_node
and
# if the arg and target are not the same, we add a copy_ node.
out_spec.arg.name != input_name_to_node[out_spec.target].name
not _inplace_lineage(
output_node.args[0][i],
gm,
ep.graph_signature,
ep.graph_signature.output_specs[i].kind,
)
):
mutated_outputs.append(out_spec.target)
else:
Expand Down
1 change: 1 addition & 0 deletions exir/tests/TARGETS
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,7 @@ python_unittest(
"//caffe2:torch",
"//executorch/exir:lib",
"//executorch/exir/passes:lib",
"//executorch/extension/pybindings:portable_lib",
],
)

Expand Down
Loading