In [1]:
import numpy as np
import sys, os
import matplotlib.image as mpimg
import argparse
import ast

# file_path = os.path.dirname(os.path.abspath(__file__)) # this is for the .py script but does not work in a notebook
file_path = os.path.dirname(os.path.abspath(''))
sys.path.append(file_path)
# sys.path.append(os.path.join(file_path, "grid-graph/python/bin"))
# sys.path.append(os.path.join(file_path, "parallel-cut-pursuit/python/wrappers"))

# Data loading

In [28]:
from time import time
import glob
from torch_geometric.data import Data
from superpoint_transformer.datasets.kitti360 import read_kitti360_window
from superpoint_transformer.datasets.kitti360_config import KITTI360_NUM_CLASSES

i_window = 0
all_filepaths = sorted(glob.glob('/media/drobert-admin/DATA2/datasets/kitti360/shared/data_3d_semantics/*/static/*.ply'))
filepath = all_filepaths[i_window]

start = time()
data = read_kitti360_window(filepath, semantic=True, instance=False, remap=True)
print(f'Loading data {i_window+1}/{len(all_filepaths)}: {time() - start:0.3f}s')
print(f'Number of loaded points: {data.num_nodes} ({data.num_nodes // 10**6:0.2f}M)')

Loading data 1/342: 0.086s
Number of loaded points: 3201318 (3.00M)


# Voxelization

In [6]:
voxel = 0.05

### Pytorch Geometric

In [7]:
from torch_geometric.nn.pool import voxel_grid

start = time()
data_sub = voxel_grid(data.pos, size=0.1)
print(f'Data  voxelization at {voxel}m: {time() - start:0.3f}s')
print(f'Number of sampled points: {data_sub.shape[0]} ({data_sub.shape[0] // 10**6:0.2f}M, {100 * data_sub.shape[0] / data.num_nodes:0.1f}%)')

Data  voxelization at 0.05m: 0.157s
Number of loaded points: 3201318 (3.00M, 1.0%)


### TorchPoints3D

In [44]:
import torch
import re
from torch_geometric.nn.pool import voxel_grid
from torch_cluster import grid_cluster
from torch_scatter import scatter_mean, scatter_add
from torch_geometric.nn.pool.consecutive import consecutive_cluster
MAPPING_KEY = 'mapping_index'

# Label will be the majority label in each voxel
_INTEGER_LABEL_KEYS = ["y", "instance_labels"]

def group_data(data, cluster=None, unique_pos_indices=None, mode="last", skip_keys=[]):
    """ Group data based on indices in cluster.
    The option ``mode`` controls how data gets agregated within each cluster.

    Parameters
    ----------
    data : Data
        [description]
    cluster : torch.Tensor
        Tensor of the same size as the number of points in data. Each element is the cluster index of that point.
    unique_pos_indices : torch.tensor
        Tensor containing one index per cluster, this index will be used to select features and labels
    mode : str
        Option to select how the features and labels for each voxel is computed. Can be ``last`` or ``mean``.
        ``last`` selects the last point falling in a voxel as the representent, ``mean`` takes the average.
    skip_keys: list
        Keys of attributes to skip in the grouping
    """

    assert mode in ["mean", "last"]
    if mode == "mean" and cluster is None:
        raise ValueError("In mean mode the cluster argument needs to be specified")
    if mode == "last" and unique_pos_indices is None:
        raise ValueError("In last mode the unique_pos_indices argument needs to be specified")

    num_nodes = data.num_nodes
    for key, item in data:
        if bool(re.search("edge", key)):
            raise ValueError("Edges not supported. Wrong data type.")
        if key in skip_keys:
            continue

        if torch.is_tensor(item) and item.size(0) == num_nodes:
            if mode == "last" or key == "batch" \
                    or key == SaveOriginalPosId.KEY \
                    or key == MAPPING_KEY:
                data[key] = item[unique_pos_indices]
            elif mode == "mean":
                is_item_bool = item.dtype == torch.bool
                if is_item_bool:
                    item = item.int()
                if key in _INTEGER_LABEL_KEYS:
                    item_min = item.min()
                    item = torch.nn.functional.one_hot(item - item_min)
                    item = scatter_add(item, cluster, dim=0)
                    data[key] = item.argmax(dim=-1) + item_min
                else:
                    data[key] = scatter_mean(item, cluster, dim=0)
                if is_item_bool:
                    data[key] = data[key].bool()
    return data

