From 7bea2950eebf1ec738ec81e512b19f3225584cc7 Mon Sep 17 00:00:00 2001 From: Tristan Heywood Date: Sat, 4 Jan 2020 14:05:20 +0100 Subject: [PATCH 1/3] ignore ide and build files --- .gitignore | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/.gitignore b/.gitignore index bf9e290..1f7023a 100644 --- a/.gitignore +++ b/.gitignore @@ -32,4 +32,8 @@ *.app build -*.pyc \ No newline at end of file +*.pyc + +.vscode/ +dist/ +torch_points.egg-info/ \ No newline at end of file From 1e027c1deb45a49e61d2012978e0dd8358273508 Mon Sep 17 00:00:00 2001 From: Tristan Heywood Date: Sun, 5 Jan 2020 14:37:17 +0100 Subject: [PATCH 2/3] add cpp group_points operation + test --- cpu/include/group_points.h | 5 +++ cpu/include/utils.h | 6 ++++ cpu/src/bindings.cpp | 5 +++ cpu/src/group_points.cpp | 30 ++++++++++++++++++ setup.py | 15 ++++++++- test/test_grouping.py | 56 ++++++++++++++++++++++++++++++++++ torch_points/torchpoints.py | 61 ++++++++++++++++++++++++++++--------- 7 files changed, 162 insertions(+), 16 deletions(-) create mode 100644 cpu/include/group_points.h create mode 100644 cpu/include/utils.h create mode 100644 cpu/src/bindings.cpp create mode 100644 cpu/src/group_points.cpp create mode 100644 test/test_grouping.py diff --git a/cpu/include/group_points.h b/cpu/include/group_points.h new file mode 100644 index 0000000..53f89f9 --- /dev/null +++ b/cpu/include/group_points.h @@ -0,0 +1,5 @@ +#pragma once +#include + +at::Tensor group_points(at::Tensor points, at::Tensor idx); +at::Tensor group_points_grad(at::Tensor grad_out, at::Tensor idx, const int n); \ No newline at end of file diff --git a/cpu/include/utils.h b/cpu/include/utils.h new file mode 100644 index 0000000..ab590e4 --- /dev/null +++ b/cpu/include/utils.h @@ -0,0 +1,6 @@ +#pragma once +#include + +#define CHECK_CPU(x) AT_ASSERTM(!x.type().is_cuda(), #x " must be a CPU tensor") + +#define CHECK_CONTIGUOUS(x) AT_ASSERTM(x.is_contiguous(), #x " must be a contiguous tensor") \ No newline at end of file diff --git a/cpu/src/bindings.cpp b/cpu/src/bindings.cpp new file mode 100644 index 0000000..d9ac5a0 --- /dev/null +++ b/cpu/src/bindings.cpp @@ -0,0 +1,5 @@ +#include "group_points.h" + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("group_points", &group_points); +} \ No newline at end of file diff --git a/cpu/src/group_points.cpp b/cpu/src/group_points.cpp new file mode 100644 index 0000000..ab0ebda --- /dev/null +++ b/cpu/src/group_points.cpp @@ -0,0 +1,30 @@ +#include "group_points.h" +#include "utils.h" + +// input: points(b, c, n) idx(b, npoints, nsample) +// output: out(b, c, npoints, nsample) +at::Tensor group_points(at::Tensor points, at::Tensor idx) { + CHECK_CPU(points); + CHECK_CPU(idx); + + at::Tensor output = torch::zeros( + {points.size(0), points.size(1), idx.size(1), idx.size(2)}, + at::device(points.device()).dtype(at::ScalarType::Float) + ); + + for (int batch_index = 0; batch_index < output.size(0); batch_index++) { + for (int feat_index = 0; feat_index < output.size(1); feat_index++) { + for (int point_index = 0; point_index < output.size(2); point_index++) { + for (int sample_index = 0; sample_index < output.size(3); sample_index++) { + output[batch_index][feat_index][point_index][sample_index] + = points[batch_index][feat_index][ + idx[batch_index][point_index][sample_index] + ]; + } + } + } + } + + return output; + +} \ No newline at end of file diff --git a/setup.py b/setup.py index 4ac3127..71aa73b 100644 --- a/setup.py +++ b/setup.py @@ -1,5 +1,5 @@ from setuptools import setup, find_packages -from torch.utils.cpp_extension import BuildExtension, CUDAExtension, CUDA_HOME +from torch.utils.cpp_extension import BuildExtension, CUDAExtension, CUDA_HOME, CppExtension import glob ext_src_root = "cuda" @@ -20,6 +20,19 @@ ) ) +cpu_ext_src_root = "cpu" +cpu_ext_sources = glob.glob("{}/src/*.cpp".format(cpu_ext_src_root)) + +ext_modules.append( + CppExtension( + name="torch_points.points_cpu", + sources=cpu_ext_sources, + extra_compile_args={ + "cxx": ["-O2", "-I{}".format("{}/include".format(cpu_ext_src_root))], + }, + ) +) + setup( name="torch_points", version="0.1.2", diff --git a/test/test_grouping.py b/test/test_grouping.py new file mode 100644 index 0000000..6c36739 --- /dev/null +++ b/test/test_grouping.py @@ -0,0 +1,56 @@ +import unittest +import torch +import numpy as np +import numpy.testing as npt +from torch_points import grouping_operation + +class TestGroup(unittest.TestCase): + + # input: points(b, c, n) idx(b, npoints, nsample) + # output: out(b, c, npoints, nsample) + def test_simple(self): + features = torch.tensor([ + [[0, 10, 0], [1, 11, 0], [2, 12, 0]], + [ + [100, 110, 120], # x-coordinates + [101, 111, 121], # y-coordinates + [102, 112, 122], # z-coordinates + ] + ]) + idx = torch.tensor([ + [[1, 0], [0, 0]], + [[0, 1], [1, 2]] + ]) + + expected = np.array([ + [ + [[10, 0], [0, 0]], + [[11, 1], [1, 1]], + [[12, 2], [2, 2]] + ], + [ # 2nd batch + [ # x-coordinates + [100, 110], #x-coordinates of samples for point 0 + [110, 120], #x-coordinates of samples for point 1 + ], + [[101, 111], [111, 121]], # y-coordinates + [[102, 112], [112, 122]], # z-coordinates + ] + ]) + + cpu_output = grouping_operation(features, idx).detach().cpu().numpy() + + npt.assert_array_equal(expected, cpu_output) + + if torch.cuda.is_available(): + npt.assert_array_equal( + grouping_operation( + features.cuda(), + idx.cuda() + ).detach().cpu().numpy(), expected) + + + + +if __name__ == '__main__': + unittest.main() \ No newline at end of file diff --git a/torch_points/torchpoints.py b/torch_points/torchpoints.py index d493408..7a6cd69 100644 --- a/torch_points/torchpoints.py +++ b/torch_points/torchpoints.py @@ -3,13 +3,19 @@ import torch.nn as nn import sys -import torch_points.points_cuda as tpcuda +import torch_points.points_cpu as tpcpu + +if torch.cuda.is_available(): + import torch_points.points_cuda as tpcuda class FurthestPointSampling(Function): @staticmethod def forward(ctx, xyz, npoint): - return tpcuda.furthest_point_sampling(xyz, npoint) + if xyz.is_cuda: + return tpcuda.furthest_point_sampling(xyz, npoint) + else: + raise NotImplementedError @staticmethod def backward(xyz, a=None): @@ -45,14 +51,20 @@ def forward(ctx, features, idx): ctx.for_backwards = (idx, C, N) - return tpcuda.gather_points(features, idx) + if features.is_cuda: + return tpcuda.gather_points(features, idx) + else: + return tpcpu.gather_points(features, idx) @staticmethod def backward(ctx, grad_out): idx, C, N = ctx.for_backwards - grad_features = tpcuda.gather_points_grad(grad_out.contiguous(), idx, N) - return grad_features, None + if grad_out.is_cuda: + grad_features = tpcuda.gather_points_grad(grad_out.contiguous(), idx, N) + return grad_features, None + else: + raise NotImplementedError def gather_operation(features, idx): @@ -64,12 +76,12 @@ def gather_operation(features, idx): (B, C, N) tensor idx : torch.Tensor - (B, npoint) tensor of the features to gather + (B, npoint, nsample) tensor of the features to gather Returns ------- torch.Tensor - (B, C, npoint) tensor + (B, C, npoint, nsample) tensor """ return GatherOperation.apply(features, idx) @@ -78,7 +90,11 @@ class ThreeNN(Function): @staticmethod def forward(ctx, unknown, known): # type: (Any, torch.Tensor, torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor] - dist2, idx = tpcuda.three_nn(unknown, known) + + if unknown.is_cuda: + dist2, idx = tpcuda.three_nn(unknown, known) + else: + raise NotImplementedError return torch.sqrt(dist2), idx @@ -116,7 +132,10 @@ def forward(ctx, features, idx, weight): ctx.three_interpolate_for_backward = (idx, weight, m) - return tpcuda.three_interpolate(features, idx, weight) + if features.is_cuda: + return tpcuda.three_interpolate(features, idx, weight) + else: + raise NotImplementedError @staticmethod def backward(ctx, grad_out): @@ -138,9 +157,12 @@ def backward(ctx, grad_out): """ idx, weight, m = ctx.three_interpolate_for_backward - grad_features = tpcuda.three_interpolate_grad( - grad_out.contiguous(), idx, weight, m - ) + if grad_out.is_cuda: + grad_features = tpcuda.three_interpolate_grad( + grad_out.contiguous(), idx, weight, m + ) + else: + raise NotImplementedError return grad_features, None, None @@ -174,7 +196,10 @@ def forward(ctx, features, idx): ctx.for_backwards = (idx, N) - return tpcuda.group_points(features, idx) + if features.is_cuda: + return tpcuda.group_points(features, idx) + else: + return tpcpu.group_points(features, idx) @staticmethod def backward(ctx, grad_out): @@ -194,7 +219,10 @@ def backward(ctx, grad_out): """ idx, N = ctx.for_backwards - grad_features = tpcuda.group_points_grad(grad_out.contiguous(), idx, N) + if grad_out.is_cuda: + grad_features = tpcuda.group_points_grad(grad_out.contiguous(), idx, N) + else: + raise NotImplementedError return grad_features, None @@ -220,7 +248,10 @@ class BallQuery(Function): @staticmethod def forward(ctx, radius, nsample, xyz, new_xyz): # type: (Any, float, int, torch.Tensor, torch.Tensor) -> torch.Tensor - return tpcuda.ball_query(new_xyz, xyz, radius, nsample) + if new_xyz.is_cuda: + return tpcuda.ball_query(new_xyz, xyz, radius, nsample) + else: + raise NotImplementedError @staticmethod def backward(ctx, a=None): From 84b28cc5b43f57bf1e859ffdf5acecfdc78096fc Mon Sep 17 00:00:00 2001 From: Tristan Heywood Date: Sun, 5 Jan 2020 14:45:21 +0100 Subject: [PATCH 3/3] clean up for pr --- cpu/include/group_points.h | 2 +- cpu/include/utils.h | 2 +- cpu/src/bindings.cpp | 2 +- cpu/src/group_points.cpp | 3 +-- test/test_grouping.py | 2 -- 5 files changed, 4 insertions(+), 7 deletions(-) diff --git a/cpu/include/group_points.h b/cpu/include/group_points.h index 53f89f9..cb83f95 100644 --- a/cpu/include/group_points.h +++ b/cpu/include/group_points.h @@ -2,4 +2,4 @@ #include at::Tensor group_points(at::Tensor points, at::Tensor idx); -at::Tensor group_points_grad(at::Tensor grad_out, at::Tensor idx, const int n); \ No newline at end of file +at::Tensor group_points_grad(at::Tensor grad_out, at::Tensor idx, const int n); diff --git a/cpu/include/utils.h b/cpu/include/utils.h index ab590e4..9b26ea2 100644 --- a/cpu/include/utils.h +++ b/cpu/include/utils.h @@ -3,4 +3,4 @@ #define CHECK_CPU(x) AT_ASSERTM(!x.type().is_cuda(), #x " must be a CPU tensor") -#define CHECK_CONTIGUOUS(x) AT_ASSERTM(x.is_contiguous(), #x " must be a contiguous tensor") \ No newline at end of file +#define CHECK_CONTIGUOUS(x) AT_ASSERTM(x.is_contiguous(), #x " must be a contiguous tensor") diff --git a/cpu/src/bindings.cpp b/cpu/src/bindings.cpp index d9ac5a0..e026580 100644 --- a/cpu/src/bindings.cpp +++ b/cpu/src/bindings.cpp @@ -2,4 +2,4 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { m.def("group_points", &group_points); -} \ No newline at end of file +} diff --git a/cpu/src/group_points.cpp b/cpu/src/group_points.cpp index ab0ebda..8c086cd 100644 --- a/cpu/src/group_points.cpp +++ b/cpu/src/group_points.cpp @@ -26,5 +26,4 @@ at::Tensor group_points(at::Tensor points, at::Tensor idx) { } return output; - -} \ No newline at end of file +} diff --git a/test/test_grouping.py b/test/test_grouping.py index 6c36739..8f575c5 100644 --- a/test/test_grouping.py +++ b/test/test_grouping.py @@ -50,7 +50,5 @@ def test_simple(self): ).detach().cpu().numpy(), expected) - - if __name__ == '__main__': unittest.main() \ No newline at end of file