From 7efd19f0035015ce524569250dccb032ff81b94a Mon Sep 17 00:00:00 2001 From: Jacob Szwejbka Date: Wed, 25 Jun 2025 11:09:32 -0700 Subject: [PATCH] Reinplace.py (#11918) Summary: Pass attempts to reinplace index_put if it is safe to do so. Reviewed By: angelayi Differential Revision: D77204122 --- exir/passes/TARGETS | 12 ++++ exir/passes/reinplace.py | 103 +++++++++++++++++++++++++++++ exir/tests/TARGETS | 12 ++++ exir/tests/test_reinplace_pass.py | 104 ++++++++++++++++++++++++++++++ 4 files changed, 231 insertions(+) create mode 100644 exir/passes/reinplace.py create mode 100644 exir/tests/test_reinplace_pass.py diff --git a/exir/passes/TARGETS b/exir/passes/TARGETS index 749e8f5c2f1..8699fe2fd02 100644 --- a/exir/passes/TARGETS +++ b/exir/passes/TARGETS @@ -31,6 +31,7 @@ python_library( ":sym_shape_eval_pass", ":sym_to_tensor_pass", ":weights_to_outputs_pass", + ":reinplace_pass", "//caffe2:torch", "//executorch/exir:common", "//executorch/exir:control_flow", @@ -68,6 +69,17 @@ python_library( ], ) +python_library( + name = "reinplace_pass", + srcs = [ + "reinplace.py", + ], + deps = [ + "//caffe2:torch", + "//executorch/exir/dialects:lib", + ], +) + python_library( name = "insert_write_back_for_buffers_pass", srcs = [ diff --git a/exir/passes/reinplace.py b/exir/passes/reinplace.py new file mode 100644 index 00000000000..349869a2f4b --- /dev/null +++ b/exir/passes/reinplace.py @@ -0,0 +1,103 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-strict + +from typing import Set + +import torch +from executorch.exir.dialects._ops import ops +from torch.export import ExportedProgram + + +def _is_index_put(node: torch.fx.Node) -> bool: + """Check if a node is an index_put operation.""" + return node.op == "call_function" and node.target in ( + torch.ops.aten.index_put.default, + ops.edge.aten.index_put.default, + ) + + +def _is_safe_to_reinplace( + node: torch.fx.Node, + later_nodes: Set[torch.fx.Node], + inputs: Set[torch.fx.Node], + mutable_inputs: Set[torch.fx.Node], +) -> bool: + # This node is used later in the graph so we can't reinplace it + # There is probably a faster way to do this but this works for now. + if node in later_nodes: + return False + # If its not an input then we can reinplace it + if node not in inputs: + return True + # If its a mutable input then we can reinplace it + elif node in mutable_inputs: + return True + else: # input but not mutable input + return False + + +def _is_mutable_user_input( + node: torch.fx.Node, exported_program: ExportedProgram +) -> bool: + return ( + node.target in exported_program.graph_signature.user_inputs_to_mutate.values() + ) + + +def _is_mutable_buffer(node: torch.fx.Node, exported_program: ExportedProgram) -> bool: + if node.target not in exported_program.graph_signature.inputs_to_buffers: + return False + buf = exported_program.graph_signature.inputs_to_buffers[node.target] + return buf in exported_program.graph_signature.buffers_to_mutate.values() + + +def reinplace_pass(ep: ExportedProgram) -> ExportedProgram: + """ + Pass that loops over nodes in an exported program and collects the first argument + of every call_function node that is a view_copy operation. + + Args: + exported_program: The ExportedProgram to analyze + + Returns: + Set of nodes that are first arguments to view_copy operations + """ + seen_nodes: Set[torch.fx.Node] = set() + # Get all placeholders + inputs = set() + for node in ep.graph.nodes: + if node.op == "placeholder": + inputs.add(node) + # Get all inputs that we could potentially mutate + mutable_nodes = set( + [ + node + for node in inputs + if _is_mutable_user_input(node, ep) or _is_mutable_buffer(node, ep) + ] + ) + + results = set() + for node in reversed(ep.graph.nodes): + if _is_index_put(node): + # Check if this index_put node is safe to inplace + # The first argument is the base tensor being indexed into + first_arg = node.args[0] + if _is_safe_to_reinplace(first_arg, seen_nodes, inputs, mutable_nodes): + # This index_put is safe to reinplace + with ep.graph.inserting_before(node): + new_node = ep.graph.call_function( + ops.edge.aten.index_put_.default, args=node.args + ) + new_node.meta["val"] = node.meta["val"] + node.replace_all_uses_with(new_node) + ep.graph.erase_node(node) + results.add(first_arg) + elif node.op == "call_function": + seen_nodes.update(node.all_input_nodes) + return ep diff --git a/exir/tests/TARGETS b/exir/tests/TARGETS index 1423984c563..2c2ad3e05f0 100644 --- a/exir/tests/TARGETS +++ b/exir/tests/TARGETS @@ -136,6 +136,18 @@ python_unittest( ], ) +python_unittest( + name = "reinplace_pass", + srcs = [ + "test_reinplace_pass.py", + ], + deps = [ + "//caffe2:torch", + "//executorch/exir:lib", + "//executorch/exir/passes:lib", + ], +) + cpp_library( name = "test_lib", srcs = [ diff --git a/exir/tests/test_reinplace_pass.py b/exir/tests/test_reinplace_pass.py new file mode 100644 index 00000000000..2f4538770d6 --- /dev/null +++ b/exir/tests/test_reinplace_pass.py @@ -0,0 +1,104 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-strict + +import unittest + +import torch +from executorch.exir import to_edge +from executorch.exir.passes.reinplace import reinplace_pass +from torch.export import export + + +class TestReinplacePass(unittest.TestCase): + def test_index_put_reinplace(self) -> None: + """Test that index_put on a mutable buffer can be reinplaced.""" + + class IndexPutModel(torch.nn.Module): + def __init__(self): + super().__init__() + self.register_buffer("state", torch.zeros(5)) + + def forward( + self, indices: torch.Tensor, values: torch.Tensor + ) -> torch.Tensor: + # index_put on buffer (non-user input) should be safe + self.state.index_put_((indices,), values) + return self.state + + model = IndexPutModel() + indices = torch.tensor([0]) + values = torch.tensor([1.0]) + + exported_program = export(model, (indices, values), strict=True) + print(exported_program.graph) + edge_program = to_edge(exported_program).exported_program() + + # Find the index_put node + index_put_node = None + for node in edge_program.graph.nodes: + if node.op == "call_function" and "index_put" in str(node.target): + index_put_node = node + break + + self.assertIsNotNone(index_put_node, "Should find an index_put node") + + ep = reinplace_pass(edge_program) + # Find the index_put node + index_put_node = None + for node in ep.graph.nodes: + if node.op == "call_function" and "index_put_" in str(node.target): + index_put_node = node + break + + self.assertIsNotNone(index_put_node, "Should find an index_put_ node") + + def test_cant_reinplace(self) -> None: + """Test that index_put on a mutable buffer that is viewed later is not safe.""" + + class IndexPutModel(torch.nn.Module): + def __init__(self): + super().__init__() + self.register_buffer("state", torch.zeros(5)) + + def forward( + self, indices: torch.Tensor, values: torch.Tensor + ) -> torch.Tensor: + # index_put on buffer (non-user input) should be safe + x = self.state.index_put((indices,), values) + self.state.add_(1) + return x + + model = IndexPutModel() + indices = torch.tensor([0]) + values = torch.tensor([1.0]) + + exported_program = export(model, (indices, values), strict=True) + edge_program = to_edge(exported_program).exported_program() + + # Find the index_put node + index_put_node = None + for node in edge_program.graph.nodes: + if node.op == "call_function" and "index_put" in str(node.target): + index_put_node = node + break + + self.assertIsNotNone(index_put_node, "Should find an index_put node") + + ep = reinplace_pass(edge_program) + # Find the index_put node + index_put_node = None + for node in ep.graph.nodes: + if ( + node.op == "call_function" + and "index_put" in str(node.target) + and "index_put_" not in str(node.target) + ): + index_put_node = node + break + + self.assertIsNotNone(index_put_node, "Should still find an index_put node")