class SaveOriginalPosId:
    """ Transform that adds the index of the point to the data object
    This allows us to track this point from the output back to the input data object
    """

    KEY = "origin_id"

    def __init__(self, key=None):
        self.KEY = key if key is not None else self.KEY

    def _process(self, data):
        if hasattr(data, self.KEY):
            return data

        setattr(data, self.KEY, torch.arange(0, data.pos.shape[0]))
        return data

    def __call__(self, data):
        if isinstance(data, list):
            data = [self._process(d) for d in data]
        else:
            data = self._process(data)
        return data

    def __repr__(self):
        return self.__class__.__name__
    
class GridSampling3D:
    """ Clusters points into voxels with size :attr:`size`.
    Parameters
    ----------
    size: float
        Size of a voxel (in each dimension).
    quantize_coords: bool
        If True, it will convert the points into their associated sparse
        coordinates within the grid and store the value into a new
        `coords` attribute.
    mode: string:
        The mode can be either `last` or `mean`.
        If mode is `mean`, all the points and their features within a
        cell will be averaged. If mode is `last`, one random points per
        cell will be selected with its associated features.
    setattr_full_pos: bool
        If True, the input point positions will be saved into a new
        'full_pos' attribute. This memory-costly step may reveal
        necessary for subsequent local feature computation.
    """

    def __init__(self, size, quantize_coords=False, mode="mean", verbose=False,
                 setattr_full_pos=False):
        self._grid_size = size
        self._quantize_coords = quantize_coords
        self._mode = mode
        self._setattr_full_pos = setattr_full_pos
        if verbose:
            log.warning(
                "If you need to keep track of the position of your points, use "
                "SaveOriginalPosId transform before using GridSampling3D.")

            if self._mode == "last":
                log.warning(
                    "The tensors within data will be shuffled each time this "
                    "transform is applied. Be careful that if an attribute "
                    "doesn't have the size of num_points, it won't be shuffled")

    def _process(self, data):
        if self._mode == "last":
            data = shuffle_data(data)

        full_pos = data.pos
        coords = torch.round((data.pos) / self._grid_size)
        if "batch" not in data:
            cluster = grid_cluster(coords, torch.tensor([1, 1, 1]))
        else:
            cluster = voxel_grid(coords, data.batch, 1)
        cluster, unique_pos_indices = consecutive_cluster(cluster)

        data = group_data(data, cluster, unique_pos_indices, mode=self._mode)
        if self._quantize_coords:
            data.coords = coords[unique_pos_indices].int()

        data.grid_size = torch.tensor([self._grid_size])

        # Keep track of the initial full-resolution point cloud for
        # later use. Typically needed for local features computation.
        # However, for obvious memory-wary considerations, it is
        # recommended to delete the 'full_pos' attribute as soon as it
        # is no longer needed.
        if self._setattr_full_pos:
            data.full_pos = full_pos

        return data

    def __call__(self, data):
        if isinstance(data, list):
            data = [self._process(d) for d in data]
        else:
            data = self._process(data)
        return data

    def __repr__(self):
        return "{}(grid_size={}, quantize_coords={}, mode={})".format(
            self.__class__.__name__, self._grid_size, self._quantize_coords, self._mode
        )

In [45]:
start = time()
data_sub = GridSampling3D(size=voxel)(data)
print(f'Data  voxelization at {voxel}m: {time() - start:0.3f}s')
print(f'Number of sampled points: {data_sub.num_nodes} ({data_sub.num_nodes // 10**6:0.2f}M, {100 * data_sub.num_nodes / data.num_nodes:0.1f}%)')

Data  voxelization at 0.05m: 2.496s
Number of loaded points: 2480151 (2.00M, 1.0%)


### Loïc's C implem

In [39]:
import superpoint_transformer.partition.utils.libpoint_utils as point_utils

# WARNING: the pruning must know the number of classes. All labels are 
# offset to account for the -1 unlabeled points !
start = time()
xyz, rgb, labels, dump = point_utils.prune(data.pos.float().numpy(), voxel, (data.rgb * 255).byte().numpy(), data.y.byte().numpy() + 1, np.zeros(1, dtype='uint8'), KITTI360_NUM_CLASSES + 1, 0)
print(f'Data  voxelization at {voxel}m: {time() - start:0.3f}s')
print(f'Number of sampled points: {xyz.shape[0]} ({xyz.shape[0] // 10**6:0.2f}M, {100 * xyz.shape[0] / data.num_nodes:0.1f}%)')

