In [1]:
import numpy as np
import torch
import sys, os

# file_path = os.path.dirname(os.path.abspath(__file__)) # this is for the .py script but does not work in a notebook
file_path = os.path.dirname(os.path.abspath(''))
sys.path.append(file_path)
# sys.path.append(os.path.join(file_path, "grid-graph/python/bin"))
# sys.path.append(os.path.join(file_path, "parallel-cut-pursuit/python/wrappers"))

# Data loading

In [2]:
from time import time
import glob
from superpoint_transformer.data import Data
from superpoint_transformer.datasets.kitti360 import read_kitti360_window
from superpoint_transformer.datasets.kitti360_config import KITTI360_NUM_CLASSES

# DATA_ROOT
DATA_ROOT = '/media/drobert-admin/DATA2'
# DATA_ROOT = '/var/data/drobert'

i_window = 0
all_filepaths = sorted(glob.glob(DATA_ROOT + '/datasets/kitti360/shared/data_3d_semantics/*/static/*.ply'))
filepath = all_filepaths[i_window]

start = time()
data = read_kitti360_window(filepath, semantic=True, instance=False, remap=True)
print(f'Loading data {i_window+1}/{len(all_filepaths)}: {time() - start:0.3f}s')
print(f'Number of loaded points: {data.num_nodes} ({data.num_nodes // 10**6:0.2f}M)')

# Offset labels by 1 to account for unlabelled points -> !!!!!!!!!!!!!!!! IMPORTANT !!!!!!!!!!!!!!!!
data.y[data.y == -1] = KITTI360_NUM_CLASSES
KITTI360_NUM_CLASSES += 1

Loading data 1/342: 0.071s
Number of loaded points: 3201318 (3.00M)


# Voxelization

### Pytorch Geometric

In [4]:
from torch_geometric.nn.pool import voxel_grid

start = time()
data_sub = voxel_grid(data.pos, size=0.1)
print(f'Data  voxelization at {voxel}m: {time() - start:0.3f}s')
print(f'Number of sampled points: {data_sub.shape[0]} ({data_sub.shape[0] / 10**6:0.2f}M, {100 * data_sub.shape[0] / data.num_nodes:0.1f}%)')

Data  voxelization at 0.05m: 0.154s
Number of sampled points: 3201318 (3.00M, 100.0%)


### TorchPoints3D

In [27]:
##### import torch
import re
from torch_geometric.nn.pool import voxel_grid
from torch_cluster import grid_cluster
from torch_scatter import scatter_mean, scatter_add
from torch_geometric.nn.pool.consecutive import consecutive_cluster


def shuffle_data(data):
    """ Shuffle the order of nodes in Data. Only `torch.Tensor` 
    attributes of size `Data.num_nodes` are affected.  
    
    Warning: this modifies the input Data object in-place

    Parameters
    ----------
    data : Data
    """
    num_points = data.pos.shape[0]
    shuffle_idx = torch.randperm(num_points)
    for key in set(data.keys):
        item = data[key]
        if torch.is_tensor(item) and num_points == item.shape[0]:
            data[key] = item[shuffle_idx]
    return data


def group_data(
        data, cluster=None, unique_pos_indices=None, mode="mean", skip_keys=[], 
        bins={}):
    """ Group data based on indices in cluster. The option ``mode`` 
    controls how data gets aggregated within each cluster.
    
    Warning: this modifies the input Data object in-place

    Parameters
    ----------
    data : Data
        [description]
    cluster : torch.Tensor
        Tensor of the same size as the number of points in data. Each 
        element is the cluster index of that point.
    unique_pos_indices : torch.tensor
        Tensor containing one index per cluster, this index will be used
        to select features and labels
    mode : str
        Option to select how the features and labels for each voxel is 
        computed. Can be ``last`` or ``mean``. ``last`` selects the last 
        point falling in a voxel as the representent, ``mean`` takes the
        average.
    skip_keys: list
        Keys of attributes to skip in the grouping
    bins: dict
        Dictionary holding ``{'key': n_bins}`` where ``key`` is a Data 
        attribute for which we would like to aggregate values into an 
        histogram and ``n_bins`` accounts for the corresponding number 
        of bins. This is typically needed when we want to aggregate 
        point labels without losing the distribution, as opposed to 
        majority voting.
    """
    
    # Keys for which voxel aggregation will be based on majority voting
    _VOTING_KEYS = ["y", "instance_labels"]

    # Keys for which voxel aggregation will be based on majority voting
    _LAST_KEYS = ["batch", SaveOriginalPosId.KEY]

    assert mode in ["mean", "last"]
    if mode == "mean" and cluster is None:
        raise ValueError(
            "In mean mode the cluster argument needs to be specified")
    if mode == "last" and unique_pos_indices is None:
        raise ValueError(
            "In last mode the unique_pos_indices argument needs to be specified")
    
    # Save the number of nodes here because the subsequent in-place 
    # modifications will affect it
    num_nodes = data.num_nodes
    
    # Aggregate Data attributes for same-cluster points
    for key, item in data:
        
        # `skip_keys` are not aggregated
        if key in skip_keys:
            continue
        
        # Edges cannot be aggregated
        if bool(re.search("edge", key)):
            raise ValueError("Edges not supported. Wrong data type.")
        
        # Only torch.Tensor attributes of size Data.num_nodes are 
        # considered for aggregation
        if torch.is_tensor(item) and item.size(0) == num_nodes:
                        
            # For 'last' mode, use unique_pos_indices to pick values from 
            # a single point within each cluster. The same behavior is 
            # expected for the _LAST_KEYS
            if mode == "last" or key in _LAST_KEYS:
                data[key] = item[unique_pos_indices]
            
            # For 'mean' mode, the attributes will be aggregated 
            # depending on their nature
            elif mode == "mean":
                
                # If the attribute is a boolean, temporarily convert is 
                # to integer to facilitate aggregation 
                is_item_bool = item.dtype == torch.bool
                if is_item_bool:
                    item = item.int()
                
                # For keys requiring a voting scheme or a histogram
                if key in _VOTING_KEYS or key in bins.keys():
                    
                    assert item.ge(0).all(), "Mean aggregation only supports positive integers"
                    assert item.dtype in [torch.uint8, torch.int, torch.long], "Mean aggregation only supports positive integers"
                                        
                    # Initialization
                    voting = key not in bins.keys()
                    n_bins = item.max() if voting else bins[key]
                    
                    # Convert values to one-hot encoding. Values are 
                    # temporarily offset to 0 to save some memory and 
                    # compute in one-hot encoding and scatter_add
                    offset = item.min()
                    item = torch.nn.functional.one_hot(item - offset)

                    # Count number of occurrence of each value
                    hist = scatter_add(item, cluster, dim=0)
                    N = hist.shape[0]
                    device = hist.device
                        
                    # Prepend 0 columns to the histogram for bins 
                    # removed due to offsetting
                    bins_before = torch.zeros(
                        (N, offset), device=device).long()
                    hist = torch.cat((bins_before, hist), dim=1)
                        
                    # Append columns to the histogram for unobserved 
                    # classes/bins
                    bins_after = torch.zeros(
                        (N, n_bins - hist.shape[1]), device=device).long()
                    hist = torch.cat((hist, bins_after), dim=1)
                    
                    # Either save the histogram or the majority vote
                    data[key] = hist.argmax(dim=-1) if voting else hist
                
                # Standard behavior, where attributes are simply 
                # averaged across the clusters
                else:
                    data[key] = scatter_mean(item, cluster, dim=0)
                    
                # Convert back to boolean if need be 
                if is_item_bool:
                    data[key] = data[key].bool()
                    
    return data


class SaveOriginalPosId:
    """Adds the index of the point to the Data object attributes. This 
    allows tracking this point from the output back to the input
    data object
    """

    KEY = "origin_id"

    def __init__(self, key=None):
        self.KEY = key if key is not None else self.KEY

    def _process(self, data):
        if hasattr(data, self.KEY):
            return data

        setattr(data, self.KEY, torch.arange(0, data.pos.shape[0]))
        return data

    def __call__(self, data):
        if isinstance(data, list):
            data = [self._process(d) for d in data]
        else:
            data = self._process(data)
        return data

    def __repr__(self):
        return self.__class__.__name__


class GridSampling3D:
    """ Clusters 3D points into voxels with size :attr:`size`.
    Parameters
    ----------
    size: float
        Size of a voxel (in each dimension).
    quantize_coords: bool
        If True, it will convert the points into their associated sparse
        coordinates within the grid and store the value into a new
        `coords` attribute.
    mode: string:
        The mode can be either `last` or `mean`.
        If mode is `mean`, all the points and their features within a
        cell will be averaged. If mode is `last`, one random points per
        cell will be selected with its associated features.
     bins: dict
        Dictionary holding ``{'key': n_bins}`` where ``key`` is a Data 
        attribute for which we would like to aggregate values into an 
        histogram and ``n_bins`` accounts for the corresponding number 
        of bins. This is typically needed when we want to aggregate 
        point labels without losing the distribution, as opposed to 
        majority voting.
    inplace: bool
        Whether the input Data object should be modified in-place
    verbose: bool
        Verbosity
    """

    def __init__(
            self, size, quantize_coords=False, mode="mean", bins={}, 
            inplace=False, verbose=False):
        self.grid_size = size
        self.quantize_coords = quantize_coords
        self.mode = mode
        self.bins = bins
        self.inplace = inplace
        if verbose:
            log.warning(
                "If you need to keep track of the position of your points, use "
                "SaveOriginalPosId transform before using GridSampling3D.")

            if self.mode == "last":
                log.warning(
                    "The tensors within data will be shuffled each time this "
                    "transform is applied. Be careful that if an attribute "
                    "doesn't have the size of num_nodes, it won't be shuffled")

    def _process(self, data_in):
        # In-place option will modify the input Data object directly
        data = data_in if self.inplace else data_in.clone()
        
        # If the aggregation mode is 'last', shuffle the point order.
        # Note that voxelization of point attributes will be stochastic
        if self.mode == "last":
            data = shuffle_data(data)
        
        # Convert point coordinates to the voxel grid coordinates
        coords = torch.round((data.pos) / self.grid_size)
        
        # Match each point with a voxel identifier
        if "batch" not in data:
            cluster = grid_cluster(coords, torch.ones(3, device=coords.device))
        else:
            cluster = voxel_grid(coords, data.batch, 1)
            
        # Reindex the clusters to make sure the indices used are 
        # consecutive. Basically, we do not want cluster indices to span 
        # [0, i_max] without all in-between indices to be used, because
        # this will affect the speed and output size of torch_scatter 
        # operations 
        cluster, unique_pos_indices = consecutive_cluster(cluster)
        
        # Perform voxel aggregation 
        data = group_data(
            data, cluster, unique_pos_indices, mode=self.mode, bins=self.bins)
        
        # Optionally convert quantize the coordinates. This is useful 
        # for sparse convolution models 
        if self.quantize_coords:
            data.coords = coords[unique_pos_indices].int()
        
        # Save the grid size in the Data attributes
        data.grid_size = torch.tensor([self.grid_size])

        return data

    def __call__(self, data):
        if isinstance(data, list):
            data = [self._process(d) for d in data]
        else:
            data = self._process(data)
        return data

    def __repr__(self):
        return "{}(grid_size={}, quantize_coords={}, mode={})".format(
            self.__class__.__name__, self.grid_size, self.quantize_coords, 
            self.mode)

