tests/test_repros.py (24 additions, 0 deletions)

@@ -1105,3 +1105,27 @@ def fn3():
        with torchdynamo.optimize("eager"):
            res = fn3()
        self.assertTrue(same(ref, res))

    def test_with_on_graph_break_inst(self):
        def reversible(x):
            print("Hello world")  # Cause graph break so inline fails
            return torch.sin(torch.cos(x))

        def fn(x):
            with torch.enable_grad():
                a = torch.sin(x)
                b = reversible(a)
                c = torch.sigmoid(b)
                c.sum().backward()
                return x.grad

        x = torch.randn(3, requires_grad=True)
        x.grad = None
        with torch.no_grad():
            ref = fn(x)

        x.grad = None
        with torchdynamo.optimize("eager"):
            with torch.no_grad():
                res = fn(x)
        self.assertTrue(same(ref, res))
torchdynamo/codegen.py (6 additions, 0 deletions)

@@ -298,3 +298,9 @@ def load_import_from(self, module_name, object_name):
                self
            )
        )

    def create_begin_finally(self):
        if sys.version_info < (3, 8):
            return self.create_load_const(None)
        else:
            return create_instruction("BEGIN_FINALLY")
torchdynamo/symbolic_convert.py (25 additions, 2 deletions)

@@ -62,6 +62,7 @@
from .variables.misc import ClosureVariable
from .variables.misc import ContextManagerVariable
from .variables.misc import GetAttrVariable
from .variables.misc import GradModeVariable
from .variables.misc import PythonModuleVariable
from .variables.misc import UnknownVariable
from .variables.misc import WithExitFunctionVariable
@@ -149,11 +150,33 @@ def wrapper(self: "InstructionTranslatorBase", inst: Instruction):
            self.restore_graphstate(state)
            self.output.compile_subgraph(self)
            self.popn(push - dis.stack_effect(inst.opcode, inst.arg))
-            self.output.add_output_instructions([inst])

            for _ in range(push):
                self.push(UnknownVariable())

            resume_call_insts = self.create_call_resume_at(self.next_instruction)
            # If the block stack has a single GradModeVariable entry, wrap the
            # instruction causing the graph break in a try..finally block so the
            # grad mode is restored on every exit path. See more details at
            # https://github.com/pytorch/torchdynamo/issues/207
            cleanup = []
            if len(self.block_stack) == 1 and isinstance(
                self.block_stack[0].with_context, GradModeVariable
            ):
                ctx_variable = self.block_stack[0].with_context

                cg = PyCodegen(self)
                setup_finally, cleanup = ctx_variable.reconstruct(
                    cg, resume_call_insts[0]
                )
                self.output.add_output_instructions(setup_finally)

            self.output.add_output_instructions([inst])

            # Add the cleanup instructions from the try..finally block
            self.output.add_output_instructions(cleanup)
            self.output.add_output_instructions(
-                self.create_call_resume_at(self.next_instruction)
                resume_call_insts,
            )

        return wrapper
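A rough Python-level picture of the control flow this change makes the emitted bytecode implement when a graph break happens inside a grad-mode with-block. The names below are hypothetical (the real output is bytecode spliced via add_output_instructions, not source), and the True/False values match the repro test, where enable_grad() is entered under an outer no_grad():

import torch

def transformed(x, graph_breaking_call, __resume_fn):  # hypothetical names
    # ... instructions of the compiled prefix run here ...
    torch._C._set_grad_enabled(True)        # prologue (setup_finally block)
    try:
        graph_breaking_call()               # the instruction that caused the break
    finally:
        torch._C._set_grad_enabled(False)   # cleanup: restore the outer grad mode
    return __resume_fn(x)                   # resume_call_insts: continue in a new frame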
torchdynamo/variables/misc.py (108 additions, 0 deletions)

@@ -1,4 +1,5 @@
import inspect
import sys
import types
from typing import Dict
from typing import List
@@ -135,6 +136,113 @@ def _change_mode(tx, value):
        ),
        torch._C._set_grad_enabled(value)

    def reconstruct(self, codegen, target_inst=None):
        """
        Generate the following Python bytecode.

        Python 3.8
             0 LOAD_GLOBAL              0 (torch)
             2 LOAD_ATTR                1 (_C)
             4 LOAD_METHOD              2 (_set_grad_enabled)
             6 LOAD_CONST               1 (False)
             8 CALL_METHOD              1
            10 POP_TOP

            12 SETUP_FINALLY           10 (to 24)

            14 LOAD_GLOBAL              3 (user_inst)
            16 CALL_FUNCTION            0
            18 POP_TOP
            20 POP_BLOCK
            22 BEGIN_FINALLY

            24 LOAD_GLOBAL              0 (torch)
            26 LOAD_ATTR                1 (_C)
            28 LOAD_METHOD              2 (_set_grad_enabled)
            30 LOAD_CONST               2 (True)
            32 CALL_METHOD              1
            34 POP_TOP
            36 END_FINALLY
            38 LOAD_CONST               0 (None)
            40 RETURN_VALUE

        Instructions 0-10 and 24-34 call torch._C._set_grad_enabled(False/True).

        Python 3.9, 3.10
             0 LOAD_GLOBAL              0 (torch)
             2 LOAD_ATTR                1 (_C)
             4 LOAD_METHOD              2 (_set_grad_enabled)
             6 LOAD_CONST               1 (False)
             8 CALL_METHOD              1
            10 POP_TOP

            12 SETUP_FINALLY           22 (to 36)

            14 LOAD_GLOBAL              3 (user_inst)
            16 CALL_FUNCTION            0
            18 POP_TOP
            20 POP_BLOCK

            22 LOAD_GLOBAL              0 (torch)
            24 LOAD_ATTR                1 (_C)
            26 LOAD_METHOD              2 (_set_grad_enabled)
            28 LOAD_CONST               2 (True)
            30 CALL_METHOD              1
            32 POP_TOP

            34 JUMP_FORWARD            14 (to 50)

            36 LOAD_GLOBAL              0 (torch)
            38 LOAD_ATTR                1 (_C)
            40 LOAD_METHOD              2 (_set_grad_enabled)
            42 LOAD_CONST               2 (True)
            44 CALL_METHOD              1
            46 POP_TOP
            48 RERAISE

            50 LOAD_CONST               0 (None)
            52 RETURN_VALUE
        """
        if self.target_mode == self.original_mode:
            return ([], [])

        def set_grad_insts(mode):
            global_torch_source = codegen.tx.import_source("torch")
            attr_source = AttrSource(global_torch_source, "_C._set_grad_enabled")
            load_set_grad_enabled_insts = attr_source.reconstruct(codegen)
            return [
                *load_set_grad_enabled_insts,
                codegen.create_load_const(mode),
                create_instruction("CALL_FUNCTION", 1),
                create_instruction("POP_TOP"),
            ]

        init_block = set_grad_insts(self.target_mode)
        finally_block = set_grad_insts(self.original_mode)
        setup_final_inst = create_instruction("SETUP_FINALLY", target=finally_block[0])
        prologue = init_block + [setup_final_inst]

        # Generate the epilogue: everything from POP_BLOCK onwards that restores
        # the original grad mode on both the normal and the exception path.
        if sys.version_info < (3, 9):
            # Python 3.7/3.8: enter the finally block normally (LOAD_CONST None /
            # BEGIN_FINALLY) and close it with END_FINALLY.
            epilogue = [
                create_instruction("POP_BLOCK"),
                codegen.create_begin_finally(),
                *finally_block,
                create_instruction("END_FINALLY"),
            ]
        else:
            # Python 3.9+: the finally body is duplicated; the normal path jumps
            # over the exception copy, which restores the mode and re-raises.
            except_block = set_grad_insts(self.original_mode)
            epilogue = [
                create_instruction("POP_BLOCK"),
                *except_block,
                create_instruction("JUMP_FORWARD", target=target_inst),
                *finally_block,
                create_instruction("RERAISE"),
            ]

        return (prologue, epilogue)


class WithExitFunctionVariable(VariableTracker):
    def __init__(self, ctx: VariableTracker, target, **kwargs):
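Not part of the diff: the disassembly in the reconstruct docstring corresponds roughly to compiling a try/finally that toggles the grad mode around the graph-breaking call. A small sketch that reproduces its shape (modulo LOAD_NAME vs LOAD_GLOBAL at module scope); `user_inst` is the docstring's placeholder for the breaking call:

import dis

src = """
torch._C._set_grad_enabled(False)
try:
    user_inst()
finally:
    torch._C._set_grad_enabled(True)
"""
# compile() does not resolve names, so torch/user_inst need not exist here.
dis.dis(compile(src, "<sketch>", "exec"))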