# Aligning objects in scenes

This notebook explores ways to align objects in a scene and extends that idea to one shot detection. The explored pipeline for one shot detection for a given object category C is as follows:

- pick a scene
- pick a random representative of C that is not in the scene
- compute a feature descriptor for all points in the scene and in the object. For that we use a pretrained network on a registration task on 3d match
- compute matches between the scene and the object using a symmetry constraint
- compute an estimate of scale + rotation + translation using a robust estimator such as RANSAC or Teaser++

The first part of the notebook shows how the registration algorithm works while the second part applies that to one shot detection. Initial results show that if the actual object is in the scene we have a very good chance of aligning it properly, however if the object is different then 

In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import pyvista as pv
import panel as pn
import os
import os.path as osp
from omegaconf import OmegaConf
# Use pyvista's "document" plot theme for the plots of this notebook.
pv.set_plot_theme("document")

# Load panel's VTK extension; plotter render windows are wrapped with
# `pn.panel(...)` further down.
pn.extension('vtk')
# Start a virtual X server (Xvfb) on display :99 in the background and point
# this process at it.
# NOTE(review): presumably needed because the notebook runs on a machine
# without a physical display -- confirm.
os.system('/usr/bin/Xvfb :99 -screen 0 1024x768x24 &')
os.environ['DISPLAY'] = ':99'
# Environment switches for pyvista (presumably off-screen rendering and
# panel-based notebook output -- confirm against the installed version).
os.environ['PYVISTA_OFF_SCREEN'] = 'True'
os.environ['PYVISTA_USE_PANEL'] = 'True'

In [None]:
import torch
import numpy as np
from plyfile import PlyData
import copy

In [None]:
import sys
sys.path.append("..")

In [None]:
from torch_points3d.datasets.segmentation.scannet import Scannet
from torch_points3d.datasets.oneshot_detection.scannet import ScannetOneShotDetection
from torch_points3d.datasets.segmentation import IGNORE_LABEL
from torch_points3d.core.data_transform import GridSampling3D, AddOnes, AddFeatByKey
from torch_geometric.transforms import Compose
from torch_geometric.data import Data, Batch
from torch_points3d.utils.registration import get_matches, fast_global_registration
from torch_points3d.applications.pretrained_api import PretainedRegistry

In [None]:
# Id of the object category used for one shot detection.
ONE_SHOT_CLASS = 4

# The data is stored in <parent of the working directory>/data/scannet-oneshot.
DIR = os.path.dirname(os.getcwd())
dataroot = os.path.join(DIR, "data", "scannet-oneshot")

# Transforms applied to the samples, composed in this order.
transform = Compose(
    [
        GridSampling3D(mode='last', size=0.02, quantize_coords=True),
        AddOnes(),
        AddFeatByKey(add_to_x=True, feat_name="ones"),
    ]
)

dataset = ScannetOneShotDetection(dataroot, transform=transform)
print(dataset)

## Some utilities
Utilities for plotting and getting the data

In [None]:
def plot(clouds, together=False, colors=None):
    """Render point clouds with pyvista and wrap the result for panel.

    Args:
        clouds: iterable of objects exposing a ``pos`` tensor holding the
            point coordinates (assumed to be of shape (N, 3) -- TODO confirm).
        together: when True every cloud is drawn in one shared viewer,
            otherwise each cloud gets its own viewer.
        colors: optional sequence of colors; ``colors[i]`` is used for
            ``clouds[i]``. Clouds without a matching entry fall back to a
            default color.

    Returns:
        The panel object of the shared viewer when ``together`` is True
        (an empty viewer if ``clouds`` is empty), otherwise a ``pn.Row``
        holding one panel object per cloud.
    """
    default_color = [0.9, 0.7, 0.1]
    # ``None`` instead of a mutable default argument.
    colors = [] if colors is None else colors

    # In "together" mode the single viewer is created up front so that an
    # empty list of clouds still yields something to display.
    viewers = [pv.Plotter(notebook=True)] if together else []
    for i, cloud in enumerate(clouds):
        if together:
            viewer = viewers[0]
        else:
            viewer = pv.Plotter(notebook=True)
            viewers.append(viewer)
        color = colors[i] if i < len(colors) else default_color
        # ``.cpu()`` is a no-op for CPU tensors and makes GPU tensors usable.
        viewer.add_points(cloud.pos.cpu().numpy(), color=color)

    panes = [
        pn.panel(viewer.ren_win, sizing_mode='scale_both', aspect_ratio=1, orientation_widget=True)
        for viewer in viewers
    ]
    return panes[0] if together else pn.Row(*panes)
        

