Skip to content

Commit

Permalink
Integrate EdgeIndex in cugraph conv layers (#8937)
Browse files Browse the repository at this point in the history
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: rusty1s <matthias.fey@tu-dortmund.de>
  • Loading branch information
3 people committed Feb 20, 2024
1 parent 6517997 commit 844a9dc
Show file tree
Hide file tree
Showing 8 changed files with 58 additions and 81 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Changed

- Breaking Change: Added support for `EdgeIndex` in `cugraph` GNN layers ([#8938](https://github.com/pyg-team/pytorch_geometric/pull/8937))
- Added the `dim` arg to `torch.cross` calls ([#8918](https://github.com/pyg-team/pytorch_geometric/pull/8918))

### Deprecated
Expand Down
8 changes: 6 additions & 2 deletions test/nn/conv/cugraph/test_cugraph_gat_conv.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import pytest
import torch

from torch_geometric import EdgeIndex
from torch_geometric.nn import CuGraphGATConv, GATConv
from torch_geometric.testing import onlyCUDA, withPackage

Expand Down Expand Up @@ -37,8 +38,11 @@ def test_gat_conv_equality(bias, bipartite, concat, heads, max_num_neighbors):
else:
out1 = conv1(x, edge_index)

csc = CuGraphGATConv.to_csc(edge_index, size)
out2 = conv2(x, csc, max_num_neighbors=max_num_neighbors)
out2 = conv2(
x,
EdgeIndex(edge_index, sparse_size=size),
max_num_neighbors=max_num_neighbors,
)
assert torch.allclose(out1, out2, atol=1e-3)

grad_output = torch.rand_like(out1)
Expand Down
9 changes: 7 additions & 2 deletions test/nn/conv/cugraph/test_cugraph_rgcn_conv.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import pytest
import torch

from torch_geometric import EdgeIndex
from torch_geometric.nn import CuGraphRGCNConv
from torch_geometric.nn import FastRGCNConv as RGCNConv
from torch_geometric.testing import onlyCUDA, withPackage
Expand Down Expand Up @@ -41,8 +42,12 @@ def test_rgcn_conv_equality(aggr, bias, bipartite, max_num_neighbors,
else:
out1 = conv1(x, edge_index, edge_type)

csc, edge_type = CuGraphRGCNConv.to_csc(edge_index, size, edge_type)
out2 = conv2(x, csc, edge_type, max_num_neighbors=max_num_neighbors)
out2 = conv2(
x,
EdgeIndex(edge_index, sparse_size=size),
edge_type,
max_num_neighbors=max_num_neighbors,
)
assert torch.allclose(out1, out2, atol=1e-3)

grad_out = torch.rand_like(out1)
Expand Down
8 changes: 6 additions & 2 deletions test/nn/conv/cugraph/test_cugraph_sage_conv.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import pytest
import torch

from torch_geometric import EdgeIndex
from torch_geometric.nn import CuGraphSAGEConv, SAGEConv
from torch_geometric.testing import onlyCUDA, withPackage

Expand Down Expand Up @@ -42,8 +43,11 @@ def test_sage_conv_equality(aggr, bias, bipartite, max_num_neighbors,
else:
out1 = conv1(x, edge_index)

csc = CuGraphSAGEConv.to_csc(edge_index, size)
out2 = conv2(x, csc, max_num_neighbors=max_num_neighbors)
out2 = conv2(
x,
EdgeIndex(edge_index, sparse_size=size),
max_num_neighbors=max_num_neighbors,
)
assert torch.allclose(out1, out2, atol=1e-6)

grad_out = torch.rand_like(out1)
Expand Down
85 changes: 24 additions & 61 deletions torch_geometric/nn/conv/cugraph/base.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,9 @@
import warnings
from typing import Any, Optional, Tuple, Union
from typing import Any, Optional

import torch
from torch import Tensor

from torch_geometric.utils import index_sort
from torch_geometric.utils.sparse import index2ptr
from torch_geometric import EdgeIndex

try: # pragma: no cover
LEGACY_MODE = False
Expand Down Expand Up @@ -45,63 +43,28 @@ def reset_parameters(self):
r"""Resets all learnable parameters of the module."""
pass

@staticmethod
def to_csc(
edge_index: Tensor,
size: Optional[Tuple[int, int]] = None,
edge_attr: Optional[Tensor] = None,
) -> Union[Tuple[Tensor, Tensor, int], Tuple[Tuple[Tensor, Tensor, int],
Tensor]]:
r"""Returns a CSC representation of an :obj:`edge_index` tensor to be
used as input to a :class:`CuGraphModule`.
Args:
edge_index (torch.Tensor): The edge indices.
size ((int, int), optional): The shape of :obj:`edge_index` in each
dimension. (default: :obj:`None`)
edge_attr (torch.Tensor, optional): The edge features.
(default: :obj:`None`)
"""
if size is None:
warnings.warn(f"Inferring the graph size from 'edge_index' causes "
f"a decline in performance and does not work for "
f"bipartite graphs. To suppress this warning, pass "
f"the 'size' explicitly in '{__name__}.to_csc()'.")
num_src_nodes = num_dst_nodes = int(edge_index.max()) + 1
else:
num_src_nodes, num_dst_nodes = size

row, col = edge_index
col, perm = index_sort(col, max_value=num_dst_nodes)
row = row[perm]

colptr = index2ptr(col, num_dst_nodes)

if edge_attr is not None:
return (row, colptr, num_src_nodes), edge_attr[perm]

return row, colptr, num_src_nodes

def get_cugraph(
self,
csc: Tuple[Tensor, Tensor, int],
edge_index: EdgeIndex,
max_num_neighbors: Optional[int] = None,
) -> Any:
r"""Constructs a :obj:`cugraph` graph object from CSC representation.
Supports both bipartite and non-bipartite graphs.
Args:
csc ((torch.Tensor, torch.Tensor, int)): A tuple containing the CSC
representation of a graph, given as a tuple of
:obj:`(row, colptr, num_src_nodes)`. Use the
:meth:`CuGraphModule.to_csc` method to convert an
:obj:`edge_index` representation to the desired format.
edge_index (EdgeIndex): The edge indices.
max_num_neighbors (int, optional): The maximum number of neighbors
of a target node. It is only effective when operating in a
bipartite graph. When not given, will be computed on-the-fly,
leading to slightly worse performance. (default: :obj:`None`)
"""
row, colptr, num_src_nodes = csc
if not isinstance(edge_index, EdgeIndex):
raise ValueError(f"'edge_index' needs to be of type 'EdgeIndex' "
f"(got {type(edge_index)})")

edge_index = edge_index.sort_by('col')[0]
num_src_nodes = edge_index.get_sparse_size(0)
(colptr, row), _ = edge_index.get_csc()

if not row.is_cuda:
raise RuntimeError(f"'{self.__class__.__name__}' requires GPU-"
Expand All @@ -125,7 +88,7 @@ def get_cugraph(

def get_typed_cugraph(
self,
csc: Tuple[Tensor, Tensor, int],
edge_index: EdgeIndex,
edge_type: Tensor,
num_edge_types: Optional[int] = None,
max_num_neighbors: Optional[int] = None,
Expand All @@ -135,11 +98,7 @@ def get_typed_cugraph(
Supports both bipartite and non-bipartite graphs.
Args:
csc ((torch.Tensor, torch.Tensor, int)): A tuple containing the CSC
representation of a graph, given as a tuple of
:obj:`(row, colptr, num_src_nodes)`. Use the
:meth:`CuGraphModule.to_csc` method to convert an
:obj:`edge_index` representation to the desired format.
edge_index (EdgeIndex): The edge indices.
edge_type (torch.Tensor): The edge type.
num_edge_types (int, optional): The maximum number of edge types.
When not given, will be computed on-the-fly, leading to
Expand All @@ -152,7 +111,15 @@ def get_typed_cugraph(
if num_edge_types is None:
num_edge_types = int(edge_type.max()) + 1

row, colptr, num_src_nodes = csc
if not isinstance(edge_index, EdgeIndex):
raise ValueError(f"'edge_index' needs to be of type 'EdgeIndex' "
f"(got {type(edge_index)})")

edge_index, perm = edge_index.sort_by('col')
edge_type = edge_type[perm]
num_src_nodes = edge_index.get_sparse_size(0)
(colptr, row), _ = edge_index.get_csc()

edge_type = edge_type.int()

if num_src_nodes != colptr.numel() - 1: # Bipartite graph:
Expand Down Expand Up @@ -181,18 +148,14 @@ def get_typed_cugraph(
def forward(
self,
x: Tensor,
csc: Tuple[Tensor, Tensor, int],
edge_index: EdgeIndex,
max_num_neighbors: Optional[int] = None,
) -> Tensor:
r"""Runs the forward pass of the module.
Args:
x (torch.Tensor): The node features.
csc ((torch.Tensor, torch.Tensor, int)): A tuple containing the CSC
representation of a graph, given as a tuple of
:obj:`(row, colptr, num_src_nodes)`. Use the
:meth:`CuGraphModule.to_csc` method to convert an
:obj:`edge_index` representation to the desired format.
edge_index (EdgeIndex): The edge indices.
max_num_neighbors (int, optional): The maximum number of neighbors
of a target node. It is only effective when operating in a
bipartite graph. When not given, the value will be computed
Expand Down
7 changes: 4 additions & 3 deletions torch_geometric/nn/conv/cugraph/gat_conv.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
from typing import Optional, Tuple
from typing import Optional

import torch
from torch import Tensor
from torch.nn import Linear, Parameter

from torch_geometric import EdgeIndex
from torch_geometric.nn.conv.cugraph import CuGraphModule
from torch_geometric.nn.conv.cugraph.base import LEGACY_MODE
from torch_geometric.nn.inits import zeros
Expand Down Expand Up @@ -65,10 +66,10 @@ def reset_parameters(self):
def forward(
self,
x: Tensor,
csc: Tuple[Tensor, Tensor, int],
edge_index: EdgeIndex,
max_num_neighbors: Optional[int] = None,
) -> Tensor:
graph = self.get_cugraph(csc, max_num_neighbors)
graph = self.get_cugraph(edge_index, max_num_neighbors)

x = self.lin(x)

Expand Down
14 changes: 6 additions & 8 deletions torch_geometric/nn/conv/cugraph/rgcn_conv.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
from typing import Optional, Tuple
from typing import Optional

import torch
from torch import Tensor
from torch.nn import Parameter

from torch_geometric import EdgeIndex
from torch_geometric.nn.conv.cugraph import CuGraphModule
from torch_geometric.nn.conv.cugraph.base import LEGACY_MODE
from torch_geometric.nn.inits import glorot, zeros
Expand Down Expand Up @@ -76,27 +77,24 @@ def reset_parameters(self):
def forward(
self,
x: Tensor,
csc: Tuple[Tensor, Tensor, int],
edge_index: EdgeIndex,
edge_type: Tensor,
max_num_neighbors: Optional[int] = None,
) -> Tensor:
r"""Runs the forward pass of the module.
Args:
x (torch.Tensor): The node features.
csc ((torch.Tensor, torch.Tensor)): A tuple containing the CSC
representation of a graph, given as a tuple of
:obj:`(row, colptr)`. Use the :meth:`to_csc` method to convert
an :obj:`edge_index` representation to the desired format.
edge_index (EdgeIndex): The edge indices.
edge_type (torch.Tensor): The edge type.
max_num_neighbors (int, optional): The maximum number of neighbors
of a target node. It is only effective when operating in a
bipartite graph.. When not given, the value will be computed
on-the-fly, leading to slightly worse performance.
(default: :obj:`None`)
"""
graph = self.get_typed_cugraph(csc, edge_type, self.num_relations,
max_num_neighbors)
graph = self.get_typed_cugraph(edge_index, edge_type,
self.num_relations, max_num_neighbors)

out = RGCNConvAgg(x, self.comp, graph, concat_own=self.root_weight,
norm_by_out_degree=bool(self.aggr == 'mean'))
Expand Down
7 changes: 4 additions & 3 deletions torch_geometric/nn/conv/cugraph/sage_conv.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
from typing import Optional, Tuple
from typing import Optional

import torch.nn.functional as F
from torch import Tensor
from torch.nn import Linear

from torch_geometric import EdgeIndex
from torch_geometric.nn.conv.cugraph import CuGraphModule
from torch_geometric.nn.conv.cugraph.base import LEGACY_MODE

Expand Down Expand Up @@ -68,10 +69,10 @@ def reset_parameters(self):
def forward(
self,
x: Tensor,
csc: Tuple[Tensor, Tensor, int],
edge_index: EdgeIndex,
max_num_neighbors: Optional[int] = None,
) -> Tensor:
graph = self.get_cugraph(csc, max_num_neighbors)
graph = self.get_cugraph(edge_index, max_num_neighbors)

if self.project:
x = self.pre_lin(x).relu()
Expand Down

0 comments on commit 844a9dc

Please sign in to comment.