In [6]:
# CPU
start = time()
data_sub = GridSampling3D(size=voxel)(data)
print(f'Data  voxelization at {voxel}m: {time() - start:0.3f}s')
print(f'Number of sampled points: {data_sub.num_nodes} ({data_sub.num_nodes / 10**6:0.2f}M, {100 * data_sub.num_nodes / data.num_nodes:0.1f}%)')

Data  voxelization at 0.05m: 2.559s
Number of sampled points: 2480151 (2.00M, 77.5%)


In [318]:
# GPU
torch.cuda.synchronize()
start = time()
data_sub = GridSampling3D(size=voxel)(data.cuda()).cpu()
torch.cuda.synchronize()
print(f'Data  voxelization at {voxel}m: {time() - start:0.3f}s')
print(f'Number of sampled points: {data_sub.num_nodes} ({data_sub.num_nodes / 10**6:0.2f}M, {100 * data_sub.num_nodes / data.num_nodes:0.1f}%)')

Data  voxelization at 0.05m: 0.120s
Number of sampled points: 2480168 (2.00M, 77.5%)


### SPG C implem

In [5]:
import superpoint_transformer.partition.utils.libpoint_utils as point_utils

# WARNING: the pruning must know the number of classes. All labels are 
# offset to account for the -1 unlabeled points !
start = time()
xyz, rgb, labels, dump = point_utils.prune(data.pos.float().numpy(), voxel, (data.rgb * 255).byte().numpy(), data.y.byte().numpy() + 1, np.zeros(1, dtype='uint8'), KITTI360_NUM_CLASSES + 1, 0)
print(f'Data  voxelization at {voxel}m: {time() - start:0.3f}s')
print(f'Number of sampled points: {xyz.shape[0]} ({xyz.shape[0] // 10**6:0.2f}M, {100 * xyz.shape[0] / data.num_nodes:0.1f}%)')

Data  voxelization at 0.05m: 8.355s
Number of sampled points: 2479935 (2.00M, 77.5%)


So it seems the C-based voxelization is not that fast. Can we somehow make it faster with more CPU cores ? Otherwise, will fallback to a custom implementation based on TP3D or PyG and keeping track of the in-voxel label distribution.

And even increasing the number of CPU cores (on AI4GEO) gave the same results.

The fastest is GPU-based TP3D-based computation.

### Final

In [3]:
from superpoint_transformer.transforms import GridSampling3D

voxel = 0.05
# voxel = 1

# GPU
torch.cuda.synchronize()
start = time()
n_in = data.num_nodes
data = GridSampling3D(size=voxel, bins={'y': KITTI360_NUM_CLASSES})(data.cuda()).cpu()
torch.cuda.synchronize()
print(f'Data voxelization at {voxel}m: {time() - start:0.3f}s')
print(f'Number of sampled points: {data.num_nodes} ({data.num_nodes / 10**6:0.2f}M, {100 * data.num_nodes / n_in:0.1f}%)')

Data voxelization at 0.05m: 1.895s
Number of sampled points: 2480168 (2.48M, 77.5%)


# Neighbour search

### Sklearn

In [6]:
from sklearn.neighbors import KDTree

start = time()
kdt = KDTree(x.numpy(), leaf_size=30, metric='euclidean')
neighbors = kdt.query(x.numpy(), k=k, return_distance=False)
print(f'Neighbor search: {time() - start:0.3f}s')

Neighbor search: 15.029s


### FAISS-GPU
```
conda install -c pytorch faiss-gpu cudatoolkit=10.2
pip install faiss-gpu cudatoolkit==10.2
```

In [None]:
import faiss

def find_neighbours(x, y, k=10, ncells=None, nprobes=10):
    # if batch_x is not None or batch_y is not None:
    #     raise NotImplementedError(
    #         "FAISSGPUKNNNeighbourFinder does not support batches yet")

    x = x.view(-1, 1) if x.dim() == 1 else x
    y = y.view(-1, 1) if y.dim() == 1 else y
    x, y = x.contiguous(), y.contiguous()

    # FAISS-GPU consumes numpy arrays
    x_np = x.cpu().numpy()
    y_np = y.cpu().numpy()

    # Initialization
    n_fit = x_np.shape[0]
    d = x_np.shape[1]
    gpu = faiss.StandardGpuResources()

    # Heuristics to prevent k from being too large
    k_max = 1024
    k = min(k, n_fit, k_max)

    # Heuristic to parameterize the number of cells for FAISS index,
    # if not provided
    if ncells is None:
        f1 = 3.5 * np.sqrt(n_fit)
        f2 = 1.6 * np.sqrt(n_fit)
        if n_fit > 2 * 10 ** 6:
            p = 1 / (1 + np.exp(2 * 10 ** 6 - n_fit))
        else:
            p = 0
        ncells = int(p * f1 + (1 - p) * f2)

    # Building a GPU IVFFlat index + Flat quantizer
    torch.cuda.empty_cache()
    quantizer = faiss.IndexFlatL2(d)  # the quantizer index
    index = faiss.IndexIVFFlat(quantizer, d, ncells, faiss.METRIC_L2)  # the main index
    gpu_index_flat = faiss.index_cpu_to_gpu(gpu, 0, index)  # pass index it to GPU
    gpu_index_flat.train(x_np)  # fit the cells to the training set distribution
    gpu_index_flat.add(x_np)

    # Querying the K-NN
    gpu_index_flat.setNumProbes(nprobes)
    return torch.LongTensor(gpu_index_flat.search(y_np, k)[1]).to(x.device)

start = time()
out = find_neighbours(x, x, k=k, ncells=None, nprobes=10)
print(f'Neighbor search: {time() - start:0.3f}s')

### PyKeOps
```
pip install pykeops
```

In [11]:
from pykeops.torch import LazyTensor

start = time()
# K-NN search with KeOps. If the number of points is greater
# than 16 millions, KeOps requires double precision.
xyz_query = x.contiguous()
xyz_search = x.contiguous()
if xyz_search.shape[0] > 1.6e7:
    xyz_query_keops = LazyTensor(xyz_query[:, None, :].double())
    xyz_search_keops = LazyTensor(xyz_search[None, :, :].double())
else:
    xyz_query_keops = LazyTensor(xyz_query[:, None, :])
    xyz_search_keops = LazyTensor(xyz_search[None, :, :])
d_keops = ((xyz_query_keops - xyz_search_keops) ** 2).sum(dim=2)
neighbors = d_keops.argKmin(k, dim=1)
print(f'Neighbor search: {time() - start:0.3f}s')

Neighbor search: 39.054s


### FLANN
```
conda install -c conda-forge pyflann -y
```

In [8]:
import pyflann

start = time()
flann = pyflann.FLANN()
result, dists = flann.nn(x.numpy(), x.numpy(), k, algorithm="kmeans", branching=32, iterations=7, checks=16)
print(f'Neighbor search: {time() - start:0.3f}s')

ImportError: Cannot load dynamic library. Did you compile FLANN?

### Pytorch Geometric

In [13]:
from torch_geometric.nn import knn

start = time()
out = knn(x, x, k, batch_x=None, batch_y=None, num_workers=1)
print(f'Neighbor search: {time() - start:0.3f}s')

Neighbor search: 9.466s


In [17]:
start = time()
out = knn(x, x, k, batch_x=None, batch_y=None, num_workers=2)
print(f'Neighbor search: {time() - start:0.3f}s')

Neighbor search: 9.215s


In [14]:
start = time()
out = knn(x, x, k, batch_x=None, batch_y=None, num_workers=4)
print(f'Neighbor search: {time() - start:0.3f}s')

Neighbor search: 9.076s


In [16]:
start = time()
out = knn(x, x, k, batch_x=None, batch_y=None, num_workers=8)
print(f'Neighbor search: {time() - start:0.3f}s')

Neighbor search: 9.438s


In [None]:
x_cuda = x.cuda()
start = time()
out = knn(x_cuda, x_cuda, k, batch_x=None, batch_y=None, num_workers=1)
print(f'Neighbor search: {time() - start:0.3f}s')
del x_cuda

### GriSPy
```
pip install grispy
```

In [14]:
import grispy as gsp

start = time()
grid = gsp.GriSPy(x.numpy())
dist, ind = grid.nearest_neighbors(x.numpy(), n=k)
print(f'Neighbor search: {time() - start:0.3f}s')

KeyboardInterrupt: 

### FRNN = Final
https://github.com/lxxue/FRNN

In [None]:
from superpoint_transformer.partition.FRNN import frnn
import torch


