Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -32,4 +32,8 @@
*.app

build
*.pyc
*.pyc

.vscode/
dist/
torch_points.egg-info/
5 changes: 5 additions & 0 deletions cpu/include/group_points.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
#pragma once
#include <torch/extension.h>

at::Tensor group_points(at::Tensor points, at::Tensor idx);
at::Tensor group_points_grad(at::Tensor grad_out, at::Tensor idx, const int n);
6 changes: 6 additions & 0 deletions cpu/include/utils.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
#pragma once
#include <torch/extension.h>

#define CHECK_CPU(x) AT_ASSERTM(!x.type().is_cuda(), #x " must be a CPU tensor")

#define CHECK_CONTIGUOUS(x) AT_ASSERTM(x.is_contiguous(), #x " must be a contiguous tensor")
5 changes: 5 additions & 0 deletions cpu/src/bindings.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
#include "group_points.h"

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("group_points", &group_points);
}
29 changes: 29 additions & 0 deletions cpu/src/group_points.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
#include "group_points.h"
#include "utils.h"

// input: points(b, c, n) idx(b, npoints, nsample)
// output: out(b, c, npoints, nsample)
at::Tensor group_points(at::Tensor points, at::Tensor idx) {
CHECK_CPU(points);
CHECK_CPU(idx);

at::Tensor output = torch::zeros(
{points.size(0), points.size(1), idx.size(1), idx.size(2)},
at::device(points.device()).dtype(at::ScalarType::Float)
);

for (int batch_index = 0; batch_index < output.size(0); batch_index++) {
for (int feat_index = 0; feat_index < output.size(1); feat_index++) {
for (int point_index = 0; point_index < output.size(2); point_index++) {
for (int sample_index = 0; sample_index < output.size(3); sample_index++) {
output[batch_index][feat_index][point_index][sample_index]
= points[batch_index][feat_index][
idx[batch_index][point_index][sample_index]
];
}
}
}
}

return output;
}
15 changes: 14 additions & 1 deletion setup.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from setuptools import setup, find_packages
from torch.utils.cpp_extension import BuildExtension, CUDAExtension, CUDA_HOME
from torch.utils.cpp_extension import BuildExtension, CUDAExtension, CUDA_HOME, CppExtension
import glob

ext_src_root = "cuda"
Expand All @@ -20,6 +20,19 @@
)
)

cpu_ext_src_root = "cpu"
cpu_ext_sources = glob.glob("{}/src/*.cpp".format(cpu_ext_src_root))

ext_modules.append(
CppExtension(
name="torch_points.points_cpu",
sources=cpu_ext_sources,
extra_compile_args={
"cxx": ["-O2", "-I{}".format("{}/include".format(cpu_ext_src_root))],
},
)
)

