Support torch.compile with EdgeIndex (#9007)
rusty1s committed Mar 3, 2024
1 parent 3548318 commit 8a84823
Showing 4 changed files with 72 additions and 10 deletions.
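`EdgeIndex` is a `torch.Tensor` subclass, and this commit wires it into Dynamo's traceable-subclass machinery so that models taking an `EdgeIndex` compile without graph breaks. A minimal usage sketch mirroring the new test below, assuming PyTorch >= 2.2 and this commit (the shapes and the choice of `SAGEConv` are illustrative):

```python
import torch
from torch_geometric import EdgeIndex
from torch_geometric.nn import SAGEConv

conv = SAGEConv(16, 32)
x = torch.randn(10, 16)
edge_index = EdgeIndex(
    torch.randint(0, 10, (2, 40)),  # 40 random edges over 10 nodes
    sparse_size=(10, 10),
)

# With this change, the subclass is captured in a single graph:
out = torch.compile(conv, fullgraph=True)(x, edge_index)
```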
3 changes: 2 additions & 1 deletion CHANGELOG.md
@@ -7,7 +7,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 ### Added
 
-- Added a `ogbn-mag240m` example ([#8249](https://github.com/pyg-team/pytorch_geometric/pull/8249/))
+- Added support for `torch.compile` in combination with `EdgeIndex` ([#9007](https://github.com/pyg-team/pytorch_geometric/pull/9007))
+- Added a `ogbn-mag240m` example ([#8249](https://github.com/pyg-team/pytorch_geometric/pull/8249))
 - Added `EdgeIndex.sparse_resize_` functionality ([#8983](https://github.com/pyg-team/pytorch_geometric/pull/8983))
 - Added approximate `faiss`-based KNN-search ([#8952](https://github.com/pyg-team/pytorch_geometric/pull/8952))

38 changes: 36 additions & 2 deletions test/nn/test_compile_conv.py
@@ -2,6 +2,7 @@
 import torch
 from torch import Tensor
 
+from torch_geometric import EdgeIndex
 from torch_geometric.nn import GCNConv, SAGEConv
 from torch_geometric.profile import benchmark
 from torch_geometric.testing import (
@@ -28,17 +29,50 @@ def forward(self, x: Tensor, edge_index: Tensor) -> Tensor:
 @withCUDA
 @onlyLinux
 @onlyFullTest
-@withPackage('torch>=2.0.0')
+@withPackage('torch>=2.1.0')
 @pytest.mark.parametrize('Conv', [GCNConv, SAGEConv])
 def test_compile_conv(device, Conv):
     import torch._dynamo as dynamo
 
     x = torch.randn(10, 16, device=device)
     edge_index = torch.randint(0, x.size(0), (2, 40), device=device)
 
-    conv = Conv(16, 32).to(device)
+    if Conv == GCNConv:
+        conv = Conv(16, 32, add_self_loops=False).to(device)
+    else:
+        conv = Conv(16, 32).to(device)
 
     explanation = dynamo.explain(conv)(x, edge_index)
     assert explanation.graph_break_count == 0
 
     out = torch.compile(conv)(x, edge_index)
     assert torch.allclose(conv(x, edge_index), out, atol=1e-6)


+@withCUDA
+@onlyLinux
+@onlyFullTest
+@withPackage('torch>=2.2.0')
+@pytest.mark.parametrize('Conv', [GCNConv, SAGEConv])
+def test_compile_conv_edge_index(device, Conv):
+    import torch._dynamo as dynamo
+
+    x = torch.randn(10, 16, device=device)
+    edge_index = torch.randint(0, x.size(0), (2, 40), device=device)
+    edge_index = EdgeIndex(edge_index, sparse_size=(10, 10))
+
+    if Conv == GCNConv:
+        conv = Conv(16, 32, normalize=False).to(device)
+    else:
+        conv = Conv(16, 32).to(device)
+
+    explanation = dynamo.explain(conv)(x, edge_index)
+    assert explanation.graph_break_count == 0
+
+    out = torch.compile(conv, fullgraph=True)(x, edge_index)
+    assert torch.allclose(conv(x, edge_index), out, atol=1e-6)


 if __name__ == '__main__':
     import argparse

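For context on the assertions above: `torch._dynamo.explain` runs the model once and reports, among other things, how often tracing fell back to eager mode, while `fullgraph=True` turns any such break into a hard error. A tiny standalone sketch (the function here is illustrative):

```python
import torch
import torch._dynamo as dynamo

def fn(x: torch.Tensor) -> torch.Tensor:
    return x.relu() + 1

# `graph_break_count == 0` means the function was captured in one graph:
explanation = dynamo.explain(fn)(torch.randn(8))
assert explanation.graph_break_count == 0

# `fullgraph=True` raises instead of splitting the graph on a break:
out = torch.compile(fn, fullgraph=True)(torch.randn(8))
```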
17 changes: 10 additions & 7 deletions test/test_edge_index.py
@@ -1144,29 +1144,32 @@ def forward(self, x: Tensor, edge_index: EdgeIndex) -> Tensor:


 @onlyLinux
-@withPackage('torch>=2.1.0')
+@withPackage('torch>=2.2.0')
 def test_compile():
     import torch._dynamo as dynamo
 
     class Model(torch.nn.Module):
         def forward(self, x: Tensor, edge_index: EdgeIndex) -> Tensor:
-            row, col = edge_index[0], edge_index[1]
-            x_j = x[row]
-            out = scatter(x_j, col, dim_size=edge_index.num_cols)
+            x_j = x[edge_index[0]]
+            out = scatter(x_j, edge_index[1], dim_size=edge_index.num_cols)
             return out
 
     x = torch.randn(3, 8)
     # Test that `num_cols` gets picked up by making last node isolated.
-    edge_index = EdgeIndex([[0, 1, 1, 2], [1, 0, 0, 1]], sparse_size=(3, 3))
+    edge_index = EdgeIndex(
+        [[0, 1, 1, 2], [1, 0, 0, 1]],
+        sparse_size=(3, 3),
+        sort_order='row',
+    ).fill_cache_()
 
     model = Model()
     expected = model(x, edge_index)
     assert expected.size() == (3, 8)
 
     explanation = dynamo.explain(model)(x, edge_index)
-    assert explanation.graph_break_count <= 0
+    assert explanation.graph_break_count == 0
 
-    compiled_model = torch.compile(model)
+    compiled_model = torch.compile(model, fullgraph=True)
     out = compiled_model(x, edge_index)
     assert torch.allclose(out, expected)

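A note on the `dim_size=edge_index.num_cols` argument the test exercises: node 2 never appears as a destination, so `scatter` would otherwise infer an output size of `col.max() + 1 == 2` and silently drop the isolated node. A small sketch of that behavior, assuming the usual `torch_geometric.utils.scatter` semantics:

```python
import torch
from torch_geometric.utils import scatter

x_j = torch.randn(4, 8)           # one message per edge
col = torch.tensor([1, 0, 0, 1])  # destination node of each edge

out = scatter(x_j, col)           # inferred size: col.max() + 1 == 2
assert out.size(0) == 2           # isolated node 2 is dropped

out = scatter(x_j, col, dim_size=3)  # what `edge_index.num_cols` supplies
assert out.size(0) == 3              # node 2 kept as an all-zero row
```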
24 changes: 24 additions & 0 deletions torch_geometric/edge_index.py
@@ -228,6 +228,9 @@ class EdgeIndex(Tensor):
     # See "https://pytorch.org/docs/stable/notes/extending.html"
     # for a basic tutorial on how to subclass `torch.Tensor`.
 
+    # The underlying tensor representation:
+    _data: Optional[Tensor] = None
+
     # The size of the underlying sparse matrix:
     _sparse_size: Tuple[Optional[int], Optional[int]] = (None, None)

@@ -340,6 +343,8 @@ def __new__(

         # Attach metadata:
         assert isinstance(out, EdgeIndex)
+        if torch_geometric.typing.WITH_PT22:
+            out._data = data
         out._sparse_size = sparse_size
         out._sort_order = None if sort_order is None else SortOrder(sort_order)
         out._is_undirected = is_undirected
@@ -1077,6 +1082,25 @@ def sparse_narrow(
         edge_index._indptr = colptr
         return edge_index
 
+    def __tensor_flatten__(self) -> Tuple[List[str], Tuple[Any, ...]]:
+        if not torch_geometric.typing.WITH_PT22:
+            raise RuntimeError("'torch.compile' with 'EdgeIndex' only "
+                               "supported from PyTorch 2.2 onwards")
+        assert self._data is not None
+        # TODO Add `_T_index`.
+        attrs = ['_data', '_indptr', '_T_perm', '_T_indptr']
+        return attrs, ()
+
+    @staticmethod
+    def __tensor_unflatten__(
+        inner_tensors: Tuple[Any],
+        ctx: Tuple[Any, ...],
+    ) -> 'EdgeIndex':
+        if not torch_geometric.typing.WITH_PT22:
+            raise RuntimeError("'torch.compile' with 'EdgeIndex' only "
+                               "supported from PyTorch 2.2 onwards")
+        raise NotImplementedError
+
     @classmethod
     def __torch_function__(
         cls: Type,
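`__tensor_flatten__` and `__tensor_unflatten__` form the protocol `torch.compile` uses to trace tensor subclasses: flatten names the inner tensor attributes plus any static context, and unflatten rebuilds the instance from them (left as `NotImplementedError` here, since only the flatten side is needed so far). A sketch of the flatten side against this commit, assuming PyTorch >= 2.2; the attribute names come from the diff above:

```python
import torch
from torch_geometric import EdgeIndex

edge_index = EdgeIndex(
    torch.tensor([[0, 1, 1, 2], [1, 0, 0, 1]]),
    sparse_size=(3, 3),
    sort_order='row',
).fill_cache_()  # populates `_indptr`, so it is a real inner tensor

attrs, ctx = edge_index.__tensor_flatten__()
print(attrs)  # ['_data', '_indptr', '_T_perm', '_T_indptr'] per the diff
print(ctx)    # () -- no static context yet
```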
