Bug - Missing with context on the graph break instruction #207
Description
Repro
from functools import partial

import torch
import torchdynamo

print = partial(print, flush=True)


def reversible(x):
    print("Hello world")  # Cause graph break so inline fails
    return torch.sin(torch.cos(x))


def fn(x):
    torch._C._set_grad_enabled(False)
    with torch.enable_grad():
        a = torch.sin(x)
        b = reversible(a)
        c = torch.sigmoid(b)
        c.sum().backward()
        return x.grad


x = torch.randn(4, requires_grad=True)
x.grad = None
ref = fn(x)
print("Eager done")

# torchdynamo.config.trace = True
# torchdynamo.config.debug = True

x.grad = None
with torchdynamo.optimize("eager"):
    res = fn(x)
print(res)
This fails with the following error:
Traceback (most recent call last):
  File "with_test.py", line 39, in <module>
    res = fn(x)
  File "with_test.py", line 17, in fn
    def fn(x):
  File "with_test.py", line 17, in fn
    def fn(x):
  File "/data/home/anijain/miniconda/envs/pytorch_dev/lib/python3.8/site-packages/torch/_tensor.py", line 399, in backward
    torch.autograd.backward(self, gradient, retain_graph, create_graph, inputs=inputs)
  File "/data/home/anijain/miniconda/envs/pytorch_dev/lib/python3.8/site-packages/torch/autograd/__init__.py", line 173, in backward
    Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn
The following is the bytecode:
ORIGINAL BYTECODE fn with_test.py 13
14 0 LOAD_GLOBAL 0 (torch)
2 LOAD_ATTR 1 (_C)
4 LOAD_METHOD 2 (_set_grad_enabled)
6 LOAD_CONST 1 (False)
8 CALL_METHOD 1
10 POP_TOP
16 12 LOAD_GLOBAL 0 (torch)
14 LOAD_METHOD 3 (enable_grad)
16 CALL_METHOD 0
18 SETUP_WITH 60 (to 80)
20 POP_TOP
17 22 LOAD_GLOBAL 0 (torch)
24 LOAD_METHOD 4 (sin)
26 LOAD_FAST 0 (x)
28 CALL_METHOD 1
30 STORE_FAST 1 (a)
18 32 LOAD_GLOBAL 5 (reversible)
34 LOAD_FAST 1 (a)
36 CALL_FUNCTION 1
38 STORE_FAST 2 (b)
19 40 LOAD_GLOBAL 0 (torch)
42 LOAD_METHOD 6 (sigmoid)
44 LOAD_FAST 2 (b)
46 CALL_METHOD 1
48 STORE_FAST 3 (c)
20 50 LOAD_FAST 3 (c)
52 LOAD_METHOD 7 (sum)
54 CALL_METHOD 0
56 LOAD_METHOD 8 (backward)
58 CALL_METHOD 0
60 POP_TOP
21 62 LOAD_FAST 0 (x)
64 LOAD_ATTR 9 (grad)
66 POP_BLOCK
68 ROT_TWO
70 BEGIN_FINALLY
72 WITH_CLEANUP_START
74 WITH_CLEANUP_FINISH
76 POP_FINALLY 0
78 RETURN_VALUE
>> 80 WITH_CLEANUP_START
82 WITH_CLEANUP_FINISH
84 END_FINALLY
86 LOAD_CONST 0 (None)
88 RETURN_VALUE
MODIFIED BYTECODE
13 0 LOAD_GLOBAL 11 (__compiled_fn_0)
2 LOAD_FAST 0 (x)
4 CALL_FUNCTION 1
6 STORE_FAST 4 (___graph_out_0)
8 LOAD_GLOBAL 10 (__import_torch)
10 LOAD_ATTR 3 (enable_grad)
12 LOAD_GLOBAL 5 (reversible)
14 LOAD_FAST 4 (___graph_out_0)
16 LOAD_CONST 2 (0)
18 BINARY_SUBSCR
20 CALL_FUNCTION 1
22 LOAD_GLOBAL 12 (__resume_at_38_1)
24 ROT_THREE
26 LOAD_FAST 0 (x)
28 CALL_FUNCTION 3
30 RETURN_VALUE
In the modified bytecode, the call to the reversible function does not happen inside the with context. Therefore, reversible is called with the grad flag disabled, which triggers the above error.

Note that __compiled_fn_0 and the resume function both correctly keep track of the with context. We are missing the with context only for the instruction(s) that cause the graph break (CALL_FUNCTION here).
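For reference, the same RuntimeError can be reproduced with plain PyTorch (this snippet is not from the issue, just a minimal illustration of the failure mode): an op executed with grad disabled produces a tensor without a grad_fn, so a later backward through it fails.

import torch

x = torch.randn(4, requires_grad=True)
with torch.no_grad():
    y = torch.sin(x)       # grad disabled, so y has no grad_fn
print(y.requires_grad)     # False
try:
    y.sum().backward()
except RuntimeError as e:
    print(e)               # element 0 of tensors does not require grad and does not have a grad_fn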
Also note that Dynamo has special handling for no_grad and enable_grad (https://github.com/pytorch/torchdynamo/blob/main/torchdynamo/symbolic_convert.py#L421). For any other type of context manager, we just break the graph on SETUP_WITH. So, a quick fix (but with poorer coverage) is to always break on SETUP_WITH.
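As background for the suggested fix below, entering and exiting torch.enable_grad() roughly amounts to toggling the global grad mode and restoring it afterwards (a rough sketch of the equivalent behavior, not the actual implementation):

import torch

prev = torch.is_grad_enabled()        # save the ambient grad mode
torch._C._set_grad_enabled(True)      # __enter__: enable grad
try:
    pass                              # body of the with block goes here
finally:
    torch._C._set_grad_enabled(prev)  # __exit__: restore the previous mode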
@jansel suggested hand-writing the Python bytecode for a try .. finally block and conditionally inserting it at https://github.com/pytorch/torchdynamo/blob/main/torchdynamo/symbolic_convert.py#L137:
>>> import dis
>>> def foo():
...     set_grad_true()
...     try:
...         user_inst()
...     finally:
...         set_grad_false()
...
>>> dis.dis(foo)
2 0 LOAD_GLOBAL 0 (set_grad_true)
2 CALL_FUNCTION 0
4 POP_TOP
3 6 SETUP_FINALLY 16 (to 24)
4 8 LOAD_GLOBAL 1 (user_inst)
10 CALL_FUNCTION 0
12 POP_TOP
14 POP_BLOCK
6 16 LOAD_GLOBAL 2 (set_grad_false)
18 CALL_FUNCTION 0
20 POP_TOP
22 JUMP_FORWARD 8 (to 32)
>> 24 LOAD_GLOBAL 2 (set_grad_false)
26 CALL_FUNCTION 0
28 POP_TOP
30 RERAISE
>> 32 LOAD_CONST 0 (None)
34 RETURN_VALUE
>>>
My initial experience is that this is a little more tedious, as we have to set up the resume call and its arguments as well.
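For illustration only, here is a rough Python-level sketch of what the rewritten fn could look like under that approach. The names __compiled_fn_0, __resume_at_38_1, and reversible come from the bytecode dumps above; the resume-call signature and the exact grad-mode bookkeeping are assumptions, and getting them right at the bytecode level is exactly the tedious part.

def fn(x):
    graph_out = __compiled_fn_0(x)            # traced part: sin under enable_grad
    torch._C._set_grad_enabled(True)          # re-enter the grad mode of the with block
    try:
        b = reversible(graph_out[0])          # the graph-break call, now under grad
    finally:
        torch._C._set_grad_enabled(False)     # restore the mode set before the with block
    return __resume_at_38_1(torch.enable_grad, b, x)  # resume continues inside the with context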