setup(
name="torch_points",
version="0.1.2",
Expand Down
54 changes: 54 additions & 0 deletions test/test_grouping.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
import unittest
import torch
import numpy as np
import numpy.testing as npt
from torch_points import grouping_operation

class TestGroup(unittest.TestCase):

# input: points(b, c, n) idx(b, npoints, nsample)
# output: out(b, c, npoints, nsample)
def test_simple(self):
features = torch.tensor([
[[0, 10, 0], [1, 11, 0], [2, 12, 0]],
[
[100, 110, 120], # x-coordinates
[101, 111, 121], # y-coordinates
[102, 112, 122], # z-coordinates
]
])
idx = torch.tensor([
[[1, 0], [0, 0]],
[[0, 1], [1, 2]]
])

expected = np.array([
[
[[10, 0], [0, 0]],
[[11, 1], [1, 1]],
[[12, 2], [2, 2]]
],
[ # 2nd batch
[ # x-coordinates
[100, 110], #x-coordinates of samples for point 0
[110, 120], #x-coordinates of samples for point 1
],
[[101, 111], [111, 121]], # y-coordinates
[[102, 112], [112, 122]], # z-coordinates
]
])

cpu_output = grouping_operation(features, idx).detach().cpu().numpy()

npt.assert_array_equal(expected, cpu_output)

if torch.cuda.is_available():
npt.assert_array_equal(
grouping_operation(
features.cuda(),
idx.cuda()
).detach().cpu().numpy(), expected)


if __name__ == '__main__':
unittest.main()
61 changes: 46 additions & 15 deletions torch_points/torchpoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,19 @@
import torch.nn as nn
import sys

import torch_points.points_cuda as tpcuda
import torch_points.points_cpu as tpcpu

if torch.cuda.is_available():
import torch_points.points_cuda as tpcuda


class FurthestPointSampling(Function):
@staticmethod
def forward(ctx, xyz, npoint):
return tpcuda.furthest_point_sampling(xyz, npoint)
if xyz.is_cuda:
return tpcuda.furthest_point_sampling(xyz, npoint)
else:
raise NotImplementedError

@staticmethod
def backward(xyz, a=None):
Expand Down Expand Up @@ -45,14 +51,20 @@ def forward(ctx, features, idx):

ctx.for_backwards = (idx, C, N)

return tpcuda.gather_points(features, idx)
if features.is_cuda:
return tpcuda.gather_points(features, idx)
else:
return tpcpu.gather_points(features, idx)

@staticmethod
def backward(ctx, grad_out):
idx, C, N = ctx.for_backwards

grad_features = tpcuda.gather_points_grad(grad_out.contiguous(), idx, N)
return grad_features, None
if grad_out.is_cuda:
grad_features = tpcuda.gather_points_grad(grad_out.contiguous(), idx, N)
return grad_features, None
else:
raise NotImplementedError


def gather_operation(features, idx):
Expand All @@ -64,12 +76,12 @@ def gather_operation(features, idx):
(B, C, N) tensor

idx : torch.Tensor
(B, npoint) tensor of the features to gather
(B, npoint, nsample) tensor of the features to gather

Returns
-------
torch.Tensor
(B, C, npoint) tensor
(B, C, npoint, nsample) tensor
"""
return GatherOperation.apply(features, idx)

Expand All @@ -78,7 +90,11 @@ class ThreeNN(Function):
@staticmethod
def forward(ctx, unknown, known):
# type: (Any, torch.Tensor, torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]
dist2, idx = tpcuda.three_nn(unknown, known)

if unknown.is_cuda:
dist2, idx = tpcuda.three_nn(unknown, known)
else:
raise NotImplementedError

return torch.sqrt(dist2), idx

Expand Down Expand Up @@ -116,7 +132,10 @@ def forward(ctx, features, idx, weight):

ctx.three_interpolate_for_backward = (idx, weight, m)

return tpcuda.three_interpolate(features, idx, weight)
if features.is_cuda:
return tpcuda.three_interpolate(features, idx, weight)
else:
raise NotImplementedError

@staticmethod
def backward(ctx, grad_out):
Expand All @@ -138,9 +157,12 @@ def backward(ctx, grad_out):
"""
idx, weight, m = ctx.three_interpolate_for_backward

grad_features = tpcuda.three_interpolate_grad(
grad_out.contiguous(), idx, weight, m
)
if grad_out.is_cuda:
grad_features = tpcuda.three_interpolate_grad(
grad_out.contiguous(), idx, weight, m
)
else:
raise NotImplementedError

return grad_features, None, None

Expand Down Expand Up @@ -174,7 +196,10 @@ def forward(ctx, features, idx):

ctx.for_backwards = (idx, N)

return tpcuda.group_points(features, idx)
if features.is_cuda:
return tpcuda.group_points(features, idx)
else:
return tpcpu.group_points(features, idx)

@staticmethod
def backward(ctx, grad_out):
Expand All @@ -194,7 +219,10 @@ def backward(ctx, grad_out):
"""
idx, N = ctx.for_backwards

grad_features = tpcuda.group_points_grad(grad_out.contiguous(), idx, N)
if grad_out.is_cuda:
grad_features = tpcuda.group_points_grad(grad_out.contiguous(), idx, N)
else:
raise NotImplementedError

return grad_features, None

Expand All @@ -220,7 +248,10 @@ class BallQuery(Function):
@staticmethod
def forward(ctx, radius, nsample, xyz, new_xyz):
# type: (Any, float, int, torch.Tensor, torch.Tensor) -> torch.Tensor
return tpcuda.ball_query(new_xyz, xyz, radius, nsample)
if new_xyz.is_cuda:
return tpcuda.ball_query(new_xyz, xyz, radius, nsample)
else:
raise NotImplementedError

@staticmethod
def backward(ctx, a=None):
Expand Down