In [None]:
def get_instances(data, label_idx):
    """Extract every object instance of a given class from a scene.

    Args:
        data: scene exposing the per point tensors ``instance_labels``, ``y``,
            ``pos``, ``x`` and ``coords``.
        label_idx: class label the returned instances must have.

    Returns:
        A list of ``Data`` objects (``pos``, ``x``, ``coords``), one per
        matching instance. The list is empty when nothing matches or when the
        scene has no points.
    """
    instances = []
    for instance_id in torch.unique(data.instance_labels):
        # Instance id 0 is skipped (presumably points that do not belong to
        # any annotated instance -- confirm against the dataset).
        if instance_id == 0:
            continue
        instance_mask = data.instance_labels == instance_id
        # The class of an instance is read from its last point.
        label = data.y[instance_mask][-1]
        if label == label_idx:
            instances.append(
                Data(
                    pos=data.pos[instance_mask],
                    x=data.x[instance_mask],
                    coords=data.coords[instance_mask],
                )
            )
    return instances

## Placing an object in a scene
This first section explores the precision of placing an object in a scene when the object is present. This is the simplest one shot detection we can think of.

In [None]:
# We first load a model pretrained for registration (registry entry
# "minkowski-registration-3dmatch") and move it to the GPU.
# This will log some errors, don't worry it's all good!
model = PretainedRegistry.from_pretrained("minkowski-registration-3dmatch").cuda()

In [None]:
# Load a scene from the dataset and extract the instances of a given class.
# Here 4 is the bed class.
d15 =dataset[15]
beds = get_instances(d15,4)
# Keep the first bed found in the scene as the object to align.
bed15 = beds[0]

In [None]:
def compute_features(data):
    """Return the per point features of ``data`` given by the pretrained model.

    ``data`` is wrapped in a batch of size one and fed to the module-level
    ``model`` on the "cuda" device, with gradient tracking disabled.
    """
    single_item_batch = Batch.from_data_list([data])
    with torch.no_grad():
        model.set_input(single_item_batch, "cuda")
        return model.forward()

def register(data, obj):
    """Align ``obj`` onto the scene ``data`` and return the moved object.

    Steps:
      1. compute features for the scene and for the object
      2. match them using the symmetry constraint
      3. estimate rotation and translation with fast-global-registration

    ``obj`` itself is not modified; a deep copy with the transformed ``pos``
    is returned.
    """
    scene_descriptors = compute_features(data)
    object_descriptors = compute_features(obj)
    matches = get_matches(scene_descriptors, object_descriptors, sym=True)
    print("Number of matches = %i" % matches.shape[0])

    # Column 0 of a match indexes the scene, column 1 indexes the object.
    scene_idx, obj_idx = matches[:, 0], matches[:, 1]
    T_est = fast_global_registration(obj.pos[obj_idx], data.pos[scene_idx])

    rotation, translation = T_est[:3, :3], T_est[:3, 3]
    aligned = copy.deepcopy(obj)
    aligned.pos = obj.pos @ rotation.T + translation
    return aligned

In [None]:
# Align the bed extracted from scene 15 onto that same scene.
transformed = register(d15, bed15)

In [None]:
# Show the scene and the aligned bed in one viewer, each with its own color.
plot([d15,transformed], together=True, colors = [[0.9, 0.7, 0.1],[0.1, 0.7, 0.9]])

The method below can also be used to visualise how things were matched:

