diff --git a/tests/test_repros.py b/tests/test_repros.py
index fb95a77160..adadaa12ea 100755
--- a/tests/test_repros.py
+++ b/tests/test_repros.py
@@ -1105,3 +1105,27 @@ def fn3():
         with torchdynamo.optimize("eager"):
             res = fn3()
         self.assertTrue(same(ref, res))
+
+    def test_with_on_graph_break_inst(self):
+        def reversible(x):
+            print("Hello world")  # Cause graph break so inline fails
+            return torch.sin(torch.cos(x))
+
+        def fn(x):
+            with torch.enable_grad():
+                a = torch.sin(x)
+                b = reversible(a)
+                c = torch.sigmoid(b)
+                c.sum().backward()
+                return x.grad
+
+        x = torch.randn(3, requires_grad=True)
+        x.grad = None
+        with torch.no_grad():
+            ref = fn(x)
+
+        x.grad = None
+        with torchdynamo.optimize("eager"):
+            with torch.no_grad():
+                res = fn(x)
+        self.assertTrue(same(ref, res))
diff --git a/torchdynamo/codegen.py b/torchdynamo/codegen.py
index 3933992aad..07128e8988 100644
--- a/torchdynamo/codegen.py
+++ b/torchdynamo/codegen.py
@@ -298,3 +298,9 @@ def load_import_from(self, module_name, object_name):
                 self
             )
         )
+
+    def create_begin_finally(self):
+        if sys.version_info < (3, 8):
+            return self.create_load_const(None)
+        else:
+            return create_instruction("BEGIN_FINALLY")
diff --git a/torchdynamo/symbolic_convert.py b/torchdynamo/symbolic_convert.py
index 7be0945b32..aada8e9754 100644
--- a/torchdynamo/symbolic_convert.py
+++ b/torchdynamo/symbolic_convert.py
@@ -62,6 +62,7 @@ from .variables.misc import ClosureVariable
 from .variables.misc import ContextManagerVariable
 from .variables.misc import GetAttrVariable
+from .variables.misc import GradModeVariable
 from .variables.misc import PythonModuleVariable
 from .variables.misc import UnknownVariable
 from .variables.misc import WithExitFunctionVariable
@@ -149,11 +150,33 @@ def wrapper(self: "InstructionTranslatorBase", inst: Instruction):
             self.restore_graphstate(state)
             self.output.compile_subgraph(self)
             self.popn(push - dis.stack_effect(inst.opcode, inst.arg))
-            self.output.add_output_instructions([inst])
+
             for _ in range(push):
                 self.push(UnknownVariable())
+
+            resume_call_insts = self.create_call_resume_at(self.next_instruction)
+            # If there is a block stack entry with a GradModeVariable, wrap the
+            # instruction causing the graph break inside a try..finally block.
+            # See more details at
+            # https://github.com/pytorch/torchdynamo/issues/207
+            cleanup = []
+            if len(self.block_stack) == 1 and isinstance(
+                self.block_stack[0].with_context, GradModeVariable
+            ):
+                ctx_variable = self.block_stack[0].with_context
+
+                cg = PyCodegen(self)
+                setup_finally, cleanup = ctx_variable.reconstruct(
+                    cg, resume_call_insts[0]
+                )
+                self.output.add_output_instructions(setup_finally)
+
+            self.output.add_output_instructions([inst])
+
+            # Add the cleanup instructions from the try..finally block
+            self.output.add_output_instructions(cleanup)
             self.output.add_output_instructions(
-                self.create_call_resume_at(self.next_instruction)
+                resume_call_insts,
             )
 
         return wrapper
diff --git a/torchdynamo/variables/misc.py b/torchdynamo/variables/misc.py
index 03461e03f8..9c27870816 100644
--- a/torchdynamo/variables/misc.py
+++ b/torchdynamo/variables/misc.py
@@ -1,4 +1,5 @@
 import inspect
+import sys
 import types
 from typing import Dict
 from typing import List
@@ -135,6 +136,113 @@ def _change_mode(tx, value):
             ),
             torch._C._set_grad_enabled(value)
 
+    def reconstruct(self, codegen, target_inst=None):
+        """
+        Generate the following Python bytecode
+        Python 3.8
+             0 LOAD_GLOBAL              0 (torch)
+             2 LOAD_ATTR                1 (_C)
+             4 LOAD_METHOD              2 (_set_grad_enabled)
+             6 LOAD_CONST               1 (False)
+             8 CALL_METHOD              1
+            10 POP_TOP
+
+            12 SETUP_FINALLY           10 (to 24)
+
+            14 LOAD_GLOBAL              3 (user_inst)
+            16 CALL_FUNCTION            0
+            18 POP_TOP
+            20 POP_BLOCK
+            22 BEGIN_FINALLY
+
+            24 LOAD_GLOBAL              0 (torch)
+            26 LOAD_ATTR                1 (_C)
+            28 LOAD_METHOD              2 (_set_grad_enabled)
+            30 LOAD_CONST               2 (True)
+            32 CALL_METHOD              1
+            34 POP_TOP
+            36 END_FINALLY
+            38 LOAD_CONST               0 (None)
+            40 RETURN_VALUE
+
+        Instructions 0-10 and 24-34 call torch._C._set_grad_enabled(True/False)
+
+        Python 3.9, 3.10
+             0 LOAD_GLOBAL              0 (torch)
+             2 LOAD_ATTR                1 (_C)
+             4 LOAD_METHOD              2 (_set_grad_enabled)
+             6 LOAD_CONST               1 (False)
+             8 CALL_METHOD              1
+            10 POP_TOP
+
+            12 SETUP_FINALLY           22 (to 36)
+
+            14 LOAD_GLOBAL              3 (user_inst)
+            16 CALL_FUNCTION            0
+            18 POP_TOP
+            20 POP_BLOCK
+
+            22 LOAD_GLOBAL              0 (torch)
+            24 LOAD_ATTR                1 (_C)
+            26 LOAD_METHOD              2 (_set_grad_enabled)
+            28 LOAD_CONST               2 (True)
+            30 CALL_METHOD              1
+            32 POP_TOP
+
+            34 JUMP_FORWARD            14 (to 50)
+
+            36 LOAD_GLOBAL              0 (torch)
+            38 LOAD_ATTR                1 (_C)
+            40 LOAD_METHOD              2 (_set_grad_enabled)
+            42 LOAD_CONST               2 (True)
+            44 CALL_METHOD              1
+            46 POP_TOP
+            48 RERAISE
+
+            50 LOAD_CONST               0 (None)
+            52 RETURN_VALUE
+
+        """
+        if self.target_mode == self.original_mode:
+            return ([], [])
+
+        def set_grad_insts(mode):
+            global_torch_source = codegen.tx.import_source("torch")
+            attr_source = AttrSource(global_torch_source, "_C._set_grad_enabled")
+            load_set_grad_enabled_insts = attr_source.reconstruct(codegen)
+            return [
+                *load_set_grad_enabled_insts,
+                codegen.create_load_const(mode),
+                create_instruction("CALL_FUNCTION", 1),
+                create_instruction("POP_TOP"),
+            ]
+
+        init_block = set_grad_insts(self.target_mode)
+        finally_block = set_grad_insts(self.original_mode)
+        setup_final_inst = create_instruction("SETUP_FINALLY", target=finally_block[0])
+        prologue = init_block + [setup_final_inst]
+
+        # Generate the epilogue - starts at POP_BLOCK (offset 20) and restores the original grad mode
+        if sys.version_info < (3, 9):
+            # Python < 3.9 reaches the finally block via BEGIN_FINALLY/END_FINALLY
+            epilogue = [
+                create_instruction("POP_BLOCK"),
+                codegen.create_begin_finally(),
+                *finally_block,
+                create_instruction("END_FINALLY"),
+            ]
+        else:
+            except_block = set_grad_insts(self.original_mode)
+            epilogue = [
+                create_instruction("POP_BLOCK"),
+                *except_block,
+                create_instruction("JUMP_FORWARD", target=target_inst),
+                *finally_block,
+                create_instruction("RERAISE"),
+            ]
+
+        return (prologue, epilogue)
+
 
 class WithExitFunctionVariable(VariableTracker):
     def __init__(self, ctx: VariableTracker, target, **kwargs):
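For intuition, the prologue/epilogue pair produced by GradModeVariable.reconstruct corresponds roughly to the Python source below. This is only an illustrative sketch (user_inst is the placeholder name used in the docstring above, and the mode values are examples); the real output is raw bytecode followed by a call into the generated resume function.

    import torch

    def user_inst():
        # stands in for the unsupported instruction that caused the graph break
        pass

    target_mode, original_mode = True, False  # e.g. a break inside torch.enable_grad()

    torch._C._set_grad_enabled(target_mode)        # prologue (offsets 0-10 in the listing)
    try:
        user_inst()                                # instruction that broke the graph
    finally:
        torch._C._set_grad_enabled(original_mode)  # epilogue: restore the enclosing grad mode
    # ...then fall through / JUMP_FORWARD to the call of the resume function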