def _search_outliers(
        xyz_query, xyz_search, k_min, r_max=1, recursive=False, q_in_s=False):
    """
    Optionally recursive outlier search. The `xyz_query` and `xyz_search`
    Search for points with less than `k_min` neighbors within a radius 
    of `r_max`. 
    
    Since removing outliers may cause some points to become outliers 
    themselves, this problem can be tackled with the `recursive` option. 
    Note that this recursive search holds no garantee of reasonable 
    convergence as one could design a point cloud for given `k_min` and 
    `r_max` whose points would all recursively end up as outliers.  
    """
    # Data initialization
    xyz_query = xyz_query.view(1, -1, 3)
    xyz_search = xyz_search.view(1, -1, 3)
    device = xyz_query.device
    
    # KNN on GPU. Actual neighbor search now
    neighbors = frnn.frnn_grid_points(
        xyz_query, xyz_search, K=k_min + q_in_s, r=r_max)[1]
    
    # If the Query points are included in the Search points, remove each
    # point from its own neighborhood
    if q_in_s:
        neighbors = neighbors[0][:, 1:]
    
    # Get the number of found neighbors for each point. Indeed, 
    # depending on the cloud properties and the chosen K and radius, 
    # some points may receive "-1" neighbors
    n_found_nn = (neighbors != -1).sum(dim=1)

    # Identify points which have less than k_min neighbor. Those are 
    # treated as outliers
    mask_outliers = n_found_nn < k_min
    idx_outliers = torch.where(mask_outliers)[0]
    idx_inliers = torch.where(~mask_outliers)[0]
    
    # Exit here if not recursively searching for outliers 
    if not recursive:
        return idx_outliers, idx_inliers
    
    # Identify the points affected by the removal of the outliers. Those
    # inliers are potential outliers
    idx_potential = torch.where(
        torch.isin(neighbors[idx_inliers], idx_outliers).any(dim=1))[0]
        
    # Exit here if there are no potential new outliers among the inliers
    if idx_potential.shape[0] == 0:
        return idx_outliers, idx_inliers
    
    # Recursviely search actual outliers among the potential
    xyz_query_sub = xyz_query[0, idx_inliers[idx_potential]]
    xyz_search_sub = xyz_search[0, idx_inliers]
    idx_outliers_sub, idx_inliers_sub = _search_outliers(
        xyz_query_sub, xyz_search_sub, k_min, r_max=r_max, recursive=True, 
        q_in_s=True)
    
    # Update the outliers mask
    mask_outliers[idx_inliers[idx_potential][idx_outliers_sub]] = True
    idx_outliers = torch.where(mask_outliers)[0]
    idx_inliers = torch.where(~mask_outliers)[0]
    
    return idx_outliers, idx_inliers


def search_outliers(data, k_min, r_max=1, recursive=False):
    """
    Search for points with less than `k_min` neighbors within a radius 
    of `r_max`. 
    
    Since removing outliers may cause some points to become outliers 
    themselves, this problem can be tackled with the `recursive` option. 
    Note that this recursive search holds no garantee of reasonable 
    convergence as one could design a point cloud for given `k_min` and 
    `r_max` whose points would all recursively end up as outliers.  
    """       
    # Actual outlier search, optionally recursive
    idx_outliers, idx_inliers = _search_outliers(
        data.pos, data.pos, k_min, r_max=r_max, recursive=recursive, 
        q_in_s=True)
    
    # Create a Data object for the inliers and outliers    
    # Save the index for these isolated points in the Data object. This
    # will help properly handle neighborhoods, features and adjacency  
    # graph for those specific points. 
    # NB: it is important this attribute follows the "*index" naming 
    # convention, see:
    # https://pytorch-geometric.readthedocs.io/en/latest/notes/batching.html
    data_in = Data()
    data_out = Data(outliers_index=idx_outliers)
    for key, item in data:
        if torch.is_tensor(item) and item.size(0) == data.num_nodes:
            data_in[key] = data[key][idx_inliers]
            data_out[key] = data[key][idx_outliers]

    return data_in, data_out


def oversample_partial_neighborhoods(neighbors, distances, k):
    """
    Oversample partial neighborhoods with less than k points. Missing 
    neighbors are indicated by the "-1" index.
    
    Remarks
      - Neighbors and distances are assumed to be sorted in order of 
      increasing distance
      - All neighbors are assumed to have at least one valid neighbor. 
      See `search_outliers` to remove points with not enough neighbors 
    """
    # Initialization
    assert neighbors.dim() == distances.dim() == 2
    device = neighbors.device
    
    # Get the number of found neighbors for each point. Indeed, 
    # depending on the cloud properties and the chosen K and radius, 
    # some points may receive `-1` neighbors
    n_found_nn = (neighbors != -1).sum(dim=1)
        
    # Identify points which have more than k_min and less than k 
    # neighbors within R. For those, we oversample the neighbors to 
    # reach k
    idx_partial = torch.where(n_found_nn < k)[0]
    neighbors_partial = neighbors[idx_partial]
    distances_partial = distances[idx_partial]

    # Since the neighbors are sorted by increasing distance, the missing 
    # neighbors will always be the last ones. This helps finding their 
    # number and position, for oversampling.

    #*******************************************************************
    # The above statement is actually INCORRECT because the outlier 
    # removal may produce "-1" neighbors at unexpected positions. So 
    # either we manage to treat this in a clean vectorized way, or we 
    # fall back to the 2-searches solution...
    # Honestly, this feels like it is getting out of hand, let's keep 
    # things simple, since we are not going to save so much computation 
    # time with KNN wrt the partition.
    #*******************************************************************
    
    # For each missing neighbor, compute the size of the discrete set to 
    # oversample from.
    n_valid = n_found_nn[idx_partial].repeat_interleave(
        k - n_found_nn[idx_partial])

    # Compute the oversampling row indices.
    idx_x_sampling = torch.arange(neighbors_partial.shape[0], device=device
        ).repeat_interleave(k - n_found_nn[idx_partial])
    
    # Compute the oversampling column indices. The 0.9999 factor is a 
    # security to handle the case where torch.rand is to close to 1.0, 
    # which would yield incorrect sampling coordinates that would in 
    # result in sampling '-1' indices (ie all we try to avoid here)
    idx_y_sampling = (n_valid * torch.rand(
        n_valid.shape[0], device=device) * 0.9999).floor().long()

    # Apply the oversampling
    idx_missing = torch.where(neighbors_partial == -1)
    neighbors_partial[idx_missing] = neighbors_partial[
        idx_x_sampling, idx_y_sampling]
    distances_partial[idx_missing] = distances_partial[
        idx_x_sampling, idx_y_sampling]

    # Restore the oversampled neighborhods with the rest
    neighbors[idx_partial] = neighbors_partial
    distances[idx_partial] = distances_partial
    
    return neighbors, distances
    

def search_neighbors(data, k, r_max=1):
    # Data initialization
    xyz_query = data.pos.view(1, -1, 3)
    xyz_search = data.pos.view(1, -1, 3)
    
#     #--------------------------------
#     # KNN on GPU. Search for outliers first
#     _, neighbors, _, _ = frnn.frnn_grid_points(
#         xyz_query, xyz_search, K=k_min + 1, r=r_max)
    
#     # Remove each point from its own neighborhood
#     neighbors = neighbors[0][:, 1:]
    
#     # Get the number of found neighbors for each point. Indeed, 
#     # depending on the cloud properties and the chosen K and radius, 
#     # some points may receive `-1` neighbors
#     n_found_nn = (neighbors != -1).sum(dim=1)

#     # Identify points which have less than k_min neighbors within R. 
#     # Those are treated as outliers and will be discarded
#     idx_isolated = torch.where(n_found_nn < k_min)[0]
    
#     # Save the outliers in a separate Data object
#     outliers = Data(
#         pos=data.pos[idx_isolated], rgb=data.rgb[idx_isolated], 
#         y=data.y[idx_isolated], idx_isolated=idx_isolated)
    
#     # KNN on GPU. Search for outliers first
#     _, neighbors, _, _ = frnn.frnn_grid_points(
#         xyz_query, xyz_search, K=k_min + 1, r=r_max)
#     #--------------------------------
    
    # KNN on GPU. Actual neighbor search now
    distances, neighbors, _, _ = frnn.frnn_grid_points(
        xyz_query, xyz_search, K=k + 1, r=r_max)
    
    # Remove each point from its own neighborhood
    neighbors = neighbors[0][:, 1:]
    distances = distances[0][:, 1:]

    # Oversample the neighborhoods where less than k points were found
    neighbors, distances = oversample_partial_neighborhoods(
        neighbors, distances, k)
    
    # Store the neighbors and distances as a Data object attribute
    data.neighbors = neighbors.cpu()
    data.distances = distances.cpu()
    
    # Save the index for these isolated points in the Data object. This
    # will help properly handle neighborhoods, features and adjacency  
    # graph for those specific points. 
    # NB: it is important this attribute follows the "*index" naming 
    # convention, see:
    # https://pytorch-geometric.readthedocs.io/en/latest/notes/batching.html
    # data.isolated_index = idx_isolated.cpu()
    
    return data

# IMPORTANT !!!
#   - points with no neighbors within radius -> set to 0-feature !

### Final

In [4]:
from superpoint_transformer.transforms import search_outliers, search_neighbors

radius = 1
# radius = 10
k_min = 5
k_feat = 30
k_adjacency = 10

data = data.cuda()

torch.cuda.synchronize()
start = time()
data, data_outliers = search_outliers(data, k_min, r_max=radius, recursive=True)
data_outliers = data_outliers.cpu()
torch.cuda.synchronize()
print(f'Outliers search: {time() - start:0.3f}s')

torch.cuda.synchronize()
start = time()
data = search_neighbors(data, k_feat, r_max=radius)
# Make sure all points have k neighbors (no "-1" missing neighbors)
assert (data.neighbors != -1).all(), "Some points have incomplete neighborhoods, make sure to remove the outliers to avoid this issue."
torch.cuda.synchronize()
print(f'Neighbor search: {time() - start:0.3f}s')

