backends/test/harness/tester.py (7 changes: 1 addition & 6 deletions)
@@ -416,12 +416,7 @@ def _calculate_reference_output(
         """
 
         # Locate the output node.
-        output_node = None
-        for node in program.graph.nodes:
-            if node.op == "output":
-                output_node = node
-                break
-        assert output_node is not None
+        output_node = program.graph.output_node()
 
         # Look for a dequantization node in the output node args. Returned values are found in the first
         # argument of the output node.
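Note (illustrative, not part of the diff): the pattern being replaced throughout this PR is a manual scan for the node whose op is "output". torch.fx.Graph.output_node() returns that same node directly, assuming a PyTorch build recent enough to ship the helper. A minimal sketch of the equivalence, on a made-up toy module:

import torch
import torch.fx


class AddOne(torch.nn.Module):
    def forward(self, x):
        return x + 1


gm = torch.fx.symbolic_trace(AddOne())

# Old pattern: scan every node for the single "output" op.
manual_output = None
for node in gm.graph.nodes:
    if node.op == "output":
        manual_output = node
        break
assert manual_output is not None

# New pattern: ask the graph directly; both names refer to the same node.
assert gm.graph.output_node() is manual_output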
@@ -61,7 +61,7 @@ def _test_create_delete(kind: InputKind, persistent_buffer: bool = None):
         kwargs={},
     )
 
-    output_node = list(graph.nodes)[-1]
+    output_node = graph.output_node()
     output_node.replace_input_with(input_node, add_node)
 
     # We should now have four nodes: test_node, input, add, output
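Note (illustrative, not part of the diff): the test above rewires the graph output through a newly created node. A hedged sketch of the same kind of edit on a toy graph — the Identity module and torch.add below are stand-ins, not the helpers the test itself uses:

import torch
import torch.fx


class Identity(torch.nn.Module):
    def forward(self, x):
        return x


gm = torch.fx.symbolic_trace(Identity())
graph = gm.graph

placeholder = next(n for n in graph.nodes if n.op == "placeholder")
output_node = graph.output_node()

# Insert an add node before the output, then route the output through it.
with graph.inserting_before(output_node):
    add_node = graph.call_function(torch.add, (placeholder, 1.0))
output_node.replace_input_with(placeholder, add_node)

graph.lint()
gm.recompile()
print(gm(torch.zeros(2)))  # tensor([1., 1.])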
exir/backend/backend_api.py (16 changes: 4 additions & 12 deletions)
@@ -288,12 +288,8 @@ def _partition_and_lower_one_graph_module(
             tagged_graph_module, node_list, tag
         )
 
-        tagged_graph_module_output_node = [
-            node for node in tagged_graph_module.graph.nodes if node.op == "output"
-        ][0]
-        submodule_output_node = [
-            node for node in submodule.graph.nodes if node.op == "output"
-        ][0]
+        tagged_graph_module_output_node = tagged_graph_module.graph.output_node()
+        submodule_output_node = submodule.graph.output_node()
         # Copy the output node meta from the original output node, because
         # create_submodule_from_nodes doesn't cover the meta field
         submodule_output_node.meta = tagged_graph_module_output_node.meta
@@ -476,12 +472,8 @@ def _create_partitions_in_graph_module(
             tagged_graph_module, node_list, tag
         )
 
-        tagged_graph_module_output_node = [
-            node for node in tagged_graph_module.graph.nodes if node.op == "output"
-        ][0]
-        submodule_output_node = [
-            node for node in submodule.graph.nodes if node.op == "output"
-        ][0]
+        tagged_graph_module_output_node = tagged_graph_module.graph.output_node()
+        submodule_output_node = submodule.graph.output_node()
         # Copy the output node meta from the original output node, because
         # create_submodule_from_nodes doesn't cover the meta field
         submodule_output_node.meta = tagged_graph_module_output_node.meta
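Note (illustrative, not part of the diff): both hunks in this file copy output-node meta because a freshly created output node starts with an empty meta dict. A hedged sketch of that situation using plain torch.fx — graph_copy stands in for create_submodule_from_nodes, which is ExecuTorch-internal, and the "delegation_tag" key is just an example:

import torch
import torch.fx


class Double(torch.nn.Module):
    def forward(self, x):
        return x * 2


src = torch.fx.symbolic_trace(Double())
src.graph.output_node().meta["custom"] = {"delegation_tag": "tag0"}

# Rebuild the graph node by node; the copy's output node gets a fresh, empty meta.
new_graph = torch.fx.Graph()
out_val = new_graph.graph_copy(src.graph, val_map={})
new_graph.output(out_val)
assert "custom" not in new_graph.output_node().meta

# So the meta has to be carried over explicitly, as the hunks above do.
new_graph.output_node().meta = src.graph.output_node().meta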
exir/emit/_emit_program.py (6 changes: 1 addition & 5 deletions)
@@ -57,11 +57,7 @@ class EmitterOutput:
 
 def _remove_non_user_outputs(exported_program: ExportedProgram) -> torch.fx.GraphModule:
     gm = exported_program.graph_module
-    output_node = None
-    for node in gm.graph.nodes:
-        if node.op == "output":
-            output_node = node
-    assert output_node is not None
+    output_node = gm.graph.output_node()
 
     mutated_outputs: List[Optional[str]] = [
         out_spec.target if out_spec.kind in (OutputKind.BUFFER_MUTATION,) else None
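Note (illustrative, not part of the diff): the surrounding code walks the graph signature's output specs to pick out buffer mutations among the values returned by the output node. A hedged sketch of what those specs look like for a small exported program — the Counter module is made up, and the exact spec ordering and names are whatever recent torch.export produces, so treat the printed results as indicative:

import torch
from torch.export import export
from torch.export.graph_signature import OutputKind


class Counter(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.register_buffer("steps", torch.zeros(1))

    def forward(self, x):
        self.steps.add_(1)
        return x + self.steps


ep = export(Counter(), (torch.randn(2),))

# One spec per entry in output_node().args[0]; buffer mutations carry the
# buffer's target name, user outputs do not.
mutated = [
    spec.target if spec.kind == OutputKind.BUFFER_MUTATION else None
    for spec in ep.graph_signature.output_specs
]
print(mutated)  # e.g. ['steps', None]
print(len(ep.graph.output_node().args[0]) == len(mutated))  # True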
exir/lowered_backend_module.py (21 changes: 6 additions & 15 deletions)
@@ -233,14 +233,11 @@ def program(
             )
         ]
 
-        output_node = [
-            node for node in lowered_exported_program.graph.nodes if node.op == "output"
-        ]
-        assert len(output_node) == 1, "There should be only one output node"
+        output_node = lowered_exported_program.graph.output_node()
 
         # Step 1. Cleaning up the graph before inserting the call_delegate node
         # Remove the original output node
-        lowered_exported_program.graph.erase_node(output_node[0])
+        lowered_exported_program.graph.erase_node(output_node)
 
         # Remove all the everything else except the input
         for node in reversed(lowered_exported_program.graph.nodes):
@@ -269,11 +266,9 @@ def program(
         )
         # Get the output list. Since the output node is a tuple of list, like ([aten_mul_tensor, aten_add_tensor],)
         # We add some handling logic to get the list `[aten_mul_tensor, aten_add_tensor]` properly
-        original_output_nodes = [
-            node
-            for node in self._original_exported_program.graph.nodes
-            if node.op == "output"
-        ][0].args[0]
+        original_output_nodes = (
+            self._original_exported_program.graph.output_node().args[0]
+        )
 
         delegate_node.meta["spec"] = tuple(
             [make_spec(node.meta["val"]) for node in original_output_nodes]
@@ -927,11 +922,7 @@ def _unsafe_adjust_original_program(  # noqa: C901
             raise RuntimeError(f"Invalid input spec {input_spec} received")
 
     # Delete buffer mutations from the output which were consumed by the delegate
-    toplevel_output_node = None
-    for node in reversed(original_program.graph.nodes):
-        if node.op == "output":
-            toplevel_output_node = node
-            break
+    toplevel_output_node = original_program.graph.output_node()
 
     assert toplevel_output_node is not None
     assert (
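Note (illustrative, not part of the diff): the second hunk above leans on the fx convention that an output node's args is a one-element tuple wrapping the returned values, so output_node().args[0] is the collection of value nodes. A minimal sketch — the node names are simply whatever symbolic_trace picks:

import torch
import torch.fx


class TwoOutputs(torch.nn.Module):
    def forward(self, x):
        return x * 2, x + 1


gm = torch.fx.symbolic_trace(TwoOutputs())
output_node = gm.graph.output_node()

assert len(output_node.args) == 1   # the single "returned values" slot
returned = output_node.args[0]      # e.g. (mul_node, add_node)
print([n.name for n in returned])   # e.g. ['mul', 'add']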
exir/passes/insert_write_back_for_buffers_pass.py (12 changes: 2 additions & 10 deletions)
@@ -30,11 +30,7 @@ def _insert_copy(
     Find the all the buffers and inputs that were mutated and insert copy_
     operators to reflect mutations.
     """
-    output_node = None
-    for node in gm.graph.nodes:
-        if node.op == "output":
-            output_node = node
-            break
+    output_node = gm.graph.output_node()
     assert output_node is not None
     outputs = pytree.tree_flatten(output_node.args)[0]
     assert len(outputs) == len(mutated_outputs)
@@ -139,11 +135,7 @@ def insert_write_back_for_buffers_pass(
         if lifted_node is not None:
             input_name_to_node[lifted_node] = input_node
 
-    output_node = None
-    for node in gm.graph.nodes:
-        if node.op == "output":
-            output_node = node
-            break
+    output_node = gm.graph.output_node()
 
     # Grab the mutable buffer nodes in the outputs,
    mutated_outputs: List[Optional[str]] = []
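Note (illustrative, not part of the diff): the first hunk flattens output_node.args with pytree before matching entries against mutated_outputs. A hedged sketch of that flattening on a toy graph, assuming the pass's pytree is torch.utils._pytree:

import torch
import torch.fx
import torch.utils._pytree as pytree


class M(torch.nn.Module):
    def forward(self, x):
        return x + 1, x * 2


gm = torch.fx.symbolic_trace(M())
output_node = gm.graph.output_node()

# args is ((add_node, mul_node),); tree_flatten peels the nesting away and
# leaves one entry per returned value, which is what gets zipped against
# mutated_outputs in the pass.
flat_outputs, _spec = pytree.tree_flatten(output_node.args)
print(len(flat_outputs))  # 2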
exir/passes/quantize_io_pass.py (5 changes: 1 addition & 4 deletions)
@@ -145,11 +145,8 @@ def quantize_output(exported_program, output_index):
     output quantization.
     """
     graph = exported_program.graph_module.graph
-    outputs = [n for n in graph.nodes if n.op == "output"]
-    if len(outputs) != 1:
-        raise NotImplementedError("Only 1 output node is supported")
 
-    output_node = outputs[0]
+    output_node = graph.output_node()
     output_list = list(output_node.args[0])
     if output_index >= len(output_list):
         raise ValueError(
exir/passes/weights_to_outputs_pass.py (7 changes: 1 addition & 6 deletions)
@@ -46,12 +46,7 @@ def weights_to_outputs_pass(
     inputs_to_params = gs.inputs_to_parameters
 
     # Get output node
-    output_node = None
-    for node in gm.graph.nodes:
-        if node.op == "output":
-            output_node = node
-            break
-    assert output_node is not None
+    output_node = gm.graph.output_node()
 
     # Get input nodes that are weights with an associated gradient
     placeholder_nodes = [
exir/tests/test_joint_graph.py (12 changes: 2 additions & 10 deletions)
@@ -42,11 +42,7 @@ def forward(self, x, y):
         joint_ep = _export_forward_backward(ep)
         edge = to_edge(joint_ep)
 
-        output_node = None
-        for node in edge.exported_program().graph.nodes:
-            if node.op == "output":
-                output_node = node
-                break
+        output_node = edge.exported_program().graph.output_node()
 
         orig_outputs = len(output_node.args[0])
 
@@ -58,11 +54,7 @@ def forward(self, x, y):
             if spec.kind == OutputKind.TOKEN
         ]
 
-        output_node = None
-        for node in et.exported_program().graph.nodes:
-            if node.op == "output":
-                output_node = node
-                break
+        output_node = et.exported_program().graph.output_node()
 
         weight_outputs = len(output_node.args[0])
 