Only use `segment` in case `torch.use_deterministic_algorithms` is set (#9009)
rusty1s committed Mar 4, 2024
1 parent fefa636 commit d491f43
Showing 10 changed files with 49 additions and 66 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -15,6 +15,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 ### Changed
 
+- Default to `scatter` operations in `MessagePassing` in case `torch.use_deterministic_algorithms` is not set ([#9009](https://github.com/pyg-team/pytorch_geometric/pull/9009))
 - Made `MessagePassing` interface thread-safe ([#9001](https://github.com/pyg-team/pytorch_geometric/pull/9001))
 - Breaking Change: Added support for `EdgeIndex` in `cugraph` GNN layers ([#8938](https://github.com/pyg-team/pytorch_geometric/pull/8937))
 - Added the `dim` arg to `torch.cross` calls ([#8918](https://github.com/pyg-team/pytorch_geometric/pull/8918))
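For context, a minimal sketch of the behavioral contract this commit establishes, assuming PyTorch >= 1.11 (for `warn_only`) and a PyG build that includes this change; the tensors and the choice of `MeanAggregation` are illustrative:

import torch
from torch_geometric.nn.aggr import MeanAggregation

x = torch.randn(6, 16)
index = torch.tensor([0, 0, 1, 1, 2, 2])  # per-row cluster assignment
ptr = torch.tensor([0, 2, 4, 6])          # equivalent CSR boundaries

# Default: when both `index` and `ptr` are given, the faster but
# potentially non-deterministic `scatter` path is now preferred.
aggr = MeanAggregation()
out = aggr(x, index=index, ptr=ptr)

# Opting in to deterministic kernels flips the preference to `segment`.
# Note that the flag is captured once, at module construction time.
torch.use_deterministic_algorithms(True, warn_only=True)
det_aggr = MeanAggregation()
det_out = det_aggr(x, index=index, ptr=ptr)

assert torch.allclose(out, det_out, atol=1e-6)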
9 changes: 1 addition & 8 deletions test/nn/aggr/test_attention.py
@@ -1,7 +1,5 @@
-import pytest
 import torch
 
-import torch_geometric.typing
 from torch_geometric.nn import MLP
 from torch_geometric.nn.aggr import AttentionalAggregation
 
@@ -21,9 +19,4 @@ def test_attentional_aggregation():
 
     out = aggr(x, index)
     assert out.size() == (3, channels)
-
-    if not torch_geometric.typing.WITH_TORCH_SCATTER:
-        with pytest.raises(NotImplementedError, match="requires 'index'"):
-            aggr(x, ptr=ptr)
-    else:
-        assert torch.allclose(out, aggr(x, ptr=ptr))
+    assert torch.allclose(out, aggr(x, ptr=ptr))
20 changes: 3 additions & 17 deletions test/nn/aggr/test_basic.py
@@ -1,7 +1,6 @@
 import pytest
 import torch
 
-import torch_geometric.typing
 from torch_geometric.nn import (
     MaxAggregation,
     MeanAggregation,
@@ -53,10 +52,7 @@ def test_basic_aggregation(Aggregation):
     assert out.size() == (3, x.size(1))
 
     if isinstance(aggr, MulAggregation):
-        with pytest.raises(NotImplementedError, match="requires 'index'"):
-            aggr(x, ptr=ptr)
-    elif not torch_geometric.typing.WITH_TORCH_SCATTER:
-        with pytest.raises(NotImplementedError, match="requires 'index'"):
+        with pytest.raises(RuntimeError, match="requires 'index'"):
             aggr(x, ptr=ptr)
     else:
         assert torch.allclose(out, aggr(x, ptr=ptr))
@@ -100,12 +96,7 @@ def test_learnable_aggregation(Aggregation, learn):
 
     out = aggr(x, index)
     assert out.size() == (3, x.size(1))
-
-    if not torch_geometric.typing.WITH_TORCH_SCATTER:
-        with pytest.raises(NotImplementedError, match="requires 'index'"):
-            aggr(x, ptr=ptr)
-    else:
-        assert torch.allclose(out, aggr(x, ptr=ptr))
+    assert torch.allclose(out, aggr(x, ptr=ptr))
 
     if learn:
         out.mean().backward()
@@ -127,12 +118,7 @@ def test_learnable_channels_aggregation(Aggregation):
 
     out = aggr(x, index)
     assert out.size() == (3, x.size(1))
-
-    if not torch_geometric.typing.WITH_TORCH_SCATTER:
-        with pytest.raises(NotImplementedError, match="requires 'index'"):
-            aggr(x, ptr=ptr)
-    else:
-        assert torch.allclose(out, aggr(x, ptr=ptr))
+    assert torch.allclose(out, aggr(x, ptr=ptr))
 
     out.mean().backward()
     for param in aggr.parameters():
8 changes: 1 addition & 7 deletions test/nn/aggr/test_multi.py
@@ -1,7 +1,6 @@
 import pytest
 import torch
 
-import torch_geometric.typing
 from torch_geometric.nn import MultiAggregation
 
 
@@ -38,12 +37,7 @@ def test_multi_aggr(multi_aggr_tuple):
 
     out = aggr(x, index)
     assert out.size() == (4, expand * x.size(1))
-
-    if not torch_geometric.typing.WITH_TORCH_SCATTER:
-        with pytest.raises(NotImplementedError, match="requires 'index'"):
-            aggr(x, ptr=ptr)
-    else:
-        assert torch.allclose(out, aggr(x, ptr=ptr))
+    assert torch.allclose(out, aggr(x, ptr=ptr))
 
     jit = torch.jit.script(aggr)
     assert torch.allclose(out, jit(x, index))
2 changes: 1 addition & 1 deletion test/nn/models/test_basic_gnn.py
@@ -207,7 +207,7 @@ def test_basic_gnn_inference(get_dataset, jk):
 @onlyLinux
 @onlyFullTest
 @withPackage('torch>=2.0.0')
-def test_compile(device):
+def test_compile_basic(device):
     x = torch.randn(3, 8, device=device)
     edge_index = torch.tensor([[0, 1, 1, 2], [1, 0, 2, 1]], device=device)
 
21 changes: 10 additions & 11 deletions test/utils/test_softmax.py
@@ -16,17 +16,12 @@ def test_softmax():
 
     out = softmax(src, index)
     assert out.tolist() == [0.5, 0.5, 1, 1]
-    if CALCULATION_VIA_PTR_AVAILABLE:
-        assert softmax(src, ptr=ptr).tolist() == out.tolist()
-    else:
-        with pytest.raises(NotImplementedError, match="requires 'index'"):
-            softmax(src, ptr=ptr)
+    assert softmax(src, ptr=ptr).tolist() == out.tolist()
 
     src = src.view(-1, 1)
     out = softmax(src, index)
     assert out.tolist() == [[0.5], [0.5], [1], [1]]
-    if CALCULATION_VIA_PTR_AVAILABLE:
-        assert softmax(src, None, ptr).tolist() == out.tolist()
+    assert softmax(src, ptr=ptr).tolist() == out.tolist()
 
     jit = torch.jit.script(softmax)
     assert torch.allclose(jit(src, index), out)
@@ -55,23 +50,27 @@ def test_softmax_dim():
 
     src = torch.randn(4)
     assert torch.allclose(softmax(src, index, dim=0), src.softmax(dim=0))
-    if CALCULATION_VIA_PTR_AVAILABLE:
-        assert torch.allclose(softmax(src, ptr=ptr, dim=0), src.softmax(dim=0))
+    assert torch.allclose(softmax(src, ptr=ptr, dim=0), src.softmax(dim=0))
 
     src = torch.randn(4, 16)
     assert torch.allclose(softmax(src, index, dim=0), src.softmax(dim=0))
-    if CALCULATION_VIA_PTR_AVAILABLE:
-        assert torch.allclose(softmax(src, ptr=ptr, dim=0), src.softmax(dim=0))
+    assert torch.allclose(softmax(src, ptr=ptr, dim=0), src.softmax(dim=0))
 
     src = torch.randn(4, 4)
     assert torch.allclose(softmax(src, index, dim=-1), src.softmax(dim=-1))
     if CALCULATION_VIA_PTR_AVAILABLE:
         assert torch.allclose(softmax(src, ptr=ptr, dim=-1), src.softmax(-1))
+    else:
+        with pytest.raises(ImportError, match="requires the 'torch-scatter'"):
+            softmax(src, ptr=ptr, dim=-1)
 
     src = torch.randn(4, 4, 16)
     assert torch.allclose(softmax(src, index, dim=1), src.softmax(dim=1))
     if CALCULATION_VIA_PTR_AVAILABLE:
         assert torch.allclose(softmax(src, ptr=ptr, dim=1), src.softmax(dim=1))
+    else:
+        with pytest.raises(ImportError, match="requires the 'torch-scatter'"):
+            softmax(src, ptr=ptr, dim=1)
 
 
 if __name__ == '__main__':
23 changes: 14 additions & 9 deletions torch_geometric/nn/aggr/base.py
@@ -1,10 +1,8 @@
-from typing import Optional, Tuple
+from typing import Final, Optional, Tuple
 
 import torch
 from torch import Tensor
 
-import torch_geometric.typing
-from torch_geometric import is_compiling
 from torch_geometric.experimental import disable_dynamic_shapes
 from torch_geometric.utils import scatter, segment, to_dense_batch
 
@@ -61,6 +59,13 @@ class Aggregation(torch.nn.Module):
     - **output:** graph features :math:`(*, |\mathcal{G}|, F_{out})` or
       node features :math:`(*, |\mathcal{V}|, F_{out})`
     """
+    def __init__(self) -> None:
+        super().__init__()
+
+        self._deterministic: Final[bool] = (
+            torch.are_deterministic_algorithms_enabled()
+            or torch.is_deterministic_algorithms_warn_only_enabled())
+
     def forward(
         self,
         x: Tensor,
@@ -171,14 +176,14 @@ def reduce(self, x: Tensor, index: Optional[Tensor] = None,
                ptr: Optional[Tensor] = None, dim_size: Optional[int] = None,
                dim: int = -2, reduce: str = 'sum') -> Tensor:
 
-        if (ptr is not None and torch_geometric.typing.WITH_TORCH_SCATTER
-                and not is_compiling()):
-            ptr = expand_left(ptr, dim, dims=x.dim())
-            return segment(x, ptr, reduce=reduce)
+        if ptr is not None:
+            if index is None or self._deterministic:
+                ptr = expand_left(ptr, dim, dims=x.dim())
+                return segment(x, ptr, reduce=reduce)
+
         if index is None:
-            raise NotImplementedError(
-                "Aggregation requires 'index' to be specified")
+            raise RuntimeError("Aggregation requires 'index' to be specified")
 
         return scatter(x, index, dim, dim_size, reduce)
 
     def to_dense_batch(
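A short sketch of the resulting `reduce` contract, assuming the hunk above is applied; `SumAggregation` is just a convenient concrete subclass:

import pytest
import torch
from torch_geometric.nn.aggr import SumAggregation

aggr = SumAggregation()
x = torch.randn(6, 8)
ptr = torch.tensor([0, 2, 4, 6])

# A `ptr`-only call (no `index`) always takes the `segment` path now:
out = aggr(x, ptr=ptr)
assert out.size() == (3, 8)

# Omitting both `index` and `ptr` raises a RuntimeError
# (previously a NotImplementedError):
with pytest.raises(RuntimeError, match="requires 'index'"):
    aggr(x)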
8 changes: 5 additions & 3 deletions torch_geometric/utils/_segment.py
@@ -25,17 +25,19 @@ def segment(src: Tensor, ptr: Tensor, reduce: str = 'sum') -> Tensor:
     if not torch_geometric.typing.WITH_TORCH_SCATTER or is_compiling():
         return _torch_segment(src, ptr, reduce)
 
-    if torch_geometric.typing.WITH_PT20 and src.is_cuda and reduce == 'mean':
+    if (ptr.dim() == 1 and torch_geometric.typing.WITH_PT20 and src.is_cuda
+            and reduce == 'mean'):
         return _torch_segment(src, ptr, reduce)
 
     # TODO Fallback to `scatter` if deterministic algorithms are turned off.
 
     return torch_scatter.segment_csr(src, ptr, reduce=reduce)
 
 
 def _torch_segment(src: Tensor, ptr: Tensor, reduce: str = 'sum') -> Tensor:
-    if not torch_geometric.typing.WITH_PT20:
-        raise ImportError("'segment' requires the 'torch-scatter' package")
+    if ptr.dim() > 1:
+        raise ImportError("'segment' in an arbitrary dimension "
+                          "requires the 'torch-scatter' package")
 
     if reduce == 'min' or reduce == 'max':
         reduce = f'a{reduce}'  # `amin` or `amax`
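To illustrate the new `ptr.dim()` guard: with a one-dimensional `ptr`, the pure-PyTorch fallback `_torch_segment` can serve the call even without `torch-scatter` (a sketch, assuming PyTorch >= 2.0):

import torch
from torch_geometric.utils import segment

src = torch.randn(6, 16)
ptr = torch.tensor([0, 2, 4, 6])  # 1-D CSR boundaries: three row segments

# Served by the fallback path; a multi-dimensional `ptr` ("segment in an
# arbitrary dimension") still raises unless `torch-scatter` is installed.
out = segment(src, ptr, reduce='mean')
assert out.size() == (3, 16)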
6 changes: 4 additions & 2 deletions torch_geometric/utils/_softmax.py
@@ -57,8 +57,10 @@ def softmax(
             and not is_compiling()):  # pragma: no cover
         return pyg_lib.ops.softmax_csr(src, ptr, dim)
 
-    if (ptr is not None and torch_geometric.typing.WITH_TORCH_SCATTER
-            and not is_compiling()):
+    if (ptr is not None and
+            (ptr.dim() == 1 or (ptr.dim() > 1 and index is None) or
+             (torch_geometric.typing.WITH_TORCH_SCATTER and not is_compiling()))):
+
         dim = dim + src.dim() if dim < 0 else dim
         size = ([1] * dim) + [-1]
         count = ptr[1:] - ptr[:-1]
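The widened condition means a one-dimensional `ptr` no longer needs `torch-scatter` for the grouped softmax; a minimal sketch with both groupings describing the same segments:

import torch
from torch_geometric.utils import softmax

src = torch.randn(6)
index = torch.tensor([0, 0, 1, 1, 2, 2])
ptr = torch.tensor([0, 2, 4, 6])

# `index`- and `ptr`-based calls agree on identical segment boundaries:
assert torch.allclose(softmax(src, index), softmax(src, ptr=ptr))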
17 changes: 9 additions & 8 deletions torch_geometric/utils/loop.py
@@ -460,14 +460,15 @@ def add_self_loops(  # noqa: F811
         size = (N, N)
 
     device = edge_index.device
-    if torch.jit.is_scripting():
-        loop_index = torch.arange(0, N, device=device).view(1, -1).repeat(2, 1)
-    else:
-        loop_index = EdgeIndex(
+    if not torch.jit.is_scripting() and isinstance(edge_index, EdgeIndex):
+        loop_index: Tensor = EdgeIndex(
             torch.arange(0, N, device=device).view(1, -1).repeat(2, 1),
             sparse_size=(N, N),
             is_undirected=True,
         )
+    else:
+        loop_index = torch.arange(0, N, device=device).view(1, -1).repeat(2, 1)
+
     full_edge_index = torch.cat([edge_index, loop_index], dim=1)
 
     if is_sparse:
@@ -623,14 +624,14 @@ def add_remaining_self_loops(  # noqa: F811
     mask = edge_index[0] != edge_index[1]
 
     device = edge_index.device
-    if torch.jit.is_scripting():
-        loop_index = torch.arange(0, N, device=device).view(1, -1).repeat(2, 1)
-    else:
-        loop_index = EdgeIndex(
+    if not torch.jit.is_scripting() and isinstance(edge_index, EdgeIndex):
+        loop_index: Tensor = EdgeIndex(
             torch.arange(0, N, device=device).view(1, -1).repeat(2, 1),
             sparse_size=(N, N),
             is_undirected=True,
         )
+    else:
+        loop_index = torch.arange(0, N, device=device).view(1, -1).repeat(2, 1)
 
     if edge_attr is not None:
 
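The rewritten branches dispatch on the input type instead of only on scripting mode; a sketch of the intended behavior (graph size illustrative):

import torch
from torch_geometric import EdgeIndex
from torch_geometric.utils import add_self_loops

edge_index = torch.tensor([[0, 1], [1, 0]])

# Plain `Tensor` inputs keep returning a plain `Tensor` ...
out, _ = add_self_loops(edge_index, num_nodes=2)
assert not isinstance(out, EdgeIndex)

# ... while `EdgeIndex` inputs preserve the metadata-carrying type.
out, _ = add_self_loops(EdgeIndex(edge_index, sparse_size=(2, 2)),
                        num_nodes=2)
assert isinstance(out, EdgeIndex)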
