In [4]:
import pandas as pd
import numpy as np
import torch

In [95]:
def square_distance(src, dst):
    """
    Calculate Euclid distance between each two points.
    src^T * dst = xn * xm + yn * ym + zn * zm；
    sum(src^2, dim=-1) = xn*xn + yn*yn + zn*zn;
    sum(dst^2, dim=-1) = xm*xm + ym*ym + zm*zm;
    dist = (xn - xm)^2 + (yn - ym)^2 + (zn - zm)^2
         = sum(src**2,dim=-1)+sum(dst**2,dim=-1)-2*src^T*dst
    Input:
        src: source points, [B, N, C]
        dst: target points, [B, M, C]
    Output:
        dist: per-point square distance, [B, N, M]
    """
    return torch.sum((src[:, :, None] - dst[:, None]) ** 2, dim=-1)

def index_points(points, idx):
    """
    Input:
        points: input points data, [B, N, C]
        idx: sample index data, [B, S, [K]]
    Return:
        new_points:, indexed points data, [B, S, [K], C]
    """
    raw_size = idx.size()
    idx = idx.reshape(raw_size[0], -1)
    res = torch.gather(points, 1, idx[..., None].expand(-1, -1, points.size(-1)))
    return res.reshape(*raw_size, -1)


In [13]:
def query_ball_point(radius, nsample, xyz, new_xyz):
    """
    Input:
        radius: local region radius
        nsample: max sample number in local region
        xyz: all points, [B, N, 3]
        new_xyz: query points, [B, S, 3]
    Return:
        group_idx: grouped points index, [B, S, nsample]
    """
    device = xyz.device
    B, N, C = xyz.shape
    _, S, _ = new_xyz.shape
    group_idx = torch.arange(N, dtype=torch.long).to(device).view(1, 1, N).repeat([B, S, 1])
    sqrdists = square_distance(new_xyz, xyz)
    group_idx[sqrdists > radius ** 2] = N
    group_idx = group_idx.sort(dim=-1)[0][:, :, :nsample]
    group_first = group_idx[:, :, 0].view(B, S, 1).repeat([1, 1, nsample])
    mask = group_idx == N
    group_idx[mask] = group_first[mask]
    return group_idx


In [113]:
radius = 20
nsample = 2
B = 1
S = 3
N = 10

xyz = torch.rand(B, N, 3) * 1.0
new_xyz = torch.rand(B, S, 3) * 1.0

In [114]:
xyz

tensor([[[8.7886e-01, 9.4075e-01, 8.3077e-01],
         [5.6785e-01, 7.9251e-01, 6.8068e-01],
         [4.1897e-01, 9.3082e-01, 2.4444e-01],
         [7.0621e-01, 5.6720e-01, 5.0883e-01],
         [5.5471e-01, 4.9437e-01, 7.6719e-01],
         [3.6126e-01, 2.3615e-04, 7.8746e-01],
         [8.5152e-02, 8.4506e-01, 7.4679e-01],
         [1.3982e-01, 9.7048e-01, 5.9605e-01],
         [6.7859e-01, 2.4837e-02, 9.4025e-01],
         [5.7001e-02, 4.2502e-01, 8.4366e-01]]])

In [115]:
new_xyz

tensor([[[0.9250, 0.7176, 0.8224],
         [0.6307, 0.5977, 0.1762],
         [0.5833, 0.9340, 0.2189]]])

In [116]:
idx = query_ball_point(radius**0.5, nsample, xyz, new_xyz)

In [117]:
grouped_points = index_points(xyz, idx)

In [118]:
grouped_points

tensor([[[[0.8789, 0.9407, 0.8308],
          [0.5679, 0.7925, 0.6807]],

         [[0.8789, 0.9407, 0.8308],
          [0.5679, 0.7925, 0.6807]],

         [[0.8789, 0.9407, 0.8308],
          [0.5679, 0.7925, 0.6807]]]])

In [119]:
local_feat = grouped_points - new_xyz.view(B, S, 1, -1)

In [120]:
local_feat.shape

torch.Size([1, 3, 2, 3])

In [121]:
local_feat

tensor([[[[-0.0462,  0.2231,  0.0084],
          [-0.3572,  0.0749, -0.1417]],

         [[ 0.2482,  0.3430,  0.6545],
          [-0.0628,  0.1948,  0.5045]],

         [[ 0.2956,  0.0067,  0.6119],
          [-0.0154, -0.1415,  0.4618]]]])

In [122]:
torch.cat((local_feat, local_feat), dim=-1)

tensor([[[[-0.0462,  0.2231,  0.0084, -0.0462,  0.2231,  0.0084],
          [-0.3572,  0.0749, -0.1417, -0.3572,  0.0749, -0.1417]],

         [[ 0.2482,  0.3430,  0.6545,  0.2482,  0.3430,  0.6545],
          [-0.0628,  0.1948,  0.5045, -0.0628,  0.1948,  0.5045]],

         [[ 0.2956,  0.0067,  0.6119,  0.2956,  0.0067,  0.6119],
          [-0.0154, -0.1415,  0.4618, -0.0154, -0.1415,  0.4618]]]])

In [124]:
torch.empty(B, S, nsample)

tensor([[[2.3694e-38, 2.3694e-38],
         [2.3694e-38, 2.3694e-38],
         [3.6013e-43, 0.0000e+00]]])

In [125]:
local_feat

tensor([[[[-0.0462,  0.2231,  0.0084],
          [-0.3572,  0.0749, -0.1417]],

         [[ 0.2482,  0.3430,  0.6545],
          [-0.0628,  0.1948,  0.5045]],

         [[ 0.2956,  0.0067,  0.6119],
          [-0.0154, -0.1415,  0.4618]]]])