In [None]:
def plot_matches(data, obj, max_lines=100):
    """Visualise the feature matches between a scene and an object.

    The object is drawn at its own position and the scene is drawn shifted
    by 5 along the first axis so the two do not overlap. Each displayed match
    is a green line between the matched scene point and object point.

    Args:
        data: the scene; must expose a ``pos`` tensor.
        obj: the object; must expose a ``pos`` tensor.
        max_lines: maximum number of matches drawn. When there are more
            matches, a random subset of that size is shown.

    Returns:
        A panel object wrapping the pyvista render window.
    """
    d_feats = compute_features(data)
    obj_feats = compute_features(obj)
    matches = get_matches(d_feats, obj_feats, sym=True)
    if matches.shape[0] > max_lines:
        keep = torch.randperm(matches.shape[0])[:max_lines]
        matches = matches[keep, :]

    # Convert both clouds to numpy once; ``.cpu()`` is used for the object as
    # well as for the scene so tensors living on the GPU are handled the same.
    obj_points = obj.pos.cpu().numpy()
    moved_scan = data.pos.cpu().numpy() + np.asarray([5, 0, 0])

    v = pv.Plotter(notebook=True)
    v.add_points(obj_points)
    v.add_points(moved_scan)
    # Column 0 of a match indexes the scene, column 1 indexes the object.
    for scene_idx, obj_idx in matches[:, :2].tolist():
        segment = np.asarray([moved_scan[scene_idx], obj_points[obj_idx]])
        v.add_lines(segment, width=5, color="green")
    return pn.panel(v.ren_win, sizing_mode='scale_both', aspect_ratio=1, orientation_widget=True)
    

In [None]:
# Draw the matches found between scene 15 and the bed extracted from it.
plot_matches(d15, bed15)

## Extension for one shot detection
Now that we have the basic building block in place let's explore the potential of using such method for one shot detection. As explained in the introduction, the proposed pipeline looks like that:

- pick a scene
- pick a random representative of C that is not in the scene
- compute a feature descriptor for all points in the scene and in the object. For that we use a pretrained network on a registration task on 3d match
- compute matches between the scene and the object using a symmetry constraint
- compute an estimate of scale + rotation + translation using a robust estimator such as RANSAC or Teaser++
- if there are more than N inliers that fit this transform then we consider that we have a detection

The performance of our detector is tested using precision and recall metrics: an object is considered detected when the overlap between the ground truth and predicted bounding boxes is higher than 25% in mIoU.


In [None]:
from torch_points3d.datasets.object_detection.box_data import BoxData
from torch_points3d.metrics.oneshottracker import OneShotObjectTracker

In [None]:
from torch_points3d.utils.registration import teaser_pp_registration

In [None]:
class RegistrationResult:
    """Stores the result of the registration model.

    It provides the `get_boxes` method that is required by the tracker.
    """

    def __init__(self, obj, class_label):
        """Build the result from the transformed object.

        Args:
            obj: the object after it has been aligned in the scene (must
                expose a ``pos`` tensor), or None when nothing was detected.
            class_label: class label attached to the predicted box.
        """
        # Always defined, so callers can read it even without a detection.
        self.transformed_obj = obj
        if obj is None:
            self.box = None
            return

        # Axis aligned bounding box of the transformed object.
        min_pos, max_pos = torch.min(obj.pos, 0)[0], torch.max(obj.pos, 0)[0]
        xi, yi, zi = min_pos
        xm, ym, zm = max_pos
        # First row: the four corners at the minimum of the last axis,
        # second row: the same four corners at its maximum.
        corners = torch.tensor([
            [xi, yi, zi], [xm, yi, zi], [xm, ym, zi], [xi, ym, zi],
            [xi, yi, zm], [xm, yi, zm], [xm, ym, zm], [xi, ym, zm],
        ])
        # NOTE(review): the third argument is presumably the box score -- confirm.
        self.box = BoxData(class_label, corners, 1)

    def get_boxes(self):
        """Return the boxes as a nested list: ``[[box]]``, or ``[[]]`` when there is no detection."""
        if self.box is not None:
            return [[self.box]]
        return [[]]
        

