Skip to content

Commit 26f1da0

Browse files
soulitzerpytorchmergebot
authored andcommitted
Fix node traversal when setting up stacktrace preservation hooks (#118252)
We only want to traverse over each node in the graph exactly once, and we do that by inserting nodes into the "seen" set. The issue is that we forget to check the "seen" set when inserting the root nodes. Typically that is not a problem, because the root nodes are from the different outputs and thus usually correspond to different nodes. With split_with_sizes, though all of the outputs correspond to the same node, ands this leads to the node being iterated over 3 times, and 3 sets of hooks being attached to the same node. Pull Request resolved: #118252 Approved by: https://github.com/zou3519 ghstack dependencies: #117552, #118234, #118249
1 parent b8bd3bb commit 26f1da0

File tree

2 files changed

+20
-1
lines changed

2 files changed

+20
-1
lines changed

test/dynamo/test_aot_autograd.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -885,6 +885,25 @@ def _prepare_model_args():
885885
),
886886
)
887887

888+
def test_split_with_sizes_aot_autograd_cleans_up_traceback_meta(self):
889+
from torch._functorch.aot_autograd import setup_stacktrace_preservation_hooks
890+
891+
def fn(result, split_sizes):
892+
rs = torch.ops.aten.split_with_sizes(result, split_sizes.tolist())
893+
return rs
894+
895+
example_inputs = (
896+
torch.randn(32, requires_grad=True),
897+
torch.tensor((7, 16, 9)),
898+
)
899+
outs = fn(*example_inputs)
900+
setup_stacktrace_preservation_hooks([out.grad_fn for out in outs])
901+
with fx_traceback.preserve_node_meta():
902+
(outs[0].sum() + outs[1].sum() + outs[2].sum()).backward()
903+
904+
self.assertNotIn("grad_fn_seq_nr", fx_traceback.current_meta)
905+
self.assertNotIn("in_grad_fn", fx_traceback.current_meta)
906+
888907
# https://github.com/pytorch/pytorch/issues/110121
889908
def test_aot_export_joint_simple_repro(self):
890909
class Mod(torch.nn.Module):

torch/_functorch/_aot_autograd/logging_utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,7 @@ def iter_graph(roots):
6565
seen = set()
6666
q = collections.deque() # type: ignore[var-annotated]
6767
for node in roots:
68-
if node is not None:
68+
if node is not None and node not in seen:
6969
seen.add(node)
7070
q.append(node)
7171

0 commit comments

Comments
 (0)