`torch.compile` is the latest method to speed up your PyTorch code in `torch >= 2.0.0`! `torch.compile` makes PyTorch code run faster by JIT-compiling it into optimized kernels, all while requiring minimal code changes. Under the hood, `torch.compile` captures PyTorch programs via TorchDynamo, canonicalizes over 2,000 PyTorch operators via PrimTorch, and finally generates fast code out of it across multiple accelerators and backends via the deep learning compiler TorchInductor.
Note: See here for a general tutorial on how to leverage `torch.compile`, and here for a description of its interface.
In this tutorial, we show how to optimize your custom PyG model via `torch.compile`.
Note: From PyG 2.5 onwards, `torch.compile` is fully compatible with all PyG GNN layers. If you are on an earlier version of PyG, consider using `torch_geometric.compile` instead.
Once you have a PyG model defined, simply wrap it with `torch.compile` to obtain its optimized version:
```python
import torch
from torch_geometric.nn import GraphSAGE

model = GraphSAGE(in_channels, hidden_channels, num_layers, out_channels)
model = model.to(device)
model = torch.compile(model)
```
and execute it as usual:
```python
from torch_geometric.datasets import Planetoid

dataset = Planetoid(root, name="Cora")
data = dataset[0].to(device)

out = model(data.x, data.edge_index)
```
The `torch.compile` method provides two important arguments to be aware of:

- Most of the mini-batches observed in PyG are dynamic by nature, meaning that their shape varies across different mini-batches. For these scenarios, we can enforce dynamic shape tracing in PyTorch via the `dynamic=True` argument:

  ```python
  torch.compile(model, dynamic=True)
  ```

  With this, PyTorch will up-front attempt to generate a kernel that is as dynamic as possible to avoid recompilations when sizes change across mini-batches. Note that when `dynamic` is set to `False`, PyTorch will never generate dynamic kernels, which thus only works when graph sizes are guaranteed to never change (e.g., in full-batch training on small graphs). By default, `dynamic` is set to `None` in PyTorch >= 2.1.0, and PyTorch will automatically detect whether dynamism has occurred. Note that support for dynamic shape tracing requires PyTorch >= 2.1.0 to be installed.

- In order to maximize speedup, graph breaks in the compiled model should be limited. We can force compilation to raise an error upon the first graph break encountered by using the `fullgraph=True` argument:

  ```python
  torch.compile(model, fullgraph=True)
  ```

  It is generally good practice to confirm that your model does not contain any graph breaks. Importantly, there exist a few operations in PyG that will currently lead to graph breaks (but workarounds exist), e.g.:

  - `global_mean_pool` (and other pooling operators) perform device synchronization in case the batch size `size` is not passed, leading to a graph break.
  - `remove_self_loops` and `add_remaining_self_loops` mask the given `edge_index`, leading to a device synchronization to compute its final output shape. As such, we recommend augmenting your graph before inputting it into your GNN, e.g., via the `AddSelfLoops` or `GCNNorm` transformations, and setting `add_self_loops=False`/`normalize=False` when initializing layers such as `GCNConv`.
We have incorporated multiple examples in `examples/compile` that further show the practical usage of `torch.compile`:

- Node Classification via `GCN` (`dynamic=False`)
- Graph Classification via `GIN` (`dynamic=True`)
If you notice that `torch.compile` fails for a certain PyG model, do not hesitate to reach out, either on GitHub or Slack. We are very eager to improve `torch.compile` support across the whole PyG code base.
`torch.compile` works fantastically well for many PyG models. Overall, we observe speedups of up to nearly 3x. Specifically, we benchmark `GCN`, `GraphSAGE` and `GIN`, and compare runtimes obtained from traditional eager mode and `torch.compile`. We use a synthetic graph with 10,000 nodes and 200,000 edges, and a hidden feature dimensionality of 64. We report runtimes over 500 optimization steps:
Model | Mode | Forward | Backward | Total | Speedup
---|---|---|---|---|---
`GCN` | Eager | 2.6396s | 2.1697s | 4.8093s |
`GCN` | Compiled | 1.1082s | 0.5896s | 1.6978s | 2.83x
`GraphSAGE` | Eager | 1.6023s | 1.6428s | 3.2451s |
`GraphSAGE` | Compiled | 0.7033s | 0.7465s | 1.4498s | 2.24x
`GIN` | Eager | 1.6701s | 1.6990s | 3.3690s |
`GIN` | Compiled | 0.7320s | 0.7407s | 1.4727s | 2.29x
To reproduce these results, run

```bash
python test/nn/models/test_basic_gnn.py
```

from the root folder of your checked-out PyG repository from GitHub.