# Introduction to `torch.compile`

**Author:** William Wen

`torch.compile` is the latest method to speed up your PyTorch code! It makes PyTorch code run faster by JIT-compiling it into optimized kernels, all while requiring minimal code changes.

In this tutorial, we cover basic `torch.compile` usage, and demonstrate the advantages of `torch.compile` over previous PyTorch compiler solutions, such as TorchScript and FX Tracing.

**Required pip dependencies for this tutorial**

* `torch >= 2.0`
* `torchvision`
* `numpy`
* `scipy`
* `tabulate`

**NOTE:** a modern NVIDIA GPU (H100, A100, or V100) is recommended for this tutorial in order to reproduce the speedup numbers shown below and documented elsewhere.

In [1]:
import torch
import warnings

gpu_ok = False
if torch.cuda.is_available():
    device_cap = torch.cuda.get_device_capability()
    if device_cap in ((7, 0), (8, 0), (9, 0)):
        gpu_ok = True

if not gpu_ok:
    warnings.warn(
        "GPU is not NVIDIA V100, A100, or H100. Speedup numbers may be lower "
        "than expected."
    )

## Basic Usage

`torch.compile` is included in the latest PyTorch. Running TorchInductor on GPU requires Triton, which is included with the PyTorch 2.0 nightly binary. If Triton is still missing, try installing `torchtriton` via `pip` (`pip install torchtriton --extra-index-url "https://download.pytorch.org/whl/nightly/cu117"` for CUDA 11.7).
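
A quick way to see which of these pieces your environment already has is sketched below. It uses only `torch` and the standard-library `importlib.util.find_spec`, which looks up the `triton` package without importing it; it does not verify that Triton actually works with your GPU.

```python
import importlib.util

import torch

# Report the installed PyTorch version and whether a CUDA device is visible.
print("torch version:", torch.__version__)
print("CUDA available:", torch.cuda.is_available())

# find_spec returns None when a top-level package cannot be found, so this
# checks for Triton without importing it.
triton_installed = importlib.util.find_spec("triton") is not None
print("triton installed:", triton_installed)
```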

Arbitrary Python functions can be optimized by passing the callable to `torch.compile`. We can then call the returned optimized function in place of the original function.

In [2]:
def foo(x, y):
    a = torch.sin(x)
    b = torch.cos(y)
    return a + b
opt_foo1 = torch.compile(foo)
print(opt_foo1(torch.randn(10, 10), torch.randn(10, 10)))

tensor([[ 1.6922e+00, -1.3079e+00,  1.8881e+00,  9.2692e-01,  7.1303e-01,
          6.2709e-01, -1.0181e+00,  1.1799e+00,  1.0656e+00,  5.0181e-02],
        [-3.8351e-02,  1.7844e+00,  1.5407e+00,  1.1245e+00,  8.8961e-01,
          1.1635e+00,  9.4375e-01,  1.4531e+00,  1.4312e-01,  7.8938e-01],
        [-2.4845e-01,  4.9754e-01, -1.4061e-03,  1.6899e+00,  7.9759e-01,
          8.6161e-01,  6.9971e-01,  1.2607e+00,  8.1494e-01, -1.7857e-02],
        [ 1.5330e+00, -5.1327e-01,  1.1237e-01, -1.5571e-01,  1.2751e+00,
          8.7541e-01,  9.2658e-01,  3.9375e-01,  6.8470e-01,  8.7922e-01],
        [ 1.6161e+00,  6.5746e-01, -1.3410e-01,  8.6125e-01,  1.6745e+00,
         -5.9303e-01, -2.9838e-01,  1.4963e+00,  2.0170e-01, -9.7567e-01],
        [ 1.5343e+00,  1.2795e-01, -9.8592e-01,  1.8861e+00,  5.3889e-02,
         -1.3971e-01,  7.5570e-01,  1.6637e+00,  1.2802e-01,  5.4864e-01],
        [ 1.0105e+00,  5.4418e-01,  1.0590e+00, -4.9262e-02,  1.3561e-01,
          5.3599e-01,  1.3287e-0

Alternatively, we can decorate the function.

In [3]:
t1 = torch.randn(10, 10)
t2 = torch.randn(10, 10)

@torch.compile
def opt_foo2(x, y):
    a = torch.sin(x)
    b = torch.cos(y)
    return a + b
print(opt_foo2(t1, t2))

tensor([[ 1.1658, -1.2850,  0.9785,  0.4676,  0.6177,  0.5203, -0.5915, -0.4507,
         -0.5545,  1.0223],
        [ 0.6054,  0.2063,  0.3110,  0.7613, -0.3570,  0.2013,  0.9728,  1.9195,
         -0.8023,  1.7299],
        [ 0.8754,  1.9269, -0.2669,  1.5307,  0.4050, -0.1303,  1.7022,  1.7485,
          1.4283,  1.0293],
        [ 1.7531, -0.6572, -0.0596,  0.7978, -0.1523,  0.0730,  0.3352, -0.7632,
          1.5219,  0.4020],
        [ 1.3965, -0.8392,  1.4538,  0.3531, -0.0178, -0.1744,  1.2756,  0.0310,
          1.2749, -0.0575],
        [ 1.0903,  0.0743, -1.1493,  1.7084, -0.0548,  0.0088, -0.0098,  0.4674,
          1.1412,  1.6406],
        [ 1.1101, -0.6055, -0.7934,  0.0426, -0.3504,  1.9223,  1.7483,  0.1735,
          0.2343,  0.5436],
        [ 1.0473,  0.9783,  0.6461, -0.2204, -0.2482,  1.1901,  0.4395,  1.7900,
          0.7811,  0.0416],
        [-0.4371,  0.2625,  0.4305,  1.4590,  1.1109, -0.1579,  1.3633,  1.9848,
          0.6311,  1.7412],
        [ 0.1202,  

We can also optimize `torch.nn.Module` instances.

In [4]:
t = torch.randn(10, 100)

class MyModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.lin = torch.nn.Linear(100, 10)

    def forward(self, x):
        return torch.nn.functional.relu(self.lin(x))

mod = MyModule()
opt_mod = torch.compile(mod)
print(opt_mod(t))

tensor([[0.6194, 0.0000, 0.0000, 1.3402, 0.0000, 0.5395, 0.8070, 0.7041, 0.7132,
         0.0000],
        [0.0000, 0.3945, 0.5750, 0.4356, 0.9109, 0.2037, 0.2142, 0.4301, 0.0000,
         0.1682],
        [0.2559, 0.0000, 0.0000, 0.5103, 0.0000, 0.5222, 0.7330, 1.1618, 0.3316,
         0.0000],
        [0.3484, 0.3980, 0.0000, 0.1898, 0.2904, 0.3463, 0.6790, 0.8740, 0.0000,
         0.0000],
        [0.1930, 0.0370, 0.0000, 0.0000, 0.1458, 0.1000, 0.0000, 0.0000, 0.0340,
         0.0000],
        [0.0000, 0.6415, 0.6900, 0.2303, 0.0000, 0.6913, 0.2326, 0.1811, 0.5319,
         0.0000],
        [0.0000, 0.0000, 0.0000, 0.0000, 0.6115, 0.5116, 0.2032, 0.0000, 0.0000,
         0.0000],
        [0.0281, 0.0989, 0.0000, 0.0000, 0.0000, 0.0573, 0.0000, 0.2625, 1.0456,
         0.6681],
        [0.9763, 0.0000, 0.0000, 0.0000, 1.0746, 0.0543, 0.0000, 0.0000, 0.0000,
         0.0919],
        [0.0445, 0.2499, 0.9341, 0.0000, 0.9033, 0.0000, 0.6043, 0.2127, 0.0000,
         0.0000]], grad_fn=<
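
Note that `torch.compile` does not copy a module: it returns a wrapper around the original that shares its parameters. The sketch below checks this on a fresh copy of `MyModule`. It reads `_orig_mod`, a private attribute of the wrapper in current PyTorch releases, so treat that part as an implementation detail that may change. Nothing is actually compiled here, since compilation happens lazily on the first call.

```python
import torch

class MyModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.lin = torch.nn.Linear(100, 10)

    def forward(self, x):
        return torch.nn.functional.relu(self.lin(x))

mod = MyModule()
opt_mod = torch.compile(mod)  # lazy: no compilation until the first call

# The wrapper keeps a reference to the original module (private attribute).
print(opt_mod._orig_mod is mod)
# Attribute access is forwarded, so both names refer to the same Parameter.
print(opt_mod.lin.weight is mod.lin.weight)
# The wrapper registers the original as a submodule, so its keys are prefixed.
print(list(opt_mod.state_dict().keys()))
```

One practical consequence is that the wrapper's `state_dict()` keys carry an `_orig_mod.` prefix, so saving `mod.state_dict()` from the original module is often the simpler choice for checkpoints.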

## `torch.compile` and Nested Calls

Nested function calls within the decorated function will also be compiled.

In [5]:
def nested_function(x):
    return torch.sin(x)

@torch.compile
def outer_function(x, y):
    a = nested_function(x)
    b = torch.cos(y)
    return a + b

print(outer_function(t1, t2))

tensor([[ 1.1658, -1.2850,  0.9785,  0.4676,  0.6177,  0.5203, -0.5915, -0.4507,
         -0.5545,  1.0223],
        [ 0.6054,  0.2063,  0.3110,  0.7613, -0.3570,  0.2013,  0.9728,  1.9195,
         -0.8023,  1.7299],
        [ 0.8754,  1.9269, -0.2669,  1.5307,  0.4050, -0.1303,  1.7022,  1.7485,
          1.4283,  1.0293],
        [ 1.7531, -0.6572, -0.0596,  0.7978, -0.1523,  0.0730,  0.3352, -0.7632,
          1.5219,  0.4020],
        [ 1.3965, -0.8392,  1.4538,  0.3531, -0.0178, -0.1744,  1.2756,  0.0310,
          1.2749, -0.0575],
        [ 1.0903,  0.0743, -1.1493,  1.7084, -0.0548,  0.0088, -0.0098,  0.4674,
          1.1412,  1.6406],
        [ 1.1101, -0.6055, -0.7934,  0.0426, -0.3504,  1.9223,  1.7483,  0.1735,
          0.2343,  0.5436],
        [ 1.0473,  0.9783,  0.6461, -0.2204, -0.2482,  1.1901,  0.4395,  1.7900,
          0.7811,  0.0416],
        [-0.4371,  0.2625,  0.4305,  1.4590,  1.1109, -0.1579,  1.3633,  1.9848,
          0.6311,  1.7412],
        [ 0.1202,  

Similarly, when compiling a module, all of its sub-modules and methods that are not in a skip list are also compiled.

In [6]:
class OuterModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.inner_module = MyModule()
        self.outer_lin = torch.nn.Linear(10, 2)

    def forward(self, x):
        x = self.inner_module(x)
        return torch.nn.functional.relu(self.outer_lin(x))

outer_mod = OuterModule()
opt_outer_mod = torch.compile(outer_mod)
print(opt_outer_mod(t))

tensor([[0.0972, 0.0000],
        [0.0000, 0.0973],
        [0.2504, 0.3826],
        [0.1147, 0.1527],
        [0.0873, 0.4107],
        [0.0000, 0.2425],
        [0.0508, 0.0000],
        [0.1172, 0.2697],
        [0.0000, 0.0000],
        [0.0120, 0.0000]], grad_fn=<CompiledFunctionBackward>)


We can also exclude functions from compilation by using `torch.compiler.disable`. Suppose you want to disable tracing on just the `complex_function` function, but want tracing to resume in `complex_conjugate`, which it calls. In this case, you can use the `torch.compiler.disable(recursive=False)` option. With the default, `recursive=True`, every function called from `complex_function` is skipped as well.

In [7]:
def complex_conjugate(z):
    return torch.conj(z)

@torch.compiler.disable(recursive=False)
def complex_function(real, imag):
    # Assume this function causes problems during compilation
    z = torch.complex(real, imag)
    return complex_conjugate(z)

def outer_function():
    real = torch.tensor([2, 3], dtype=torch.float32)
    imag = torch.tensor([4, 5], dtype=torch.float32)
    z = complex_function(real, imag)
    return torch.abs(z)

# Try to compile the outer_function
try:
    opt_outer_function = torch.compile(outer_function)
    print(opt_outer_function())
except Exception as e:
    print("Compilation of outer_function failed:", e)

tensor([4.4721, 5.8310])


## Best Practices and Recommendations

When you use `torch.compile`, the compiler will try to recursively compile every function call inside the target function or module that is not in a skip list (such as built-ins and some functions in the `torch.*` namespace).

Best Practices:

1. **Top-Level Compilation:** One approach is to compile at the highest level possible (i.e., when the top-level module is initialized/called) and selectively disable compilation when encountering excessive graph breaks or errors. If there are still many compile issues, compile individual subcomponents instead.

2. **Modular Testing:** Test individual functions and modules with `torch.compile` before integrating them into larger models to isolate potential issues.

3. **Disable Compilation Selectively:** If certain functions or sub-modules cannot be handled by `torch.compile`, use the `torch.compiler.disable` decorator to exclude them (and, by default, everything they call) from compilation.

4. **Compile Leaf Functions First:** In complex models with multiple nested functions and modules, start by compiling the leaf functions or modules first. For more information see [TorchDynamo APIs for fine-grained tracing](https://pytorch.org/docs/stable/torch.compiler_fine_grain_apis.html).

## Demonstrating Speedups

Let’s now demonstrate that using `torch.compile` can speed up real models. We will compare standard eager mode and `torch.compile` by evaluating and training a `torchvision` model on random data.

Before we start, we need to define some utility functions.

In [8]:
# Returns the result of running `fn()` and the time it took for `fn()` to run,
# in seconds. We use CUDA events and synchronization for the most accurate
# measurements.
def timed(fn):
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    result = fn()
    end.record()
    torch.cuda.synchronize()
    return result, start.elapsed_time(end) / 1000

# Generates random input and target data for the model, where `b` is
# the batch size.
def generate_data(b):
    return (
        torch.randn(b, 3, 128, 128).to(torch.float32).cuda(),
        torch.randint(1000, (b,)).cuda(),
    )

N_ITERS = 10

from torchvision.models import densenet121
def init_model():
    return densenet121().to(torch.float32).cuda()

First, let’s compare inference.

Note that in the call to `torch.compile`, we have the additional `mode` argument, which we will discuss below.

In [9]:
model = init_model()

# Reset since we are using a different mode.
import torch._dynamo
torch._dynamo.reset()

model_opt = torch.compile(model, mode="reduce-overhead")

inp = generate_data(16)[0]
with torch.no_grad():
    print("eager:", timed(lambda: model(inp))[1])
    print("compile:", timed(lambda: model_opt(inp))[1])

eager: 3.491248779296875
compile: 158.399140625


Notice that `torch.compile` takes a lot longer to complete compared to eager. This is because `torch.compile` compiles the model into optimized kernels as it executes. In our example, the structure of the model doesn’t change, and so recompilation is not needed. So if we run our optimized model several more times, we should see a significant improvement compared to eager.

In [10]:
eager_times = []
for i in range(N_ITERS):
    inp = generate_data(16)[0]
    with torch.no_grad():
        _, eager_time = timed(lambda: model(inp))
    eager_times.append(eager_time)
    print(f"eager eval time {i}: {eager_time}")

print("~" * 10)

compile_times = []
for i in range(N_ITERS):
    inp = generate_data(16)[0]
    with torch.no_grad():
        _, compile_time = timed(lambda: model_opt(inp))
    compile_times.append(compile_time)
    print(f"compile eval time {i}: {compile_time}")
print("~" * 10)

import numpy as np
eager_med = np.median(eager_times)
compile_med = np.median(compile_times)
speedup = eager_med / compile_med
assert speedup > 1
print(f"(eval) eager median: {eager_med}, compile median: {compile_med}, speedup: {speedup}x")
print("~" * 10)

eager eval time 0: 0.028770719528198242
eager eval time 1: 0.017812095642089843
eager eval time 2: 0.017108640670776366
eager eval time 3: 0.016643135070800782
eager eval time 4: 0.016642688751220703
eager eval time 5: 0.016266719818115234
eager eval time 6: 0.016456928253173828
eager eval time 7: 0.017984512329101563
eager eval time 8: 0.016288799285888673
eager eval time 9: 0.016147775650024412
~~~~~~~~~~
compile eval time 0: 1.3922445068359375
compile eval time 1: 0.006485631942749023
compile eval time 2: 0.007281856060028076
compile eval time 3: 0.005829951763153076
compile eval time 4: 0.005723040103912353
compile eval time 5: 0.005723616123199463
compile eval time 6: 0.005710368156433106
compile eval time 7: 0.005717984199523926
compile eval time 8: 0.0057062082290649414
compile eval time 9: 0.005712672233581543
~~~~~~~~~~
(eval) eager median: 0.016642911911010742, compile median: 0.0057233281135559075, speedup: 2.907908052937138x
~~~~~~~~~~


And indeed, we can see that running our model with `torch.compile` results in a significant speedup. The speedup comes mainly from reducing Python overhead and GPU reads/writes, so the observed speedup may vary depending on factors such as model architecture and batch size. For example, if a model’s architecture is simple and the amount of data is large, then the bottleneck would be GPU compute and the observed speedup may be less significant.
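
Because compilation is a one-time cost, it can help to estimate how many iterations it takes to pay for itself. The arithmetic below reuses the numbers printed above (first-call times and steady-state medians) and, as a simplifying assumption, ignores the slower warm-up iteration; your own numbers will differ.

```python
import math

# Values copied from the outputs above, in seconds.
eager_first_call = 3.491248779296875
compile_first_call = 158.399140625
eager_median = 0.016642911911010742
compile_median = 0.0057233281135559075

# Extra one-time cost of compiling, and time saved per steady-state iteration.
extra_cost = compile_first_call - eager_first_call
saving_per_iter = eager_median - compile_median

break_even_iters = math.ceil(extra_cost / saving_per_iter)
print(f"one-time cost: {extra_cost:.1f} s, saving: {saving_per_iter * 1e3:.2f} ms/iter")
print(f"break-even after about {break_even_iters} iterations")
```

With these particular numbers, the compiled model needs on the order of 14,000 inference calls to recoup its compilation time, so the compile cost weighs most heavily on short-lived jobs.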

You may also see different speedup results depending on the chosen `mode` argument. The `"reduce-overhead"` mode uses CUDA graphs to further reduce the overhead of Python. For your own models, you may need to experiment with different modes to maximize speedup. You can read more about modes [here](https://pytorch.org/get-started/pytorch-2.0/#user-experience).

You may also notice that the second run of our model with `torch.compile` is significantly slower than the runs after it, although it is much faster than the first run. This is because the `"reduce-overhead"` mode runs a few warm-up iterations for CUDA graphs.

For general PyTorch benchmarking, you can try using `torch.utils.benchmark` instead of the `timed` function we defined above. We wrote our own timing function in this tutorial to show `torch.compile`’s compilation latency.
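
A minimal sketch with `torch.utils.benchmark.Timer`, using a small CPU function for illustration. `Timer` synchronizes CUDA for you when a GPU is present, and `blocked_autorange` chooses the number of runs automatically. When benchmarking a compiled function this way, call it once beforehand so that compilation is not part of the measurement.

```python
import torch
import torch.utils.benchmark as benchmark

def foo(x, y):
    return torch.sin(x) + torch.cos(y)

x = torch.randn(256, 256)
y = torch.randn(256, 256)

# `stmt` is executed with the names provided in `globals`.
timer = benchmark.Timer(stmt="foo(x, y)", globals={"foo": foo, "x": x, "y": y})
measurement = timer.blocked_autorange(min_run_time=0.2)
print(measurement)
print(f"median: {measurement.median * 1e6:.1f} us")
```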

Now, let’s compare training.

In [11]:
model = init_model()
opt = torch.optim.Adam(model.parameters())

def train(mod, data):
    opt.zero_grad(True)
    pred = mod(data[0])
    loss = torch.nn.CrossEntropyLoss()(pred, data[1])
    loss.backward()
    opt.step()

eager_times = []
for i in range(N_ITERS):
    inp = generate_data(16)
    _, eager_time = timed(lambda: train(model, inp))
    eager_times.append(eager_time)
    print(f"eager train time {i}: {eager_time}")
print("~" * 10)

model = init_model()
opt = torch.optim.Adam(model.parameters())
train_opt = torch.compile(train, mode="reduce-overhead")

compile_times = []
for i in range(N_ITERS):
    inp = generate_data(16)
    _, compile_time = timed(lambda: train_opt(model, inp))
    compile_times.append(compile_time)
    print(f"compile train time {i}: {compile_time}")
print("~" * 10)

eager_med = np.median(eager_times)
compile_med = np.median(compile_times)
speedup = eager_med / compile_med
assert speedup > 1
print(f"(train) eager median: {eager_med}, compile median: {compile_med}, speedup: {speedup}x")
print("~" * 10)

eager train time 0: 0.5444177856445312
eager train time 1: 0.06874956512451172
eager train time 2: 0.04701449584960937
eager train time 3: 0.04661782455444336
eager train time 4: 0.04657648086547852
eager train time 5: 0.04545606231689453
eager train time 6: 0.04611872100830078
eager train time 7: 0.045399742126464845
eager train time 8: 0.04607027053833008
eager train time 9: 0.04458924865722656
~~~~~~~~~~
compile train time 0: 370.23490625
compile train time 1: 13.7634853515625
compile train time 2: 0.02405561637878418
compile train time 3: 0.014702560424804688
compile train time 4: 0.013050848007202149
compile train time 5: 0.013004351615905761
compile train time 6: 0.01349516773223877
compile train time 7: 0.013042079925537109
compile train time 8: 0.012937919616699219
compile train time 9: 0.012921695709228515
~~~~~~~~~~
(train) eager median: 0.04634760093688965, compile median: 0.013273007869720459, speedup: 3.491868715200707x
~~~~~~~~~~


Again, we can see that `torch.compile` takes longer in the first iteration, as it must compile the model, but in subsequent iterations, we see significant speedups compared to eager.

We remark that the speedup numbers presented in this tutorial are for demonstration purposes only. Official speedup values can be seen at the [TorchInductor performance dashboard](https://hud.pytorch.org/benchmark/compilers).
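
Speed aside, a compiled training step should compute the same gradients as eager mode. The sketch below checks this on a tiny CPU model. It assumes the `"aot_eager"` debugging backend, which traces the forward and backward graphs through AOTAutograd but runs them without Inductor code generation, so it is useful for correctness checks rather than speed. The model and data are illustrative only.

```python
import copy

import torch

torch.manual_seed(0)
model = torch.nn.Sequential(
    torch.nn.Linear(20, 32), torch.nn.ReLU(), torch.nn.Linear(32, 5)
)
model_copy = copy.deepcopy(model)  # identical weights for the compiled run
compiled = torch.compile(model_copy, backend="aot_eager")

x = torch.randn(8, 20)
target = torch.randint(5, (8,))
loss_fn = torch.nn.CrossEntropyLoss()

# Eager forward/backward: gradients land on `model`.
loss_fn(model(x), target).backward()
# Compiled forward/backward: gradients land on `model_copy`.
loss_fn(compiled(x), target).backward()

for (name, p_eager), p_comp in zip(model.named_parameters(), model_copy.parameters()):
    print(name, torch.allclose(p_eager.grad, p_comp.grad, atol=1e-6))
```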

## Comparison to TorchScript and FX Tracing

We have seen that `torch.compile` can speed up PyTorch code. Why else should we use `torch.compile` over existing PyTorch compiler solutions, such as TorchScript or FX Tracing? Primarily, the advantage of `torch.compile` lies in its ability to handle arbitrary Python code with minimal changes to existing code.

One case that `torch.compile` can handle that other compiler solutions struggle with is data-dependent control flow (the `if x.sum() < 0:` line below).

In [12]:
def f1(x, y):
    if x.sum() < 0:
        return -y
    return y

# Test that `fn1` and `fn2` return the same result, given
# the same arguments `args`. Typically, `fn1` will be an eager function
# while `fn2` will be a compiled function (torch.compile, TorchScript, or FX graph).
def test_fns(fn1, fn2, args):
    out1 = fn1(*args)
    out2 = fn2(*args)
    return torch.allclose(out1, out2)

inp1 = torch.randn(5, 5)
inp2 = torch.randn(5, 5)

TorchScript tracing `f1` produces silently incorrect results, since only the control flow path taken by the example inputs is traced.

In [13]:
traced_f1 = torch.jit.trace(f1, (inp1, inp2))
print("traced 1, 1:", test_fns(f1, traced_f1, (inp1, inp2)))
print("traced 1, 2:", test_fns(f1, traced_f1, (-inp1, inp2)))

traced 1, 1: True
traced 1, 2: False


[W 250221 12:16:33 4005623017:2] Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!


FX tracing `f1` results in an error due to the presence of data-dependent control flow.

In [14]:
import traceback as tb
try:
    torch.fx.symbolic_trace(f1)
except Exception:
    tb.print_exc()

Traceback (most recent call last):
  File "/tmp/ipykernel_505258/1451191637.py", line 3, in <module>
    torch.fx.symbolic_trace(f1)
  File "/mnt/xarfuse/uid-279513/197d327d-seed-nspid4026531836_cgpid370380623-ns-4026531841/torch/fx/_symbolic_trace.py", line 1312, in symbolic_trace
    graph = tracer.trace(root, concrete_args)
  File "/mnt/xarfuse/uid-279513/197d327d-seed-nspid4026531836_cgpid370380623-ns-4026531841/torch/_dynamo/eval_frame.py", line 764, in _fn
    return fn(*args, **kwargs)
  File "/mnt/xarfuse/uid-279513/197d327d-seed-nspid4026531836_cgpid370380623-ns-4026531841/torch/fx/_symbolic_trace.py", line 836, in trace
    (self.create_arg(fn(*args)),),
  File "/tmp/ipykernel_505258/4005623017.py", line 2, in f1
    if x.sum() < 0:
  File "/mnt/xarfuse/uid-279513/197d327d-seed-nspid4026531836_cgpid370380623-ns-4026531841/torch/fx/proxy.py", line 562, in __bool__
    return self.tracer.to_bool(self)
  File "/mnt/xarfuse/uid-279513/197d327d-seed-nspid4026531836_cgpid370380623-

If we provide a value for `x` as we try to FX trace `f1`, then we run into the same problem as TorchScript tracing, as the data-dependent control flow is removed in the traced function.

In [15]:
fx_f1 = torch.fx.symbolic_trace(f1, concrete_args={"x": inp1})
print("fx 1, 1:", test_fns(f1, fx_f1, (inp1, inp2)))
print("fx 1, 2:", test_fns(f1, fx_f1, (-inp1, inp2)))

fx 1, 1: True
fx 1, 2: False


Now we can see that `torch.compile` correctly handles data-dependent control flow.

In [16]:
# Reset since we are using a different mode.
torch._dynamo.reset()

compile_f1 = torch.compile(f1)
print("compile 1, 1:", test_fns(f1, compile_f1, (inp1, inp2)))
print("compile 1, 2:", test_fns(f1, compile_f1, (-inp1, inp2)))
print("~" * 10)

compile 1, 1: True
compile 1, 2: True
~~~~~~~~~~
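
Under the hood, TorchDynamo handles the data-dependent `if` by inserting a graph break: it compiles the code before the branch as one graph, falls back to Python to evaluate the condition, and then traces whichever branch is taken. The sketch below observes this with a custom backend that simply records each FX graph it receives and runs it unmodified (the name `recording_backend` is hypothetical).

```python
import torch

graphs = []

def recording_backend(gm: torch.fx.GraphModule, example_inputs):
    # Keep every FX graph TorchDynamo hands us, then run it unmodified.
    graphs.append(gm)
    return gm.forward

def f1(x, y):
    if x.sum() < 0:
        return -y
    return y

compiled_f1 = torch.compile(f1, backend=recording_backend)
y = torch.randn(3)
out_pos = compiled_f1(torch.ones(3), y)   # condition False: returns y
out_neg = compiled_f1(-torch.ones(3), y)  # condition True: returns -y

print(torch.equal(out_pos, y), torch.equal(out_neg, -y))
print(f"{len(graphs)} graphs captured")
for gm in graphs:
    print(gm.code)
```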


TorchScript scripting can handle data-dependent control flow, but this solution comes with its own set of problems. Namely, TorchScript scripting can require major code changes and will raise errors when unsupported Python is used.

In the example below, we forget TorchScript type annotations and we receive a TorchScript error because the input type for argument `y`, an `int`, does not match the default argument type, `torch.Tensor`.

In [17]:
def f2(x, y):
    return x + y

inp1 = torch.randn(5, 5)
inp2 = 3

script_f2 = torch.jit.script(f2)
try:
    script_f2(inp1, inp2)
except Exception:
    tb.print_exc()

Traceback (most recent call last):
  File "/tmp/ipykernel_505258/895162831.py", line 9, in <module>
    script_f2(inp1, inp2)
RuntimeError: f2() Expected a value of type 'Tensor (inferred)' for argument 'y' but instead found type 'int'.
Inferred 'y' to be of type 'Tensor' because it was not annotated with an explicit type.
Position: 1
Value: 3
Declaration: f2(Tensor x, Tensor y) -> Tensor
Cast error details: Unable to cast 3 to Tensor
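
For comparison, making `f2` scriptable requires editing its source to annotate `y` explicitly. A minimal sketch (the name `f2_annotated` is hypothetical; like any TorchScript function, it must be defined in a file or notebook cell whose source Python can inspect):

```python
import torch

# TorchScript treats unannotated arguments as Tensors, so `y: int` is needed
# for the scripted function to accept a Python int.
@torch.jit.script
def f2_annotated(x: torch.Tensor, y: int) -> torch.Tensor:
    return x + y

x = torch.randn(5, 5)
print(torch.allclose(f2_annotated(x, 3), x + 3))
```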


However, `torch.compile` is easily able to handle `f2`.

In [18]:
compile_f2 = torch.compile(f2)
print("compile 2:", test_fns(f2, compile_f2, (inp1, inp2)))
print("~" * 10)

compile 2: True
~~~~~~~~~~


Another case that `torch.compile` handles well compared to previous compiler solutions is the usage of non-PyTorch functions.

In [19]:
import scipy
def f3(x):
    x = x * 2
    x = scipy.fft.dct(x.numpy())
    x = torch.from_numpy(x)
    x = x * 2
    return x

TorchScript tracing treats results from non-PyTorch function calls as constants, and so our results can be silently wrong.

In [20]:
inp1 = torch.randn(5, 5)
inp2 = torch.randn(5, 5)
traced_f3 = torch.jit.trace(f3, (inp1,))
print("traced 3:", test_fns(f3, traced_f3, (inp2,)))

traced 3: False


[W 250221 12:16:54 123398134:4] Converting a tensor to a NumPy array might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!


TorchScript scripting and FX tracing disallow non-PyTorch function calls.

In [21]:
try:
    torch.jit.script(f3)
except Exception:
    tb.print_exc()

try:
    torch.fx.symbolic_trace(f3)
except Exception:
    tb.print_exc()

Traceback (most recent call last):
  File "/tmp/ipykernel_505258/1991772511.py", line 2, in <module>
    torch.jit.script(f3)
  File "/mnt/xarfuse/uid-279513/197d327d-seed-nspid4026531836_cgpid370380623-ns-4026531841/torch/jit/_script.py", line 1439, in script
    ret = _script_impl(
  File "/mnt/xarfuse/uid-279513/197d327d-seed-nspid4026531836_cgpid370380623-ns-4026531841/torch/jit/_script.py", line 1212, in _script_impl
    fn = torch._C._jit_script_compile(
  File "/mnt/xarfuse/uid-279513/197d327d-seed-nspid4026531836_cgpid370380623-ns-4026531841/torch/_jit_internal.py", line 1234, in _try_get_dispatched_fn
    return boolean_dispatched.get(fn)
  File "/mnt/xarfuse/uid-279513/197d327d-seed-nspid4026531836_cgpid370380623-ns-4026531841/runtime/lib/python3.10/weakref.py", line 453, in get
    return self.data.get(ref(key),default)
TypeError: cannot create weak reference to 'uarray._Function' object
Traceback (most recent call last):
  File "/tmp/ipykernel_505258/1991772511.py", line 7,

In comparison, `torch.compile` is easily able to handle the non-PyTorch function call.

In [22]:
compile_f3 = torch.compile(f3)
print("compile 3:", test_fns(f3, compile_f3, (inp2,)))

compile 3: True


## TorchDynamo and FX Graphs

One important component of `torch.compile` is TorchDynamo. TorchDynamo is responsible for JIT compiling arbitrary Python code into [FX graphs](https://pytorch.org/docs/stable/fx.html#torch.fx.Graph), which can then be further optimized. TorchDynamo extracts FX graphs by analyzing Python bytecode during runtime and detecting calls to PyTorch operations.

Normally, TorchInductor, another component of `torch.compile`, further compiles the FX graphs into optimized kernels, but TorchDynamo also supports other backends. To inspect the FX graphs that TorchDynamo outputs, let's create a custom backend that prints the FX graph and simply returns the graph's unoptimized forward method.

In [23]:
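from typing import List

# A backend receives the FX GraphModule captured by TorchDynamo plus example inputs,
# and must return a callable that runs the graph.
def custom_backend(gm: torch.fx.GraphModule, example_inputs: List[torch.Tensor]):
    print("custom backend called with FX graph:")
    gm.graph.print_tabular()
    # Return the unoptimized forward method, i.e. run the captured graph as-is.
    return gm.forward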

# Reset since we are using a different backend.
torch._dynamo.reset()

opt_model = torch.compile(init_model(), backend=custom_backend)
opt_model(generate_data(16)[0])

custom backend called with FX graph:
opcode         name    target    args    kwargs
-------------  ------  --------  ------  --------

tensor([[-0.1652,  0.2909,  0.4834,  ..., -0.1868,  0.0739, -0.3966],
        [-0.1317,  0.4502,  0.3291,  ..., -0.2256,  0.2245, -0.4491],
        [-0.1168,  0.2676,  0.1739,  ..., -0.1554,  0.2463, -0.3396],
        ...,
        [-0.2693,  0.5311,  0.2218,  ..., -0.0489,  0.1891, -0.3440],
        [-0.1732,  0.1918,  0.3215,  ..., -0.1224,  0.0871, -0.5373],
        [-0.2747,  0.3258,  0.1076,  ..., -0.2460,  0.1852, -0.4085]],
       device='cuda:0', grad_fn=<AddmmBackward0>)

Using our custom backend, we can now see how TorchDynamo handles data-dependent control flow. Consider the function below, where the line `if b.sum() < 0` is the source of data-dependent control flow.

In [24]:
def bar(a, b):
    x = a / (torch.abs(a) + 1)
    if b.sum() < 0:
        b = b * -1
    return x * b

opt_bar = torch.compile(bar, backend=custom_backend)
inp1 = torch.randn(10)
inp2 = torch.randn(10)
opt_bar(inp1, inp2)
opt_bar(inp1, -inp2)

custom backend called with FX graph:
opcode         name    target                                                  args         kwargs
-------------  ------  ------------------------------------------------------  -----------  --------
placeholder    l_a_    L_a_                                                    ()           {}
placeholder    l_b_    L_b_                                                    ()           {}
call_function  abs_1   <built-in method abs of type object at 0x55698d211960>  (l_a_,)      {}
call_function  add     <built-in function add>                                 (abs_1, 1)   {}
call_function  x       <built-in function truediv>                             (l_a_, add)  {}
call_method    sum_1   sum                                                     (l_b_,)      {}
call_function  lt      <built-in function lt>                                  (sum_1, 0)   {}
output         output  output                                                  ((x, lt),)   {}
cus

tensor([ 0.1642,  0.9337, -0.0785, -0.2383, -0.3237, -0.0518, -0.2949, -0.3723,
        -0.4776, -0.7030])

The output reveals that TorchDynamo extracted 3 different FX graphs corresponding to the following code (the order may differ from the output above):

1. `x = a / (torch.abs(a) + 1)`

2. `b = b * -1; return x * b`

3. `return x * b`
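The graph count can also be checked programmatically. The sketch below is a variant of `custom_backend` that appends each FX graph it receives to a list instead of printing it; `recording_backend` and `captured_graphs` are illustrative names, not part of PyTorch. Assuming the two calls take different branches, we would expect three graphs, matching the list above, although the exact count may vary across PyTorch versions.

```python
from typing import List

import torch
import torch._dynamo

captured_graphs: List[torch.fx.GraphModule] = []

def recording_backend(gm: torch.fx.GraphModule, example_inputs: List[torch.Tensor]):
    # Keep the captured graph for later inspection, then run it unoptimized.
    captured_graphs.append(gm)
    return gm.forward

def bar(a, b):
    x = a / (torch.abs(a) + 1)
    if b.sum() < 0:
        b = b * -1
    return x * b

torch._dynamo.reset()
opt_bar = torch.compile(bar, backend=recording_backend)
inp1, inp2 = torch.randn(10), torch.randn(10)
out_1 = opt_bar(inp1, inp2)   # takes one branch of the `if`
out_2 = opt_bar(inp1, -inp2)  # takes the other branch
print("graphs captured:", len(captured_graphs))
```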

When TorchDynamo encounters unsupported Python features, such as data-dependent control flow, it breaks the computation graph, lets the default Python interpreter handle the unsupported code, then resumes capturing the graph.

Let's walk through how TorchDynamo steps through `bar`. In both cases, TorchDynamo first runs graph 1 and then lets Python evaluate the conditional. If `b.sum() < 0`, it then runs graph 2; otherwise, it runs graph 3.
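The execution strategy just described can be mimicked by hand. In the sketch below, `graph_1`, `graph_2`, and `graph_3` are hypothetical plain-Python stand-ins for the three extracted graphs, not code that TorchDynamo generates. As in the tabular output above, graph 1 returns both `x` and the boolean tensor for the condition, and ordinary Python decides which of the other two graphs to run.

```python
import torch

def bar(a, b):
    x = a / (torch.abs(a) + 1)
    if b.sum() < 0:
        b = b * -1
    return x * b

# Hypothetical stand-ins for the three FX graphs extracted from `bar`.
def graph_1(a, b):
    x = a / (torch.abs(a) + 1)
    return x, b.sum() < 0  # also returns the 0-dim boolean tensor for the branch

def graph_2(x, b):
    b = b * -1
    return x * b

def graph_3(x, b):
    return x * b

def bar_stitched(a, b):
    x, cond = graph_1(a, b)
    # The data-dependent branch is decided by the Python interpreter, not by a graph.
    if cond:
        return graph_2(x, b)
    return graph_3(x, b)

a, b = torch.randn(10), torch.randn(10)
print(torch.allclose(bar_stitched(a, b), bar(a, b)))
print(torch.allclose(bar_stitched(a, -b), bar(a, -b)))
```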

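This highlights a major difference between TorchDynamo and previous PyTorch compiler solutions. When they encounter unsupported Python features, previous solutions either raise an error or silently produce incorrect results (for example, by recording only the branch taken during tracing). TorchDynamo, on the other hand, breaks the computation graph and falls back to Python for the unsupported code.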
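As an illustration of the "raise an error" case, the sketch below applies FX symbolic tracing to `bar`. We would expect tracing to fail with `torch.fx.proxy.TraceError` when it reaches `if b.sum() < 0`, because a symbolic value has no concrete truth value at trace time; the exact message may differ between PyTorch versions. For contrast, `torch.compile` (here with the built-in `"eager"` backend, chosen only to skip kernel generation) runs the same function.

```python
import torch
import torch.fx
from torch.fx.proxy import TraceError

def bar(a, b):
    x = a / (torch.abs(a) + 1)
    if b.sum() < 0:
        b = b * -1
    return x * b

try:
    torch.fx.symbolic_trace(bar)
    fx_trace_failed = False
except TraceError as e:
    # Symbolic (Proxy) values cannot be used as the condition of a Python `if`.
    fx_trace_failed = True
    print("FX tracing failed:", e)

# TorchDynamo graph-breaks at the `if` instead of failing.
opt_bar = torch.compile(bar, backend="eager")
a, b = torch.randn(10), torch.randn(10)
print(torch.allclose(opt_bar(a, b), bar(a, b)))
```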

We can see where TorchDynamo breaks the graph by using `torch._dynamo.explain`:

In [25]:
# Reset since we are using a different backend.
torch._dynamo.reset()
explain_output = torch._dynamo.explain(bar)(torch.randn(10), torch.randn(10))
print(explain_output)

Graph Count: 2
Graph Break Count: 1
Op Count: 6
Break Reasons:
  Break Reason 1:
    Reason: generic_jump TensorVariable()
    User Stack:
      <FrameSummary file /tmp/ipykernel_505258/3263660924.py, line 3 in bar>
Ops per Graph:
  Ops 1:
    <built-in method abs of type object at 0x55698d211960>
    <built-in function add>
    <built-in function truediv>
    <built-in function lt>
  Ops 2:
    <built-in function mul>
    <built-in function mul>
Out Guards:
  Guard 1:
    Name: ''
    Source: global
    Create Function: DEFAULT_DEVICE
    Guard Types: ['DEFAULT_DEVICE']
    Code List: ['utils_device.CURRENT_DEVICE == None']
    Object Weakref: None
    Guarded Class Weakref: None
  Guard 2:
    Name: ''
    Source: global
    Create Function: GRAD_MODE
    Guard Types: None
    Code List: None
    Object Weakref: None
    Guarded Class Weakref: None
  Guard 3:
    Name: "L['b'].sum"
    Source: local
    Create Function: HASATTR
    Guard Types: ['HASATTR']
    Code List: ["hasattr(L[

To maximize speedup, graph breaks should be kept to a minimum. We can force TorchDynamo to raise an error at the first graph break it encounters by passing `fullgraph=True`:

In [26]:
opt_bar = torch.compile(bar, fullgraph=True)
try:
    opt_bar(torch.randn(10), torch.randn(10))
except Exception:
    tb.print_exc()

Traceback (most recent call last):
  File "/tmp/ipykernel_505258/3610564610.py", line 3, in <module>
    opt_bar(torch.randn(10), torch.randn(10))
  File "/mnt/xarfuse/uid-279513/197d327d-seed-nspid4026531836_cgpid370380623-ns-4026531841/torch/_dynamo/eval_frame.py", line 585, in _fn
    return fn(*args, **kwargs)
  File "/mnt/xarfuse/uid-279513/197d327d-seed-nspid4026531836_cgpid370380623-ns-4026531841/torch/_dynamo/convert_frame.py", line 1393, in __call__
    return self._torchdynamo_orig_callable(
  File "/mnt/xarfuse/uid-279513/197d327d-seed-nspid4026531836_cgpid370380623-ns-4026531841/torch/_dynamo/convert_frame.py", line 585, in __call__
    return _compile(
  File "/mnt/xarfuse/uid-279513/197d327d-see
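One way to remove this particular graph break, sketched below under the assumption that the branch only flips the sign of `b`, is to express the conditional as a tensor operation with `torch.where`, so that no Python `if` depends on tensor data. `bar_no_branch` is a hypothetical rewrite, and `backend="eager"` is used only to skip kernel generation; whether a graph break occurs is decided by TorchDynamo, not by the backend.

```python
import torch
import torch._dynamo

def bar(a, b):
    x = a / (torch.abs(a) + 1)
    if b.sum() < 0:
        b = b * -1
    return x * b

def bar_no_branch(a, b):
    x = a / (torch.abs(a) + 1)
    # Choose between -b and b on the tensor side instead of with a Python `if`.
    b = torch.where(b.sum() < 0, -b, b)
    return x * b

torch._dynamo.reset()
# fullgraph=True raises on any graph break, so success means a single graph.
opt_bar_no_branch = torch.compile(bar_no_branch, fullgraph=True, backend="eager")

a, b = torch.randn(10), torch.randn(10)
for sign in (1.0, -1.0):
    print(torch.allclose(opt_bar_no_branch(a, sign * b), bar(a, sign * b)))
```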

Below, we confirm that TorchDynamo does not break the graph on the model we used earlier to demonstrate speedups.

In [27]:
opt_model = torch.compile(init_model(), fullgraph=True)
print(opt_model(generate_data(16)[0]))

tensor([[ 0.0294,  0.2753,  0.1290,  ...,  0.1568,  0.0533, -0.1716],
        [-0.1070,  0.1701, -0.0249,  ...,  0.1851,  0.0640, -0.2015],
        [-0.0762,  0.2020,  0.1385,  ...,  0.3403,  0.0474, -0.0348],
        ...,
        [ 0.1323,  0.2607,  0.1046,  ...,  0.3043,  0.0453, -0.2665],
        [-0.0415,  0.3157,  0.1270,  ...,  0.2658, -0.0085, -0.2415],
        [-0.1058,  0.2242,  0.1776,  ...,  0.3469,  0.0636, -0.0558]],
       device='cuda:0', grad_fn=<CompiledFunctionBackward>)


We can use `torch.export` (available in PyTorch 2.1+) to extract a single, exportable FX graph from the input PyTorch program. The exported graph is intended to be run in different (e.g. Python-less) environments. One important restriction is that `torch.export` does not support graph breaks. Please check [this tutorial](https://pytorch.org/tutorials/intermediate/torch_export_tutorial.html) for more details on `torch.export`.

## Conclusion

In this tutorial, we introduced `torch.compile` by covering basic usage, demonstrating speedups over eager mode, comparing to previous PyTorch compiler solutions, and briefly investigating TorchDynamo and its interactions with FX graphs. We hope that you will give `torch.compile` a try!