Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 18 additions & 9 deletions csrc/cpu/radius_cpu.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,8 @@
torch::Tensor radius_cpu(torch::Tensor x, torch::Tensor y,
torch::optional<torch::Tensor> ptr_x,
torch::optional<torch::Tensor> ptr_y, double r,
int64_t max_num_neighbors, int64_t num_workers) {
int64_t max_num_neighbors, int64_t num_workers,
bool ignore_same_index) {

CHECK_CPU(x);
CHECK_INPUT(x.dim() == 2);
Expand Down Expand Up @@ -54,10 +55,14 @@ torch::Tensor radius_cpu(torch::Tensor x, torch::Tensor y,
size_t num_matches = mat_index.index->radiusSearch(
y_data + i * y.size(1), r * r, ret_matches, params);

for (size_t j = 0; j < std::min(num_matches, (size_t)max_num_neighbors);
j++) {
out_vec.push_back(ret_matches[j].first);
out_vec.push_back(i);
for (size_t j = 0, count = 0;
j < num_matches && count < (size_t)max_num_neighbors;
j++) {
if (!ignore_same_index || ret_matches[j].first != i) {
out_vec.push_back(ret_matches[j].first);
out_vec.push_back(i);
count++;
}
}
}

Expand Down Expand Up @@ -91,10 +96,14 @@ torch::Tensor radius_cpu(torch::Tensor x, torch::Tensor y,
size_t num_matches = mat_index.index->radiusSearch(
y_data + i * y.size(1), r * r, ret_matches, params);

for (size_t j = 0;
j < std::min(num_matches, (size_t)max_num_neighbors); j++) {
out_vec.push_back(x_start + ret_matches[j].first);
out_vec.push_back(i);
for (size_t j = 0, count = 0;
j < num_matches && count < (size_t)max_num_neighbors;
j++) {
if (!ignore_same_index || x_start + ret_matches[j].first != i) {
out_vec.push_back(x_start + ret_matches[j].first);
out_vec.push_back(i);
count++;
}
}
}
}
Expand Down
3 changes: 2 additions & 1 deletion csrc/cpu/radius_cpu.h
Original file line number Diff line number Diff line change
Expand Up @@ -5,4 +5,5 @@
torch::Tensor radius_cpu(torch::Tensor x, torch::Tensor y,
torch::optional<torch::Tensor> ptr_x,
torch::optional<torch::Tensor> ptr_y, double r,
int64_t max_num_neighbors, int64_t num_workers);
int64_t max_num_neighbors, int64_t num_workers,
bool ignore_same_index);
10 changes: 6 additions & 4 deletions csrc/cuda/radius_cuda.cu
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,8 @@ radius_kernel(const scalar_t *__restrict__ x, const scalar_t *__restrict__ y,
const int64_t *__restrict__ ptr_y, int64_t *__restrict__ row,
int64_t *__restrict__ col, const scalar_t r, const int64_t n,
const int64_t m, const int64_t dim, const int64_t num_examples,
const int64_t max_num_neighbors) {
const int64_t max_num_neighbors,
const bool ignore_same_index) {

const int64_t n_y = blockIdx.x * blockDim.x + threadIdx.x;
if (n_y >= m)
Expand All @@ -29,7 +30,7 @@ radius_kernel(const scalar_t *__restrict__ x, const scalar_t *__restrict__ y,
(x[n_x * dim + d] - y[n_y * dim + d]);
}

if (dist < r) {
if (dist < r && !(ignore_same_index && n_y == n_x)) {
row[n_y * max_num_neighbors + count] = n_y;
col[n_y * max_num_neighbors + count] = n_x;
count++;
Expand All @@ -43,7 +44,8 @@ radius_kernel(const scalar_t *__restrict__ x, const scalar_t *__restrict__ y,
torch::Tensor radius_cuda(const torch::Tensor x, const torch::Tensor y,
torch::optional<torch::Tensor> ptr_x,
torch::optional<torch::Tensor> ptr_y, const double r,
const int64_t max_num_neighbors) {
const int64_t max_num_neighbors,
const bool ignore_same_index) {
CHECK_CUDA(x);
CHECK_CONTIGUOUS(x);
CHECK_INPUT(x.dim() == 2);
Expand Down Expand Up @@ -86,7 +88,7 @@ torch::Tensor radius_cuda(const torch::Tensor x, const torch::Tensor y,
ptr_x.value().data_ptr<int64_t>(),
ptr_y.value().data_ptr<int64_t>(), row.data_ptr<int64_t>(),
col.data_ptr<int64_t>(), r * r, x.size(0), y.size(0), x.size(1),
ptr_x.value().numel() - 1, max_num_neighbors);
ptr_x.value().numel() - 1, max_num_neighbors, ignore_same_index);
});

auto mask = row != -1;
Expand Down
3 changes: 2 additions & 1 deletion csrc/cuda/radius_cuda.h
Original file line number Diff line number Diff line change
Expand Up @@ -5,4 +5,5 @@
torch::Tensor radius_cuda(torch::Tensor x, torch::Tensor y,
torch::optional<torch::Tensor> ptr_x,
torch::optional<torch::Tensor> ptr_y, double r,
int64_t max_num_neighbors);
int64_t max_num_neighbors,
bool ignore_same_index);
7 changes: 4 additions & 3 deletions csrc/radius.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,15 +22,16 @@ PyMODINIT_FUNC PyInit__radius_cpu(void) { return NULL; }
CLUSTER_API torch::Tensor radius(torch::Tensor x, torch::Tensor y,
torch::optional<torch::Tensor> ptr_x,
torch::optional<torch::Tensor> ptr_y, double r,
int64_t max_num_neighbors, int64_t num_workers) {
int64_t max_num_neighbors, int64_t num_workers,
bool ignore_same_index) {
if (x.device().is_cuda()) {
#ifdef WITH_CUDA
return radius_cuda(x, y, ptr_x, ptr_y, r, max_num_neighbors);
return radius_cuda(x, y, ptr_x, ptr_y, r, max_num_neighbors, ignore_same_index);
#else
AT_ERROR("Not compiled with CUDA support");
#endif
} else {
return radius_cpu(x, y, ptr_x, ptr_y, r, max_num_neighbors, num_workers);
return radius_cpu(x, y, ptr_x, ptr_y, r, max_num_neighbors, num_workers, ignore_same_index);
}
}

Expand Down
41 changes: 41 additions & 0 deletions test/test_radius.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,15 @@ def to_set(edge_index):
return set([(i, j) for i, j in edge_index.t().tolist()])


def to_degree(edge_index):
    """Count how often each distinct index occurs in ``edge_index[1]``.

    Returns a plain Python list with one count per distinct value found in
    row 1 of ``edge_index``. Values that never occur contribute no entry.
    """
    unique_result = torch.unique(edge_index[1], return_counts=True)
    # With return_counts=True the result is (unique_values, counts); only the
    # counts are of interest here.
    return unique_result[1].tolist()


def to_batch(nodes, nodes_per_example=4):
    """Map each node index to the index of the example it belongs to.

    Assumes examples are stored back to back with ``nodes_per_example``
    nodes each, and that node indices are non-negative.

    Args:
        nodes: Iterable of node indices (Python ints or 0-d integer tensors).
        nodes_per_example: Number of consecutive nodes per example.
            Defaults to 4, the value that was previously hard-coded.

    Returns:
        A list of Python ints, one example index per entry of ``nodes``.
    """
    # Convert to a Python int first and use floor division: this stays exact
    # for arbitrarily large indices, whereas true division goes through a
    # float and can round.
    return [int(i) // nodes_per_example for i in nodes]


@pytest.mark.parametrize('dtype,device', product(floating_dtypes, devices))
def test_radius(dtype, device):
x = tensor([
Expand Down Expand Up @@ -74,6 +83,38 @@ def test_radius_graph(dtype, device):
assert to_set(edge_index) == set([(1, 0), (3, 0), (0, 1), (2, 1), (1, 2),
(3, 2), (0, 3), (2, 3)])

edge_index = radius_graph(x, r=100, flow='source_to_target',
max_num_neighbors=1)
assert set(to_degree(edge_index)) == set([1])

x = tensor([
[-1, -1],
[-1, -1],
[-1, -1],
[-1, -1],
], dtype, device)

edge_index = radius_graph(x, r=100, flow='source_to_target',
max_num_neighbors=1)
assert set(to_degree(edge_index)) == set([1])

x = tensor([
[-1, -1],
[-1, +1],
[+1, +1],
[+1, -1],
[-1, -1],
[-1, +1],
[+1, +1],
[+1, -1],
], dtype, device)
batch_x = tensor([0, 0, 0, 0, 1, 1, 1, 1], torch.long, device)

edge_index = radius_graph(x, r=100, batch=batch_x, flow='source_to_target',
max_num_neighbors=1)
assert set(to_degree(edge_index)) == set([1])
assert to_batch(edge_index[0]) == batch_x.tolist()


@pytest.mark.parametrize('dtype,device', product([torch.float], devices))
def test_radius_graph_large(dtype, device):
Expand Down
15 changes: 8 additions & 7 deletions torch_cluster/radius.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ def radius(
max_num_neighbors: int = 32,
num_workers: int = 1,
batch_size: Optional[int] = None,
ignore_same_index: bool = False
) -> torch.Tensor:
r"""Finds for each element in :obj:`y` all points in :obj:`x` within
distance :obj:`r`.
Expand Down Expand Up @@ -40,6 +41,9 @@ def radius(
:obj:`None`, or the input lies on the GPU. (default: :obj:`1`)
batch_size (int, optional): The number of examples :math:`B`.
Automatically calculated if not given. (default: :obj:`None`)
ignore_same_index (bool, optional): If :obj:`True`, each element in
:obj:`y` ignores the point in :obj:`x` with the same index.
(default: :obj:`False`)

.. code-block:: python

Expand Down Expand Up @@ -80,7 +84,8 @@ def radius(
ptr_y = torch.bucketize(arange, batch_y)

return torch.ops.torch_cluster.radius(x, y, ptr_x, ptr_y, r,
max_num_neighbors, num_workers)
max_num_neighbors, num_workers,
ignore_same_index)


def radius_graph(
Expand Down Expand Up @@ -133,15 +138,11 @@ def radius_graph(

assert flow in ['source_to_target', 'target_to_source']
edge_index = radius(x, x, r, batch, batch,
max_num_neighbors if loop else max_num_neighbors + 1,
num_workers, batch_size)
max_num_neighbors,
num_workers, batch_size, not loop)
if flow == 'source_to_target':
row, col = edge_index[1], edge_index[0]
else:
row, col = edge_index[0], edge_index[1]

if not loop:
mask = row != col
row, col = row[mask], col[mask]

return torch.stack([row, col], dim=0)