In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import pyvista as pv
import panel as pn
import os
import os.path as osp
from omegaconf import OmegaConf
# Apply PyVista's "document" plot theme to every plotter created afterwards.
pv.set_plot_theme("document")

# Load Panel's VTK extension so that VTK render windows can be embedded in the notebook.
pn.extension('vtk')
# Launch an Xvfb server on display :99 (screen 0, 1024x768, 24-bit) in the background
# (trailing `&`) and point DISPLAY at it — presumably for headless rendering.
os.system('/usr/bin/Xvfb :99 -screen 0 1024x768x24 &')
os.environ['DISPLAY'] = ':99'
# NOTE(review): these variables are set after `import pyvista` has already run in this
# cell — confirm that PyVista still reads them at this point rather than only at import time.
os.environ['PYVISTA_OFF_SCREEN'] = 'True'
os.environ['PYVISTA_USE_PANEL'] = 'True'

In [None]:
import torch
import numpy as np
from plyfile import PlyData
import copy

In [None]:
import sys
sys.path.append("..")

In [None]:
from torch_points3d.datasets.segmentation.scannet import Scannet
from torch_points3d.datasets.oneshot_detection.scannet import ScannetOneShotDetection
from torch_points3d.datasets.segmentation import IGNORE_LABEL
from torch_points3d.core.data_transform import GridSampling3D, AddOnes, AddFeatByKey
from torch_geometric.transforms import Compose
from torch_geometric.data import Data, Batch
from torch_points3d.utils.registration import get_matches, fast_global_registration
from torch_points3d.applications.pretrained_api import PretainedRegistry

# Aligning objects in scenes

This notebook explores ways to align objects in a scene, with the idea that this could be used for object detection. Let's start by loading ScanNet.

In [None]:
# Parent of the current working directory; the data folder is resolved relative to it.
DIR = os.path.dirname(os.getcwd())
# NOTE(review): this value is reassigned to 5 further down, before any cell reads it.
ONE_SHOT_CLASS=4
dataroot = os.path.join(DIR,"data","scannet-oneshot")
# Pre-processing applied to every sample: grid sampling with a cell size of 0.02
# (quantized coordinates), then a "ones" feature that is added to `x`.
transform = Compose([GridSampling3D(mode='last', size=0.02, quantize_coords=True), AddOnes(), AddFeatByKey(add_to_x=True, feat_name="ones")])
dataset = ScannetOneShotDetection(dataroot,transform=transform)
print(dataset)

## Some utilities

In [None]:
def plot(clouds, together=False, colors=[]):
    """Render point clouds as Panel VTK views.

    Args:
        clouds: iterable of objects whose ``pos`` attribute is a tensor of
            point positions (converted with ``.numpy()``).
        together: when True, every cloud is drawn in one shared plotter and a
            single panel is returned; otherwise each cloud gets its own
            plotter and the panels are returned side by side in a ``pn.Row``.
        colors: optional per-cloud colors, matched to ``clouds`` by position.
            Clouds without a matching entry use ``[0.9, 0.7, 0.1]``.

    Returns:
        A single panel when ``together`` is True, else a ``pn.Row`` of panels.

    Raises:
        IndexError: if ``together`` is True and ``clouds`` is empty.
    """
    plotters = []
    for index, cloud in enumerate(clouds):
        if together and plotters:
            # Shared mode: keep drawing into the first plotter.
            plotter = plotters[0]
        else:
            plotter = pv.Plotter(notebook=True)
            plotters.append(plotter)

        point_color = colors[index] if index < len(colors) else [0.9, 0.7, 0.1]
        plotter.add_points(cloud.pos.numpy(), color=point_color)

    panels = [
        pn.panel(plotter.ren_win, sizing_mode='scale_both', aspect_ratio=1, orientation_widget=True)
        for plotter in plotters
    ]
    if not together:
        return pn.Row(*panels)
    return panels[0]
        