Outliers search: 0.584s
Neighbor search: 1.701s


### Pytorch3D
```
pip install -U fvcore
pip install -U iopath
pip install --no-index --no-cache-dir pytorch3d -f https://dl.fbaipublicfiles.com/pytorch3d/packaging/wheels/py38_cu113_pyt1110/download.html
```

In [25]:
import pytorch3d

start = time()
pytorch3d.ops.knn_points(x.view(1, -1, 3), x.view(1, -1, 3), K=k)
print(f'Neighbor search: {time() - start:0.3f}s')

ModuleNotFoundError: No module named 'pytorch3d'

So it seems FRNN on GPU is the clear winner here !

# !!! ___ CAREFUL WITH CPU-CUDA MOVES ___ !!! 

# Geometric features computation

### SPG C implem

In [9]:
import superpoint_transformer.partition.utils.libpoint_utils as point_utils

data = data.cpu()

torch.cuda.synchronize()
start = time()
geof = point_utils.compute_geometric_features(
    data.pos.numpy(), data.neighbors.flatten().numpy().astype('uint32'), 
    np.arange(data.pos.shape[0] + 1).astype('uint32') * k_feat, False).astype('float32')  # IMPORTANT CAREFUL WITH UINT32 = 4G MAX
print(f'Geometric features: {time() - start:0.3f}s')

Geometric features: 1.909s


This is the fasest way of computing the geometric features. Surprisingly, the CPU implementation is faster than the TP3D-based GPU one.

In [None]:
import superpoint_transformer.partition.utils.libpoint_utils as point_utils

def compute_pointfeatures(
        data, pos=True, radius=5, rgb=True, linearity=True, planarity=True,
        scattering=True, verticality=True, normal=True, length=False, 
        surface=False, volume=False):
    """ Compute the pointwise features that will be used for the 
    partition.
    
    All local geometric features assume the input ``Data`` has a 
    ``neighbors`` attribute, holding a ``(num_nodes, k)`` tensor of 
    indices. All k neighbors will be used for local geometric features 
    computation.
    
    Parameters
    ----------
    pos: bool
        Use point position.
    radius: bool
        Radius used to scale the point position features, to mitigate 
        the maximum superpoint size.
    rgb: bool
        Use rgb color. Assumes Data.rgb holds either [0, 1] flaots or 
        [0, 255] integers
    linearity: bool
        Use local linearity. Assumes ``Data.neighbors``.
    lanarity: bool
        Use local lanarity. Assumes ``Data.neighbors``.
    scattering: bool
        Use local scattering. Assumes ``Data.neighbors``.
    verticality: bool
        Use local verticality. Assumes ``Data.neighbors``.
    normal: bool
        Use local normal. Assumes ``Data.neighbors``.
    length: bool
        Use local length. Assumes ``Data.neighbors``.
    surface: bool
        Use local surface. Assumes ``Data.neighbors``.
    volume: bool
        Use local volume. Assumes ``Data.neighbors``.
    """
    features = []
    
    # Add xyz normalized. The scaling factor drives the maximum cluster
    # size the partition may produce
    if 'pos' in data.keys:
        features.append(data.pos / radius)
    
    # Add rgb to the features. If colors are stored in int, we assume 
    # they are encoded in  [0, 255] and normalize them. Otherwise, we 
    # assume they have already been [0, 1] normalized
    if rgb:
        f = data.rgb
        if f.type in [torch.uint8, torch.int, torch.long]:
            f = f.float() / 255
        features.append(f)
    
    # Add local geometric features
    if any((linearity, planarity, scattering, verticality, normal)):
        
        # Prepare data for numpy boost interface
        xyz = data.pos.cpu().numpy()
        nn = data.neighbors.flatten().cpu().numpy().astype('uint32')  # !!!! IMPORTANT CAREFUL WITH UINT32 = 4 BILLION points MAXIMUM !!!!
        k = data.neighbors.shape[1]
        nn_ptr = np.arange(xyz.shape[0] + 1).astype('uint32') * k  # !!!! IMPORTANT CAREFUL WITH UINT32 = 4 BILLION points MAXIMUM !!!!
        
        # C++ geometric features computation on CPU
        f = point_utils.compute_geometric_features(xyz, nn, nn_ptr, False)
        f = torch.from_numpy(f.astype('float32'))
        
        # Heuristic to increase the importance of verticality
        f[:, 3] *= 2
        
        # Select only required features
        mask = (
            [linearity, planarity, scattering, verticality] 
            + [normal] * 3 
            + [length, surface, volume])
        features.append(f[:, mask].to(data.pos.device))
        
    # Save all features in the Data.x attribute
    data.x = torch.cat(features, dim=1).to(data.pos.device)
    
    return data

### TP3D-based

In [7]:
import torch
from superpoint_transformer.data import Data

def batch_pca(xyz):
    """
    Compute the PCA of a batch of point clouds of size (*, N, M).
    """
    assert 2 <= xyz.dim() <= 3
    xyz = xyz.unsqueeze(0) if xyz.dim() == 2 else xyz

    pos_centered = xyz - xyz.mean(dim=1).unsqueeze(1)
    cov_matrix = pos_centered.transpose(1, 2).bmm(pos_centered) / pos_centered.shape[1]
    eigenval, eigenvect = torch.linalg.eigh(cov_matrix)

    # If Nan values are computed, return equal eigenvalues and
    # Identity eigenvectors
    idx_nan = torch.where(torch.logical_and(
        eigenval.isnan().any(1), eigenvect.flatten(1).isnan().any(1)))
    eigenval[idx_nan] = torch.ones(3, dtype=eigenval.dtype, device=xyz.device)
    eigenvect[idx_nan] = torch.eye(3, dtype=eigenvect.dtype, device=xyz.device)

    # Precision errors may cause close-to-zero eigenvalues to be
    # negative. Hard-code these to zero
    eigenval[torch.where(eigenval < 0)] = 0

    return eigenval, eigenvect


class PCAComputePointwise(object):
    """
    Compute PCA for the local neighborhood of each point in the cloud.

    Input data is expected to be stored in DENSE format.

    Results are saved in `eigenvalues` and `eigenvectors` attributes.
    `data.eigenvalues` is a tensor
    :math:`(\lambda_1, \lambda_2, \lambda_3)` such that
    :math:`\lambda_1 \leq \lambda_2 \leq \lambda_3`.
    `data.eigenvectors` is 1x9 tensor containing the eigenvectors
    associated with `data.eigenvalues`, concatenated in the same order.

    Parameters
    ----------
    num_neighbors: int, optional
        Controls the maximum number of neighbors on which to compute
        PCA. If `r=None`, `num_neighbors` will be used as K for
        K-nearest neighbor search. Otherwise, `num_neighbors` will be
        the maximum number of neighbors used in radial neighbor search.
    r: float, optional
        If not `None`, neighborhoods will be computed with a
        radius-neighbor approach. If `None`, K-nearest neighbors will
        be used.
    use_full_pos: bool, optional
        If True, the neighborhood search will be carried on the point
        positions found in the `data.full_pos`. An error will be raised
        if data carries no such attribute. See `GridSampling3D` for
        producing `data.full_pos`.
        If False, the neighbor search will be computed on `data.pos`.
    use_cuda: bool, optional
        If True, the computation will be carried on CUDA.
    workers: int, optional
        If not `None`, the features computation will be distributed
        across the provided number of workers.
    """

    def __init__(
            self, num_neighbors=40, r=None, use_full_pos=False, use_cuda=False,
            use_faiss=True, ncells=None, nprobes=10, chunk_size=1000000):
        self.num_neighbors = num_neighbors
        self.r = r
        self.use_full_pos = use_full_pos
        self.use_cuda = use_cuda and torch.cuda.is_available()
        self.use_faiss = use_faiss and torch.cuda.is_available()
        self.ncells = ncells
        self.nprobes = nprobes
        self.chunk_size = chunk_size

    def _process(self, data: Data):
        assert getattr(data, 'pos', None) is not None, \
            "Data must contain a 'pos' attribute."
        assert not self.use_full_pos \
               or getattr(data, 'full_pos', None) is not None, \
            "Data must contain a 'full_pos' attribute."

        # Recover the query and search clouds
        xyz_query = data.pos
        xyz_search = data.full_pos if self.use_full_pos else data.pos

        # Move computation to CUDA if required
        input_device = xyz_query.device
        if self.use_cuda and not xyz_query.is_cuda and not self.use_faiss:
            xyz_query = xyz_query.cuda()
            xyz_search = xyz_search.cuda()

        # Compute the neighborhoods
        if self.r is not None:
            # Radius-NN search with torch_points_kernel
            sampler = RadiusNeighbourFinder(
                self.r, self.num_neighbors, conv_type='DENSE')
            neighbors = sampler.find_neighbours(
                xyz_search.unsqueeze(0), xyz_query.unsqueeze(0))[0]
        elif self.use_faiss:
            # K-NN search with FAISS
            nn_finder = FAISSGPUKNNNeighbourFinder(
                self.num_neighbors, ncells=self.ncells, nprobes=self.nprobes)
            neighbors = nn_finder(xyz_search, xyz_query, None, None)
        else:
            # K-NN search with KeOps. If the number of points is greater
            # than 16 millions, KeOps requires double precision.
            xyz_query = xyz_query.contiguous()
            xyz_search = xyz_search.contiguous()
            if xyz_search.shape[0] > 1.6e7:
                xyz_query_keops = LazyTensor(xyz_query[:, None, :].double())
                xyz_search_keops = LazyTensor(xyz_search[None, :, :].double())
            else:
                xyz_query_keops = LazyTensor(xyz_query[:, None, :])
                xyz_search_keops = LazyTensor(xyz_search[None, :, :])
            d_keops = ((xyz_query_keops - xyz_search_keops) ** 2).sum(dim=2)
            neighbors = d_keops.argKmin(self.num_neighbors, dim=1)
            # raise NotImplementedError(
            #     "Fast K-NN search has not been implemented yet. Please "
            #     "consider using radius search instead.")

        # Compute PCA for each neighborhood
        # Note: this is surprisingly slow on GPU, so better run on CPU
        eigenvalues = []
        eigenvectors = []
        n_chunks = math.ceil(neighbors.shape[0] / self.chunk_size)
        for i in range(n_chunks):
            xyz_neigh_batch = xyz_search[
                neighbors[i * self.chunk_size: (i + 1) * self.chunk_size]]
            eval, evec = batch_pca(xyz_neigh_batch.cpu())
            evec = evec.transpose(2, 1).flatten(1)
            eigenvalues.append(eval)
            eigenvectors.append(evec)
        eigenvalues = torch.cat(eigenvalues, dim=0)
        eigenvectors = torch.cat(eigenvectors, dim=0)

        # Save eigendecomposition results in data attributes
        data.eigenvalues = eigenvalues.to(input_device)
        data.eigenvectors = eigenvectors.to(input_device)

        return data

    def __call__(self, data):
        if isinstance(data, list):
            data = [self._process(d) for d in tq(data)]
        else:
            data = self._process(data)
        return data

    def __repr__(self):
        attr_repr = ', '.join([f'{k}={v}' for k, v in self.__dict__.items()])
        return f'{self.__class__.__name__}({attr_repr})'


