diff --git a/CMakeLists.txt b/CMakeLists.txt index e4f8385..6c93254 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -67,6 +67,7 @@ install(FILES csrc/cpu/fps_cpu.h csrc/cpu/graclus_cpu.h csrc/cpu/grid_cpu.h + csrc/cpu/nearest_cpu.h csrc/cpu/rw_cpu.h csrc/cpu/sampler_cpu.h DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}/${PROJECT_NAME}/cpu) diff --git a/csrc/cpu/nearest_cpu.cpp b/csrc/cpu/nearest_cpu.cpp new file mode 100644 index 0000000..9b5564a --- /dev/null +++ b/csrc/cpu/nearest_cpu.cpp @@ -0,0 +1,86 @@ +#include + +#include "nearest_cpu.h" + +#include + +#include "utils.h" + +torch::Tensor dists(const torch::Tensor &x, const torch::Tensor &y) { + if (x.dtype() != c10::ScalarType::Half && y.dtype() != c10::ScalarType::Half) + return torch::cdist(x, y); + + // Get the sizes of x and y. + int64_t n = x.size(0); + int64_t m = y.size(0); + + // Initialize the distances. + torch::Tensor distances = torch::zeros({n, m}); + + // Calculate the distances. + const int64_t grain_size = 1; + at::parallel_for(0, n, grain_size, [&](int64_t begin, int64_t end) { + const auto calcDistances = [&](torch::Tensor &out, int64_t offset = 0) { + for (int idx = begin; idx 1) { + torch::Tensor distances_chunk = torch::zeros({size, m}); + calcDistances(distances_chunk, begin); + distances.slice(0, begin, end) = distances_chunk; + } else { + calcDistances(distances); + } + }); + + // Return the distances. 
+  return distances;
+}
+
+// Returns, for every row of `x`, the index of the closest row of `y`. With
+// non-empty batch vectors the batch id is added as an extra coordinate so
+// matches stay inside a batch. Throws std::invalid_argument on bad input.
+torch::Tensor nearest_cpu(torch::Tensor x, torch::Tensor y,
+                          torch::Tensor batch_x, torch::Tensor batch_y) {
+  CHECK_CPU(x);
+  CHECK_CPU(y);
+  CHECK_CPU(batch_x);
+  CHECK_CPU(batch_y);
+  batch_x = batch_x.contiguous();
+  batch_y = batch_y.contiguous();
+
+  // Batch handling only applies when both batch vectors are non-empty.
+  if (batch_x.size(0) && batch_y.size(0)) {
+    const auto unique_batch_x = std::get<0>(at::unique_consecutive(batch_x));
+    const auto unique_batch_y = std::get<0>(at::unique_consecutive(batch_y));
+    if (!torch::equal(unique_batch_x, unique_batch_y))
+      throw std::invalid_argument(
+          "Some batch indices occur in 'batch_x' "
+          "that do not occur in 'batch_y'");
+
+    if (x.dim() != 2 || batch_x.dim() != 1 || y.dim() != 2 ||
+        batch_y.dim() != 1 || x.size(0) != batch_x.size(0) ||
+        y.size(0) != batch_y.size(0))
+      throw std::invalid_argument(
+          "'x' and 'y' must be 2-D with one 1-D batch entry per row");
+
+    // Translate and rescale x and y to [0, 1] using their joint min and max.
+    const auto min_xy = at::minimum(x.min(), y.min());
+    x = at::sub(x, min_xy);
+    y = at::sub(y, min_xy);
+    const auto max_xy = at::maximum(x.max(), y.max());
+    x = at::div(x, max_xy);
+    y = at::div(y, max_xy);
+
+    // Append the batch id scaled by 2 * D: with features in [0, 1], rows of
+    // different (integer) batches end up farther apart than same-batch rows.
+    const double D = x.size(x.dim() - 1);
+    x = at::cat({x, batch_x.view({-1, 1}).to(x.dtype()).mul(2 * D)}, -1);
+    y = at::cat({y, batch_y.view({-1, 1}).to(y.dtype()).mul(2 * D)}, -1);
+  }
+
+  return at::argmin(dists(x, y), 1);
+}
diff --git a/csrc/cpu/nearest_cpu.h b/csrc/cpu/nearest_cpu.h
new file mode 100644
index 0000000..ce8e03d
--- /dev/null
+++ b/csrc/cpu/nearest_cpu.h
@@ -0,0 +1,6 @@
+#pragma once
+
+#include "../extensions.h"
+
+torch::Tensor nearest_cpu(torch::Tensor x, torch::Tensor y,
+                          torch::Tensor batch_x, torch::Tensor batch_y);
diff --git a/csrc/nearest.cpp b/csrc/nearest.cpp
index 83eb41e..569e0fd 100644
--- a/csrc/nearest.cpp
+++ b/csrc/nearest.cpp
@@ -5,6 +5,8 @@
 
 #include "extensions.h"
 