In [None]:
def get_instances(data, label_idx):
    """Extract every object instance of a given class from a scene.

    Args:
        data: scene exposing per-point ``instance_labels``, ``y``, ``pos``,
            ``x`` and ``coords`` (the last three are only read for instances
            that match).
        label_idx: class label to keep.

    Returns:
        A list of ``Data`` objects (``pos``, ``x``, ``coords``), one per
        matching instance. The list is empty when nothing matches, including
        when the scene has no points at all.
    """
    instances = []
    # Unique ids are computed once. Nothing indexes into the result, so an
    # empty scene simply yields an empty list.
    for instance_id in torch.unique(data.instance_labels):
        # Instance id 0 is skipped (presumably "no instance" — confirm against the dataset).
        if instance_id == 0:
            continue
        instance_mask = data.instance_labels == instance_id
        # The class of an instance is read from its last point.
        label = data.y[instance_mask][-1]
        if label == label_idx:
            instances.append(
                Data(
                    pos=data.pos[instance_mask],
                    x=data.x[instance_mask],
                    coords=data.coords[instance_mask],
                )
            )
    return instances

## Getting the data

In [None]:
# Take sample 15 of the dataset and pull out its instances with class label 3
# (presumably beds, going by the variable name — confirm against the label map).
d = dataset[15]
beds = get_instances(d, 3)
beds

In [None]:
# Overlay the first extracted instance (second color) on the full scene (first color).
plot([d,beds[0]], together=True, colors = [[0.9, 0.7, 0.1],[0.1, 0.7, 0.9]])

## Register

In [None]:
# Load the "minkowski-registration-3dmatch" pretrained model and move it to the GPU
# (`.cuda()` requires a CUDA device).
# This will log some errors, don't worry it's all good!
model = PretainedRegistry.from_pretrained("minkowski-registration-3dmatch").cuda()

In [None]:
# NOTE(review): neither of these two imports is used in the cells visible in this notebook.
from torch_points3d.utils.geometry import euler_angles_to_rotation_matrix
import random

# First label-3 instance of sample 15.
d15 =dataset[15]
beds = get_instances(d15,3)
bed15 = beds[0]

# First label-3 instance of sample 0 (`beds` is rebound here).
d0= dataset[0]
beds = get_instances(d0,3)
bed0= beds[0]

In [None]:
def compute_features(data):
    """Run the global ``model`` on a single sample and return its output.

    The sample is wrapped in a one-element ``Batch`` and fed to the model on
    the "cuda" device with gradient tracking disabled.
    """
    single_batch = Batch.from_data_list([data])
    with torch.no_grad():
        model.set_input(single_batch, "cuda")
        return model.forward()

def register(data, obj):
    """Align ``obj`` onto ``data`` using feature matches.

    Features are computed for both inputs, matched symmetrically, and the
    matched positions are passed to ``fast_global_registration`` to estimate
    a transform from ``obj`` to ``data``.

    Returns:
        A deep copy of ``obj`` whose ``pos`` has been moved by the estimated
        transform; ``obj`` itself is left untouched.
    """
    scene_feats = compute_features(data)
    object_feats = compute_features(obj)
    matches = get_matches(scene_feats, object_feats, sym=True)
    print("Number of matches = %i" % matches.shape[0])

    # Column 0 of `matches` indexes `data`, column 1 indexes `obj`.
    source_points = obj.pos[matches[:, 1]]
    target_points = data.pos[matches[:, 0]]
    T_est = fast_global_registration(source_points, target_points)

    # Apply the upper-left 3x3 block and the last column of T_est to the points.
    rotation = T_est[:3, :3]
    translation = T_est[:3, 3]
    aligned = copy.deepcopy(obj)
    aligned.pos = obj.pos @ rotation.T + translation
    return aligned

In [None]:
# Align the instance taken from sample 15 onto sample 0.
transformed = register(d0, bed15)

In [None]:
# Show the aligned instance (second color) inside the target scene (first color).
plot([d0,transformed], together=True, colors = [[0.9, 0.7, 0.1],[0.1, 0.7, 0.9]])

## Visualise matches

In [None]:
from torch_geometric.nn import knn

