Skip to content

Commit

Permalink
Added EdgeIndex support in MessagePassing (#9008)
Browse files Browse the repository at this point in the history
  • Loading branch information
rusty1s committed Mar 3, 2024
1 parent 8a84823 commit fefa636
Show file tree
Hide file tree
Showing 7 changed files with 72 additions and 23 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Added

- Added support for `EdgeIndex` in `MessagePassing` ([#9007](https://github.com/pyg-team/pytorch_geometric/pull/9007))
- Added support for `torch.compile` in combination with `EdgeIndex` ([#9007](https://github.com/pyg-team/pytorch_geometric/pull/9007))
- Added a `ogbn-mag240m` example ([#8249](https://github.com/pyg-team/pytorch_geometric/pull/8249))
- Added `EdgeIndex.sparse_resize_` functionality ([#8983](https://github.com/pyg-team/pytorch_geometric/pull/8983))
Expand Down
18 changes: 15 additions & 3 deletions test/nn/conv/test_message_passing.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
import copy
import os.path as osp
from typing import Tuple, Union
from typing import Optional, Tuple, Union

import pytest
import torch
from torch import Tensor
from torch.nn import Linear

import torch_geometric.typing
from torch_geometric import EdgeIndex
from torch_geometric.nn import MessagePassing, aggr
from torch_geometric.typing import (
Adj,
Expand Down Expand Up @@ -57,8 +58,8 @@ def forward(

return out

def message(self, x_j: Tensor, edge_weight: Tensor) -> Tensor:
return edge_weight.view(-1, 1) * x_j
def message(self, x_j: Tensor, edge_weight: Optional[Tensor]) -> Tensor:
return x_j if edge_weight is None else edge_weight.view(-1, 1) * x_j

def message_and_aggregate(self, adj_t: Adj, x: OptPairTensor) -> Tensor:
return spmm(adj_t, x[0], reduce=self.aggr)
Expand Down Expand Up @@ -129,6 +130,17 @@ def test_my_conv_basic():
assert torch_adj_t.grad is not None


def test_my_conv_edge_index():
x = torch.randn(4, 8)
edge_index = torch.tensor([[0, 1, 2, 3], [0, 0, 1, 1]])
edge_index = EdgeIndex(edge_index, sparse_size=(4, 4), sort_order='col')

conv = MyConv(8, 32)

out = conv(x, edge_index)
assert out.size() == (4, 32)


def test_my_conv_out_of_bounds():
x = torch.randn(3, 8)
value = torch.randn(4)
Expand Down
8 changes: 4 additions & 4 deletions test/test_edge_index.py
Original file line number Diff line number Diff line change
Expand Up @@ -1031,11 +1031,11 @@ def test_sparse_resize(device):
assert out.sparse_size() == (3, 3)
assert out._indptr.equal(tensor([0, 1, 3, 4], device=device))
assert out._T_indptr.equal(tensor([0, 1, 3, 4], device=device))
out = out.sparse_resize_((4, 5))
out = out.sparse_resize_(4, 5)
assert out.sparse_size() == (4, 5)
assert out._indptr.equal(tensor([0, 1, 3, 4, 4], device=device))
assert out._T_indptr.equal(tensor([0, 1, 3, 4, 4, 4], device=device))
out = out.sparse_resize_((3, 3))
out = out.sparse_resize_(3, 3)
assert out.sparse_size() == (3, 3)
assert out._indptr is None
assert out._T_indptr is None
Expand All @@ -1044,11 +1044,11 @@ def test_sparse_resize(device):
assert out.sparse_size() == (3, 3)
assert out._indptr.equal(tensor([0, 1, 3, 4], device=device))
assert out._T_indptr.equal(tensor([0, 1, 3, 4], device=device))
out = out.sparse_resize_((4, 5))
out = out.sparse_resize_(4, 5)
assert out.sparse_size() == (4, 5)
assert out._indptr.equal(tensor([0, 1, 3, 4, 4, 4], device=device))
assert out._T_indptr.equal(tensor([0, 1, 3, 4, 4], device=device))
out = out.sparse_resize_((3, 3))
out = out.sparse_resize_(3, 3)
assert out.sparse_size() == (3, 3)
assert out._indptr is None
assert out._T_indptr is None
Expand Down
39 changes: 25 additions & 14 deletions torch_geometric/edge_index.py
Original file line number Diff line number Diff line change
Expand Up @@ -522,22 +522,33 @@ def get_sparse_size(
return torch.Size((self.get_sparse_size(0), self.get_sparse_size(1)))

def sparse_resize_( # type: ignore
self,
sparse_size: Tuple[int, int],
self,
num_rows: Optional[int],
num_cols: Optional[int],
) -> 'EdgeIndex':
r"""Assigns or re-assigns the size of the underlying sparse matrix.
Args:
sparse_size (tuple[int, int]): The size of the sparse matrix.
num_rows (int, optional): The number of rows.
num_cols (int, optional): The number of columns.
"""
num_rows, num_cols = sparse_size

if self.is_undirected and num_rows != num_cols:
raise ValueError(f"'EdgeIndex' is undirected but received a "
f"non-symmetric size (got {list(sparse_size)})")

def _modify_ptr(ptr: Optional[Tensor], size: int) -> Optional[Tensor]:
if ptr is None:
if self.is_undirected:
if num_rows is not None and num_cols is None:
num_cols = num_rows
elif num_cols is not None and num_rows is None:
num_rows = num_cols

if num_rows is not None and num_rows != num_cols:
raise ValueError(f"'EdgeIndex' is undirected but received a "
f"non-symmetric size "
f"(got [{num_rows}, {num_cols}])")

def _modify_ptr(
ptr: Optional[Tensor],
size: Optional[int],
) -> Optional[Tensor]:

if ptr is None or size is None:
return None

if ptr.numel() - 1 == size:
Expand All @@ -560,7 +571,7 @@ def _modify_ptr(ptr: Optional[Tensor], size: int) -> Optional[Tensor]:
self._indptr = _modify_ptr(self._indptr, num_cols)
self._T_indptr = _modify_ptr(self._T_indptr, num_rows)

self._sparse_size = sparse_size
self._sparse_size = (num_rows, num_cols)

return self

Expand Down Expand Up @@ -1083,7 +1094,7 @@ def sparse_narrow(
return edge_index

def __tensor_flatten__(self) -> Tuple[List[str], Tuple[Any, ...]]:
if not torch_geometric.typing.WITH_PT22:
if not torch_geometric.typing.WITH_PT22: # pragma: no cover
raise RuntimeError("'torch.compile' with 'EdgeIndex' only "
"supported from PyTorch 2.2 onwards")
assert self._data is not None
Expand All @@ -1096,7 +1107,7 @@ def __tensor_unflatten__(
inner_tensors: Tuple[Any],
ctx: Tuple[Any, ...],
) -> 'EdgeIndex':
if not torch_geometric.typing.WITH_PT22:
if not torch_geometric.typing.WITH_PT22: # pragma: no cover
raise RuntimeError("'torch.compile' with 'EdgeIndex' only "
"supported from PyTorch 2.2 onwards")
raise NotImplementedError
Expand Down
7 changes: 7 additions & 0 deletions torch_geometric/nn/conv/collect.jinja
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ from typing import List, NamedTuple, Optional, Union
import torch
from torch import Tensor

from torch_geometric import EdgeIndex
from torch_geometric.utils import is_torch_sparse_tensor
from torch_geometric.utils.sparse import ptr2index
from torch_geometric.typing import SparseTensor
Expand Down Expand Up @@ -65,7 +66,13 @@ def {{collect_name}}(
{%- endif %}
edge_index_i = edge_index[i]
edge_index_j = edge_index[j]

ptr = None
if not torch.jit.is_scripting() and isinstance(edge_index, EdgeIndex):
if i == 0 and edge_index.is_sorted_by_row:
(ptr, _), _ = edge_index.get_csr()
elif i == 1 and edge_index.is_sorted_by_col:
(ptr, _), _ = edge_index.get_csc()

elif isinstance(edge_index, SparseTensor):
{%- if 'edge_index' in collect_param_dict %}
Expand Down
11 changes: 10 additions & 1 deletion torch_geometric/nn/conv/message_passing.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from torch import Tensor
from torch.utils.hooks import RemovableHandle

from torch_geometric import is_compiling
from torch_geometric import EdgeIndex, is_compiling
from torch_geometric.inspector import Inspector, Signature
from torch_geometric.nn.aggr import Aggregation
from torch_geometric.nn.resolver import aggregation_resolver as aggr_resolver
Expand Down Expand Up @@ -251,6 +251,9 @@ def _check_input(
size: Optional[Tuple[int, int]],
) -> List[Optional[int]]:

if not torch.jit.is_scripting() and isinstance(edge_index, EdgeIndex):
return [edge_index.num_rows, edge_index.num_cols]

if is_sparse(edge_index):
if self.flow == 'target_to_source':
raise ValueError(
Expand Down Expand Up @@ -421,7 +424,13 @@ def _collect(
out['edge_index'] = edge_index
out['edge_index_i'] = edge_index[i]
out['edge_index_j'] = edge_index[j]

out['ptr'] = None
if isinstance(edge_index, EdgeIndex):
if i == 0 and edge_index.is_sorted_by_row:
(out['ptr'], _), _ = edge_index.get_csr()
elif i == 1 and edge_index.is_sorted_by_col:
(out['ptr'], _), _ = edge_index.get_csc()

elif isinstance(edge_index, SparseTensor):
row, col, value = edge_index.coo()
Expand Down
11 changes: 10 additions & 1 deletion torch_geometric/utils/_trim_to_layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import torch
from torch import Tensor

from torch_geometric import EdgeIndex
from torch_geometric.typing import (
Adj,
EdgeType,
Expand Down Expand Up @@ -202,11 +203,19 @@ def trim_adj(
return edge_index

if isinstance(edge_index, Tensor):
return edge_index.narrow(
edge_index = edge_index.narrow(
dim=1,
start=0,
length=edge_index.size(1) - num_sampled_edges_per_hop[-layer],
)
if not torch.jit.is_scripting() and isinstance(edge_index, EdgeIndex):
num_rows, num_cols = edge_index.sparse_size()
if num_rows is not None:
num_rows -= num_sampled_src_nodes_per_hop[-layer]
if num_cols is not None:
num_cols -= num_sampled_dst_nodes_per_hop[-layer]
edge_index.sparse_resize_(num_rows, num_cols)
return edge_index

elif isinstance(edge_index, SparseTensor):
size = (
Expand Down

0 comments on commit fefa636

Please sign in to comment.