In [None]:
class RegistrationModel(torch.nn.Module):
    """Wraps the functionalities explored in the previous section into an actual pytorch module.

    It is not something that can be trained but it provides the nice forward
    interface. Results are exposed through the `get_output` method so that
    the tracker can work with that.
    """

    def __init__(self, class_label, min_inliers=10):
        """
        Args:
            class_label: class label attached to every detection.
            min_inliers: a detection is reported only when the registration
                returns strictly more inliers than this value.
        """
        super().__init__()
        self._min_inliers = min_inliers
        self.class_label = class_label
        # No result is available until `forward` has been called.
        self.output = None
        self._model = PretainedRegistry.from_pretrained("minkowski-registration-3dmatch").cuda()

    def compute_features(self, data):
        """Return the per point features of ``data`` computed by the pretrained model (no gradients)."""
        batch = Batch.from_data_list([data])
        with torch.no_grad():
            self._model.set_input(batch, "cuda")
            output = self._model.forward()
        return output

    def forward(self, data, one_instance):
        """Try to place ``one_instance`` in the scene ``data``.

        The outcome is stored in ``self.output`` (nothing is returned): a
        `RegistrationResult` holding the transformed object when enough
        inliers support the estimated transform, an empty one otherwise.
        """
        data_feat = self.compute_features(data)
        obj_feats = self.compute_features(one_instance)
        matches = get_matches(data_feat, obj_feats, sym=True)

        # Column 0 of a match indexes the scene, column 1 indexes the object.
        T_est, inliers = teaser_pp_registration(one_instance.pos[matches[:, 1]], data.pos[matches[:, 0]])
        if len(inliers) > self._min_inliers:
            transformed_obj = copy.deepcopy(one_instance)
            transformed_obj.pos = one_instance.pos @ T_est[:3, :3].T + T_est[:3, 3]
            self.output = RegistrationResult(transformed_obj, self.class_label)
        else:
            self.output = RegistrationResult(None, self.class_label)

    def get_output(self):
        """Return the result of the last `forward` call, or None if `forward` has not run yet."""
        return self.output

    def get_current_losses(self):
        """Return an empty dict: this model is not trained, so it has no losses."""
        return {}

Let's initialise our inference pipeline by picking a category and one object within this category:

In [None]:
# Category used for one shot detection.
ONE_SHOT_CLASS = 4

# The reference object is the first instance of that category in scene 0.
d0 = dataset[0]
reference_instances = get_instances(d0, ONE_SHOT_CLASS)
bed0 = reference_instances[0]

# NOTE: this rebinds the notebook-level name `model`, which until here held
# the pretrained network.
class_name = dataset.NYU40ID2CLASS[ONE_SHOT_CLASS]
model = RegistrationModel(class_name, min_inliers=10)

# Show the scene and the reference object in one viewer.
scene_color, object_color = [0.9, 0.7, 0.1], [0.1, 0.7, 0.9]
plot([d0, bed0], together=True, colors=[scene_color, object_color])

We can now run the inference loop:

In [None]:
# Import the submodule explicitly: `tqdm.notebook` is only reachable as an
# attribute of `tqdm` once the submodule itself has been imported.
import tqdm.notebook

max_iter = 2
tracker = OneShotObjectTracker(dataset)
# (scene, transformed object) pairs for every scene where a box was predicted.
detected = []
# Number of scenes that were actually evaluated (scenes without an instance
# of the category are skipped and not counted).
count = 0
with tqdm.notebook.tqdm(dataset) as bar:
    for i, d in enumerate(bar):
        # Scenes with index 0..max_iter (inclusive) are visited.
        if i > max_iter:
            break
        beds = get_instances(d, ONE_SHOT_CLASS)
        if not len(beds):
            continue
        model(d, bed0)
        out = model.get_output()
        if len(out.get_boxes()[0]):
            detected.append((d, out.transformed_obj))
        tracker.track(model, Batch.from_data_list([d]))
        count += 1
        bar.set_postfix(**tracker.get_metrics())

In [None]:
# Inspect the tracker's private `_tp` attribute
# (presumably the true positives -- confirm against OneShotObjectTracker).
tracker._tp

In [None]:
# Inspect the tracker's private `_ngt` attribute
# (presumably the number of ground truth objects -- confirm against OneShotObjectTracker).
tracker._ngt