+#include "cpu/nearest_cpu.h"
+ #ifdef WITH_CUDA #include "cuda/nearest_cuda.h" #endif @@ -19,16 +21,16 @@ PyMODINIT_FUNC PyInit__nearest_cpu(void) { return NULL; } #endif #endif -CLUSTER_API torch::Tensor nearest(torch::Tensor x, torch::Tensor y, torch::Tensor ptr_x, - torch::Tensor ptr_y) { +CLUSTER_API torch::Tensor nearest(torch::Tensor x, torch::Tensor y, torch::Tensor batch_ptr_x, + torch::Tensor batch_ptr_y) { if (x.device().is_cuda()) { #ifdef WITH_CUDA - return nearest_cuda(x, y, ptr_x, ptr_y); + return nearest_cuda(x, y, batch_ptr_x, batch_ptr_y); #else AT_ERROR("Not compiled with CUDA support"); #endif } else { - AT_ERROR("No CPU version supported"); + return nearest_cpu(x, y, batch_ptr_x, batch_ptr_y); } } diff --git a/setup.py b/setup.py index 0b8c798..db74a11 100644 --- a/setup.py +++ b/setup.py @@ -106,10 +106,6 @@ def get_extensions(): return extensions -install_requires = [ - 'scipy', -] - test_requires = [ 'pytest', 'pytest-cov', @@ -136,7 +132,6 @@ def get_extensions(): 'cluster-algorithms', ], python_requires='>=3.7', - install_requires=install_requires, extras_require={ 'test': test_requires, }, diff --git a/test/test_nearest.py b/test/test_nearest.py index 582818e..2aca230 100644 --- a/test/test_nearest.py +++ b/test/test_nearest.py @@ -44,6 +44,11 @@ def test_nearest(dtype, device): with pytest.raises(ValueError): nearest(x, y, batch_x, batch_y=None) + # Invalid input: instance 1 only in batch_y (implicitly as batch_x=None) + batch_y = tensor([0, 0, 1, 1], torch.long, device) + with pytest.raises(ValueError): + nearest(x, y, batch_x=None, batch_y=batch_y) + # Invalid input: instance 2 only in batch_x # (i.e.instance in the middle missing) batch_x = tensor([0, 0, 1, 1, 2, 2, 3, 3], torch.long, device) diff --git a/torch_cluster/nearest.py b/torch_cluster/nearest.py index 1ba4db6..a35a258 100644 --- a/torch_cluster/nearest.py +++ b/torch_cluster/nearest.py @@ -1,6 +1,5 @@ from typing import Optional -import scipy.cluster import torch @@ -86,39 +85,13 @@ def nearest( return 
torch.ops.torch_cluster.nearest(x, y, ptr_x, ptr_y) else: - - if batch_x is None and batch_y is not None: - batch_x = x.new_zeros(x.size(0), dtype=torch.long) - if batch_y is None and batch_x is not None: + if batch_x is None: + if batch_y is None: + batch_x = torch.tensor([]) + batch_y = torch.tensor([]) + else: + batch_x = x.new_zeros(x.size(0), dtype=torch.long) + elif batch_y is None: batch_y = y.new_zeros(y.size(0), dtype=torch.long) - # Translate and rescale x and y to [0, 1]. - if batch_x is not None and batch_y is not None: - # If an instance in `batch_x` is non-empty, it must be non-empty in - # `batch_y `as well: - unique_batch_x = batch_x.unique_consecutive() - unique_batch_y = batch_y.unique_consecutive() - if not torch.equal(unique_batch_x, unique_batch_y): - raise ValueError("Some batch indices occur in 'batch_x' " - "that do not occur in 'batch_y'") - - assert x.dim() == 2 and batch_x.dim() == 1 - assert y.dim() == 2 and batch_y.dim() == 1 - assert x.size(0) == batch_x.size(0) - assert y.size(0) == batch_y.size(0) - - min_xy = min(x.min().item(), y.min().item()) - x, y = x - min_xy, y - min_xy - - max_xy = max(x.max().item(), y.max().item()) - x.div_(max_xy) - y.div_(max_xy) - - # Concat batch/features to ensure no cross-links between examples. - D = x.size(-1) - x = torch.cat([x, 2 * D * batch_x.view(-1, 1).to(x.dtype)], -1) - y = torch.cat([y, 2 * D * batch_y.view(-1, 1).to(y.dtype)], -1) - - return torch.from_numpy( - scipy.cluster.vq.vq(x.detach().cpu(), - y.detach().cpu())[0]).to(torch.long) + return torch.ops.torch_cluster.nearest(x, y, batch_x, batch_y)