In [None]:
import open3d as o3d
import torch
import numpy as np
import os
import os.path as osp
from plyfile import PlyData

In [None]:
import sys

# Make the parent of the working directory importable, presumably so that
# `torch_points3d` resolves to the repository checkout — TODO confirm.
sys.path.extend([".."])

<p align="center">
  <img width="40%" src="https://raw.githubusercontent.com/nicolas-chaulet/torch-points3d/master/docs/logo.png" />
</p>

# Registration Demo on 3DMatch

In this notebook, we demonstrate point cloud registration on 3DMatch from scratch, using a pretrained network.

First, let's load two example point clouds.

In [None]:
# We read the data

def read_ply(path):
    """Load the vertex coordinates stored in a PLY file.

    Args:
        path: path to the ``.ply`` file.

    Returns:
        numpy array of shape (N, 3) whose columns are the ``x``, ``y`` and ``z``
        vertex properties.
    """
    with open(path, 'rb') as ply_file:
        ply = PlyData.read(ply_file)
    vertices = ply['vertex']
    columns = [vertices[axis] for axis in ('x', 'y', 'z')]
    return np.vstack(columns).T
# Two fragments of the 3DMatch `redkitchen` scene: source (s) and target (t).
path_s, path_t = "data/3DMatch/redkitchen_000.ply", "data/3DMatch/redkitchen_010.ply"
pcd_s, pcd_t = read_ply(path_s), read_ply(path_t)


Next, we wrap each point cloud in a `Batch` object and apply some transforms: we voxelize the points into a sparse grid and add a constant feature of ones as input. We also load the pretrained model.

In [None]:
# data preprocessing import
from torch_points3d.core.data_transform import GridSampling3D, AddOnes, AddFeatByKey
from torch_geometric.transforms import Compose
from torch_geometric.data import Batch

# Model
from torch_points3d.applications.pretrained_api import PretainedRegistry

# post processing
from torch_points3d.utils.registration import get_matches, fast_global_registration

In [None]:
# Preprocessing pipeline applied to each cloud: grid sampling with size 0.02
# (presumably 2 cm if coordinates are in meters — TODO confirm), then a constant
# "ones" feature that is presumably copied into `x` as the network input.
transform = Compose(
    [
        GridSampling3D(mode='last', size=0.02, quantize_coords=True),
        AddOnes(),
        AddFeatByKey(add_to_x=True, feat_name="ones"),
    ]
)

In [None]:
def _to_voxel_batch(points):
    """Wrap an (N, 3) numpy array as a single-cloud Batch (every point in batch 0) and apply `transform`."""
    positions = torch.from_numpy(points).float()
    batch_index = torch.zeros(points.shape[0], dtype=torch.long)
    return transform(Batch(pos=positions, batch=batch_index))

data_s = _to_voxel_batch(pcd_s)
data_t = _to_voxel_batch(pcd_t)

model = PretainedRegistry.from_pretrained("minkowski-registration-3dmatch").cuda()

In [None]:
def make_o3d_point_cloud(points, color):
    """Build an Open3D point cloud painted with a single color.

    Args:
        points: numpy array of point coordinates, one row per point.
        color: RGB triple with components in [0, 1].

    Returns:
        The new ``o3d.geometry.PointCloud``.
    """
    cloud = o3d.geometry.PointCloud()
    cloud.points = o3d.utility.Vector3dVector(points)
    cloud.paint_uniform_color(color)
    return cloud

# Voxelized source and target clouds, in distinct colors so the overlap is visible.
o3d_pcd_s = make_o3d_point_cloud(data_s.pos.cpu().numpy(), [0.9, 0.7, 0.1])
o3d_pcd_t = make_o3d_point_cloud(data_t.pos.cpu().numpy(), [0.1, 0.7, 0.9])

In [None]:
# Inference only. eval() switches layers such as batch norm / dropout to inference
# behaviour in case from_pretrained left the model in training mode (assumes `model`
# is a torch.nn.Module, as its use of .cuda() suggests); no_grad skips autograd.
model.eval()
with torch.no_grad():
    # Presumably one descriptor per voxelized point — TODO confirm shape of output_*.
    model.set_input(data_s, "cuda")
    output_s = model.forward()
    model.set_input(data_t, "cuda")
    output_t = model.forward()

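Now that we have our features, let's match them. We randomly select 5000 points in each point cloud, match their features, and estimate the transformation with Fast Global Registration.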

In [None]:
# Number of points randomly sampled from each cloud before feature matching.
NUM_MATCH_POINTS = 5000
# Seed so that the subsampling, and therefore T_est, is reproducible across runs.
torch.manual_seed(0)

# Sample indices without replacement (torch.randint would allow duplicate points and
# hence duplicate correspondences). If a cloud has fewer than NUM_MATCH_POINTS points,
# all of its points are used.
rand_s = torch.randperm(len(output_s))[:NUM_MATCH_POINTS]
rand_t = torch.randperm(len(output_t))[:NUM_MATCH_POINTS]

# As used below, matches[:, 0] indexes the sampled source points and
# matches[:, 1] the sampled target points.
matches = get_matches(output_s[rand_s], output_t[rand_t])

T_est = fast_global_registration(data_s.pos[rand_s][matches[:, 0]], data_t.pos[rand_t][matches[:, 1]])

In [None]:
# Source and target before registration.
visualizer = o3d.JVisualizer()
for geometry in (o3d_pcd_s, o3d_pcd_t):
    visualizer.add_geometry(geometry)
visualizer.show()

In [None]:
# PointCloud.transform modifies the geometry in place, so transform a copy. Otherwise
# re-running this cell would apply T_est twice, and o3d_pcd_s would stay moved for
# every later cell.
o3d_pcd_s_registered = o3d.geometry.PointCloud(o3d_pcd_s)
o3d_pcd_s_registered.transform(T_est.cpu().numpy())

visualizer = o3d.JVisualizer()
visualizer.add_geometry(o3d_pcd_s_registered)
visualizer.add_geometry(o3d_pcd_t)
visualizer.show()

### Visualization
Let's visualize the features to see what the network has learned. We project them onto three principal components with PCA and use the result as RGB colors.


In [None]:
from sklearn.decomposition import PCA

def compute_color_from_features(list_feat):
    """Map per-point feature vectors to RGB colors via a shared 3-component PCA.

    A single PCA is fitted on the concatenation of all inputs, and the per-channel
    min-max normalisation is shared too. The same feature therefore gets the same
    color in every input array.

    Args:
        list_feat: list of 2-D arrays (one row per point) with the same number of
            columns.

    Returns:
        List of arrays, one per input and in the same order, each with 3 columns
        scaled to [0, 1].
    """
    feats = np.vstack(list_feat)
    pca = PCA(n_components=3)
    # Project everything once and reuse it; the original projected feats three times.
    projected = pca.fit(feats).transform(feats)
    min_col = projected.min(axis=0)
    range_col = projected.max(axis=0) - min_col
    # A constant component would otherwise divide by zero and produce NaN colors.
    range_col[range_col == 0] = 1.0
    colors = (projected - min_col) / range_col
    # Split back into one block per input, following the stacking order.
    split_points = np.cumsum([len(feat) for feat in list_feat])[:-1]
    return np.split(colors, split_points)
# Color both clouds from one PCA fit so the two colorings are directly comparable.
feats_s = output_s.detach().cpu().numpy()
feats_t = output_t.detach().cpu().numpy()
list_color = compute_color_from_features([feats_s, feats_t])

In [None]:
# Source cloud colored by its PCA-projected features.
source_colors = o3d.utility.Vector3dVector(list_color[0])
o3d_pcd_s.colors = source_colors
visualizer = o3d.JVisualizer()
for geometry in (o3d_pcd_s,):
    visualizer.add_geometry(geometry)
visualizer.show()

In [None]:
# Target cloud colored by its PCA-projected features.
target_colors = o3d.utility.Vector3dVector(list_color[1])
o3d_pcd_t.colors = target_colors
visualizer = o3d.JVisualizer()
for geometry in (o3d_pcd_t,):
    visualizer.add_geometry(geometry)
visualizer.show()