In [1]:
import nmslib

In [3]:
import pandas as pd
import numpy as np
import torch as pt
import multiprocessing
#from bps import bps
from sembps.bps import bps
from torch.utils.data import Dataset, DataLoader
import os
from tqdm import tqdm
import h5py

In [4]:
MAIN_PATH = os.path.join('aptbps-code')
LOGS_PATH = os.path.join(MAIN_PATH, 'logs')
DATA_PATH = os.path.join(MAIN_PATH, 'data')

train_path = os.path.join(DATA_PATH, 'train')
hdf5_path = os.path.join(DATA_PATH, 'hdf5')
no_unlabeled_h5_path = os.path.join(DATA_PATH, 'no_unlabeled_hdf5')
hdf5_train_path = os.path.join(hdf5_path, 'train')
encoded_hdf5_path = os.path.join(DATA_PATH, 'tree_encoded_hdf5')


# All the clouds in the training dataset
train_files = [
    "bildstein_station1_xyz_intensity_rgb",
    "bildstein_station3_xyz_intensity_rgb",
    "bildstein_station5_xyz_intensity_rgb",
    "domfountain_station1_xyz_intensity_rgb",
    "domfountain_station2_xyz_intensity_rgb",
    "domfountain_station3_xyz_intensity_rgb",
    "neugasse_station1_xyz_intensity_rgb",
    "sg27_station1_intensity_rgb",
    "sg27_station2_intensity_rgb",
    "sg27_station4_intensity_rgb",
    "sg27_station5_intensity_rgb",
    "sg27_station9_intensity_rgb",
    "sg28_station4_intensity_rgb",
    "untermaederbrunnen_station1_xyz_intensity_rgb",
    "untermaederbrunnen_station3_xyz_intensity_rgb",
]

In [5]:
n_orig_points = 2048
n_bps_points = 512
n_dims = 3
radius = 1.5
random_seed = 13

In [6]:
# batch of 100 point clouds to convert
x = np.random.normal(size=[100, n_orig_points, 3])

# optional point cloud normalization to fit a unit sphere
x = bps.normalize(x)

In [7]:
from timeit import default_timer as timer
import time
from sklearn.neighbors import NearestNeighbors

In [8]:
n_clouds, n_points, n_dims = x.shape

basis_set = bps.generate_random_basis(n_bps_points, n_dims=n_dims, radius=radius, random_seed=random_seed)

n_bps_points = basis_set.shape[0]

x_bps = np.zeros([n_clouds, n_bps_points])
        
fid_lst = range(0, x.shape[0])

idx_bps = np.zeros([n_clouds, n_bps_points])

# demo (various cases with different tunings)

In [31]:
# create a random matrix to index
data = np.random.randn(1000, 3).astype(np.float32)

# initialize a new index, using a HNSW index on Cosine Similarity
index = nmslib.init(method='hnsw', space='cosinesimil')
index.addDataPointBatch(x[0])
index.createIndex({'post': 2}, print_progress=True)

# query for the nearest neighbours of the first datapoint
ids, distances = index.knnQuery(x[0], k=1)

# get all nearest neighbours for all the datapoint
# using a pool of 4 threads to compute
neighbours = index.knnQueryBatch(basis_set, k=1, num_threads=4)

In [55]:
start = timer()

for fid in fid_lst:
    index = nmslib.init(method='hnsw', space='l2')
    index.addDataPointBatch(x[fid])
    index.createIndex({'post': 2}, print_progress=True)

    neighbours = index.knnQueryBatch(basis_set, k=1, num_threads=4)
    
    for i in range (0, n_bps_points):
        idx_bps[fid][i] = neighbours[i][0][0]
        x_bps[fid][i] = neighbours[i][1][0]
    
    
    
end = timer()
print(end-start)

4.299919313052669


In [56]:
idx_bps

array([[1858., 1218.,   78., ..., 1625., 1697.,  368.],
       [1308.,  586.,  324., ..., 1631.,  586., 1878.],
       [ 343., 1079.,  263., ..., 1557., 1079., 1101.],
       ...,
       [ 733.,  162., 1322., ..., 1667.,  554., 1175.],
       [1066., 1622., 1894., ..., 1034., 1622.,  498.],
       [ 228., 1097., 1332., ...,  794., 1097.,  541.]])

In [50]:
start = timer()
for fid in fid_lst:
            nbrs = NearestNeighbors(n_neighbors=1, leaf_size=16, algorithm='kd_tree').fit(x[fid])
            fid_dist, npts_ix = nbrs.kneighbors(basis_set)
            x_bps[fid] = fid_dist.squeeze()
            idx_bps[fid] = npts_ix.squeeze()
end = timer()
print(end-start)

0.2178341569378972


In [51]:
idx_bps

array([[1858., 1218.,   78., ..., 1625., 1697.,  368.],
       [1308.,  586.,  324., ..., 1631.,  586., 1878.],
       [ 343., 1079.,  263., ..., 1557., 1079., 1101.],
       ...,
       [ 733.,  162., 1322., ..., 1667.,  554., 1175.],
       [1066., 1622., 1894., ..., 1034., 1622.,  498.],
       [ 228., 1097., 1332., ...,  794., 1097.,  541.]])