Skip to content

Commit

Permalink
add warnings when radius or radius_graph are called on CPU with `…
Browse files Browse the repository at this point in the history
…max_num_neighbors` (#9076)

Warn users of the issue described
[here](#9036).

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: rusty1s <matthias.fey@tu-dortmund.de>
  • Loading branch information
3 people committed Mar 25, 2024
1 parent 8a9ace7 commit 37b7616
Showing 1 changed file with 14 additions and 0 deletions.
14 changes: 14 additions & 0 deletions torch_geometric/nn/pool/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,6 +218,13 @@ def radius(
Automatically calculated if not given. (default: :obj:`None`)
:rtype: :class:`torch.Tensor`
.. warning::
The CPU implementation of :meth:`radius` with :obj:`max_num_neighbors`
is biased towards certain quadrants.
Consider setting :obj:`max_num_neighbors` to :obj:`None` or moving
inputs to GPU before proceeding.
"""
if not torch_geometric.typing.WITH_TORCH_CLUSTER_BATCH_SIZE:
return torch_cluster.radius(x, y, r, batch_x, batch_y,
Expand Down Expand Up @@ -268,6 +275,13 @@ def radius_graph(
Automatically calculated if not given. (default: :obj:`None`)
:rtype: :class:`torch.Tensor`
.. warning::
The CPU implementation of :meth:`radius_graph` with
:obj:`max_num_neighbors` is biased towards certain quadrants.
Consider setting :obj:`max_num_neighbors` to :obj:`None` or moving
inputs to GPU before proceeding.
"""
if batch is not None and x.device != batch.device:
warnings.warn("Input tensor 'x' and 'batch' are on different devices "
Expand Down

0 comments on commit 37b7616

Please sign in to comment.