class EigenFeatures(object):
    """
    Compute local geometric features based on local eigenvalues and
    eigenvectors.

    The following local geometric features are computed and saved in
    dedicated data attributes: `normal`, `scattering`, `linearity` and
    `planarity`. The formulation of those is inspired from
    "Hierarchical extraction of urban objects from mobile laser
    scanning data" [Yang et al. 2015]

    Data is expected to carry `eigenvectors` and `eigenvectors`
    attributes:
    `data.eigenvalues` is a tensor
    :math:`(\lambda_1, \lambda_2, \lambda_3)` such that
    :math:`\lambda_1 \leq \lambda_2 \leq \lambda_3`.
    `data.eigenvectors` is 1x9 tensor containing the eigenvectors
    associated with `data.eigenvalues`, concatenated in the same order.
    See `PCAComputePointwise` for generating those.

    Parameters
    ----------
    normal: bool, optional
        If True, the normal to the local surface will be computed.
    linearity: bool, optional
        If True, the local linearity will be computed.
    planarity: bool, optional
        If True, the local planarity will be computed.
    scattering: bool, optional
        If True, the local scattering will be computed.
    temperature: float, optional
        If set to a float value, the returned features will be run
        through a scaled softmax with temperature being the scale. Set
        to None by default.
    """

    def __init__(self, normal=True, linearity=True, planarity=True,
                 scattering=True, verticality=True, temperature=None):
        self.normal = normal
        self.linearity = linearity
        self.planarity = planarity
        self.scattering = scattering
        self.verticality = verticality
        self.temperature = temperature

    def _process(self, data: Data):
        assert getattr(data, 'eigenvalues', None) is not None, \
            "Data must contain an 'eigenvalues' attribute."
        assert getattr(data, 'eigenvectors', None) is not None, \
            "Data must contain an 'eigenvectors' attribute."

        if self.normal:
            # The normal is the eigenvector carried by the smallest
            # eigenvalue
            data.normal = data.eigenvectors[:, :3]

        # Eigenvalues: 0 <= l0 <= l1 <= l2
        # Following, [Yang et al. 2015] we use the sqrt of eigenvalues
        v0 = data.eigenvalues[:, 0].sqrt().squeeze()
        v1 = data.eigenvalues[:, 1].sqrt().squeeze()
        v2 = data.eigenvalues[:, 2].sqrt().squeeze() + 1e-6
        
        e0 = eigenvectors[:, :, 0].abs() * eigenvalues[:, [0]]
        e1 = eigenvectors[:, :, 1].abs() * eigenvalues[:, [1]]
        e2 = eigenvectors[:, :, 2].abs() * eigenvalues[:, [2]]
        u = e0 + e1 + e2

        # Compute the eigen features
        linearity = (v2 - v1) / v2
        planarity = (v1 - v0) / v2
        scattering = v0 / v2
        verticality = u[:, 2] / torch.linalg.norm(u, dim=1)

        # Compute the softmax version of the features, for more
        # opinionated geometric information. As a heuristic, set
        # temperature=5 for clouds of 30 points or more.
        if self.temperature:
            values = (self.temperature * torch.cat([
                linearity.view(-1, 1),
                planarity.view(-1, 1),
                scattering.view(-1, 1)], dim=1)).exp()
            values = values / values.sum(dim=1).view(-1, 1)
            linearity = values[:, 0]
            planarity = values[:, 1]
            scattering = values[:, 2]

        if self.linearity:
            data.linearity = linearity

        if self.planarity:
            data.planarity = planarity

        if self.scattering:
            data.scattering = scattering
        
        if self.verticality:
            data.verticality = verticality

        return data

    def __call__(self, data):
        if isinstance(data, list):
            data = [self._process(d) for d in data]
        else:
            data = self._process(data)
        return data

    def __repr__(self):
        attr_repr = ', '.join([f'{k}={v}' for k, v in self.__dict__.items()])
        return f'{self.__class__.__name__}({attr_repr})'


In [None]:
# On GPU
xyz = torch.rand(10**6, 50, 3).cuda()

torch.cuda.synchronize()
start = time()
batch_pca(xyz)
torch.cuda.synchronize()
print(f'PCA: {time() - start:0.3f}s')

PCA: 54.819s


In [None]:
# On CPU
xyz = torch.rand(10**6, 50, 3)

start = time()
batch_pca(xyz)
print(f'PCA: {time() - start:0.3f}s')

PCA: 1.981s


In [156]:
from torch_geometric.data import Data

# On CPU
xyz = torch.rand(10**6, 50, 3)

start = time()
eigenvalues, eigenvectors = batch_pca(xyz)
print(f'PCA CPU: {time() - start:0.3f}s')

# On CPU
start = time()
d = Data(x=xyz, eigenvalues=eigenvalues, eigenvectors=eigenvectors)
d = EigenFeatures(normal=True, linearity=True, planarity=True, scattering=True, verticality=True, temperature=None)(d)
print(f'Geometric Features: {time() - start:0.3f}s')

PCA CPU: 2.021s
Geometric Features: 0.012s


In [143]:
from torch_geometric.data import Data

# On CPU
xyz = torch.rand(10**6, 50, 3)

start = time()
eigenvalues, eigenvectors = batch_pca(xyz)
print(f'PCA CPU: {time() - start:0.3f}s')

# On GPU
torch.cuda.synchronize()
start = time()
d = Data(x=xyz.cuda(), eigenvalues=eigenvalues.cuda(), eigenvectors=eigenvectors.cuda())
d = EigenFeatures(normal=True, linearity=True, planarity=True, scattering=True, verticality=True, temperature=None)(d)
torch.cuda.synchronize()
print(f'Geometric Features: {time() - start:0.3f}s')

PCA CPU: 2.037s
Geometric Features: 0.137s


In [14]:
from torch_geometric.data import Data

# On CPU
start = time()
eigenvalues, eigenvectors = batch_pca(data.pos[data.neighbors])
data.eigenvalues = eigenvalues
data.eigenvectors = eigenvectors
print(f'PCA CPU: {time() - start:0.3f}s')

# On CPU
start = time()
data = EigenFeatures(normal=True, linearity=True, planarity=True, scattering=True, verticality=True, temperature=None)(data)
print(f'Geometric Features: {time() - start:0.3f}s')

PCA CPU: 5.287s
Geometric Features: 0.093s


Surprisingly, torch's CPU implementation is faster both for computing PCA and geometric features is faster on CPU overall !

### Final

In [5]:
from superpoint_transformer.transforms import compute_pointfeatures

data = data.cpu()
torch.cuda.synchronize()

start = time()
data = compute_pointfeatures(data, pos=True, radius=5, rgb=True, linearity=True, planarity=True, scattering=True, verticality=True, normal=False, length=False, surface=False, volume=False)
print(f'Geometric features: {time() - start:0.3f}s')

Geometric features: 2.147s


# Point adjacency graph computation
This graph is based on the nearest neighbor graph computed for geometric features. However, although features may require 30-50 neighbors to produce good partition, the adjacency graph benefits from using fewer neighbors (eg 10 in the paper).

In [6]:
from superpoint_transformer.transforms import compute_ajacency_graph

k_adjacency = 10
lambda_edge_weight = 1

start = time()
data = compute_ajacency_graph(data, k_adjacency, lambda_edge_weight)
print(f'Adjacency graph: {time() - start:0.3f}s')

Adjacency graph: 0.369s


# Partition

In [None]:
from superpoint_transformer.transforms import compute_partition

torch.cuda.synchronize()
start = time()

# Parallel cut-pursuit
nag = compute_partition(data, 0.5, cutoff=10, verbose=True, iterations=10)
# nag = compute_partition(data, 0.5, cutoff=10, verbose=True, iterations=5)

torch.cuda.synchronize()
print(f'Partition num_nodes={data.num_nodes}, num_edges={data.num_edges}: {time() - start:0.3f}s')

# Superpoint graph computation
In the original SPG implementation, the SP graph would be computed based on the pointwise Delaunay triangulation graph. This is super inefficient. Instead, we will compute the Delaunay triangulation on the superpoint level, which sould be much faster. However, to account for large and long-shaped superpoints, we will not work with the SP centroids only (Delaunay triangulation would not capture all adjacent SPs), but on random/farthest point samplings inside the SPs (as function of SP area/volume/number of points).  

In [12]:
from scipy.spatial import Delaunay
import itertools
from torch_scatter import scatter_mean, scatter_std, scatter_min, scatter_sum
import superpoint_transformer.partition.utils.libpoint_utils as point_utils
from torch_geometric.nn.pool.consecutive import consecutive_cluster


