
# Scannet visualisation

In [None]:
%load_ext autoreload
%autoreload 2
import os
import sys
import importlib

# `importlib.util` is a submodule: a bare `import importlib` does not guarantee
# that it has been loaded, so import it explicitly before using it below.
import importlib.util

# Parent of the current working directory; the torch_points3d package is
# expected to sit directly inside it.
DIR = os.path.dirname(os.getcwd())
torch_points3d = os.path.join(DIR, "torch_points3d")
assert os.path.exists(torch_points3d), f"torch_points3d not found at {torch_points3d}"

# Load the package straight from its __init__.py and register it in
# sys.modules under its own name before executing it.
MODULE_PATH = os.path.join(torch_points3d, "__init__.py")
MODULE_NAME = "torch_points3d"
spec = importlib.util.spec_from_file_location(MODULE_NAME, MODULE_PATH)
module = importlib.util.module_from_spec(spec)
sys.modules[spec.name] = module
spec.loader.exec_module(module)

In [None]:
from torch_points3d import datasets
from torch_points3d.datasets import object_detection

In [None]:
%load_ext autoreload
%autoreload 2
import os
import sys
import panel as pn
import numpy as np
import pyvista as pv
pv.set_plot_theme("document")
import glob
from matplotlib.colors import ListedColormap
from omegaconf import OmegaConf
import random

# Load panel's 'vtk' extension.
pn.extension('vtk')

# Start an Xvfb virtual display in the background on :99, then export the
# display and the pyvista settings through the process environment.
# NOTE(review): pyvista is already imported earlier in this cell when these
# variables are set — confirm it still picks them up.
os.system('/usr/bin/Xvfb :99 -screen 0 1024x768x24 &')
os.environ.update(
    {
        'DISPLAY': ':99',
        'PYVISTA_OFF_SCREEN': 'True',
        'PYVISTA_USE_PANEL': 'True',
    }
)

# Put the parent of the working directory on the import path.
DIR = os.path.dirname(os.getcwd())
sys.path.append(DIR)

from torch_points3d.datasets.object_detection.scannet import ScannetDataset, ScannetObjectDetection
from torch_points3d.datasets.segmentation.scannet import Scannet, SCANNET_COLOR_MAP
from torch_points3d.datasets.segmentation import IGNORE_LABEL

## Load Scannet dataset

In [None]:
# Read the Scannet object-detection configuration from the YAML file under DIR.
dataset_options = OmegaConf.load(os.path.join(DIR,'conf/data/object_detection/scannet.yaml'))

In [None]:
# Point the configuration at the "data" folder under DIR, then build the dataset from it.
dataset_options.data.dataroot = os.path.join(DIR,"data")
dataset = ScannetDataset(dataset_options.data)
# Clear the train-set transform — presumably so samples are shown untransformed; TODO confirm.
dataset.train_dataset.transform = None
print(dataset)

## Visualise the data

In [None]:
# Grab the first training sample (index 0) for inspection in the next cells.
d = dataset.train_dataset[0]

In [None]:
# Show the sample's size_class_label attribute as the cell output.
d.size_class_label

In [None]:
# Show the sample's sem_cls_label attribute as the cell output.
d.sem_cls_label 

In [None]:
def load_random_data(event):
    """Render a randomly chosen training sample into the ``pan`` pane.

    Draws the sample's points (those selected by ``vote_label_mask`` in blue,
    the remaining ones in translucent grey) and one semi-transparent box for
    every entry selected by ``box_label_mask``. The new plotter's render
    window is then assigned to the module-level ``pan`` pane.

    Args:
        event: Argument passed by the button's click callback; not used.
    """
    train_set = dataset.train_dataset
    sample = train_set[np.random.randint(0, len(train_set))]
    pl = pv.Plotter(notebook=True)

    # Points selected by the vote mask vs. the rest.
    mask = sample.vote_label_mask
    pl.add_points(sample.pos[mask == True].numpy(), color="blue")  # noqa: E712
    pl.add_points(sample.pos[mask == False].numpy(), color="grey", opacity=0.75)  # noqa: E712

    # Keep only the box entries flagged by box_label_mask. Centres, labels and
    # sizes are all filtered with the same mask so they stay aligned.
    box_mask = sample.box_label_mask
    centres = sample.center_label[box_mask].numpy()
    labels = sample.sem_cls_label[box_mask]
    box_size = sample.size_residual_label[box_mask].numpy() + train_set.MEAN_SIZE_ARR[labels]
    # Uncomment to also mark the box centres:
    # pl.add_points(centres, color="red", point_size=10.)

    # Bounding boxes
    for centre, size, label in zip(centres, box_size, labels):
        # pv.Box takes (xmin, xmax, ymin, ymax, zmin, zmax).
        bounds = []
        for axis in range(3):
            half = size[axis] / 2
            bounds.extend((centre[axis] - half, centre[axis] + half))
        box = pv.Box(tuple(bounds))

        # Use the *masked* label belonging to this box; indexing the unmasked
        # sem_cls_label with the loop position would pick the wrong class
        # whenever box_label_mask skips an earlier entry.
        nyu_id = train_set.NYU40IDS[label.item()]
        color = np.asarray(SCANNET_COLOR_MAP[nyu_id]) / 255.
        pl.add_mesh(box, color=color, show_edges=True, opacity=0.5)

    pan.object = pl.ren_win

In [None]:
# Initial plotter with nothing added to it; its render window seeds the pane.
pl = pv.Plotter(notebook=True)
# `pan` is the pane that load_random_data refreshes by assigning to pan.object.
pan = pn.panel(pl.ren_win, sizing_mode='scale_both', aspect_ratio=1,orientation_widget=True,)

In [None]:
# Button whose click callback is load_random_data.
button = pn.widgets.Button(name='Load new model', button_type='primary')
button.on_click(load_random_data)

In [None]:
# Layout: a column with the title and the button, placed in a row next to the `pan` pane.
dashboard = pn.Row(
    pn.Column('## Scannet vizualise',button),
    pan
)

In [None]:
# Show the dashboard as the cell output.
dashboard