Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@ install(FILES
csrc/cpu/fps_cpu.h
csrc/cpu/graclus_cpu.h
csrc/cpu/grid_cpu.h
csrc/cpu/nearest_cpu.h
csrc/cpu/rw_cpu.h
csrc/cpu/sampler_cpu.h
DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}/${PROJECT_NAME}/cpu)
Expand Down
86 changes: 86 additions & 0 deletions csrc/cpu/nearest_cpu.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
#include <algorithm>

#include "nearest_cpu.h"

#include <ATen/Parallel.h>

#include "utils.h"

// Computes the pairwise distance matrix between the rows of `x` and `y`.
//
// Returns a tensor of shape (x.size(0), y.size(0)).
//
// If neither input is half precision the work is delegated to torch::cdist.
// Otherwise a manual fallback is used, which stores *squared* Euclidean
// distances in a default-dtype tensor. The two paths therefore differ in
// scale, but squaring is monotone for non-negative values, so callers that
// only take an argmin over the result are unaffected.
torch::Tensor dists(const torch::Tensor &x, const torch::Tensor &y) {
  if (x.dtype() != c10::ScalarType::Half && y.dtype() != c10::ScalarType::Half)
    return torch::cdist(x, y);

  const int64_t n = x.size(0);
  const int64_t m = y.size(0);

  torch::Tensor distances = torch::zeros({n, m});

  // Each task owns the disjoint row range [begin, end), so rows can be
  // written straight into `distances` without a per-task staging buffer.
  const int64_t grain_size = 1;
  at::parallel_for(0, n, grain_size, [&](int64_t begin, int64_t end) {
    // 64-bit index: `begin`/`end` are int64_t and must not be narrowed.
    for (int64_t idx = begin; idx < end; ++idx) {
      // Broadcast row `idx` of x against every row of y, then reduce over
      // the feature dimension to get one squared distance per row of y.
      const auto sqr_dist = torch::pow(x[idx] - y, 2).sum(1);
      distances[idx].copy_(sqr_dist);
    }
  });

  return distances;
}

// CPU implementation of nearest-neighbour assignment.
//
// For every row of `x`, returns the index of the closest row of `y`
// (argmin along dim 1 of the pairwise distance matrix).
//
// `batch_x` / `batch_y` assign each row of `x` / `y` to an example. When both
// are non-empty, points are only matched within the same example: the
// coordinates are translated and rescaled into [0, 1] and an extra column
// holding `2 * D * batch` (D = number of features) is appended, which puts
// points of different examples farther apart than any two points of the same
// example can be. If either batch vector is empty, no batch handling is done.
//
// Throws std::invalid_argument if the consecutive-unique batch ids of
// `batch_x` and `batch_y` differ, or if the shapes are inconsistent.
torch::Tensor nearest_cpu(torch::Tensor x, torch::Tensor y,
                          torch::Tensor batch_x, torch::Tensor batch_y) {
  CHECK_CPU(x);
  CHECK_CPU(y);
  CHECK_CPU(batch_x);
  CHECK_CPU(batch_y);

  batch_x = batch_x.contiguous();
  batch_y = batch_y.contiguous();

  if (batch_x.size(0) && batch_y.size(0)) {
    // Every example present in `batch_x` must be present in `batch_y` and
    // vice versa.
    const auto unique_batch_x = std::get<0>(at::unique_consecutive(batch_x));
    const auto unique_batch_y = std::get<0>(at::unique_consecutive(batch_y));
    if (!torch::equal(unique_batch_x, unique_batch_y))
      throw std::invalid_argument(
          "Some batch indices occur in 'batch_x' "
          "that do not occur in 'batch_y'");

    if ((x.dim() != 2 || batch_x.dim() != 1) ||
        (y.dim() != 2 || batch_y.dim() != 1) ||
        x.size(0) != batch_x.size(0) ||
        y.size(0) != batch_y.size(0))
      throw std::invalid_argument(
          "'x' and 'y' must be 2-dimensional, and 'batch_x' and 'batch_y' "
          "must be 1-dimensional with one entry per row of 'x' and 'y'");

    // Translate and rescale x and y to [0, 1].
    // NOTE(review): if all coordinates are identical, max_xy is 0 and the
    // division yields NaN; confirm whether callers can hit this case.
    const auto min_xy = at::minimum(x.min(), y.min());
    x = at::sub(x, min_xy);
    y = at::sub(y, min_xy);

    const auto max_xy = at::maximum(x.max(), y.max());
    x = at::div(x, max_xy);
    y = at::div(y, max_xy);

    // Concatenate a scaled batch column so that no cross-links between
    // examples can be the nearest match. The scale must be applied to the
    // batch vector (not to the coordinates); 2 * D keeps the gap between
    // examples strictly larger than any within-example distance.
    const int64_t feature_dim = x.dim() - 1;
    const double batch_scale = 2.0 * static_cast<double>(x.size(feature_dim));

    const auto batch_x_column =
        batch_x.view({-1, 1}).to(x.dtype()).mul(batch_scale);
    x = at::cat({x, batch_x_column}, feature_dim);

    const auto batch_y_column =
        batch_y.view({-1, 1}).to(y.dtype()).mul(batch_scale);
    y = at::cat({y, batch_y_column}, feature_dim);
  }

  const auto distances = dists(x, y);
  return at::argmin(distances, 1);
}
6 changes: 6 additions & 0 deletions csrc/cpu/nearest_cpu.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
#pragma once