In [None]:
def plot_matches(data,obj, max_lines=100):
    """Draw feature matches between a scene and an object as green segments.

    The scene is shifted by 5 units along x so that it does not overlap the
    object, and each match is drawn as a line from the shifted scene point to
    the matching object point.

    Args:
        data: scene with a ``pos`` tensor.
        obj: object with a ``pos`` tensor.
        max_lines: upper bound on the number of matches drawn; when there are
            more, a random subset of this size is shown.

    Returns:
        A Panel VTK view of the plot.
    """
    d_feats = compute_features(data)
    obj_feats = compute_features(obj)
    matches = get_matches(d_feats, obj_feats, sym=True)
    if matches.shape[0] > max_lines:
        # Keep a random subset so the plot stays readable.
        perm = torch.randperm(matches.shape[0])
        matches = matches[perm[:max_lines], :]

    # Convert both clouds to host NumPy arrays once. Using `.cpu()` for the
    # object as well keeps line endpoints valid when `obj.pos` is not on CPU.
    obj_points = obj.pos.cpu().numpy()
    moved_scan = data.pos.cpu().numpy() + np.asarray([5,0,0])

    v = pv.Plotter(notebook=True)
    v.add_points(obj_points)
    v.add_points(moved_scan)
    # `tolist()` yields plain Python ints, so indexing does not depend on the
    # device the match indices live on.
    for scan_idx, obj_idx in matches.tolist():
        segment = np.asarray([moved_scan[scan_idx], obj_points[obj_idx]])
        v.add_lines(segment, width=5, color="green")
    return pn.panel(v.ren_win, sizing_mode='scale_both', aspect_ratio=1,orientation_widget=True,)
    

In [None]:
# Visualise matches between sample 15 and an instance extracted from that same sample.
plot_matches(d15, bed15)

## Model and accuracy

In [None]:
from torch_points3d.datasets.object_detection.box_data import BoxData
from torch_points3d.metrics.oneshottracker import OneShotObjectTracker

In [None]:
from torch_points3d.utils.registration import teaser_pp_registration

In [None]:
# Class used for the one-shot evaluation below; replaces the value 4 assigned in the dataset cell.
ONE_SHOT_CLASS = 5

In [None]:
class RegistrationResult:
    """Outcome of registering an object into a scene.

    Attributes:
        transformed_obj: the registered object, or None when registration
            produced no detection.
        box: ``BoxData`` built from the axis-aligned bounding box of
            ``transformed_obj.pos``, or None when there is no detection.
    """

    def __init__(self,obj, class_label):
        """Build the result.

        Args:
            obj: registered object exposing an (N, 3) ``pos`` tensor, or None
                for "no detection".
            class_label: label passed as first argument to ``BoxData``.
        """
        # Always defined, so callers can read it even without a detection.
        self.transformed_obj = obj
        if obj is None:
            self.box = None
            return

        min_pos, max_pos = torch.min(obj.pos,0)[0],torch.max(obj.pos,0)[0]
        xi, yi, zi = min_pos
        xm, ym, zm = max_pos
        # Eight distinct corners: the four at z-min, then the same four (x, y)
        # pairs at z-max.
        corners = torch.tensor([
            [xi, yi, zi], [xm, yi, zi], [xm, ym, zi], [xi, ym, zi],
            [xi, yi, zm], [xm, yi, zm], [xm, ym, zm], [xi, ym, zm],
        ])
        # Third argument is passed as a constant 1 (presumably a score — confirm in BoxData).
        self.box = BoxData(class_label, corners, 1)

    def get_boxes(self):
        """Return the detections as a one-element list of lists (empty inner list if none)."""
        if self.box is not None:
            return [[self.box]]
        else:
            return [[]]
        

