From d2db010627a730498efa0fc7e559038a5d8cadac Mon Sep 17 00:00:00 2001 From: Jacob Szwejbka Date: Wed, 2 Jul 2025 10:44:48 -0700 Subject: [PATCH] use graph.output_node (#12139) Summary: This util got added a while back Reviewed By: angelayi Differential Revision: D77247219 --- backends/test/harness/tester.py | 7 +------ ...test_create_delete_constant_placeholder.py | 2 +- exir/backend/backend_api.py | 16 ++++---------- exir/emit/_emit_program.py | 6 +----- exir/lowered_backend_module.py | 21 ++++++------------- .../insert_write_back_for_buffers_pass.py | 12 ++--------- exir/passes/quantize_io_pass.py | 5 +---- exir/passes/weights_to_outputs_pass.py | 7 +------ exir/tests/test_joint_graph.py | 12 ++--------- 9 files changed, 19 insertions(+), 69 deletions(-) diff --git a/backends/test/harness/tester.py b/backends/test/harness/tester.py index f1dfeb23531..fccea067d7d 100644 --- a/backends/test/harness/tester.py +++ b/backends/test/harness/tester.py @@ -416,12 +416,7 @@ def _calculate_reference_output( """ # Locate the output node. - output_node = None - for node in program.graph.nodes: - if node.op == "output": - output_node = node - break - assert output_node is not None + output_node = program.graph.output_node() # Look for a dequantization node in the output node args. Returned values are found in the first # argument of the output node. diff --git a/backends/transforms/test/test_create_delete_constant_placeholder.py b/backends/transforms/test/test_create_delete_constant_placeholder.py index ad24f8bfaaf..a095d561a7a 100644 --- a/backends/transforms/test/test_create_delete_constant_placeholder.py +++ b/backends/transforms/test/test_create_delete_constant_placeholder.py @@ -61,7 +61,7 @@ def _test_create_delete(kind: InputKind, persistent_buffer: bool = None): kwargs={}, ) - output_node = list(graph.nodes)[-1] + output_node = graph.output_node() output_node.replace_input_with(input_node, add_node) # We should now have four nodes: test_node, input, add, output diff --git a/exir/backend/backend_api.py b/exir/backend/backend_api.py index 91df0409051..724bbf3fcf6 100644 --- a/exir/backend/backend_api.py +++ b/exir/backend/backend_api.py @@ -288,12 +288,8 @@ def _partition_and_lower_one_graph_module( tagged_graph_module, node_list, tag ) - tagged_graph_module_output_node = [ - node for node in tagged_graph_module.graph.nodes if node.op == "output" - ][0] - submodule_output_node = [ - node for node in submodule.graph.nodes if node.op == "output" - ][0] + tagged_graph_module_output_node = tagged_graph_module.graph.output_node() + submodule_output_node = submodule.graph.output_node() # Copy the output node meta from the original output node, because # create_submodule_from_nodes doesn't cover the meta field submodule_output_node.meta = tagged_graph_module_output_node.meta @@ -476,12 +472,8 @@ def _create_partitions_in_graph_module( tagged_graph_module, node_list, tag ) - tagged_graph_module_output_node = [ - node for node in tagged_graph_module.graph.nodes if node.op == "output" - ][0] - submodule_output_node = [ - node for node in submodule.graph.nodes if node.op == "output" - ][0] + tagged_graph_module_output_node = tagged_graph_module.graph.output_node() + submodule_output_node = submodule.graph.output_node() # Copy the output node meta from the original output node, because # create_submodule_from_nodes doesn't cover the meta field submodule_output_node.meta = tagged_graph_module_output_node.meta diff --git a/exir/emit/_emit_program.py b/exir/emit/_emit_program.py index f456626feed..cb849dde11a 100644 --- a/exir/emit/_emit_program.py +++ b/exir/emit/_emit_program.py @@ -57,11 +57,7 @@ class EmitterOutput: def _remove_non_user_outputs(exported_program: ExportedProgram) -> torch.fx.GraphModule: gm = exported_program.graph_module - output_node = None - for node in gm.graph.nodes: - if node.op == "output": - output_node = node - assert output_node is not None + output_node = gm.graph.output_node() mutated_outputs: List[Optional[str]] = [ out_spec.target if out_spec.kind in (OutputKind.BUFFER_MUTATION,) else None diff --git a/exir/lowered_backend_module.py b/exir/lowered_backend_module.py index b2021d92a2a..e1dd7cb4079 100644 --- a/exir/lowered_backend_module.py +++ b/exir/lowered_backend_module.py @@ -233,14 +233,11 @@ def program( ) ] - output_node = [ - node for node in lowered_exported_program.graph.nodes if node.op == "output" - ] - assert len(output_node) == 1, "There should be only one output node" + output_node = lowered_exported_program.graph.output_node() # Step 1. Cleaning up the graph before inserting the call_delegate node # Remove the original output node - lowered_exported_program.graph.erase_node(output_node[0]) + lowered_exported_program.graph.erase_node(output_node) # Remove all the everything else except the input for node in reversed(lowered_exported_program.graph.nodes): @@ -269,11 +266,9 @@ def program( ) # Get the output list. Since the output node is a tuple of list, like ([aten_mul_tensor, aten_add_tensor],) # We add some handling logic to get the list `[aten_mul_tensor, aten_add_tensor]` properly - original_output_nodes = [ - node - for node in self._original_exported_program.graph.nodes - if node.op == "output" - ][0].args[0] + original_output_nodes = ( + self._original_exported_program.graph.output_node().args[0] + ) delegate_node.meta["spec"] = tuple( [make_spec(node.meta["val"]) for node in original_output_nodes] @@ -927,11 +922,7 @@ def _unsafe_adjust_original_program( # noqa: C901 raise RuntimeError(f"Invalid input spec {input_spec} received") # Delete buffer mutations from the output which were consumed by the delegate - toplevel_output_node = None - for node in reversed(original_program.graph.nodes): - if node.op == "output": - toplevel_output_node = node - break + toplevel_output_node = original_program.graph.output_node() assert toplevel_output_node is not None assert ( diff --git a/exir/passes/insert_write_back_for_buffers_pass.py b/exir/passes/insert_write_back_for_buffers_pass.py index 5ac5f49f2c4..167d489bfc9 100644 --- a/exir/passes/insert_write_back_for_buffers_pass.py +++ b/exir/passes/insert_write_back_for_buffers_pass.py @@ -30,11 +30,7 @@ def _insert_copy( Find the all the buffers and inputs that were mutated and insert copy_ operators to reflect mutations. """ - output_node = None - for node in gm.graph.nodes: - if node.op == "output": - output_node = node - break + output_node = gm.graph.output_node() assert output_node is not None outputs = pytree.tree_flatten(output_node.args)[0] assert len(outputs) == len(mutated_outputs) @@ -139,11 +135,7 @@ def insert_write_back_for_buffers_pass( if lifted_node is not None: input_name_to_node[lifted_node] = input_node - output_node = None - for node in gm.graph.nodes: - if node.op == "output": - output_node = node - break + output_node = gm.graph.output_node() # Grab the mutable buffer nodes in the outputs, mutated_outputs: List[Optional[str]] = [] diff --git a/exir/passes/quantize_io_pass.py b/exir/passes/quantize_io_pass.py index 095b07a1bf7..836a7376f7d 100644 --- a/exir/passes/quantize_io_pass.py +++ b/exir/passes/quantize_io_pass.py @@ -145,11 +145,8 @@ def quantize_output(exported_program, output_index): output quantization. """ graph = exported_program.graph_module.graph - outputs = [n for n in graph.nodes if n.op == "output"] - if len(outputs) != 1: - raise NotImplementedError("Only 1 output node is supported") - output_node = outputs[0] + output_node = graph.output_node() output_list = list(output_node.args[0]) if output_index >= len(output_list): raise ValueError( diff --git a/exir/passes/weights_to_outputs_pass.py b/exir/passes/weights_to_outputs_pass.py index aaf0c0eb5dc..c3e76d44f37 100644 --- a/exir/passes/weights_to_outputs_pass.py +++ b/exir/passes/weights_to_outputs_pass.py @@ -46,12 +46,7 @@ def weights_to_outputs_pass( inputs_to_params = gs.inputs_to_parameters # Get output node - output_node = None - for node in gm.graph.nodes: - if node.op == "output": - output_node = node - break - assert output_node is not None + output_node = gm.graph.output_node() # Get input nodes that are weights with an associated gradient placeholder_nodes = [ diff --git a/exir/tests/test_joint_graph.py b/exir/tests/test_joint_graph.py index fb74b70d313..1597d71e8db 100644 --- a/exir/tests/test_joint_graph.py +++ b/exir/tests/test_joint_graph.py @@ -42,11 +42,7 @@ def forward(self, x, y): joint_ep = _export_forward_backward(ep) edge = to_edge(joint_ep) - output_node = None - for node in edge.exported_program().graph.nodes: - if node.op == "output": - output_node = node - break + output_node = edge.exported_program().graph.output_node() orig_outputs = len(output_node.args[0]) @@ -58,11 +54,7 @@ def forward(self, x, y): if spec.kind == OutputKind.TOKEN ] - output_node = None - for node in et.exported_program().graph.nodes: - if node.op == "output": - output_node = node - break + output_node = et.exported_program().graph.output_node() weight_outputs = len(output_node.args[0])