#include "../extensions.h"

// CPU nearest-neighbour lookup.
//
// Returns, for every row of `x`, the index of the row of `y` with the
// smallest distance (argmin over dim 1 of the pairwise distance matrix).
//
// `batch_x` / `batch_y` are per-row batch assignment vectors for `x` / `y`.
// If either is empty, no batch handling is performed. Otherwise the
// implementation requires `x`, `y` to be 2-D, the batch vectors to be 1-D
// with one entry per row, and the consecutive-unique batch ids of both
// vectors to be equal; it throws std::invalid_argument when this is violated.
// All four tensors must live on the CPU.
torch::Tensor nearest_cpu(torch::Tensor x, torch::Tensor y,
torch::Tensor batch_x, torch::Tensor batch_y);
10 changes: 6 additions & 4 deletions csrc/nearest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@

#include "extensions.h"

#include "cpu/nearest_cpu.h"

#ifdef WITH_CUDA
#include "cuda/nearest_cuda.h"
#endif
Expand All @@ -19,16 +21,16 @@ PyMODINIT_FUNC PyInit__nearest_cpu(void) { return NULL; }
#endif
#endif

CLUSTER_API torch::Tensor nearest(torch::Tensor x, torch::Tensor y, torch::Tensor ptr_x,
torch::Tensor ptr_y) {
CLUSTER_API torch::Tensor nearest(torch::Tensor x, torch::Tensor y, torch::Tensor batch_ptr_x,
torch::Tensor batch_ptr_y) {
if (x.device().is_cuda()) {
#ifdef WITH_CUDA
return nearest_cuda(x, y, ptr_x, ptr_y);
return nearest_cuda(x, y, batch_ptr_x, batch_ptr_y);
#else
AT_ERROR("Not compiled with CUDA support");
#endif
} else {
AT_ERROR("No CPU version supported");
return nearest_cpu(x, y, batch_ptr_x, batch_ptr_y);
}
}

Expand Down
5 changes: 0 additions & 5 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,10 +106,6 @@ def get_extensions():
return extensions


install_requires = [
'scipy',
]

test_requires = [
'pytest',
'pytest-cov',
Expand All @@ -136,7 +132,6 @@ def get_extensions():
'cluster-algorithms',
],
python_requires='>=3.7',
install_requires=install_requires,
extras_require={
'test': test_requires,
},
Expand Down
5 changes: 5 additions & 0 deletions test/test_nearest.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,11 @@ def test_nearest(dtype, device):
with pytest.raises(ValueError):
nearest(x, y, batch_x, batch_y=None)

# Invalid input: instance 1 only in batch_y (implicitly as batch_x=None)
batch_y = tensor([0, 0, 1, 1], torch.long, device)
with pytest.raises(ValueError):
nearest(x, y, batch_x=None, batch_y=batch_y)

# Invalid input: instance 2 only in batch_x
# (i.e. an instance in the middle is missing)
batch_x = tensor([0, 0, 1, 1, 2, 2, 3, 3], torch.long, device)
Expand Down
43 changes: 8 additions & 35 deletions torch_cluster/nearest.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
from typing import Optional

import scipy.cluster
import torch


Expand Down Expand Up @@ -86,39 +85,13 @@ def nearest(
return torch.ops.torch_cluster.nearest(x, y, ptr_x, ptr_y)

else:

if batch_x is None and batch_y is not None:
batch_x = x.new_zeros(x.size(0), dtype=torch.long)
if batch_y is None and batch_x is not None:
if batch_x is None:
if batch_y is None:
batch_x = torch.tensor([])
batch_y = torch.tensor([])
else:
batch_x = x.new_zeros(x.size(0), dtype=torch.long)
elif batch_y is None:
batch_y = y.new_zeros(y.size(0), dtype=torch.long)

# Translate and rescale x and y to [0, 1].
if batch_x is not None and batch_y is not None:
# If an instance in `batch_x` is non-empty, it must be non-empty in
# `batch_y `as well:
unique_batch_x = batch_x.unique_consecutive()
unique_batch_y = batch_y.unique_consecutive()
if not torch.equal(unique_batch_x, unique_batch_y):
raise ValueError("Some batch indices occur in 'batch_x' "
"that do not occur in 'batch_y'")

assert x.dim() == 2 and batch_x.dim() == 1
assert y.dim() == 2 and batch_y.dim() == 1
assert x.size(0) == batch_x.size(0)
assert y.size(0) == batch_y.size(0)

min_xy = min(x.min().item(), y.min().item())
x, y = x - min_xy, y - min_xy

max_xy = max(x.max().item(), y.max().item())
x.div_(max_xy)
y.div_(max_xy)

# Concat batch/features to ensure no cross-links between examples.
D = x.size(-1)
x = torch.cat([x, 2 * D * batch_x.view(-1, 1).to(x.dtype)], -1)
y = torch.cat([y, 2 * D * batch_y.view(-1, 1).to(y.dtype)], -1)

return torch.from_numpy(
scipy.cluster.vq.vq(x.detach().cpu(),
y.detach().cpu())[0]).to(torch.long)
return torch.ops.torch_cluster.nearest(x, y, batch_x, batch_y)