Data  voxelization at 0.05m: 7.412s
Number of sampled points: 2479935 (2.00M, 0.8%)
Voxelization into 3612 x 3216 x 335 grid
Reduced from 3201318 to 2479935 points (77.46%)


In [42]:
xyz.shape[0] / data.num_nodes

0.774660624155426

# Neighbour search

In [4]:
x = data.pos
k = 10

### Sklearn

In [6]:
from sklearn.neighbors import KDTree

start = time()
kdt = KDTree(x.numpy(), leaf_size=30, metric='euclidean')
neighbors = kdt.query(x.numpy(), k=k, return_distance=False)
print(f'Neighbor search: {time() - start:0.3f}s')

Neighbor search: 15.029s


### FAISS-GPU

In [None]:
import faiss

def find_neighbours(x, y, k=10, ncells=None, nprobes=10):
    # if batch_x is not None or batch_y is not None:
    #     raise NotImplementedError(
    #         "FAISSGPUKNNNeighbourFinder does not support batches yet")

    x = x.view(-1, 1) if x.dim() == 1 else x
    y = y.view(-1, 1) if y.dim() == 1 else y
    x, y = x.contiguous(), y.contiguous()

    # FAISS-GPU consumes numpy arrays
    x_np = x.cpu().numpy()
    y_np = y.cpu().numpy()

    # Initialization
    n_fit = x_np.shape[0]
    d = x_np.shape[1]
    gpu = faiss.StandardGpuResources()

    # Heuristics to prevent k from being too large
    k_max = 1024
    k = min(k, n_fit, k_max)

    # Heuristic to parameterize the number of cells for FAISS index,
    # if not provided
    if ncells is None:
        f1 = 3.5 * np.sqrt(n_fit)
        f2 = 1.6 * np.sqrt(n_fit)
        if n_fit > 2 * 10 ** 6:
            p = 1 / (1 + np.exp(2 * 10 ** 6 - n_fit))
        else:
            p = 0
        ncells = int(p * f1 + (1 - p) * f2)

    # Building a GPU IVFFlat index + Flat quantizer
    torch.cuda.empty_cache()
    quantizer = faiss.IndexFlatL2(d)  # the quantizer index
    index = faiss.IndexIVFFlat(quantizer, d, ncells, faiss.METRIC_L2)  # the main index
    gpu_index_flat = faiss.index_cpu_to_gpu(gpu, 0, index)  # pass index it to GPU
    gpu_index_flat.train(x_np)  # fit the cells to the training set distribution
    gpu_index_flat.add(x_np)

    # Querying the K-NN
    gpu_index_flat.setNumProbes(nprobes)
    return torch.LongTensor(gpu_index_flat.search(y_np, k)[1]).to(x.device)

start = time()
out = find_neighbours(x, x, k=k, ncells=None, nprobes=10)
print(f'Neighbor search: {time() - start:0.3f}s')

### PyKeOps

In [11]:
from pykeops.torch import LazyTensor

start = time()
# K-NN search with KeOps. If the number of points is greater
# than 16 millions, KeOps requires double precision.
xyz_query = x.contiguous()
xyz_search = x.contiguous()
if xyz_search.shape[0] > 1.6e7:
    xyz_query_keops = LazyTensor(xyz_query[:, None, :].double())
    xyz_search_keops = LazyTensor(xyz_search[None, :, :].double())
else:
    xyz_query_keops = LazyTensor(xyz_query[:, None, :])
    xyz_search_keops = LazyTensor(xyz_search[None, :, :])
d_keops = ((xyz_query_keops - xyz_search_keops) ** 2).sum(dim=2)
neighbors = d_keops.argKmin(k, dim=1)
print(f'Neighbor search: {time() - start:0.3f}s')

Neighbor search: 39.054s


### FLANN

In [8]:
import pyflann

start = time()
flann = pyflann.FLANN()
result, dists = flann.nn(x.numpy(), x.numpy(), k, algorithm="kmeans", branching=32, iterations=7, checks=16)
print(f'Neighbor search: {time() - start:0.3f}s')

ImportError: Cannot load dynamic library. Did you compile FLANN?

## Pytorch Geometric

In [13]:
from torch_geometric.nn import knn

start = time()
out = knn(x, x, k, batch_x=None, batch_y=None, num_workers=1)
print(f'Neighbor search: {time() - start:0.3f}s')

