In [7]:
import open3d as o3d
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
from pointnet2_ops import pointnet2_utils
from knn_cuda import KNN

In [8]:
print("Load a ply point cloud, print it, and render it")
path = "../dataset/Dataset/Barn_is/Barn/Barn01.ply"
pcd = o3d.io.read_point_cloud(path)
print(pcd)


def voxelize_and_show(point_cloud, voxel_size):
    """Voxelize a point cloud, collect per-voxel data, and render the grid.

    Args:
        point_cloud: Open3D point cloud to voxelize.
        voxel_size: Edge length passed to
            ``VoxelGrid.create_from_point_cloud``.

    Returns:
        Tuple ``(grid, grid_voxels, grid_indices, grid_colors)``: the voxel
        grid, its list of voxels, and the stacked ``grid_index`` / ``color``
        values of those voxels.
    """
    grid = o3d.geometry.VoxelGrid.create_from_point_cloud(
        point_cloud, voxel_size=voxel_size
    )
    grid_voxels = grid.get_voxels()  # returns list of voxels
    grid_indices = np.stack([vx.grid_index for vx in grid_voxels])
    grid_colors = np.stack([vx.color for vx in grid_voxels])
    o3d.visualization.draw_geometries([grid])
    return grid, grid_voxels, grid_indices, grid_colors


# Coarse grid first, then the finer one. After the loop the names below hold
# the voxel_size=0.2 results.
for voxel_size in (0.5, 0.2):
    voxel_grid, voxels, indices, colors = voxelize_and_show(pcd, voxel_size)


Load a ply point cloud, print it, and render it
PointCloud with 1642571 points.


In [9]:
class Tokenizer(nn.Module):
    """Group a point cloud into fixed-size local neighborhoods.

    Centers are picked with furthest point sampling, each center's
    ``group_size`` nearest points are gathered with KNN, and every
    neighborhood is expressed relative to its center.

    Args:
        num_group: Number of centers / neighborhoods per cloud (G).
        group_size: Number of points per neighborhood (M).
    """

    def __init__(self, num_group, group_size):
        super().__init__()
        self.num_group = num_group
        self.group_size = group_size
        self.knn = KNN(k=self.group_size, transpose_mode=True)

    def forward(self, xyz):
        '''
            input: B N 3
            ---------------------------
            output: B G M 3
            center : B G 3
        '''
        batch_size, num_points, _ = xyz.shape
        # fps the centers out
        center = self.fps(xyz, self.num_group)  # B G 3
        # knn to get the neighborhood
        _, idx = self.knn(xyz, center)  # B G M
        assert idx.size(1) == self.num_group
        assert idx.size(2) == self.group_size
        # KNN indices are per-cloud; offset each batch entry by its start row
        # so they address the flattened (B*N, 3) view below.
        idx_base = torch.arange(0, batch_size, device=xyz.device).view(-1, 1, 1) * num_points
        idx = idx + idx_base
        idx = idx.view(-1)
        neighborhood = xyz.view(batch_size * num_points, -1)[idx, :]
        neighborhood = neighborhood.view(batch_size, self.num_group, self.group_size, 3).contiguous()
        # normalize: make every neighborhood relative to its own center
        neighborhood = neighborhood - center.unsqueeze(2)
        return neighborhood, center

    @staticmethod
    def fps(data, number):
        '''
            data B N 3
            number int
            ---------------------------
            returns the sampled points, B number 3

            Declared static because it uses no instance state; without the
            decorator, ``self.fps(xyz, n)`` would pass ``self`` as ``data``
            and raise a TypeError.
        '''
        fps_idx = pointnet2_utils.furthest_point_sample(data, number)
        # gather_operation is fed the transposed (B 3 N) layout, so transpose
        # in and back out to stay in B N 3.
        fps_data = pointnet2_utils.gather_operation(
            data.transpose(1, 2).contiguous(), fps_idx
        ).transpose(1, 2).contiguous()
        return fps_data

In [10]:
# 32 neighborhoods of 64 points each.
tokenizer = Tokenizer(num_group=32, group_size=64)

In [None]:
# forward() requires a B N 3 tensor; build a single-cloud batch from the
# loaded point cloud instead of calling it with no input.
# NOTE(review): .cuda() assumes the pointnet2_ops / knn_cuda kernels need CUDA
# float32 tensors — confirm for your install.
xyz = torch.from_numpy(np.asarray(pcd.points)).float().unsqueeze(0).cuda()  # 1 N 3
# Call the module itself rather than .forward() so nn.Module hooks run.
neighborhood, center = tokenizer(xyz)
neighborhood.shape, center.shape