def sample_clusters(data, high=32, low=1, pointers=False):
    """Compute point indices for sampling points inside clusters saved
    in a CSR format.
    """
    #TODO: rename this function: sample_points_in_nag, sample_point_index, ...
    #TODO: operate on NAG and careful with indices !
    #TODO: must be able to sample on subset of sub too (for edge sampling)
    #TODO: could optionally sample from i into j with i > j > 0...

    # Compute the number of points that will be sampled from each
    # cluster
    if high > 0:
        # k * tanh(x / k) is bounded by k, is ~x for low x and starts
        # saturating at x~k
        n_samples = (high * torch.tanh(
            data.sub.sizes / high)).floor().long()
    else:
        # Fallback to sqrt sampling
        n_samples = data.sub.sizes.sqrt().round().long()

    # Make sure each cluster is sampled at least 'low' times
    n_samples = n_samples.clamp(min=low)

    # Sample values in [0, 1], these will be used to compute
    # corresponding integer indices for vectorized sampling of the
    # points directly from the CSR-format cluster-to-point data
    samples = torch.rand(n_samples.sum())

    # Convert the [0, 1] samples to integer indices in [0, n_samples]
    # depending on the corresponding cluster sampling size. The 0.9999
    # factor is a security to handle the case where torch.rand is to
    # close to 1.0, which would yield incorrect sampling coordinates
    n_bins = n_samples.repeat_interleave(n_samples)
    samples = (samples * 0.9999 * n_bins).floor().long()

    # As they are now, sampling indices are expressed in [0, sub_size].
    # If we want to sample directly from the cluster-to-point CSR data,
    # we need to express those coordinates in the aggregated CSR values,
    # that is to say we need to apply the cumulative cluster sizes as
    # offsets to the indices. This is typically what is stored in
    # CSRData.pointers
    offsets = data.sub.pointers[:-1].repeat_interleave(n_samples)
    samples = samples + offsets

    # Now we can get the sampled point indices
    idx_samples = data.sub.points[samples]

    # Return here if sampling pointers are not required
    if not pointers:
        return idx_samples

    # Compute the pointers
    ptr_samples = torch.cat([torch.LongTensor([0]), n_samples.cumsum(dim=0)])

    return idx_samples, ptr_samples.contiguous()


In [None]:
def _compute_cluster_graph(
        i_level, nag, high_node=32, high_edge=64, low=5):
    # TODO: WARNING the cluster geometric features will only work if we
    #  enforced a cutoff on the minimum superpoint size ! Make sure you
    #  enforce this

    # TODO: define recursive sampling for super(n)point features
    # TODO: define recursive edge for super(n)edge features
    # TODO: return all eigenvectors from the C++ geometric features, for
    #  superborder features computation
    # TODO: other superedge ideas to better describe how 2 clusters
    # relate and the geometry of their border (S=source, T=target):
    # - avg distance S/T points in border to centroid S/T (how far
    #   is the border from the cluster center)
    # - angle of mean S->T direction wrt S/T principal components (is
    #   the border along the long of short side of objects ?)
    # - PCA of points in S/T cloud (is it linear border or surfacic
    #   border ?)
    # - mean dist of S->T along S/T normal (offset along the objects
    #   normals, eg offsets between steps)

    assert isinstance(nag, NAG)
    assert i_level > 0

    # Recover the i_level Data object we will be working on
    data = nag[i_level]

    # Aggregate some point attributes into the clusters. This is not
    # performed dynamically since not all attributes can be aggregated
    # (eg 'neighbors', 'distances', 'edge_index', 'edge_attr'...)
    data_sub = nag[i_level - 1]

    if 'pos' in data_sub.keys:
        data.pos = scatter_mean(
            data_sub.pos.cuda(), data_sub.super_index.cuda(), dim=0).cpu()
        torch.cuda.empty_cache()

    if 'rgb' in data_sub.keys:
        data.rgb = scatter_mean(
            data_sub.rgb.cuda(), data_sub.super_index.cuda(), dim=0).cpu()
        torch.cuda.empty_cache()

    if 'y' in data_sub.keys:
        assert data_sub.y.dim() == 2, \
            "Expected Data.y to hold `(num_nodes, num_classes)` " \
            "histograms, not single labels"
        data.y = scatter_sum(
            data_sub.y.cuda(), data_sub.super_index.cuda(), dim=0).cpu()
        torch.cuda.empty_cache()

    # Sample points among the clusters. These will be used to compute
    # cluster geometric features as well as cluster adjacency graph and
    # edge features
    idx_samples, ptr_samples = sample_clusters(
        data, high=high_node, low=low, pointers=True)

    # Compute cluster geometric features
    xyz = nag[0].pos[idx_samples].cpu().numpy()
    nn = np.arange(idx_samples.shape[0]).astype('uint32')  # !!!! IMPORTANT CAREFUL WITH UINT32 = 4 BILLION points MAXIMUM !!!!
    nn_ptr = ptr_samples.cpu().numpy().astype('uint32')  # !!!! IMPORTANT CAREFUL WITH UINT32 = 4 BILLION points MAXIMUM !!!!

    # Heuristic to avoid issues when a cluster sampling is such that
    # it produces singular covariance matrix (eg the sampling only
    # contains the same point repeated multiple times)
    xyz = xyz + torch.rand(xyz.shape).numpy() * 1e-5

    # C++ geometric features computation on CPU
    f = point_utils.compute_geometric_features(xyz, nn, nn_ptr, False)
    f = torch.from_numpy(f.astype('float32'))

    # Recover length, surface and volume
    data.length = f[:, 7].to(data.pos.device)
    data.surface = f[:, 8].to(data.pos.device)
    data.volume = f[:, 9].to(data.pos.device)
    data.normal = f[:, 4:7].view(-1, 3).to(data.pos.device)

    # Sample points among the clusters. These will be used to compute
    # cluster adjacency graph and edge features. Note we sample more
    # generously here than for cluster features, because we need to
    # capture fine-grained adjacency
    idx_samples, ptr_samples = sample_clusters(
        data, high=high_edge, low=low, pointers=True)

    # Delaunay triangulation on the sampled points
    tri = Delaunay(nag[0].pos[idx_samples].numpy())

    # Concatenate all edges of the triangulation. For now, we do not
    # worry about directed/undirected graphs to mitigate memory and
    # compute
    pairs = torch.LongTensor(list(itertools.combinations(range(4), 2)))
    edges = torch.from_numpy(np.hstack([
        np.vstack((tri.vertices[:, i], tri.vertices[:, j]))
        for i, j in pairs]).T).long()

    # Now we are only interested in the edges connecting two different
    # clusters and not in the intra-cluster connections. So we first
    # identify the edges of interest. This step requires having access
    # to the whole NAG, since we need to convert level-0 point indices
    # into their corresponding level-i superpoint indices
    idx_point_source = idx_samples[edges[:, 0]]
    idx_point_target = idx_samples[edges[:, 1]]
    idx_source = idx_point_source
    idx_target = idx_point_target
    for i in range(i_level):
        idx_source = nag[i].super_index[idx_source]
        idx_target = nag[i].super_index[idx_target]
    inter_cluster = torch.where(idx_source != idx_target)[0]

    # Now only consider the edges of interest (ie inter-cluster edges)
    idx_point_source = idx_point_source[inter_cluster]
    idx_point_target = idx_point_target[inter_cluster]
    idx_source = idx_source[inter_cluster]
    idx_target = idx_target[inter_cluster]

    # Direction are the pointwise source->target vectors, based on which
    # we will compute superedge descriptors. So far we are manipulating
    # inter-cluster edges, but their may be multiple of those for a
    # given source-target pair. Next, we want to aggregate those into
    # "superegdes" and compute corresponding features (designated with
    # 'se_')
    direction = nag[0].pos[idx_point_target] - nag[0].pos[idx_point_source]
    dist = torch.linalg.norm(direction, dim=1)

    # Create unique and consecutive inter-cluster edge identifiers for
    # torch_scatter operations. We use 'se' to designate 'superedge' (ie
    # an edge between two clusters)
    idx_se = idx_source + data.num_nodes * idx_target
    idx_se, perm = consecutive_cluster(idx_se)
    idx_se_source = idx_source[perm]
    idx_se_target = idx_target[perm]
    se = torch.vstack((idx_se_source, idx_se_target))

    # We can now use torch_scatter operations to compute superedge
    # features
    se_direction = scatter_mean(direction.cuda(), idx_se.cuda(), dim=0).cpu()
    se_dist = scatter_mean(dist.cuda(), idx_se.cuda(), dim=0).cpu()
    se_min_dist = scatter_min(dist.cuda(), idx_se.cuda(), dim=0)[0].cpu()
    se_std_dist = scatter_std(dist.cuda(), idx_se.cuda(), dim=0).cpu()

    se_centroid_direction = data.pos[se[1]] - data.pos[se[0]]
    se_centroid_dist = torch.linalg.norm(se_centroid_direction, dim=1)

    se_normal_source = data.normal[se[0]]
    se_normal_target = data.normal[se[1]]
    se_normal_angle = (se_normal_source * se_normal_target).sum(dim=1)
    se_angle_source = (se_direction * se_normal_source).sum(dim=1)
    se_angle_target = (se_direction * se_normal_target).sum(dim=1)

    se_length_ratio = data.length[se[0]] / (data.length[se[1]] + 1e-6)
    se_surface_ratio = data.surface[se[0]] / (data.surface[se[1]] + 1e-6)
    se_volume_ratio = data.volume[se[0]] / (data.volume[se[1]] + 1e-6)
    se_size_ratio = data.sub_size[se[0]] / (data.sub_size[se[1]] + 1e-6)

    # The superedges we have created so far are oriented. We need to
    # create the edges and corresponding features for the Target->Source
    # direction now
    se = torch.cat((se, se.roll(1, 1)))

    se_feat = [
        torch.cat((se_dist, se_dist)),
        torch.cat((se_min_dist, se_min_dist)),
        torch.cat((se_std_dist, se_std_dist)),
        torch.cat((se_centroid_dist, se_centroid_dist)),
        torch.cat((se_normal_angle, se_normal_angle)),
        torch.cat((se_angle_source, se_angle_target)),
        torch.cat((se_angle_target, se_angle_source)),
        torch.cat((se_length_ratio, 1 / (se_length_ratio + 1e-6))),
        torch.cat((se_surface_ratio, 1 / (se_surface_ratio + 1e-6))),
        torch.cat((se_volume_ratio, 1 / (se_volume_ratio + 1e-6))),
        torch.cat((se_size_ratio, 1 / (se_size_ratio + 1e-6)))]

    # Aggregate all edge features in a single tensor
    se_feat = torch.vstack(se_feat).T

    # Save superedges and superedge features in the Data object
    data.edge_index = se
    data.edge_attr = se_feat

    # Restore the i_level Data object, if need be
    nag._list[i_level] = data

    return nag


