Skip to content

Commit

Permalink
Fix bug in graph partitioner
Browse files Browse the repository at this point in the history
Summary: Fix a bug in the graph partitioner: insert placeholder nodes in their original order so the exported graph signature stays correct, and remap output names in the graph signature after replace_set_grad_with_hop_pass rewrites the output node.

Test Plan: CI

Differential Revision: D56688411
  • Loading branch information
tugsbayasgalan authored and facebook-github-bot committed Apr 29, 2024
1 parent da44d2f commit c78101d
Show file tree
Hide file tree
Showing 2 changed files with 51 additions and 0 deletions.
41 changes: 41 additions & 0 deletions torch/export/_trace.py
Original file line number Diff line number Diff line change
Expand Up @@ -534,13 +534,54 @@ def _compiling_state_context():
if isinstance(mod, torch.fx.GraphModule) and hasattr(mod, "meta"):
gm.meta.update(mod.meta)


original_output_names = ()
for node in gm.graph.nodes:
if node.op == "output":
original_output_names = node.args

if pre_dispatch:
from torch._export.passes.replace_set_grad_with_hop_pass import (
replace_set_grad_with_hop_pass,
)

gm = replace_set_grad_with_hop_pass(gm)


updated_output_names = ()
for node in gm.graph.nodes:
if node.op == "output":
updated_output_names = node.args

# The output node's args may have changed because replace_set_grad_with_hop_pass
# inserts new getitem nodes into the graph.
assert len(original_output_names) == len(updated_output_names)
old_output_name_to_new_output_name = {}
for k, v in zip(*original_output_names, *updated_output_names):
if k is not None and v is not None:
old_output_name_to_new_output_name[k.name] = v.name
# If the original outputs have trailing None entries, the updated outputs must too.
else:
assert k is None and v is None

buffers_to_mutate_copy = graph_signature.buffers_to_mutate.copy()
user_inputs_to_mutate_copy = graph_signature.user_inputs_to_mutate.copy()
for k in old_output_name_to_new_output_name:
if k in graph_signature.buffers_to_mutate:
graph_signature.buffers_to_mutate[old_output_name_to_new_output_name[k]] = buffers_to_mutate_copy[k]
if k not in old_output_name_to_new_output_name.values():
del graph_signature.buffers_to_mutate[k]

if k in graph_signature.user_inputs_to_mutate:
graph_signature.user_inputs_to_mutate[old_output_name_to_new_output_name[k]] = user_inputs_to_mutate_copy[k]
if k not in old_output_name_to_new_output_name.values():
del graph_signature.user_inputs_to_mutate[k]

for i, k in enumerate(graph_signature.user_outputs):
if k in old_output_name_to_new_output_name:
new_k = old_output_name_to_new_output_name[k]
graph_signature.user_outputs[i] = new_k

# Remove nn_module_stack, stack_trace metadata from all placeholders/inputs nodes.
for _mod in gm.modules():
if not isinstance(_mod, torch.fx.GraphModule):
Expand Down
10 changes: 10 additions & 0 deletions torch/fx/passes/split_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -477,6 +477,16 @@ def instantiate_node_partition_mapping(node):
orig_mod_attr_nodes: List[Node] = [
orig_mod_env[key] for key in partition.inputs
]

# The placeholder nodes must be inserted in their original order;
# otherwise the graph signature will be wrong.
original_order = {node.name: i for i, node in enumerate(m.graph.nodes) if node.op == "placeholder"}
for node in orig_mod_attr_nodes:
if node.name not in original_order:
original_order[node.name] = float("inf")

Check failure on line 486 in torch/fx/passes/split_module.py

View workflow job for this annotation

GitHub Actions / lintrunner-noclang / linux-job

MYPY [assignment]

Incompatible types in assignment (expression has type "float", target has type "int")

orig_mod_attr_nodes.sort(key=lambda node: original_order[node.name])

# Construct GraphModule for this partition
for node in orig_mod_attr_nodes: # type: ignore[attr-defined]
if node in already_constructed_attr_nodes:
Expand Down

0 comments on commit c78101d

Please sign in to comment.