This repository was archived by the owner on Aug 1, 2025. It is now read-only.

Bug - Missing with context on the graph break instruction #207

@anijain2305

Description

Repro

from functools import partial
import torch
import torchdynamo

print = partial(print, flush=True)


def reversible(x):
    print("Hello world")  # Cause graph break so inline fails
    return torch.sin(torch.cos(x))


def fn(x):
    torch._C._set_grad_enabled(False)

    with torch.enable_grad():
        a = torch.sin(x)
        b = reversible(a)
        c = torch.sigmoid(b)
        c.sum().backward()
        return x.grad


x = torch.randn(4, requires_grad=True)
x.grad = None

ref = fn(x)
print("Eager done")

# torchdynamo.config.trace = True
# torchdynamo.config.debug = True

x.grad = None
with torchdynamo.optimize("eager"):
    res = fn(x)
print(res)

This fails with the following error:

Traceback (most recent call last):
  File "with_test.py", line 39, in <module>
    res = fn(x)
  File "with_test.py", line 17, in fn
    def fn(x):
  File "with_test.py", line 17, in fn
    def fn(x):
  File "/data/home/anijain/miniconda/envs/pytorch_dev/lib/python3.8/site-packages/torch/_tensor.py", line 399, in backward
    torch.autograd.backward(self, gradient, retain_graph, create_graph, inputs=inputs)
  File "/data/home/anijain/miniconda/envs/pytorch_dev/lib/python3.8/site-packages/torch/autograd/__init__.py", line 173, in backward
    Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn
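
For reference, the same RuntimeError reproduces without dynamo: calling .backward() on a tensor that was computed while grad was disabled fails in exactly this way, because no grad_fn was recorded.

import torch

x = torch.randn(4, requires_grad=True)
with torch.no_grad():
    y = torch.sigmoid(torch.sin(x))  # grad disabled: y has no grad_fn
y.sum().backward()
# RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn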

Following is the bytecode:

ORIGINAL BYTECODE fn with_test.py 13
 14           0 LOAD_GLOBAL              0 (torch)
              2 LOAD_ATTR                1 (_C)
              4 LOAD_METHOD              2 (_set_grad_enabled)
              6 LOAD_CONST               1 (False)
              8 CALL_METHOD              1
             10 POP_TOP

 16          12 LOAD_GLOBAL              0 (torch)
             14 LOAD_METHOD              3 (enable_grad)
             16 CALL_METHOD              0
             18 SETUP_WITH              60 (to 80)
             20 POP_TOP

 17          22 LOAD_GLOBAL              0 (torch)
             24 LOAD_METHOD              4 (sin)
             26 LOAD_FAST                0 (x)
             28 CALL_METHOD              1
             30 STORE_FAST               1 (a)

 18          32 LOAD_GLOBAL              5 (reversible)
             34 LOAD_FAST                1 (a)
             36 CALL_FUNCTION            1
             38 STORE_FAST               2 (b)

 19          40 LOAD_GLOBAL              0 (torch)
             42 LOAD_METHOD              6 (sigmoid)
             44 LOAD_FAST                2 (b)
             46 CALL_METHOD              1
             48 STORE_FAST               3 (c)

 20          50 LOAD_FAST                3 (c)
             52 LOAD_METHOD              7 (sum)
             54 CALL_METHOD              0
             56 LOAD_METHOD              8 (backward)
             58 CALL_METHOD              0
             60 POP_TOP

 21          62 LOAD_FAST                0 (x)
             64 LOAD_ATTR                9 (grad)
             66 POP_BLOCK
             68 ROT_TWO
             70 BEGIN_FINALLY
             72 WITH_CLEANUP_START
             74 WITH_CLEANUP_FINISH
             76 POP_FINALLY              0
             78 RETURN_VALUE
        >>   80 WITH_CLEANUP_START
             82 WITH_CLEANUP_FINISH
             84 END_FINALLY
             86 LOAD_CONST               0 (None)
             88 RETURN_VALUE

MODIFIED BYTECODE
 13           0 LOAD_GLOBAL             11 (__compiled_fn_0)
              2 LOAD_FAST                0 (x)
              4 CALL_FUNCTION            1
              6 STORE_FAST               4 (___graph_out_0)
              8 LOAD_GLOBAL             10 (__import_torch)
             10 LOAD_ATTR                3 (enable_grad)
             12 LOAD_GLOBAL              5 (reversible)
             14 LOAD_FAST                4 (___graph_out_0)
             16 LOAD_CONST               2 (0)
             18 BINARY_SUBSCR
             20 CALL_FUNCTION            1
             22 LOAD_GLOBAL             12 (__resume_at_38_1)
             24 ROT_THREE
             26 LOAD_FAST                0 (x)
             28 CALL_FUNCTION            3
             30 RETURN_VALUE

In the modified bytecode, the call to the reversible function does not happen inside the with context. Therefore, reversible is called with the grad flag disabled, which triggers the error above.

Note that __compiled_fn_0 and the resume function both correctly keep track of the with context. The with context is missing only for the instruction(s) that cause the graph break (CALL_FUNCTION here).
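
Read back as Python source, the modified bytecode is roughly equivalent to the following (names taken from the disassembly above; the argument order for __resume_at_38_1 follows my reading of the ROT_THREE):

def fn_rewritten(x):
    graph_out = __compiled_fn_0(x)  # compiled graph: _set_grad_enabled(False); a = torch.sin(x)
    b = reversible(graph_out[0])    # graph-break call; note it runs OUTSIDE torch.enable_grad()
    return __resume_at_38_1(torch.enable_grad, b, x)  # resume function re-enters the with context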

Also note that Dynamo has special handling for no_grad and enable_grad - https://github.com/pytorch/torchdynamo/blob/main/torchdynamo/symbolic_convert.py#L421. For any other type of context manager, we just break the graph on SETUP_WITH. So a quick fix (but with poorer coverage) is to always break on SETUP_WITH, as sketched below.
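
A hypothetical sketch of that quick fix (unimplemented is the existing graph-break helper in torchdynamo.exc; the real handler at the link above would lose its grad-mode fast path):

from .exc import unimplemented

def SETUP_WITH(self, inst):
    # Quick fix: treat every context manager as unsupported so the whole
    # `with` block, including the breaking instruction, runs in eager mode.
    unimplemented("SETUP_WITH: generic context manager")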

@jansel suggested hand-writing the Python bytecode for a try ... finally block and conditionally inserting it at https://github.com/pytorch/torchdynamo/blob/main/torchdynamo/symbolic_convert.py#L137

>>> import dis
>>> def foo():
...    set_grad_true()
...    try:
...      user_inst()
...    finally:
...      set_grad_false()
... 
>>> dis.dis(foo)
  2           0 LOAD_GLOBAL              0 (set_grad_true)
              2 CALL_FUNCTION            0
              4 POP_TOP

  3           6 SETUP_FINALLY           16 (to 24)

  4           8 LOAD_GLOBAL              1 (user_inst)
             10 CALL_FUNCTION            0
             12 POP_TOP
             14 POP_BLOCK

  6          16 LOAD_GLOBAL              2 (set_grad_false)
             18 CALL_FUNCTION            0
             20 POP_TOP
             22 JUMP_FORWARD             8 (to 32)
        >>   24 LOAD_GLOBAL              2 (set_grad_false)
             26 CALL_FUNCTION            0
             28 POP_TOP
             30 RERAISE
        >>   32 LOAD_CONST               0 (None)
             34 RETURN_VALUE
>>>

My initial experience is that this is a little more tedious, as we have to set up the resume call and its arguments as well.
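
For this example, the desired rewritten frame would then have roughly the following source-level shape (all names assumed from the disassembly above; a sketch, not the actual codegen):

def fn_rewritten(x):
    graph_out = __compiled_fn_0(x)         # compiled graph leaves grad disabled
    torch._C._set_grad_enabled(True)       # re-enter the enable_grad state before the break
    try:
        b = reversible(graph_out[0])       # unsupported call now runs with grad enabled
    finally:
        torch._C._set_grad_enabled(False)  # restore the pre-break state, even on exception
    return __resume_at_38_1(torch.enable_grad, b, x)  # resume re-enters the with context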