In [None]:
class RegistrationModel(torch.nn.Module):
    """Detect one object class by registering a reference instance into a scene.

    Wraps the "minkowski-registration-3dmatch" pretrained model (moved to the
    GPU) and exposes ``get_output`` / ``get_current_losses`` accessors.
    """

    def __init__(self,class_label, min_inliers = 10):
        """
        Args:
            class_label: label attached to every produced result.
            min_inliers: a detection is reported only when the registration
                returns strictly more inliers than this.
        """
        super().__init__()
        self.class_label = class_label
        self._min_inliers = min_inliers
        self._model = PretainedRegistry.from_pretrained("minkowski-registration-3dmatch").cuda()

    def compute_features(self,data):
        """Run the wrapped model on one sample on "cuda", without gradients."""
        single_batch = Batch.from_data_list([data])
        with torch.no_grad():
            self._model.set_input(single_batch, "cuda")
            return self._model.forward()

    def forward(self, data,  one_instance):
        """Register ``one_instance`` into ``data`` and store the result in ``self.output``.

        Nothing is returned; read the result through ``get_output``.
        """
        scene_feats = self.compute_features(data)
        instance_feats = self.compute_features(one_instance)
        matches = get_matches(scene_feats, instance_feats, sym=True)

        # Column 0 of `matches` indexes the scene, column 1 the reference instance.
        # (fast_global_registration on the same point pairs was the earlier alternative.)
        T_est, inliers = teaser_pp_registration(
            one_instance.pos[matches[:, 1]],
            data.pos[matches[:, 0]],
        )

        # Too few inliers: record "no detection".
        if len(inliers) <= self._min_inliers:
            self.output = RegistrationResult(None, self.class_label)
            return

        aligned = copy.deepcopy(one_instance)
        aligned.pos = one_instance.pos @ T_est[:3, :3].T + T_est[:3, 3]
        self.output = RegistrationResult(aligned, self.class_label)

    def get_output(self):
        """Return the result stored by the last ``forward`` call."""
        return self.output

    def get_current_losses(self):
        """Return an empty dict: this model has no losses."""
        return {}

In [None]:
# Rebinds the global `model`. The module-level `compute_features` helper reads that
# global and calls `set_input` on it, which RegistrationModel does not define, so
# `register` / `plot_matches` should not be re-run after this cell.
model = RegistrationModel(dataset.NYU40ID2CLASS[ONE_SHOT_CLASS],min_inliers = 20)

In [None]:
ONE_SHOT_CLASS = 5
# Reference instance for the evaluation: first ONE_SHOT_CLASS instance of sample 5.
# `d0` and `bed0` are rebound here and no longer refer to sample 0.
d0 = dataset[5]
bed0 = get_instances(d0, ONE_SHOT_CLASS)[0]
# Disabled experiment, kept for reference:
# model(d0, bed15)
# replaced = model.get_output().transformed_obj
plot(get_instances(d0, ONE_SHOT_CLASS))

In [None]:
# Draw a random instance of the one-shot class and apply the dataset transform to it.
# NOTE(review): `bed` is not referenced by the evaluation loop below, which uses `bed0` — confirm intent.
bed = transform(dataset.get_random_instance(ONE_SHOT_CLASS))

In [None]:
import tqdm
# `import tqdm` alone does not load the `notebook` submodule; import it explicitly so
# that `tqdm.notebook.tqdm` resolves on a fresh kernel.
import tqdm.notebook

# Number of scenes that contained the one-shot class and were therefore evaluated.
count = 0
# NOTE(review): `max_iter` is not used by the loop below — confirm whether a cap was intended.
max_iter = 15
tracker = OneShotObjectTracker(dataset)
with tqdm.notebook.tqdm(dataset) as bar:
    for d in bar:
        # Skip scenes that contain no instance of the one-shot class.
        beds = get_instances(d, ONE_SHOT_CLASS)
        if not len(beds):
            continue
        # Register the reference instance into the scene, then score the prediction.
        model(d, bed0)
        tracker.track(model, Batch.from_data_list([d]))
        count += 1
        bar.set_postfix(**tracker.get_metrics())

In [None]:
# Inspect the tracker's private `_tp` attribute (presumably true positives — confirm in OneShotObjectTracker).
tracker._tp

In [None]:
# Display the final metrics (the original cell was an unfinished `tracker.` expression,
# which is a SyntaxError).
tracker.get_metrics()