Neighbor search: 9.466s


In [17]:
start = time()
out = knn(x, x, k, batch_x=None, batch_y=None, num_workers=2)
print(f'Neighbor search: {time() - start:0.3f}s')

Neighbor search: 9.215s


In [14]:
start = time()
out = knn(x, x, k, batch_x=None, batch_y=None, num_workers=4)
print(f'Neighbor search: {time() - start:0.3f}s')

Neighbor search: 9.076s


In [16]:
start = time()
out = knn(x, x, k, batch_x=None, batch_y=None, num_workers=8)
print(f'Neighbor search: {time() - start:0.3f}s')

Neighbor search: 9.438s


In [None]:
x_cuda = x.cuda()
start = time()
out = knn(x_cuda, x_cuda, k, batch_x=None, batch_y=None, num_workers=1)
print(f'Neighbor search: {time() - start:0.3f}s')
del x_cuda

### GriSPy

In [14]:
import grispy as gsp

start = time()
grid = gsp.GriSPy(x.numpy())
dist, ind = grid.nearest_neighbors(x.numpy(), n=k)
print(f'Neighbor search: {time() - start:0.3f}s')

KeyboardInterrupt: 

### FRNN
https://github.com/lxxue/FRNN

In [21]:
sys.path.append(os.path.join(os.path.dirname(file_path), "FRNN"))
import frnn

start = time()
# first time there is no cached grid
dists, idxs, nn, grid = frnn.frnn_grid_points(x.view(1, -1, 3), x.view(1, -1, 3), None, None, k, -1, grid=None, return_nn=False, return_sorted=True)
print(f'Neighbor search: {time() - start:0.3f}s')


# # if points2 and r don't change, we can reuse the grid
# dists, idxs, nn, grid = frnn.frnn_grid_points(
#     points1, points2, lengths1, lengths2, K, r, grid=grid, return_nn=False, return_sorted=True
# )

ImportError: cannot import name '_C' from partially initialized module 'frnn' (most likely due to a circular import) (/home/ign.fr/drobert-admin/projects/FRNN/frnn/__init__.py)

### Pytorch3D

In [25]:
import pytorch3d

start = time()
pytorch3d.ops.knn_points(x.view(1, -1, 3), x.view(1, -1, 3), K=k)
print(f'Neighbor search: {time() - start:0.3f}s')

ModuleNotFoundError: No module named 'pytorch3d'

# Graph computation

In [None]:
# PCP computation
components2, in_component2, tracks_single = pcp(features, graph_nn, args.reg_strength, 10)
# en sachant que pour graph_nn tu n'as besoin que de graph_nn["source"], graph_nn["target"] et graph_nn["edge_weight"]

# Partition

In [None]:
from grid_graph import edge_list_to_forward_star
from cp_kmpp_d0_dist import cp_kmpp_d0_dist

def pcp(features, graph_nn, reg_strength, cutoff, parallel=True, return_intermediate=False, balance=True):
    """
    parallel cut pursuit
    """
    # Convert to forward-star graph representation
    first_edge, adj_vertices, reindex = edge_list_to_forward_star(
        features.shape[0], np.concatenate((graph_nn["source"][:, None], 
        graph_nn["target"][:, None]), 1))
    
    if parallel:
        max_thread = 0
    else:
        max_thread = 1
    
    if return_intermediate:
        Comp, rX, it, Obj, Time, comp_List = cp_kmpp_d0_dist(
            1, np.asfortranarray(features.T), first_edge, adj_vertices,
            edge_weights=reg_strength * graph_nn["edge_weight"][reindex], min_comp_weight=cutoff,
            cp_dif_tol=1e-2, cp_it_max=10, split_damp_ratio=0.7, verbose=False, 
            max_num_threads=max_thread, compute_Com=True, compute_Obj=True, compute_Time=True,
            balance_parallel_split=balance)
        return comp_List, Comp, (Obj, Time, rX)
    else:
        Comp, rX, it, comp_List = cp_kmpp_d0_dist(
            1, np.asfortranarray(features.T), first_edge, adj_vertices,
            edge_weights=reg_strength * graph_nn["edge_weight"][reindex], min_comp_weight=cutoff,
            cp_dif_tol=1e-2, cp_it_max=10, split_damp_ratio=0.7, verbose=False,
            max_num_threads=max_thread, compute_Com=True, balance_parallel_split=balance)
        return comp_List, Comp, []