def compute_cluster_graph(nag, high_node=32, high_edge=64, low=5):
    assert isinstance(nag, NAG)
    for i_level in range(1, nag.num_levels):
        nag = _compute_cluster_graph(
            i_level, nag, high_node=high_node, high_edge=high_edge, low=low)
    return nag

In [None]:
# from superpoint_transformer.transforms import compute_cluster_graph

start = time()
compute_cluster_graph(nag, high_node=32, high_edge=64, low=5)
print(f'SP Graph computation: {time() - start:0.3f}s')

# Visualization

In [10]:
# import os
# import torch

# temp_dir = DATA_ROOT + '/datasets/kitti360/shared/temp' 
# os.makedirs(temp_dir, exist_ok=True) 

# torch.save((data, data_c), os.path.join(temp_dir, 'preliminaries.pt'))

In [51]:
from torch_geometric.data import Data
from torch_geometric.transforms import FixedPoints
from superpoint_transformer.transforms import GridSampling3D
import os.path as osp
import plotly
import plotly.graph_objects as go
import numpy as np
import torch


# PALETTE = np.array(plotly.colors.qualitative.Plotly)
# PALETTE = np.array(plotly.colors.qualitative.Dark24)
PALETTE = np.array(plotly.colors.qualitative.Light24)


def rgb_to_plotly_rgb(rgb):
    """Convert torch.Tensor of float RGB values in [0, 1] to
    plotly-friendly RGB format.
    """
    assert isinstance(rgb, torch.Tensor) and rgb.max() <= 1.0 and rgb.dim() <= 2

    if rgb.dim() == 1:
        rgb = rgb.unsqueeze(0)

    return [f"rgb{tuple(x)}" for x in (rgb * 255).int().numpy()]


def hex_to_tensor(h):
    h = h.lstrip('#')
    rgb = tuple(int(h[i:i + 2], 16) for i in (0, 2, 4))
    return torch.Tensor(rgb) / 255


def feats_to_rgb(feats, normalize=False):
    """Convert features of the format M x N with N>=1 to an M x 3
    tensor with values in [0, 1 for RGB visualization].
    """
    is_normalized = False

    if feats.dim() == 1:
        feats = feats.unsqueeze(1)
    elif feats.dim() > 2:
        raise NotImplementedError

    if feats.shape[1] == 3:
        color = feats

    elif feats.shape[1] == 1:
        # If only 1 feature is found convert to a 3-channel
        # repetition for grayscale visualization.
        color = feats.repeat_interleave(3, 1)

    elif feats.shape[1] == 2:
        # If 2 features are found, add an extra channel.
        color = torch.cat([feats, torch.ones(feats.shape[0], 1)], 1)

    elif feats.shape[1] > 3:
        # If more than 3 features or more are found, project
        # features to a 3-dimensional space using N-simplex PCA
        # Heuristics for clamping
        #   - most features live in [0, 1]
        #   - most n-simplex PCA features live in [-0.5, 0.6]
        color = identity_PCA(feats, dim=3)
        color = (torch.clamp(color, -0.5, 0.6) + 0.5) / 1.1
        is_normalized = True

    if normalize and not is_normalized:
        # Unit-normalize the features in a hypercube of shared scale
        # for nicer visualizations
        if color.max() != color.min():
            color = color - color.min(dim=0).values.view(1, -1)
        color = color / (color.max(dim=0).values.view(1, -1) + 1e-6)

    return color


def identity_PCA(x, dim=3):
    """Reduce dimension of x based on PCA on the union of the n-simplex.
    This is a way of reducing the dimension of x while treating all
    input dimensions with the same importance, independently of the
    input distribution in x.
    """
    assert x.dim() == 2, f"Expected x.dim()=2 but got x.dim()={x.dim()} instead"

    # Create z the union of the N-simplex
    input_dim = x.shape[1]
    z = torch.eye(input_dim)

    # PCA on z
    z_offset = z.mean(axis=0)
    z_centered = z - z_offset
    cov_matrix = z_centered.T.mm(z_centered) / len(z_centered)
    _, eigenvectors = torch.linalg.eigh(cov_matrix)

    # Apply the PCA on x
    x_reduced = (x - z_offset).mm(eigenvectors[:, -dim:])

    return x_reduced


