From c6bb918d2fe0d8128588eba272361e19aec659c9 Mon Sep 17 00:00:00 2001
From: Digant Desai
Date: Fri, 29 Aug 2025 10:05:05 -0500
Subject: [PATCH] Add create_mutable_buffer util

Rationale:
* We want to support stateful custom ops inserted after to_edge in
  non-delegated cases, e.g. op-lib backends such as
  cortex-m::cmsis_nn::linear.

Details:
* Allows a pass to add a mutable buffer.
* Essentially redoes what export/to_edge do for a buffer registered on an
  nn.Module.
* Also plays nice with to_executorch passes such as the conversion to
  in-place ops and the write-back for mutated buffers.
* Verifies the above with tests, and adds a HelperPass as an example for
  anyone looking to leverage this util.

To Test:
$ python -m unittest backends.transforms.test.test_create_mutable_buffer
---
 .../test/test_create_mutable_buffer.py       | 503 ++++++++++++++++++
 backends/transforms/utils.py                 | 210 +++++++-
 2 files changed, 702 insertions(+), 11 deletions(-)
 create mode 100644 backends/transforms/test/test_create_mutable_buffer.py

diff --git a/backends/transforms/test/test_create_mutable_buffer.py b/backends/transforms/test/test_create_mutable_buffer.py
new file mode 100644
index 00000000000..4eb4d9538f2
--- /dev/null
+++ b/backends/transforms/test/test_create_mutable_buffer.py
@@ -0,0 +1,503 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+# Copyright 2025 Arm Limited and/or its affiliates.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import copy
+import unittest
+
+import executorch
+import torch
+from executorch.backends.transforms.utils import create_mutable_buffer
+from executorch.exir import ExecutorchBackendConfig, to_edge
+from executorch.exir.dialects._ops import ops as exir_ops
+from executorch.exir.pass_base import ExportPass, PassResult
+from executorch.extension.pybindings.portable_lib import (  # @manual
+    _load_for_executorch_from_buffer,
+    Verification,
+)
+from torch.export import export
+from torch.utils._pytree import tree_flatten
+
+
+class TestMutableBufferCreation(unittest.TestCase):
+    """
+    Test suite for the create_mutable_buffer utility function.
+    """
+
+    def test_create_mutable_buffer(self):
+        """
+        Tests the create_mutable_buffer utility, which creates a mutable
+        buffer that can be modified during execution.
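+
+        A minimal sketch of the call pattern exercised below (the name
+        "b_test_mutable_buffer" is just this test's choice; any name with
+        the "b_" prefix follows the export convention):
+
+            buffer_node = create_mutable_buffer(
+                exp_program=exported_program,
+                name="b_test_mutable_buffer",
+                data=torch.ones(1) * 2,
+            )
+
+        The returned placeholder node can then be used as an operand of
+        graph ops, here an add.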
+ """ + + class EmptyNetwork(torch.nn.Module): + def forward(self, x: torch.Tensor) -> torch.Tensor: + return x + + test_data: torch.Tensor = (torch.zeros(1),) + + module = EmptyNetwork() + exported_program = export(module, args=module.test_data, strict=True) + exported_program = to_edge(exported_program).exported_program() + graph = exported_program.graph_module.graph + + assert len(graph.nodes) == 2 + assert exported_program.module()(torch.zeros(1)) == 0 + assert len(exported_program.graph_signature.input_specs) == 1 + assert len(exported_program.graph_signature.output_specs) == 1 + assert len(exported_program.state_dict) == 0 + + buffer_name = "b_test_mutable_buffer" + target_name = buffer_name[2:] # Remove the "b_" prefix + initial_data = torch.ones(1) * 2 # Initialize with value 2 + + # Create a mutable buffer using create_mutable_buffer + buffer_node = create_mutable_buffer( + exp_program=exported_program, + name=buffer_name, + data=initial_data, + ) + assert "val" in buffer_node.meta + + # Verify the buffer was created correctly + input_node = list(graph.nodes)[ + 1 + ] # Original input node (buffer_node is now first) + + # Create an add operation that uses the mutable buffer + with graph.inserting_after(input_node): + graph.create_node( + "call_function", + exir_ops.edge.aten.add.Tensor, + args=(input_node, buffer_node), + kwargs={}, + ) + + # We should now have four nodes: buffer, input, add, output + assert len(graph.nodes) == 4 + + assert target_name in exported_program.state_dict + assert torch.equal(exported_program.state_dict[target_name], initial_data) + + # Check that buffer is properly referenced in graph signature + assert buffer_name in exported_program.graph_signature.inputs_to_buffers + assert ( + exported_program.graph_signature.inputs_to_buffers[buffer_name] + == target_name + ) + assert ( + exported_program.graph_signature.buffers_to_mutate[buffer_name] + == target_name + ) + + +class TestRegisterMutableBufferPass(unittest.TestCase): + """ + This test is to test the `create_mutable_buffer` utility. + """ + + class HelperPass(ExportPass): + def __init__( + self, + exported_program: torch.export.ExportedProgram, + buf_data: torch.Tensor, + ): + super().__init__() + self.registered_buffers = set() + self.exported_program = exported_program + self.buf_data = buf_data + + def call(self, graph_module: torch.fx.GraphModule): + """ + This pass will register a mutable buffer for the add op(s) in the graph. + It will insert a new index_put_ op to the graph to update the buffer using the output of the add op. + And adjust the output of the graph to return the buffer or the index_put_ op. + """ + modified = False + + assert ( + len(graph_module.graph.output_node().args[0]) == 1 + ), "For this pass, expecting only one output i.e. add" + + for node in graph_module.graph.nodes: + if ( + node.op == "call_function" + and node.target == exir_ops.edge.aten.add.Tensor + ): + # To match what we do in export. + suffix = ( + "" + if len(self.registered_buffers) == 0 + else f"{len(self.registered_buffers) + 1}" + ) + buffer_name = f"b_my_buffer{suffix}" + self.registered_buffers.add(buffer_name) + + # Utility under test! 
+                    buf_node = create_mutable_buffer(
+                        self.exported_program,
+                        data=self.buf_data,
+                        name=buffer_name,
+                    )
+
+                    # Assuming `indices` is always available
+                    indices = [
+                        n
+                        for n in graph_module.graph.nodes
+                        if n.op == "placeholder" and n.name == "indices"
+                    ]
+                    with graph_module.graph.inserting_after(node):
+                        index_put_node = graph_module.graph.create_node(
+                            "call_function",
+                            exir_ops.edge.aten.index_put_.default,
+                            (
+                                buf_node,
+                                indices,
+                                node,
+                            ),
+                            {},
+                        )
+
+                    # Replace the old add output with index_put_node
+                    output_node = graph_module.graph.output_node()
+                    outputs = list(output_node.args[0])
+                    if node in outputs:
+                        # A single output is asserted above, so this is the add
+                        outputs[-1] = index_put_node
+                        output_node.args = (outputs,)
+
+                        # Update the output node name in the graph signature
+                        graph_signature = self.exported_program.graph_signature
+                        graph_signature.replace_all_uses(
+                            node.name, index_put_node.name
+                        )
+                        self.exported_program._graph_signature = graph_signature
+
+                    modified = True
+
+            if modified:
+                graph_module = super().call(graph_module).graph_module
+                graph_module.graph.lint()
+                graph_module.graph.eliminate_dead_code()
+                graph_module.recompile()
+            return PassResult(graph_module, modified)
+
+    def _test_edge_pass(self, model, example_inputs, num_lifted_args=1):
+        exported = export(model, example_inputs)
+
+        edge_program = to_edge(
+            exported,
+            # Avoids torch._export.verifier.SpecViolationError:
+            # Operator torch._ops.aten.index_put_.default is not in Core ATen opset
+            compile_config=executorch.exir.EdgeCompileConfig(_check_ir_validity=False),
+        )
+        # Tensor data for the new buffer(s)
+        buffer_tensor = torch.zeros(2)
+
+        transformed_example_inputs = (
+            *(buffer_tensor for _ in range(num_lifted_args)),
+            *example_inputs,
+        )
+        transformed_ep = edge_program.transform(
+            passes=[
+                TestRegisterMutableBufferPass.HelperPass(
+                    edge_program.exported_program(), buffer_tensor
+                )
+            ]
+        )
+        transformed_edge_gm = transformed_ep.exported_program().graph_module
+        # Make sure it works
+        transformed_edge_gm(*transformed_example_inputs)
+
+        # Explicitly enable the reinplace pass (run_reinplace_pass) to make
+        # sure it works with our manually inserted buffer and index_put_
+        # node(s).
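+        # Note: the HelperPass was handed edge_program.exported_program(),
+        # which create_mutable_buffer mutates in place; that is presumably
+        # why lowering edge_program (rather than transformed_ep) here still
+        # picks up the new buffer(s).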
+        executorch_program_manager = edge_program.to_executorch(
+            ExecutorchBackendConfig(
+                emit_mutable_buffer_names=True, run_reinplace_pass=True
+            )
+        )
+        return executorch_program_manager
+
+    def _test_eager(self, model, example_inputs, num_lifted_args=1):
+        exported = export(model, example_inputs, strict=True)
+
+        edge_program = executorch.exir.to_edge(exported)
+        edge_gm = edge_program.exported_program().graph_module
+
+        # Add the buffer(s) as extra leading positional args
+        buffer_tensor = torch.zeros(2)
+        edge_example_inputs = (
+            *(buffer_tensor for _ in range(num_lifted_args)),
+            *example_inputs,
+        )
+        # Make sure it works
+        edge_gm(*edge_example_inputs)
+
+        executorch_program_manager = edge_program.to_executorch(
+            ExecutorchBackendConfig(
+                emit_mutable_buffer_names=True, run_reinplace_pass=True
+            )
+        )
+        return executorch_program_manager
+
+    def _compare_outputs(self, et_ep1, et_ep2, example_inputs):
+
+        def run(et_pm, inputs):
+            buffer = et_pm.buffer
+            inputs_flattened, _ = tree_flatten(inputs)
+            executorch_module = _load_for_executorch_from_buffer(
+                buffer, program_verification=Verification.Minimal
+            )
+            executorch_output = copy.deepcopy(
+                executorch_module.run_method("forward", tuple(inputs_flattened))
+            )
+            return executorch_output
+
+        # Compare the outputs of the two programs
+        output1 = run(et_ep1, example_inputs)
+        output2 = run(et_ep2, example_inputs)
+        assert len(output1) == len(output2)
+        for o1, o2 in zip(output1, output2):
+            self.assertTrue(torch.allclose(o1, o2))
+
+    def _compare_ep_state_dict(self, et_ep1, et_ep2):
+        # Compare the state dicts of the two programs
+        state_dict1 = et_ep1.exported_program().state_dict
+        state_dict2 = et_ep2.exported_program().state_dict
+        self.assertEqual(len(state_dict1), len(state_dict2))
+        # Comparing names is a bit fragile, but the util tries to match
+        # export's naming, i.e. `b_my_buffer` and `b_my_buffer2`
+        for k, _ in state_dict1.items():
+            self.assertTrue(
+                k in state_dict2, f"{state_dict1.keys()} != {state_dict2.keys()} @ {k}"
+            )
+
+    def _compare_signatures(self, et_ep1, et_ep2):
+        # Compare the graph signatures
+        def _input_spec_compare(input_spec1, input_spec2):
+            self.assertEqual(
+                input_spec1.kind,
+                input_spec2.kind,
+                f"{input_spec1.kind} != {input_spec2.kind}",
+            )
+            self.assertEqual(input_spec1.arg, input_spec2.arg)
+            self.assertEqual(
+                input_spec1.target,
+                input_spec2.target,
+                f"{input_spec1.target} != {input_spec2.target}",
+            )
+            self.assertEqual(
+                input_spec1.persistent,
+                input_spec2.persistent,
+                f"{input_spec1.persistent} != {input_spec2.persistent}",
+            )
+
+        def _output_spec_compare(output_spec1, output_spec2):
+            self.assertEqual(
+                output_spec1.kind,
+                output_spec2.kind,
+                f"{output_spec1.kind} != {output_spec2.kind}",
+            )
+            # TODO: Look into why the output names are different,
+            # and not updated by the buffer_write_back_pass when the buffer
+            # is inserted via a pass and used by an inplace op.
+            # self.assertEqual(output_spec1.arg, output_spec2.arg)
+            self.assertEqual(
+                output_spec1.target,
+                output_spec2.target,
+                f"{output_spec1.target} != {output_spec2.target}",
+            )
+
+        graph_signature1 = et_ep1.exported_program().graph_signature
+        graph_signature2 = et_ep2.exported_program().graph_signature
+
+        # Compare input spec order, kind and targets
+        self.assertEqual(
+            len(graph_signature1.input_specs), len(graph_signature2.input_specs)
+        )
+        for i1, i2 in zip(graph_signature1.input_specs, graph_signature2.input_specs):
+            _input_spec_compare(i1, i2)
+
+        # Compare output spec order, kind and targets
+        self.assertEqual(
+            len(graph_signature1.output_specs), len(graph_signature2.output_specs)
+        )
+        for o1, o2 in zip(graph_signature1.output_specs, graph_signature2.output_specs):
+            _output_spec_compare(o1, o2)
+
+    def _compare_plan_ops(self, et1, et2):
+        operators1 = et1.executorch_program.execution_plan[0].operators
+        operators2 = et2.executorch_program.execution_plan[0].operators
+        for op1, op2 in zip(operators1, operators2):
+            self.assertEqual(op1.name, op2.name, f"{op1.name} != {op2.name}")
+            self.assertEqual(
+                op1.overload, op2.overload, f"{op1.overload} != {op2.overload}"
+            )
+
+    def compare(self, et1, et2, example_inputs):
+        self._compare_signatures(et1, et2)
+        self._compare_ep_state_dict(et1, et2)
+        self._compare_plan_ops(et1, et2)
+        self._compare_outputs(et1, et2, example_inputs)
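+
+    # Each test below pairs a CustomModuleGraph, which has no buffer (the
+    # HelperPass adds one post-export), with a CustomModuleSrc reference
+    # that registers the buffer eagerly, and checks that both lower to
+    # equivalent ExecuTorch programs.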
+
+    def test_basic(self):
+        class CustomModuleGraph(torch.nn.Module):
+            def __init__(self) -> None:
+                super().__init__()
+
+            def forward(self, x, indices):
+                output = x + x
+                return output
+
+        class CustomModuleSrc(torch.nn.Module):
+            def __init__(self) -> None:
+                super().__init__()
+                self.register_buffer("my_buffer", torch.zeros(2))
+
+            def forward(self, x, indices):
+                output = x + x
+                self.my_buffer.index_put_((indices,), output)
+                return output
+
+        example_inputs = (torch.ones(2), torch.tensor([0, 1]))
+
+        with torch.no_grad():
+            graph_model = CustomModuleGraph().eval()
+            et_1 = self._test_edge_pass(graph_model, example_inputs)
+
+            src_model = CustomModuleSrc().eval()
+            et_2 = self._test_eager(src_model, example_inputs)
+
+            self.compare(et_1, et_2, example_inputs)
+
+    def test_basic_with_param(self):
+        input_tensor = torch.ones(2)
+
+        class CustomModuleGraph(torch.nn.Module):
+            def __init__(self) -> None:
+                super().__init__()
+                self.register_parameter("my_param", torch.nn.Parameter(input_tensor))
+
+            def forward(self, x, indices):
+                output = x + self.my_param
+                return output
+
+        class CustomModuleSrc(torch.nn.Module):
+            def __init__(self) -> None:
+                super().__init__()
+                self.register_buffer("my_buffer", torch.zeros(2))
+                self.register_parameter("my_param", torch.nn.Parameter(input_tensor))
+
+            def forward(self, x, indices):
+                output = x + self.my_param
+                self.my_buffer.index_put_((indices,), output)
+                return output
+
+        example_inputs = (input_tensor, torch.tensor([0, 1]))
+        with torch.no_grad():
+            graph_model = CustomModuleGraph().eval()
+            et1 = self._test_edge_pass(graph_model, example_inputs, num_lifted_args=2)
+
+            src_model = CustomModuleSrc().eval()
+            et2 = self._test_eager(src_model, example_inputs, num_lifted_args=2)
+
+            self.compare(et1, et2, example_inputs)
+
+    def test_basic_with_constant(self):
+        input_tensor = torch.ones(2) * 2
+
+        class CustomModuleGraph(torch.nn.Module):
+            def __init__(self) -> None:
+                super().__init__()
+                self.constant = input_tensor
+
+            def forward(self, x, indices):
+                output = x + self.constant
+                return output
+
+        class CustomModuleSrc(torch.nn.Module):
+            def __init__(self) -> None:
+                super().__init__()
+                self.register_buffer("my_buffer", torch.zeros(2))
+                self.constant = input_tensor
+
+            def forward(self, x, indices):
+                output = x + self.constant
+                self.my_buffer.index_put_((indices,), output)
+                return output
+
+        example_inputs = (torch.ones(2), torch.tensor([0, 1]))
+        with torch.no_grad():
+            graph_model = CustomModuleGraph().eval()
+            et1 = self._test_edge_pass(graph_model, example_inputs, num_lifted_args=2)
+
+            src_model = CustomModuleSrc().eval()
+            et2 = self._test_eager(src_model, example_inputs, num_lifted_args=2)
+
+            self.compare(et1, et2, example_inputs)
+
+    def test_single(self):
+        class CustomModuleGraph(torch.nn.Module):
+            def __init__(self) -> None:
+                super().__init__()
+
+            def forward(self, x, indices):
+                output = x + x
+                return output
+
+        class CustomModuleSrc(torch.nn.Module):
+            def __init__(self) -> None:
+                super().__init__()
+                self.register_buffer("my_buffer", torch.zeros(2))
+
+            def forward(self, x, indices):
+                output = x + x
+                self.my_buffer.index_put_((indices,), output)
+                return self.my_buffer
+
+        example_inputs = (torch.ones(2), torch.tensor([0, 1]))
+
+        graph_model = CustomModuleGraph().eval()
+        et1 = self._test_edge_pass(graph_model, example_inputs)
+
+        src_model = CustomModuleSrc().eval()
+        et2 = self._test_eager(src_model, example_inputs)
+
+        self.compare(et1, et2, example_inputs)
+
+    def test_double(self):
+        class CustomModuleGraph(torch.nn.Module):
+            def __init__(self) -> None:
+                super().__init__()
+
+            def forward(self, x, indices):
+                output = x + x
+                output = output + x
+                return output
+
+        class CustomModuleSrc(torch.nn.Module):
+            def __init__(self) -> None:
+                super().__init__()
+                self.register_buffer("my_buffer", torch.zeros(2))
+                self.register_buffer("my_buffer2", torch.zeros(2))
+
+            def forward(self, x, indices):
+                output = x + x
+                self.my_buffer.index_put_((indices,), output)
+
+                output = output + x
+                self.my_buffer2.index_put_((indices,), output)
+
+                return output
+
+        example_inputs = (torch.ones(2), torch.tensor([0, 1]))
+
+        graph_model = CustomModuleGraph().eval()
+        et1 = self._test_edge_pass(graph_model, example_inputs, num_lifted_args=2)
+
+        src_model = CustomModuleSrc().eval()
+        et2 = self._test_eager(src_model, example_inputs, num_lifted_args=2)
+
+        self.compare(et1, et2, example_inputs)

diff --git a/backends/transforms/utils.py b/backends/transforms/utils.py
index 4e451928ee4..ef19a937b0b 100644
--- a/backends/transforms/utils.py
+++ b/backends/transforms/utils.py
@@ -22,10 +22,43 @@
     ExportGraphSignature,
     InputKind,
     InputSpec,
+    OutputKind,
+    OutputSpec,
     TensorArgument,
 )
 
 
+def _get_fake_tensor_mode(graph: torch.fx.Graph, data: torch.Tensor) -> torch.Tensor:
+    """
+    Helper function to create a fake tensor using the fake_mode from existing
+    nodes in the graph.
+
+    Args:
+        graph: The graph to get fake_mode from
+        data: The tensor data to create the fake tensor for
+
+    Returns:
+        A fake tensor with the appropriate fake_mode
+
+    Raises:
+        RuntimeError: If the graph has no nodes to extract fake_mode from
+    """
+    nodes = list(graph.nodes)
+    if not nodes:
+        raise RuntimeError(
+            "Cannot create fake tensor: graph has no nodes to extract fake_mode from"
+        )
+
+    example_node = nodes[0]
+    if isinstance(
+        example_node.meta["val"], (tuple, torch.fx.immutable_collections.immutable_list)
+    ):
+        example_fake_tensor = example_node.meta["val"][0]
+    else:
+        example_fake_tensor = example_node.meta["val"]
+
+    return FakeTensorConverter().from_real_tensor(example_fake_tensor.fake_mode, t=data)
+
+
 def is_get_attr_node(node: torch.fx.Node) -> bool:
     """
     Returns true if the given node is a get attr node for a tensor of the model
@@ -98,17 +131,7 @@
         case _:
             raise RuntimeError("Can only create constant input nodes.")
 
-    # Create fake tensor using the same fake_mode as the other fake tensors in the graph
-    example_node = list(graph.nodes)[0]
-    if isinstance(
-        example_node.meta["val"], (tuple, torch.fx.immutable_collections.immutable_list)
-    ):
-        example_fake_tensor = example_node.meta["val"][0]
-    else:
-        example_fake_tensor = example_node.meta["val"]
-    fake_tensor = FakeTensorConverter().from_real_tensor(
-        example_fake_tensor.fake_mode, t=data
-    )
+    fake_tensor = _get_fake_tensor_mode(graph, data)
 
     # Create node
     node = graph.create_node(op="placeholder", name=name, target=name)
@@ -187,3 +210,168 @@ def delete_constant_placeholder(exp_program: ExportedProgram, node: torch.fx.Nod
 
     # Remove node from graph
     node.graph.erase_node(node)
+
+
+def _validate_graph_signature(exp_program: ExportedProgram):
+    """
+    Validates that the graph signature is up to date with the graph.
+    """
+    placeholders = [n for n in exp_program.graph.nodes if n.op == "placeholder"]
+    if len(placeholders) != len(exp_program.graph_signature.input_specs):
+        raise RuntimeError(
+            f"Graph has {len(placeholders)} placeholder nodes but signature has "
+            f"{len(exp_program.graph_signature.input_specs)} input specs"
+        )
+    for node, input_spec in zip(placeholders, exp_program.graph_signature.input_specs):
+        if node.name != input_spec.arg.name:
+            raise RuntimeError(
+                f"Input node {node.name} does not match input spec {input_spec.arg.name}"
+            )
+    outputs = exp_program.graph.output_node().args[0]
+    if len(outputs) != len(exp_program.graph_signature.output_specs):
+        raise RuntimeError(
+            f"Graph has {len(outputs)} output nodes but signature has "
+            f"{len(exp_program.graph_signature.output_specs)} output specs"
+        )
+    for node, output_spec in zip(outputs, exp_program.graph_signature.output_specs):
+        if node.name != output_spec.arg.name:
+            raise RuntimeError(
+                f"Output node {node.name} does not match output spec {output_spec.arg.name}"
+            )
+
+
+def _spec_to_node(
+    exp_program: ExportedProgram, spec: InputSpec | OutputSpec
+) -> torch.fx.Node:
+    """
+    Converts an InputSpec or OutputSpec to its corresponding node in the graph.
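+
+    The lookup is purely by name: the node whose `name` matches
+    `spec.arg.name` is returned.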
+ """ + # Extract the argument name from the spec + if hasattr(spec, "arg") and hasattr(spec.arg, "name"): + arg_name = spec.arg.name + else: + raise RuntimeError(f"Invalid spec format: {spec}") + + # Find the corresponding node in the graph + for node in exp_program.graph.nodes: + if node.name == arg_name: + return node + + raise RuntimeError(f"Could not find node with name '{arg_name}' in the graph") + + +def create_mutable_buffer( + exp_program: ExportedProgram, + name: str, + data: torch.Tensor, +) -> torch.fx.Node: + """ + Creates and returns a mutable buffer placeholder node. This is similar to + create_constant_placeholder but specifically for creating mutable buffers that + can be modified during execution. + + The difference between this and create_constant_placeholder is that this doesn't + expect user to set the correct position for the placeholder node to be inserted, + it finds the correct position automatically. + + It also updates the graph outputs to include the mutable buffer. + + Args: + exp_program: The exported program to modify + name: The name for the new buffer node (should start with "b_" prefix by convention) + data: The initial tensor data for the buffer + + Returns: + The created placeholder node to be used in the graph + """ + # Input validation + if not name or not name.strip(): + raise ValueError("Buffer name cannot be empty") + + if not isinstance(data, torch.Tensor): + raise ValueError("Data must be a torch.Tensor") + + # Extract target name (remove "b_" prefix if present, following export convention) + if name.startswith("b_"): + target = name[2:] + else: + target = name + + # Check if target already exists + if target in exp_program.state_dict: + raise RuntimeError(f"Buffer target '{target}' already exists in state_dict") + + _validate_graph_signature(exp_program) + + persistent_buffer = True + exp_program.state_dict[target] = data + + graph = exp_program.graph_module.graph + + # Create fake tensor using helper function + fake_tensor = _get_fake_tensor_mode(graph, data) + + # Signature ordering is as follows: + # Inputs = [*parameters_buffers_constant_tensors, *flattened_user_inputs] + # ^^^^^^^ + # insert here (at the end of buffers) + # Outputs = [*mutated_inputs, *flattened_user_outputs] + # ^^^^^^^^^^^^^^^ + # insert here (at the end of mutated inputs) + + # Inputs + # Find const or user input node if any, and insert before it + node_index = 0 + node = None + + input_specs = exp_program.graph_signature.input_specs + if len(input_specs) == 0 or all( + spec.kind not in [InputKind.CONSTANT_TENSOR, InputKind.USER_INPUT] + for spec in input_specs + ): + # No const or user input nodes + node_index = len(input_specs) + node = graph.create_node(op="placeholder", name=name, target=name) + else: + # Find the first constant or user input node + for i, spec in enumerate(input_specs): + if spec.kind in [InputKind.CONSTANT_TENSOR, InputKind.USER_INPUT]: + node_index = i + with graph.inserting_before(_spec_to_node(exp_program, spec)): + node = graph.create_node(op="placeholder", name=name, target=name) + break + + assert node is not None, "node should be created at this point" + node.meta["val"] = fake_tensor + buffer_input_spec = InputSpec( + InputKind.BUFFER, TensorArgument(name), target, persistent_buffer + ) + input_specs.insert(node_index, buffer_input_spec) + + # Outputs + # Create output spec for the mutable buffer, and insert it at the beginning of output specs + user_output_indices = [ + i + for i, spec in enumerate(exp_program.graph_signature.output_specs) 
+    """
+    # Input validation
+    if not name or not name.strip():
+        raise ValueError("Buffer name cannot be empty")
+
+    if not isinstance(data, torch.Tensor):
+        raise ValueError("Data must be a torch.Tensor")
+
+    # Extract the target name (remove the "b_" prefix if present, following
+    # the export convention)
+    if name.startswith("b_"):
+        target = name[2:]
+    else:
+        target = name
+
+    # Check if the target already exists
+    if target in exp_program.state_dict:
+        raise RuntimeError(f"Buffer target '{target}' already exists in state_dict")
+
+    _validate_graph_signature(exp_program)
+
+    persistent_buffer = True
+    exp_program.state_dict[target] = data
+
+    graph = exp_program.graph_module.graph
+
+    # Create a fake tensor using the helper function
+    fake_tensor = _get_fake_tensor_mode(graph, data)
+
+    # Signature ordering is as follows:
+    # Inputs  = [*parameters_buffers_constant_tensors, *flattened_user_inputs]
+    #                                         ^^^^^^^
+    #                                         insert here (at the end of buffers)
+    # Outputs = [*mutated_inputs, *flattened_user_outputs]
+    #                   ^^^^^^^^^^^^^^^
+    #                   insert here (at the end of mutated inputs)
+
+    # Inputs
+    # Find the first const or user input node, if any, and insert before it
+    node_index = 0
+    node = None
+
+    input_specs = exp_program.graph_signature.input_specs
+    if len(input_specs) == 0 or all(
+        spec.kind not in [InputKind.CONSTANT_TENSOR, InputKind.USER_INPUT]
+        for spec in input_specs
+    ):
+        # No const or user input nodes
+        node_index = len(input_specs)
+        node = graph.create_node(op="placeholder", name=name, target=name)
+    else:
+        # Find the first constant or user input node
+        for i, spec in enumerate(input_specs):
+            if spec.kind in [InputKind.CONSTANT_TENSOR, InputKind.USER_INPUT]:
+                node_index = i
+                with graph.inserting_before(_spec_to_node(exp_program, spec)):
+                    node = graph.create_node(op="placeholder", name=name, target=name)
+                break
+
+    assert node is not None, "node should be created at this point"
+    node.meta["val"] = fake_tensor
+    buffer_input_spec = InputSpec(
+        InputKind.BUFFER, TensorArgument(name), target, persistent_buffer
+    )
+    input_specs.insert(node_index, buffer_input_spec)
+
+    # Outputs
+    # Create the output spec for the mutable buffer and insert it before the
+    # first user output, i.e. at the end of the mutated inputs
+    user_output_indices = [
+        i
+        for i, spec in enumerate(exp_program.graph_signature.output_specs)
+        if spec.kind == OutputKind.USER_OUTPUT
+    ]
+
+    output_index = user_output_indices[0] if user_output_indices else 0
+
+    output_specs = exp_program.graph_signature.output_specs
+    mutation_output_spec = OutputSpec(
+        OutputKind.BUFFER_MUTATION, TensorArgument(name), target
+    )
+    output_specs.insert(output_index, mutation_output_spec)
+
+    # Update the outputs to include the mutable buffer
+    output_node = graph.output_node()
+    args = list(output_node.args[0])
+    args.insert(output_index, node)
+    output_node.args = (args,)
+
+    # Update the graph signature in the exported program
+    new_graph_signature = ExportGraphSignature(input_specs, output_specs)
+    exp_program._graph_signature = new_graph_signature
+
+    return node