From a7f0cf97b95a0afe98697f7f301b5b50d0214217 Mon Sep 17 00:00:00 2001
From: Animesh Jain
Date: Thu, 5 May 2022 22:13:50 +0000
Subject: [PATCH 1/6] Try finally block for with context on graph break
 instruction

---
 tests/test_repros.py            | 24 +++++++++++++
 torchdynamo/symbolic_convert.py | 18 ++++++++++
 torchdynamo/variables/misc.py   | 62 +++++++++++++++++++++++++++++++++
 3 files changed, 104 insertions(+)

diff --git a/tests/test_repros.py b/tests/test_repros.py
index fb95a77160..5f640adc46 100755
--- a/tests/test_repros.py
+++ b/tests/test_repros.py
@@ -1105,3 +1105,27 @@ def fn3():
         with torchdynamo.optimize("eager"):
             res = fn3()
         self.assertTrue(same(ref, res))
+
+    def test_with_on_graph_break_inst(self):
+        def reversible(x):
+            print("Hello world")  # Causes a graph break so inlining fails
+            return torch.sin(torch.cos(x))
+
+        def fn(x):
+            torch._C._set_grad_enabled(False)
+
+            with torch.enable_grad():
+                a = torch.sin(x)
+                b = reversible(a)
+                c = torch.sigmoid(b)
+                c.sum().backward()
+                return x.grad
+
+        x = torch.randn(3, requires_grad=True)
+        x.grad = None
+        ref = fn(x)
+
+        x.grad = None
+        with torchdynamo.optimize("eager"):
+            res = fn(x)
+        self.assertTrue(same(ref, res))
diff --git a/torchdynamo/symbolic_convert.py b/torchdynamo/symbolic_convert.py
index 7be0945b32..dcd5c8392f 100644
--- a/torchdynamo/symbolic_convert.py
+++ b/torchdynamo/symbolic_convert.py
@@ -62,6 +62,7 @@
 from .variables.misc import ClosureVariable
 from .variables.misc import ContextManagerVariable
 from .variables.misc import GetAttrVariable
+from .variables.misc import GradModeVariable
 from .variables.misc import PythonModuleVariable
 from .variables.misc import UnknownVariable
 from .variables.misc import WithExitFunctionVariable
@@ -149,7 +150,24 @@ def wrapper(self: "InstructionTranslatorBase", inst: Instruction):
             self.restore_graphstate(state)
             self.output.compile_subgraph(self)
             self.popn(push - dis.stack_effect(inst.opcode, inst.arg))
+
+            # If there is a block stack entry with a GradModeVariable, wrap
+            # the instruction causing the graph break inside a try..finally
+            # block. See more details at
+            # https://github.com/pytorch/torchdynamo/issues/207
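+            #
+            # At the source level, the rewritten code is roughly (sketch;
+            # <inst> stands for the instruction that broke the graph):
+            #
+            #     torch._C._set_grad_enabled(target_mode)
+            #     try:
+            #         <inst>
+            #     finally:
+            #         torch._C._set_grad_enabled(original_mode)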
+            cleanup = []
+            if len(self.block_stack) == 1 and isinstance(
+                self.block_stack[0].with_context, GradModeVariable
+            ):
+                ctx_variable = self.block_stack[0].with_context
+                cg = PyCodegen(self)
+                setup_finally, cleanup = ctx_variable.reconstruct(cg)
+                self.output.add_output_instructions(setup_finally)
+
+            self.output.add_output_instructions([inst])
+
+            # Add the cleanup instructions from the try..finally block
+            self.output.add_output_instructions(cleanup)
             for _ in range(push):
                 self.push(UnknownVariable())
             self.output.add_output_instructions(
diff --git a/torchdynamo/variables/misc.py b/torchdynamo/variables/misc.py
index 03461e03f8..77d4278572 100644
--- a/torchdynamo/variables/misc.py
+++ b/torchdynamo/variables/misc.py
@@ -135,6 +135,68 @@ def _change_mode(tx, value):
             ), torch._C._set_grad_enabled(value)
 
+    def reconstruct(self, codegen):
+        """
+        Generate the following Python bytecode
+        0 LOAD_GLOBAL 0 (torch)
+        2 LOAD_ATTR 1 (_C)
+        4 LOAD_METHOD 2 (_set_grad_enabled)
+        6 LOAD_CONST 1 (False)
+        8 CALL_METHOD 1
+        10 POP_TOP
+
+        12 SETUP_FINALLY 10 (to 24)
+
+        14 LOAD_GLOBAL 3 (user_inst)
+        16 CALL_FUNCTION 0
+        18 POP_TOP
+        20 POP_BLOCK
+        22 BEGIN_FINALLY
+
+        24 LOAD_GLOBAL 0 (torch)
+        26 LOAD_ATTR 1 (_C)
+        28 LOAD_METHOD 2 (_set_grad_enabled)
+        30 LOAD_CONST 2 (True)
+        32 CALL_METHOD 1
+        34 POP_TOP
+        36 END_FINALLY
+        38 LOAD_CONST 0 (None)
+        40 RETURN_VALUE
+
+        Instructions 0-10 and 24-34 call torch._C._set_grad_enabled(True/False)
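+
+        The generated bytecode is roughly equivalent to this Python source
+        (sketch; user_inst stands for the instruction that caused the graph
+        break):
+
+            torch._C._set_grad_enabled(False)
+            try:
+                user_inst()
+            finally:
+                torch._C._set_grad_enabled(True)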
+
+        """
+        if self.target_mode == self.original_mode:
+            return ([], [])
+
+        def set_grad_insts(mode):
+            codegen.load_import_from("torch", "_C")
+            codegen.load_import_from("torch._C", "_set_grad_enabled")
+            return [
+                codegen.create_load_global("torch"),
+                codegen.create_load_attr("_C"),
+                create_instruction("LOAD_METHOD", 2, "_set_grad_enabled"),
+                codegen.create_load_const(mode),
+                create_instruction("CALL_METHOD", 1),
+                create_instruction("POP_TOP"),
+            ]
+
+        init_block = set_grad_insts(self.target_mode)
+        finally_block = set_grad_insts(self.original_mode)
+
+        # Generate the prologue that ends with setup_finally
+        setup_final_inst = create_instruction("SETUP_FINALLY", target=finally_block[0])
+        prologue = init_block + [setup_final_inst]
+
+        # Generate the epilogue - starts with 20 POP_BLOCK and ends with 36 END_FINALLY
+        epilogue = [
+            create_instruction("POP_BLOCK"),
+            create_instruction("BEGIN_FINALLY"),
+            *finally_block,
+            create_instruction("END_FINALLY"),
+        ]
+        return (prologue, epilogue)
+
 
 class WithExitFunctionVariable(VariableTracker):
     def __init__(self, ctx: VariableTracker, target, **kwargs):

From ed60183900d9a1abb71458e75f648f4ae942aa63 Mon Sep 17 00:00:00 2001
From: Animesh Jain
Date: Thu, 5 May 2022 22:30:02 +0000
Subject: [PATCH 2/6] fix test

---
 tests/test_repros.py | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/tests/test_repros.py b/tests/test_repros.py
index 5f640adc46..adadaa12ea 100755
--- a/tests/test_repros.py
+++ b/tests/test_repros.py
@@ -1112,8 +1112,6 @@ def reversible(x):
             return torch.sin(torch.cos(x))
 
         def fn(x):
-            torch._C._set_grad_enabled(False)
-
             with torch.enable_grad():
                 a = torch.sin(x)
                 b = reversible(a)
                 c = torch.sigmoid(b)
                 c.sum().backward()
                 return x.grad
 
         x = torch.randn(3, requires_grad=True)
         x.grad = None
-        ref = fn(x)
+        with torch.no_grad():
+            ref = fn(x)
 
         x.grad = None
         with torchdynamo.optimize("eager"):
-            res = fn(x)
+            with torch.no_grad():
+                res = fn(x)
         self.assertTrue(same(ref, res))

From 3ea02e70f29250d6a1ec6bffe139d31475e5b2a3 Mon Sep 17 00:00:00 2001
From: Animesh Jain
Date: Fri, 6 May 2022 00:15:16 +0000
Subject: [PATCH 3/6] Support Python >= 3.9

---
 torchdynamo/symbolic_convert.py | 13 +++++--
 torchdynamo/variables/misc.py   | 65 ++++++++++++++++++++++++++++-----
 2 files changed, 65 insertions(+), 13 deletions(-)

diff --git a/torchdynamo/symbolic_convert.py b/torchdynamo/symbolic_convert.py
index dcd5c8392f..aada8e9754 100644
--- a/torchdynamo/symbolic_convert.py
+++ b/torchdynamo/symbolic_convert.py
@@ -151,6 +151,10 @@ def wrapper(self: "InstructionTranslatorBase", inst: Instruction):
             self.output.compile_subgraph(self)
             self.popn(push - dis.stack_effect(inst.opcode, inst.arg))
 
+            for _ in range(push):
+                self.push(UnknownVariable())
+
+            resume_call_insts = self.create_call_resume_at(self.next_instruction)
             # If there is a block stack entry with a GradModeVariable, wrap
             # the instruction causing the graph break inside a try..finally
             # block. See more details at
@@ -160,18 +164,19 @@ def wrapper(self: "InstructionTranslatorBase", inst: Instruction):
                 self.block_stack[0].with_context, GradModeVariable
             ):
                 ctx_variable = self.block_stack[0].with_context
+
                 cg = PyCodegen(self)
-                setup_finally, cleanup = ctx_variable.reconstruct(cg)
+                setup_finally, cleanup = ctx_variable.reconstruct(
+                    cg, resume_call_insts[0]
+                )
                 self.output.add_output_instructions(setup_finally)
 
             self.output.add_output_instructions([inst])
 
             # Add the cleanup instructions from the try..finally block
             self.output.add_output_instructions(cleanup)
-            for _ in range(push):
-                self.push(UnknownVariable())
             self.output.add_output_instructions(
-                self.create_call_resume_at(self.next_instruction)
+                resume_call_insts,
             )
         return wrapper
diff --git a/torchdynamo/variables/misc.py b/torchdynamo/variables/misc.py
index 77d4278572..d31b7ea0ad 100644
--- a/torchdynamo/variables/misc.py
+++ b/torchdynamo/variables/misc.py
@@ -1,4 +1,5 @@
 import inspect
+import sys
 import types
 from typing import Dict
 from typing import List
@@ -135,9 +136,10 @@ def _change_mode(tx, value):
             ), torch._C._set_grad_enabled(value)
 
-    def reconstruct(self, codegen):
+    def reconstruct(self, codegen, target_inst=None):
         """
         Generate the following Python bytecode
+        Python 3.8
         0 LOAD_GLOBAL 0 (torch)
         2 LOAD_ATTR 1 (_C)
         4 LOAD_METHOD 2 (_set_grad_enabled)
         6 LOAD_CONST 1 (False)
         8 CALL_METHOD 1
         10 POP_TOP
@@ -165,6 +167,41 @@ def reconstruct(self, codegen):
         Instructions 0-10 and 24-34 call torch._C._set_grad_enabled(True/False)
 
+        Python 3.9, 3.10
+        0 LOAD_GLOBAL 0 (torch)
+        2 LOAD_ATTR 1 (_C)
+        4 LOAD_METHOD 2 (_set_grad_enabled)
+        6 LOAD_CONST 1 (False)
+        8 CALL_METHOD 1
+        10 POP_TOP
+
+        12 SETUP_FINALLY 22 (to 36)
+
+        14 LOAD_GLOBAL 3 (user_inst)
+        16 CALL_FUNCTION 0
+        18 POP_TOP
+        20 POP_BLOCK
+
+        22 LOAD_GLOBAL 0 (torch)
+        24 LOAD_ATTR 1 (_C)
+        26 LOAD_METHOD 2 (_set_grad_enabled)
+        28 LOAD_CONST 2 (True)
+        30 CALL_METHOD 1
+        32 POP_TOP
+
+        34 JUMP_FORWARD 14 (to 50)
+
+        36 LOAD_GLOBAL 0 (torch)
+        38 LOAD_ATTR 1 (_C)
+        40 LOAD_METHOD 2 (_set_grad_enabled)
+        42 LOAD_CONST 2 (True)
+        44 CALL_METHOD 1
+        46 POP_TOP
+        48 RERAISE
+
+        50 LOAD_CONST 0 (None)
+        52 RETURN_VALUE
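+
+        Python 3.9+ removed BEGIN_FINALLY/END_FINALLY, so the finally body
+        is duplicated: instructions 22-32 restore the grad mode on the
+        normal path and JUMP_FORWARD skips the handler, while instructions
+        36-48 restore it on the exception path and re-raise.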
create_instruction("BEGIN_FINALLY"), - *finally_block, - create_instruction("END_FINALLY"), - ] + if sys.version_info < (3, 9): + # Generate the prologue that ends with setup_finally + epilogue = [ + create_instruction("POP_BLOCK"), + create_instruction("BEGIN_FINALLY"), + *finally_block, + create_instruction("END_FINALLY"), + ] + else: + except_block = set_grad_insts(self.original_mode) + epilogue = [ + create_instruction("POP_BLOCK"), + *except_block, + create_instruction("JUMP_FORWARD", target=target_inst), + *finally_block, + create_instruction("RERAISE"), + ] + return (prologue, epilogue) From 07be4e060a69022ad53f0422846d0da59dcf098a Mon Sep 17 00:00:00 2001 From: Animesh Jain Date: Fri, 6 May 2022 00:49:43 +0000 Subject: [PATCH 4/6] Support python 3.7 --- torchdynamo/codegen.py | 6 ++++++ torchdynamo/variables/misc.py | 2 +- 2 files changed, 7 insertions(+), 1 deletion(-) diff --git a/torchdynamo/codegen.py b/torchdynamo/codegen.py index 3933992aad..07128e8988 100644 --- a/torchdynamo/codegen.py +++ b/torchdynamo/codegen.py @@ -298,3 +298,9 @@ def load_import_from(self, module_name, object_name): self ) ) + + def create_begin_finally(self): + if sys.version_info < (3, 8): + return self.create_load_const(None) + else: + return create_instruction("BEGIN_FINALLY") diff --git a/torchdynamo/variables/misc.py b/torchdynamo/variables/misc.py index d31b7ea0ad..6d2436024e 100644 --- a/torchdynamo/variables/misc.py +++ b/torchdynamo/variables/misc.py @@ -228,7 +228,7 @@ def set_grad_insts(mode): # Generate the prologue that ends with setup_finally epilogue = [ create_instruction("POP_BLOCK"), - create_instruction("BEGIN_FINALLY"), + codegen.create_begin_finally(), *finally_block, create_instruction("END_FINALLY"), ] From 5d292749e843d22f2885465e722f62f1670ce72d Mon Sep 17 00:00:00 2001 From: Animesh Jain Date: Fri, 6 May 2022 05:12:48 +0000 Subject: [PATCH 5/6] Comments --- torchdynamo/variables/misc.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/torchdynamo/variables/misc.py b/torchdynamo/variables/misc.py index 6d2436024e..4153c9d476 100644 --- a/torchdynamo/variables/misc.py +++ b/torchdynamo/variables/misc.py @@ -207,14 +207,12 @@ def reconstruct(self, codegen, target_inst=None): return ([], []) def set_grad_insts(mode): - codegen.load_import_from("torch", "_C") - codegen.load_import_from("torch._C", "_set_grad_enabled") return [ codegen.create_load_global("torch"), codegen.create_load_attr("_C"), - create_instruction("LOAD_METHOD", 2, "_set_grad_enabled"), + codegen.create_load_attr("_set_grad_enabled"), codegen.create_load_const(mode), - create_instruction("CALL_METHOD", 1), + create_instruction("CALL_FUNCTION", 1), create_instruction("POP_TOP"), ] From 046932d6df8f21dbdea649552279c225f03b3da0 Mon Sep 17 00:00:00 2001 From: Animesh Jain Date: Fri, 6 May 2022 05:34:01 +0000 Subject: [PATCH 6/6] Replacing the global load with GlobalSource and reconstruct --- torchdynamo/variables/misc.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/torchdynamo/variables/misc.py b/torchdynamo/variables/misc.py index 4153c9d476..9c27870816 100644 --- a/torchdynamo/variables/misc.py +++ b/torchdynamo/variables/misc.py @@ -207,10 +207,11 @@ def reconstruct(self, codegen, target_inst=None): return ([], []) def set_grad_insts(mode): + global_torch_source = codegen.tx.import_source("torch") + attr_source = AttrSource(global_torch_source, "_C._set_grad_enabled") + load_set_grad_enabled_insts = attr_source.reconstruct(codegen) return [ - 
codegen.create_load_global("torch"), - codegen.create_load_attr("_C"), - codegen.create_load_attr("_set_grad_enabled"), + *load_set_grad_enabled_insts, codegen.create_load_const(mode), create_instruction("CALL_FUNCTION", 1), create_instruction("POP_TOP"),