diff --git a/csrc/cpu/radius_cpu.cpp b/csrc/cpu/radius_cpu.cpp index 0115e85..41490f9 100644 --- a/csrc/cpu/radius_cpu.cpp +++ b/csrc/cpu/radius_cpu.cpp @@ -7,7 +7,8 @@ torch::Tensor radius_cpu(torch::Tensor x, torch::Tensor y, torch::optional ptr_x, torch::optional ptr_y, double r, - int64_t max_num_neighbors, int64_t num_workers) { + int64_t max_num_neighbors, int64_t num_workers, + bool ignore_same_index) { CHECK_CPU(x); CHECK_INPUT(x.dim() == 2); @@ -54,10 +55,14 @@ torch::Tensor radius_cpu(torch::Tensor x, torch::Tensor y, size_t num_matches = mat_index.index->radiusSearch( y_data + i * y.size(1), r * r, ret_matches, params); - for (size_t j = 0; j < std::min(num_matches, (size_t)max_num_neighbors); - j++) { - out_vec.push_back(ret_matches[j].first); - out_vec.push_back(i); + for (size_t j = 0, count = 0; + j < num_matches && count < (size_t)max_num_neighbors; + j++) { + if (!ignore_same_index || ret_matches[j].first != i) { + out_vec.push_back(ret_matches[j].first); + out_vec.push_back(i); + count++; + } } } @@ -91,10 +96,14 @@ torch::Tensor radius_cpu(torch::Tensor x, torch::Tensor y, size_t num_matches = mat_index.index->radiusSearch( y_data + i * y.size(1), r * r, ret_matches, params); - for (size_t j = 0; - j < std::min(num_matches, (size_t)max_num_neighbors); j++) { - out_vec.push_back(x_start + ret_matches[j].first); - out_vec.push_back(i); + for (size_t j = 0, count = 0; + j < num_matches && count < (size_t)max_num_neighbors; + j++) { + if (!ignore_same_index || x_start + ret_matches[j].first != i) { + out_vec.push_back(x_start + ret_matches[j].first); + out_vec.push_back(i); + count++; + } } } } diff --git a/csrc/cpu/radius_cpu.h b/csrc/cpu/radius_cpu.h index 639d130..0dbafe5 100644 --- a/csrc/cpu/radius_cpu.h +++ b/csrc/cpu/radius_cpu.h @@ -5,4 +5,5 @@ torch::Tensor radius_cpu(torch::Tensor x, torch::Tensor y, torch::optional ptr_x, torch::optional ptr_y, double r, - int64_t max_num_neighbors, int64_t num_workers); + int64_t max_num_neighbors, int64_t num_workers, + bool ignore_same_index); diff --git a/csrc/cuda/radius_cuda.cu b/csrc/cuda/radius_cuda.cu index 7efb2ff..29db910 100644 --- a/csrc/cuda/radius_cuda.cu +++ b/csrc/cuda/radius_cuda.cu @@ -13,7 +13,8 @@ radius_kernel(const scalar_t *__restrict__ x, const scalar_t *__restrict__ y, const int64_t *__restrict__ ptr_y, int64_t *__restrict__ row, int64_t *__restrict__ col, const scalar_t r, const int64_t n, const int64_t m, const int64_t dim, const int64_t num_examples, - const int64_t max_num_neighbors) { + const int64_t max_num_neighbors, + const bool ignore_same_index) { const int64_t n_y = blockIdx.x * blockDim.x + threadIdx.x; if (n_y >= m) @@ -29,7 +30,7 @@ radius_kernel(const scalar_t *__restrict__ x, const scalar_t *__restrict__ y, (x[n_x * dim + d] - y[n_y * dim + d]); } - if (dist < r) { + if (dist < r && !(ignore_same_index && n_y == n_x)) { row[n_y * max_num_neighbors + count] = n_y; col[n_y * max_num_neighbors + count] = n_x; count++; @@ -43,7 +44,8 @@ radius_kernel(const scalar_t *__restrict__ x, const scalar_t *__restrict__ y, torch::Tensor radius_cuda(const torch::Tensor x, const torch::Tensor y, torch::optional ptr_x, torch::optional ptr_y, const double r, - const int64_t max_num_neighbors) { + const int64_t max_num_neighbors, + const bool ignore_same_index) { CHECK_CUDA(x); CHECK_CONTIGUOUS(x); CHECK_INPUT(x.dim() == 2); @@ -86,7 +88,7 @@ torch::Tensor radius_cuda(const torch::Tensor x, const torch::Tensor y, ptr_x.value().data_ptr(), ptr_y.value().data_ptr(), row.data_ptr(), col.data_ptr(), r * r, x.size(0), y.size(0), x.size(1), - ptr_x.value().numel() - 1, max_num_neighbors); + ptr_x.value().numel() - 1, max_num_neighbors, ignore_same_index); }); auto mask = row != -1; diff --git a/csrc/cuda/radius_cuda.h b/csrc/cuda/radius_cuda.h index 4480cbd..7bf8908 100644 --- a/csrc/cuda/radius_cuda.h +++ b/csrc/cuda/radius_cuda.h @@ -5,4 +5,5 @@ torch::Tensor radius_cuda(torch::Tensor x, torch::Tensor y, torch::optional ptr_x, torch::optional ptr_y, double r, - int64_t max_num_neighbors); + int64_t max_num_neighbors, + bool ignore_same_index); diff --git a/csrc/radius.cpp b/csrc/radius.cpp index b79bcc1..27a588a 100644 --- a/csrc/radius.cpp +++ b/csrc/radius.cpp @@ -22,15 +22,16 @@ PyMODINIT_FUNC PyInit__radius_cpu(void) { return NULL; } CLUSTER_API torch::Tensor radius(torch::Tensor x, torch::Tensor y, torch::optional ptr_x, torch::optional ptr_y, double r, - int64_t max_num_neighbors, int64_t num_workers) { + int64_t max_num_neighbors, int64_t num_workers, + bool ignore_same_index) { if (x.device().is_cuda()) { #ifdef WITH_CUDA - return radius_cuda(x, y, ptr_x, ptr_y, r, max_num_neighbors); + return radius_cuda(x, y, ptr_x, ptr_y, r, max_num_neighbors, ignore_same_index); #else AT_ERROR("Not compiled with CUDA support"); #endif } else { - return radius_cpu(x, y, ptr_x, ptr_y, r, max_num_neighbors, num_workers); + return radius_cpu(x, y, ptr_x, ptr_y, r, max_num_neighbors, num_workers, ignore_same_index); } } diff --git a/test/test_radius.py b/test/test_radius.py index b20b2bf..4289bfc 100644 --- a/test/test_radius.py +++ b/test/test_radius.py @@ -11,6 +11,15 @@ def to_set(edge_index): return set([(i, j) for i, j in edge_index.t().tolist()]) +def to_degree(edge_index): + _, counts = torch.unique(edge_index[1], return_counts=True) + return counts.tolist() + + +def to_batch(nodes): + return [int(i / 4) for i in nodes] + + @pytest.mark.parametrize('dtype,device', product(floating_dtypes, devices)) def test_radius(dtype, device): x = tensor([ @@ -74,6 +83,38 @@ def test_radius_graph(dtype, device): assert to_set(edge_index) == set([(1, 0), (3, 0), (0, 1), (2, 1), (1, 2), (3, 2), (0, 3), (2, 3)]) + edge_index = radius_graph(x, r=100, flow='source_to_target', + max_num_neighbors=1) + assert set(to_degree(edge_index)) == set([1]) + + x = tensor([ + [-1, -1], + [-1, -1], + [-1, -1], + [-1, -1], + ], dtype, device) + + edge_index = radius_graph(x, r=100, flow='source_to_target', + max_num_neighbors=1) + assert set(to_degree(edge_index)) == set([1]) + + x = tensor([ + [-1, -1], + [-1, +1], + [+1, +1], + [+1, -1], + [-1, -1], + [-1, +1], + [+1, +1], + [+1, -1], + ], dtype, device) + batch_x = tensor([0, 0, 0, 0, 1, 1, 1, 1], torch.long, device) + + edge_index = radius_graph(x, r=100, batch=batch_x, flow='source_to_target', + max_num_neighbors=1) + assert set(to_degree(edge_index)) == set([1]) + assert to_batch(edge_index[0]) == batch_x.tolist() + @pytest.mark.parametrize('dtype,device', product([torch.float], devices)) def test_radius_graph_large(dtype, device): diff --git a/torch_cluster/radius.py b/torch_cluster/radius.py index 069824a..92187eb 100644 --- a/torch_cluster/radius.py +++ b/torch_cluster/radius.py @@ -12,6 +12,7 @@ def radius( max_num_neighbors: int = 32, num_workers: int = 1, batch_size: Optional[int] = None, + ignore_same_index: bool = False ) -> torch.Tensor: r"""Finds for each element in :obj:`y` all points in :obj:`x` within distance :obj:`r`. @@ -40,6 +41,9 @@ def radius( :obj:`None`, or the input lies on the GPU. (default: :obj:`1`) batch_size (int, optional): The number of examples :math:`B`. Automatically calculated if not given. (default: :obj:`None`) + ignore_same_index (bool, optional): If :obj:`True`, each element in + :obj:`y` ignores the point in :obj:`x` with the same index. + (default: :obj:`False`) .. code-block:: python @@ -80,7 +84,8 @@ def radius( ptr_y = torch.bucketize(arange, batch_y) return torch.ops.torch_cluster.radius(x, y, ptr_x, ptr_y, r, - max_num_neighbors, num_workers) + max_num_neighbors, num_workers, + ignore_same_index) def radius_graph( @@ -133,15 +138,11 @@ def radius_graph( assert flow in ['source_to_target', 'target_to_source'] edge_index = radius(x, x, r, batch, batch, - max_num_neighbors if loop else max_num_neighbors + 1, - num_workers, batch_size) + max_num_neighbors, + num_workers, batch_size, not loop) if flow == 'source_to_target': row, col = edge_index[1], edge_index[0] else: row, col = edge_index[0], edge_index[1] - if not loop: - mask = row != col - row, col = row[mask], col[mask] - return torch.stack([row, col], dim=0)