Commit 750b84b
Reset joint graph fake mode earlier, and more comprehensively
This bug was discovered by a stronger assert (which I will be posting in a follow-up PR).

The explanation for this change is a bit long and windy, and I am not sure I entirely understand the situation myself, but here is what I think is going on.

jansel's joint graph pattern matcher does something fairly unusual: in order to initialize the pattern in question, it (lazily) runs an aot_function invocation to trace out what the joint graph of a given pattern looks like. (We ought not to use aot_function, but we can't really avoid it until bdhirsh lands AOT Autograd export properly.) However, this lazy initialization occurs within the context of a separate compilation, which has its own tracing context and, importantly, its own fake tensor mode.

What we would like is for the pattern matcher's lazy-initialization fake tensor mode to be unrelated to whatever the ambient fake tensor mode of the graph we are actually compiling is. We want these to be independent, because we don't really care what the current compiled graph is; this is a lazy init function, and it could have been initialized during any compilation — it just happens to be initialized on this one. To prevent us from picking up the ambient fake mode, we have to do two things: we have to remove the tracing context (which stores a fake mode), and we also have to disable the ambiently active fake mode.

In #99377, eellison proposed an alternative approach where we reuse the fake mode. While this probably won't cause any errors, it is morally not the right thing to do, because you end up polluting the enclosing fake tensor mode with tensors that have nothing to do with the mode itself.

This might fix #99286, but it is also possible that #99320 fixed it already.

Signed-off-by: Edward Z. Yang <ezyang@meta.com>
ghstack-source-id: f572909
Pull Request resolved: #99391
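The isolation argument above can be sketched without torch at all. In this illustrative model (all names here are hypothetical, not PyTorch APIs), a plain stack of strings stands in for the ambient fake-tensor-mode stack: a lazy initializer that runs inside someone else's compilation must hide the ambient mode and install a fresh one, then restore the ambient mode on exit.

```python
# Illustrative sketch (NOT PyTorch code): why lazy pattern initialization
# must not pick up the ambient mode of whichever compilation triggers it.
from contextlib import contextmanager

_mode_stack = []  # stands in for the ambient fake-tensor-mode stack


@contextmanager
def push_mode(mode):
    # analogue of entering a FakeTensorMode
    _mode_stack.append(mode)
    try:
        yield mode
    finally:
        _mode_stack.pop()


@contextmanager
def disable_ambient_modes():
    # analogue of torch._guards.tracing(None) plus
    # maybe_disable_fake_tensor_mode(): hide the modes the enclosing
    # compilation installed, restore them afterwards
    saved = _mode_stack[:]
    _mode_stack.clear()
    try:
        yield
    finally:
        _mode_stack[:] = saved


def current_mode():
    return _mode_stack[-1] if _mode_stack else None


def lazy_init_pattern():
    # Without isolation this would run under (and pollute) the enclosing
    # compilation's mode; with isolation it gets a fresh, independent one.
    with disable_ambient_modes(), push_mode("pattern-init-mode"):
        return current_mode()


with push_mode("compilation-mode"):
    assert current_mode() == "compilation-mode"
    assert lazy_init_pattern() == "pattern-init-mode"  # fresh mode
    assert current_mode() == "compilation-mode"        # ambient mode restored
```

The key property is the last assertion: the lazy initializer leaves the enclosing compilation's mode exactly as it found it, which is what the reuse-the-fake-mode alternative in #99377 would not guarantee.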
1 parent a763d94 commit 750b84b

File tree

2 files changed, +12 −9 lines changed

torch/_inductor/fx_passes/joint_graph.py (5 additions, 1 deletion)

@@ -2,6 +2,8 @@
 import logging

 import torch
+import torch._guards
+from torch.fx.experimental.proxy_tensor import maybe_disable_fake_tensor_mode
 from ..._subclasses import FakeTensorMode
 from .. import config
 from ..pattern_matcher import PatternMatcherPass
@@ -14,7 +16,9 @@
 def lazy_init():
     from .fuse_attention import _sfdp_init

-    with FakeTensorMode():
+    with torch._guards.tracing(
+        None
+    ), maybe_disable_fake_tensor_mode(), FakeTensorMode():
         _sfdp_init()

torch/_inductor/pattern_matcher.py (7 additions, 8 deletions)

@@ -664,14 +664,13 @@ def record_joint_graph(joint_graph, inputs, **kwargs):
         gm = clone_graph(joint_graph)
         return default_partition(joint_graph, inputs, **kwargs)

-    with torch._guards.tracing(None):
-        aot_function(
-            fn,
-            lambda g, i: make_boxed_func(g),
-            partition_fn=record_joint_graph,
-            decompositions=select_decomp_table(),
-            enable_log=False,
-        )(*args)
+    aot_function(
+        fn,
+        lambda g, i: make_boxed_func(g),
+        partition_fn=record_joint_graph,
+        decompositions=select_decomp_table(),
+        enable_log=False,
+    )(*args)

     # remove in/out specs
     gm.graph._codegen = torch.fx.graph.CodeGen()
