In [1]:
import torch

import header_code

torch._logging.set_logs(graph_breaks=True)

# Common Graph Breaks

Below are some common graph breaks and some workarounds.

## Incorrect Code
Your code might contain errors (meaning it fails even without `torch.compile`). In the example below, the `torch.sin` call has a typo: it is passed an extra argument. **Always disable `torch.compile` to check whether the code runs correctly.**

In [2]:
@torch.compile
def fn(x):
    y = torch.sin(x, x)
    return y

try:
    fn(torch.ones(3, 3))
except Exception:
    pass  # expected: torch.sin(x, x) raises a TypeError even without torch.compile

Graph break in user code at C:\Users\Aditya\AppData\Local\Temp\ipykernel_13088\343837593.py:3
Graph Break Reason: TypeError when making fake tensor call
  Explanation: 


  Developer debug context: TypeError <built-in method sin of type object at 0x00007FF9076478E0>: sin() takes 1 positional argument but 2 were given

 For more details about this graph break, please visit: https://meta-pytorch.github.io/compile-graph-break-site/gb/gb0112.html

Dynamo makes a best-effort attempt to hint when a graph break is caused by your code.
But it can still be difficult to tell from the logs whether a graph break comes from an error in your code,
from a more complicated unsupported construct, or from a `torch.compile` bug. To tell these apart, run your code without `torch.compile` and check whether you still get the error reported by the graph break.
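A minimal sketch of that check: define the function without `@torch.compile`, run it eagerly, and only compile it once it runs cleanly. The helper `eager_error` and the name `compiled_fn` are hypothetical, introduced here for illustration.

```python
import torch

def fn(x):
    # Same bug as above: torch.sin takes a single input tensor
    return torch.sin(x, x)

def eager_error(f, *args):
    """Hypothetical helper: run f without torch.compile and return the
    exception it raises, or None if it runs cleanly."""
    try:
        f(*args)
    except Exception as e:
        return e
    return None

err = eager_error(fn, torch.ones(3, 3))
if err is not None:
    # The error reproduces in eager mode, so the bug is in the code itself
    print(f"eager mode also fails: {type(err).__name__}: {err}")
else:
    # Eager mode works; compile and investigate any remaining graph breaks
    compiled_fn = torch.compile(fn)
```

If the eager run succeeds but the compiled run still reports a graph break, the cause is more likely an unsupported construct or a `torch.compile` issue.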

## Data-dependent operations

`torch.compile` graph breaks on data-dependent operations, such as data-dependent control flow (`if` statements or loops whose conditions depend on tensor values) and direct accesses to tensor data (`.item()`, `.data_ptr()`).

In [3]:
@torch.compile
def fn(x):
    y = x.sum()
    if y > 0:
        return x + y.item()
    return x - y.item()

print(fn(torch.ones(3, 3)))

Graph break in user code at C:\Users\Aditya\AppData\Local\Temp\ipykernel_13088\3495555842.py:4
Graph Break Reason: Data-dependent branching
  Explanation: Detected data-dependent branching (e.g. `if my_tensor.sum() > 0:`). Dynamo does not support tracing dynamic control flow.
  Hint: This graph break is fundamental - it is unlikely that Dynamo will ever be able to trace through your code. Consider finding a workaround.
  Hint: Use `torch.cond` to express dynamic control flow.

  Developer debug context: attempted to jump with TensorVariable()


Graph break from `Tensor.item()`, consider setting:
    torch._dynamo.config.capture_scalar_outputs = True
or:
    env TORCHDYNAMO_CAPTURE_SCALAR_OUTPUTS=1
to include these operations in the captured graph.

Graph break: from user code at:
  File "C:\Users\Aditya\AppData\Local\Temp\ipykernel_13088\3495555842.py", line 5, in torch_dynamo_resume_in_fn_at_4
    return x + y.item()




Graph break in user code at C:\Users\Aditya\AppData\Local\Temp\ipykernel_13088\3495555842.py:5
Graph Break Reason: Unsupported Tensor.item() call with capture_scalar_outputs=False
  Explanation: Dynamo does not support tracing `Tensor.item()` with config.capture_scalar_outputs=False.
  Hint: Set `torch._dynamo.config.capture_scalar_outputs = True` or `export TORCHDYNAMO_CAPTURE_SCALAR_OUTPUTS=1` to include these operations in the captured graph.

  Developer debug context: call_method TensorVariable() item () {}

 For more details about this graph break, please visit: https://meta-pytorch.github.io/compile-graph-break-site/gb/gb0124.html

tensor([[10., 10., 10.],
        [10., 10., 10.],
        [10., 10., 10.]])


The general workaround for these graph breaks is to avoid doing data-dependent operations. Some specific workarounds are:

- If your control flow doesn't actually depend on data values, consider modifying your code to perform control flow on constants.

In [4]:
# old
x = torch.randn(3, 3)
@torch.compile
def fn(y):
    if x.sum() > 0:
        return y + x
    else:
        return y - x

print(fn(torch.ones(3, 3)))

Graph break in user code at C:\Users\Aditya\AppData\Local\Temp\ipykernel_13088\2410325100.py:5
Graph Break Reason: Data-dependent branching
  Explanation: Detected data-dependent branching (e.g. `if my_tensor.sum() > 0:`). Dynamo does not support tracing dynamic control flow.
  Hint: This graph break is fundamental - it is unlikely that Dynamo will ever be able to trace through your code. Consider finding a workaround.
  Hint: Use `torch.cond` to express dynamic control flow.

  Developer debug context: attempted to jump with TensorVariable()


tensor([[ 2.0617,  3.2001, -0.2806],
        [ 0.9175,  2.3327, -0.1929],
        [ 3.8927,  0.1718,  1.1279]])


In [5]:
# new
x = torch.randn(3, 3)
cond = (x.sum() > 0).item()
@torch.compile
def fn(y):
    if cond:
        return y + x
    else:
        return y - x

print(fn(torch.ones(3, 3)))

tensor([[ 0.4605,  2.4148,  0.7653],
        [ 0.5703,  1.4693, -0.2955],
        [ 1.5892,  0.9290,  1.1816]])
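Control flow on a tensor's shape, rather than its values, also doesn't depend on data: Dynamo can evaluate (and guard on) shape conditions while tracing. A minimal sketch, using `fullgraph=True` so that a graph break would raise an error, and `backend="eager"` to skip code generation:

```python
import torch

@torch.compile(backend="eager", fullgraph=True)
def fn(x):
    # x.shape[0] is known at trace time, so this branch is not data-dependent
    if x.shape[0] > 2:
        return x + 1
    return x - 1

print(fn(torch.ones(3, 3)))  # first branch: all 2s
print(fn(torch.ones(2, 3)))  # new shape, may trigger a recompile: all 0s
```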


- Use higher-order ops like {ref}`cond` in place of data-dependent control flow

In [6]:
# old
@torch.compile
def fn(x):
    if x.sum() > 0:
        return x + 1
    return x - 1

print(fn(torch.ones(3, 3)))

Graph break in user code at C:\Users\Aditya\AppData\Local\Temp\ipykernel_13088\520574912.py:4
Graph Break Reason: Data-dependent branching
  Explanation: Detected data-dependent branching (e.g. `if my_tensor.sum() > 0:`). Dynamo does not support tracing dynamic control flow.
  Hint: This graph break is fundamental - it is unlikely that Dynamo will ever be able to trace through your code. Consider finding a workaround.
  Hint: Use `torch.cond` to express dynamic control flow.

  Developer debug context: attempted to jump with TensorVariable()


tensor([[2., 2., 2.],
        [2., 2., 2.],
        [2., 2., 2.]])


In [7]:
# new
@torch.compile
def fn(x):
    return torch.cond(
        x.sum() > 0,
        lambda x: x + 1,
        lambda x: x - 1,
        (x,),
    )

print(fn(torch.ones(3, 3)))

tensor([[2., 2., 2.],
        [2., 2., 2.],
        [2., 2., 2.]])


- If you have a `.item()` call, try setting `torch._dynamo.config.capture_scalar_outputs = True`
  or the environment variable `TORCHDYNAMO_CAPTURE_SCALAR_OUTPUTS=1`.
- Wrap problematic parts of the function in a custom operator, which Dynamo treats as a single opaque call.

## Printing and logging

Printing, logging, or issuing warnings causes a graph break.
You can try working around this with `torch._dynamo.config.reorderable_logging_functions`.
Functions added to this set are reordered so that they are called at the end of the
traced function, which avoids the graph break.
However, because the call is moved, the logged contents may differ from eager execution, for example if a logged tensor is mutated after the logging call.

In [8]:
torch._dynamo.config.reorderable_logging_functions.add(print)

@torch.compile
def fn(x):
    x += 1
    print("log!")
    return torch.sin(x)

print(fn(torch.ones(3, 3)))

log!
tensor([[0.9093, 0.9093, 0.9093],
        [0.9093, 0.9093, 0.9093],
        [0.9093, 0.9093, 0.9093]])