def visualize_3d(
        data, data_c, figsize=800, width=None, height=None, class_names=None,
        class_colors=None, class_opacities=None, voxel=0.1, max_points=100000,
        pointsize=5, error_color=None, show_superpoint_number=False, **kwargs):
    
    # 3D visualization modes
    modes = {'name': [], 'key': [], 'num_traces': []}

    # Make copies of the data to be modified in this scope
    data = data.clone()
    data_c = data_c.clone()
    
    # Check whether a partition is available
    has_partition = data_c is not None and data.p2c is not None

    # Subsample to limit the drawing time
    data.edge_index = None
    data.edge_attr = None
    data = GridSampling3D(voxel, mode='last')(data)
    if data.num_nodes > max_points:
        data = FixedPoints(
            max_points, replace=False, allow_duplicates=False)(data)
        
    # Round to the cm for cleaner hover info
    data.pos = (data.pos * 100).round() / 100
    data_c.pos = (data_c.pos * 100).round() / 100

    # Class colors initialization
    if class_colors is not None and not isinstance(class_colors[0], str):
        class_colors = [f"rgb{tuple(x)}" for x in class_colors]
    else:
        class_colors = None

    # Prepare figure
    width = width if width and height else figsize
    height = height if width and height else int(figsize / 2)
    margin = int(0.02 * min(width, height))
    layout = go.Layout(
        width=width,
        height=height,
        scene=dict(aspectmode='data', ),  # preserve aspect ratio
        margin=dict(l=margin, r=margin, b=margin, t=margin),
        uirevision=True)
    fig = go.Figure(layout=layout)
    initialized_visibility = False
    
    # Draw a trace for RGB 3D point cloud
    if getattr(data, 'rgb', None) is not None:
        fig.add_trace(
            go.Scatter3d(
                name='RGB',
                x=data.pos[:, 0],
                y=data.pos[:, 1],
                z=data.pos[:, 2],
                mode='markers',
                marker=dict(
                    size=pointsize,
                    color=rgb_to_plotly_rgb(data.rgb), ),
                hoverinfo='x+y+z',
                showlegend=False,
                visible=not initialized_visibility, ))
        modes['name'].append('RGB')
        modes['key'].append('rgb')
        modes['num_traces'].append(1)
        initialized_visibility = True

    # Draw a trace for labeled 3D point cloud
    if getattr(data, 'y', None) is not None:
        
        # If labels are expressed as histograms, keep the most frequent
        # one
        if data.y.dim() == 2:
            data.y = data.y.argmax(1)
        
        y = data.y.numpy()
        n_y_traces = 0

        for label in np.unique(y):
            indices = np.where(y == label)[0]

            fig.add_trace(
                go.Scatter3d(
                    name=class_names[label] if class_names else f"Class {label}",
                    opacity=class_opacities[label] if class_opacities else 1.0,
                    x=data.pos[indices, 0],
                    y=data.pos[indices, 1],
                    z=data.pos[indices, 2],
                    mode='markers',
                    marker=dict(
                        size=pointsize,
                        color=class_colors[label] if class_colors else None, ),
                    visible=not initialized_visibility, ))
            n_y_traces += 1  # keep track of the number of traces

        modes['name'].append('Labels')
        modes['key'].append('y')
        modes['num_traces'].append(n_y_traces)
        initialized_visibility = True

    # Draw a trace for predicted labels 3D point cloud
    if getattr(data, 'pred', None) is not None:
        pred = data.pred.numpy()
        n_pred_traces = 0

        for label in np.unique(pred):
            indices = np.where(pred == label)[0]

            fig.add_trace(
                go.Scatter3d(
                    name=class_names[label] if class_names else f"Class {label}",
                    opacity=class_opacities[label] if class_opacities else 1.0,
                    x=data.pos[indices, 0],
                    y=data.pos[indices, 1],
                    z=data.pos[indices, 2],
                    mode='markers',
                    marker=dict(
                        size=pointsize,
                        color=class_colors[label] if class_colors else None, ),
                    visible=not initialized_visibility, ))
            n_pred_traces += 1  # keep track of the number of traces

        modes['name'].append('Predictions')
        modes['key'].append('pred')
        modes['num_traces'].append(n_pred_traces)
        initialized_visibility = True
    
    # Draw a trace for position-colored 3D point cloud
    # radius = torch.norm(data.pos - data.pos.mean(dim=0), dim=1).max()
    # data.pos_rgb = (data.pos - data.pos.mean(dim=0)) / (2 * radius) + 0.5
    mini = data.pos.min(dim=0).values
    maxi = data.pos.max(dim=0).values
    data.pos_rgb = (data.pos - mini) / (maxi - mini + 1e-6)
    fig.add_trace(
        go.Scatter3d(
            name='Position RGB',
            x=data.pos[:, 0],
            y=data.pos[:, 1],
            z=data.pos[:, 2],
            mode='markers',
            marker=dict(
                size=pointsize,
                color=rgb_to_plotly_rgb(data.pos_rgb), ),
            hoverinfo='x+y+z',
            showlegend=False,
            visible=not initialized_visibility, ))
    modes['name'].append('Position RGB')
    modes['key'].append('position_rgb')
    modes['num_traces'].append(1)
    initialized_visibility = True
    
    # Draw a trace for the partition
    if has_partition:
        fig.add_trace(
            go.Scatter3d(
                name='Superpoints',
                x=data.pos[:, 0],
                y=data.pos[:, 1],
                z=data.pos[:, 2],
                mode='markers',
                marker=dict(
                    size=pointsize,
                    color=PALETTE[data.p2c % len(PALETTE)], ),
                hoverinfo='x+y+z',
                showlegend=False,
                visible=not initialized_visibility, ))
        modes['name'].append('Superpoints')
        modes['key'].append('superpoints')
        modes['num_traces'].append(1)
        initialized_visibility = True
    
    # Draw a trace for 3D point cloud features
    if getattr(data, 'x', None) is not None:
        # Recover the features and convert them to an RGB format for
        # visualization.
        data.feat_3d = feats_to_rgb(data.x, normalize=True)
        fig.add_trace(
            go.Scatter3d(
                name='Features 3D',
                x=data.pos[:, 0],
                y=data.pos[:, 1],
                z=data.pos[:, 2],
                mode='markers',
                marker=dict(
                    size=pointsize,
                    color=rgb_to_plotly_rgb(data.feat_3d), ),
                hoverinfo='x+y+z',
                showlegend=False,
                visible=not initialized_visibility, ))
        modes['name'].append('Features 3D')
        modes['key'].append('x')
        modes['num_traces'].append(1)
        initialized_visibility = True
    
    # Add a trace for prediction errors
    has_error = getattr(data, 'y', None) is not None \
                and getattr(data, 'pred', None) is not None
    if has_error:
        indices = np.where(data.pred.numpy() != data.y.numpy())[0]
        error_color = f"rgb{tuple(error_color)}" \
            if error_color is not None else 'rgb(255, 0, 0)'
        fig.add_trace(
            go.Scatter3d(
                name='Errors',
                opacity=1.0,
                x=data.pos[indices, 0],
                y=data.pos[indices, 1],
                z=data.pos[indices, 2],
                mode='markers',
                marker=dict(
                    size=pointsize,
                    color=error_color, ),
                showlegend=False,
                visible=False, ))
        modes['name'].append('Errors')
        modes['key'].append('error')
        modes['num_traces'].append(1)
        
    # Draw cluster centroid positions
    if has_partition:
        idx_sp = np.arange(data_c.num_nodes)
        sp_traces = []
        sp_traces.append(len(fig.data))
        fig.add_trace(
            go.Scatter3d(
                name=f"Superpoint centroids",
                x=data_c.pos[:, 0],
                y=data_c.pos[:, 1],
                z=data_c.pos[:, 2],
                mode='markers+text',
                marker=dict(
                    symbol='diamond',
                    line_width=2,
                    size=pointsize + 2,
                    color=PALETTE[idx_sp % len(PALETTE)], ),
                text=[f"<b>{i}</b>" for i in idx_sp] if show_superpoint_number else '',
                textposition="bottom center",
                textfont=dict(size=16),
                hoverinfo='x+y+z+name',
                showlegend=False,
                visible=True, ))
    
    # Traces visibility for interactive point cloud coloring
    def trace_visibility(mode):
        visibilities = np.array([d.visible for d in fig.data], dtype='bool')

        # Traces visibility for interactive point cloud coloring
        i_mode = modes['key'].index(mode)
        a = sum(modes['num_traces'][:i_mode])
        b = sum(modes['num_traces'][:i_mode + 1])
        n_traces = sum(modes['num_traces'])

        visibilities[:n_traces] = False
        visibilities[a:b] = True

        return [{"visible": visibilities.tolist()}]

    # Create the buttons that will serve for toggling trace visibility
    updatemenus = [
        dict(
            buttons=[dict(label=name, method='update', args=trace_visibility(key))
                     for name, key in zip(modes['name'], modes['key']) if key != 'error'],
            pad={'r': 10, 't': 10},
            showactive=True,
            type='dropdown',
            direction='right',
            xanchor='left',
            x=0.02,
            yanchor='top',
            y=1.02, ),
    ]
    
    if has_error:
        updatemenus.append(
            dict(
                buttons=[dict(
                    method='restyle',
                    label='Error',
                    visible=True,
                    args=[{'visible': True, },
                          [sum(modes['num_traces'][:modes['key'].index('error')])]],
                    args2=[{'visible': False, },
                           [sum(modes['num_traces'][:modes['key'].index('error')])]], )],
                pad={'r': 10, 't': 10},
                showactive=False,
                type='buttons',
                xanchor='left',
                x=1.02,
                yanchor='top',
                y=1.02, ),
        )
    fig.update_layout(updatemenus=updatemenus)

    # Place the legend on the left
    fig.update_layout(
        legend=dict(
            yanchor="middle",
            y=0.5,
            xanchor="right",
            x=0.99))

    # Hide all axes and no background
    fig.update_layout(
        scene=dict(
            xaxis_title='',
            yaxis_title='',
            zaxis_title='',
            xaxis=dict(
                autorange=True,
                showgrid=False,
                ticks='',
                showticklabels=False,
                backgroundcolor="rgba(0, 0, 0, 0)"
            ),
            yaxis=dict(
                autorange=True,
                showgrid=False,
                ticks='',
                showticklabels=False,
                backgroundcolor="rgba(0, 0, 0, 0)"
            ),
            zaxis=dict(
                autorange=True,
                showgrid=False,
                ticks='',
                showticklabels=False,
                backgroundcolor="rgba(0, 0, 0, 0)"
            )
        )
    )

    output = {'figure': fig, 'data': data}
    
    if has_partition:
        output['sp_traces'] = sp_traces

    return output


def show(
        data, data_c, path=None, title=None, no_output=True, **kwargs):
    """
    """
    assert isinstance(data, Data)
    assert isinstance(data_c, Data)

    # Sanitize title and path
    if title is None:
        title = "Multimodal data"
    if path is not None:
        if osp.isdir(path):
            path = osp.join(path, f"{title}.html")
        else:
            path = osp.splitext(path)[0] + '.html'
        fig_html = f'<h1 style="text-align: center;">{title}</h1>'

    # Draw a figure for 3D data visualization
    out_3d = visualize_3d(data, data_c, **kwargs)
    if no_output:
        if path is None:
            out_3d['figure'].show(config={'displayModeBar': False})
        else:
            fig_html += figure_html(out_3d['figure'])

    if path is not None:
        with open(path, "w") as f:
            f.write(fig_html)

    if not no_output:
        return out_3d

    return

In [12]:
import os
import torch

temp_dir = DATA_ROOT + '/datasets/kitti360/shared/temp' 
data, data_c = torch.load(os.path.join(temp_dir, 'preliminaries.pt'))

In [None]:
from superpoint_transformer.datasets.kitti360 import CLASS_NAMES, CLASS_COLORS
    
show(
    data, data_c, figsize=1000, width=None, height=None, class_names=CLASS_NAMES,
    class_colors=CLASS_COLORS, class_opacities=None, voxel=0.1, max_points=500000,
    pointsize=3, error_color=None)

In [78]:
data.p2c.numel()

2478764

questions
- give courses for post doc ETH ?
- paper says Z normalized over point cloud... but code seems to just take z*2 ?

In [83]:
from torch_geometric.nn.pool.consecutive import consecutive_cluster

a = torch.arange(10).repeat_interleave(torch.arange(10)) * 10
print(a)

b, perm = consecutive_cluster(a)
print(b)
print(perm)

tensor([10, 20, 20, 30, 30, 30, 40, 40, 40, 40, 50, 50, 50, 50, 50, 60, 60, 60,
        60, 60, 60, 70, 70, 70, 70, 70, 70, 70, 80, 80, 80, 80, 80, 80, 80, 80,
        90, 90, 90, 90, 90, 90, 90, 90, 90])
tensor([0, 1, 1, 2, 2, 2, 3, 3, 3, 3, 4, 4, 4, 4, 4, 5, 5, 5, 5, 5, 5, 6, 6, 6,
        6, 6, 6, 6, 7, 7, 7, 7, 7, 7, 7, 7, 8, 8, 8, 8, 8, 8, 8, 8, 8])
tensor([ 0,  2,  5,  9, 14, 20, 27, 35, 44])


In [84]:
b.shape == a.shape

True

In [87]:
a[perm], b[perm]

(tensor([10, 20, 30, 40, 50, 60, 70, 80, 90]),
 tensor([0, 1, 2, 3, 4, 5, 6, 7, 8]))

In [90]:
values = torch.arange(10)
idx = torch.LongTensor([5, 2, 7, 1])

# Get view-level indices for images to keep
view_idx = torch.where((values[..., None] == idx).any(-1))[0]

# Index the values
values = values[view_idx]

# Update the image indices. To do so, create a tensor of indices
# idx_gen so that the desired output can be computed with simple
# indexation idx_gen[images]. This avoids using map() or
# numpy.vectorize alternatives.
idx_gen = torch.full(
    (idx.max() + 1,), -1, dtype=torch.int64, )
idx_gen = idx_gen.scatter_(
    0, idx, torch.arange(idx.shape[0], ))
idx_gen[values]  # values[0] holds image indices

tensor([3, 1, 0, 2])

In [101]:
v, i = torch.randint(3, (10,)).sort()

In [118]:
torch.arange(10).equal(torch.arange(10))

True

In [126]:
torch.arange(10).flip(0).sort().values.equal()

tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])

In [127]:
data.num_edges

24787640

In [131]:
a = np.arange(10)
b = a[5:]
b *= 10
print(a, b)

[ 0  1  2  3  4 50 60 70 80 90] [50 60 70 80 90]
