# Scannet panoptic visualisation

In [None]:
%load_ext autoreload
%autoreload 2
import os
import sys
import importlib

# `import importlib` alone does not guarantee that the `importlib.util`
# submodule is loaded, so import it explicitly before it is used below.
import importlib.util

# The local checkout is expected in the parent of the current working directory.
DIR = os.path.dirname(os.getcwd())
torch_points3d = os.path.join(DIR, "torch_points3d")
assert os.path.exists(torch_points3d), f"torch_points3d package not found at {torch_points3d}"

# Build the module straight from the package's __init__.py and register it in
# sys.modules under its name *before* executing it, so that imports of
# `torch_points3d` issued while (or after) it runs resolve to this copy.
MODULE_PATH = os.path.join(torch_points3d, "__init__.py")
MODULE_NAME = "torch_points3d"
spec = importlib.util.spec_from_file_location(MODULE_NAME, MODULE_PATH)
module = importlib.util.module_from_spec(spec)
sys.modules[spec.name] = module
spec.loader.exec_module(module)

In [None]:
%load_ext autoreload
%autoreload 2
import os
import sys
import panel as pn
import numpy as np
import pyvista as pv
# Select PyVista's "document" plot theme before any plotter is created.
pv.set_plot_theme("document")


import glob
from matplotlib.colors import ListedColormap
from omegaconf import OmegaConf
import random

# Load panel's VTK extension so render windows can be embedded in the notebook.
pn.extension('vtk')

# Start an Xvfb (virtual framebuffer) X server in the background on display
# :99, then point this process and PyVista at it through the environment.
os.system('/usr/bin/Xvfb :99 -screen 0 1024x768x24 &')
os.environ.update({
    'DISPLAY': ':99',
    'PYVISTA_OFF_SCREEN': 'True',
    'PYVISTA_USE_PANEL': 'True',
})

# Make the parent of the working directory importable (the torch_points3d
# imports that follow rely on it).
DIR = os.path.dirname(os.getcwd())
sys.path.append(DIR)

from torch_points3d.datasets.panoptic.scannet import ScannetDataset, ScannetPanoptic
from torch_points3d.datasets.segmentation.scannet import Scannet, SCANNET_COLOR_MAP
from torch_points3d.datasets.segmentation import IGNORE_LABEL

## Load Scannet dataset

In [None]:
# Read the sparse ScanNet panoptic dataset configuration from the conf/ tree under DIR.
dataset_options = OmegaConf.load(os.path.join(DIR,'conf/data/panoptic/scannet-sparse.yaml'))

In [None]:
# Point the data section of the config at the local data/ folder under DIR
# and build the dataset from it.
data_options = dataset_options.data
data_options.dataroot = os.path.join(DIR, "data")
dataset = ScannetDataset(data_options)

# Clear the train split's transform before inspecting samples.
dataset.train_dataset.transform = None
print(dataset)

## Visualise the data

In [None]:
# Fetch the first sample of the train split for inspection.
d = dataset.train_dataset[0]

In [None]:
# Bare expression: the notebook renders the sample as the cell output.
d

In [None]:
def buil_cmap():
    """Build a 256-entry colormap for Scannet labels.

    The colormap spans the scalar range from ``IGNORE_LABEL`` up to
    ``len(Scannet.VALID_CLASS_IDS) + 1``. Entries below ``-0.5`` stay black;
    every other entry takes the colour of the last class index ``i`` whose
    threshold ``i - 0.5`` it reaches, looked up in
    ``Scannet.SCANNET_COLOR_MAP`` and rescaled from 0-255 to 0-1.

    Returns:
        A ``ListedColormap`` holding the 256 RGB rows.
    """
    n_classes = len(Scannet.VALID_CLASS_IDS)
    levels = np.linspace(IGNORE_LABEL, n_classes + 1, 256)
    table = np.zeros((256, 3))
    for class_idx, raw_id in enumerate(Scannet.VALID_CLASS_IDS):
        rgb = np.asarray(Scannet.SCANNET_COLOR_MAP[raw_id]) / 255.
        # Later classes overwrite earlier ones on the upper part of the range.
        reached = levels >= class_idx - 0.5
        table[reached] = rgb
    return ListedColormap(table)


cmap = buil_cmap()

In [None]:
def load_random_data(event):
    """Render a randomly chosen training sample into the two panel views.

    Registered as the click callback of the notebook's button. Two fresh
    plotters are built and their render windows are swapped into the
    existing panes ``pan`` and ``pan2``:

    * first view: every point coloured by its semantic label ``y`` through
      ``cmap``; points outside ``instance_mask`` are drawn with point size 1,
      points of each instance with the default point size;
    * second view: points outside ``instance_mask`` in gray, plus
      ``pos + vote_label`` of the points inside ``instance_mask`` in red
      (presumably the voted instance centres).

    Empty selections are skipped rather than handed to the plotter, because
    PyVista refuses to plot empty meshes by default.

    Args:
        event: Click event passed in by the button; not used.
    """
    sample_idx = np.random.randint(0, len(dataset.train_dataset))
    sample = dataset.train_dataset[sample_idx]
    pl = pv.Plotter(notebook=True)
    pl2 = pv.Plotter(notebook=True)

    # Same scalar range as the one `cmap` was built over.
    clim = [IGNORE_LABEL, len(Scannet.VALID_CLASS_IDS) + 1]

    # Points that do not belong to any instance, coloured by semantic label.
    in_instance = sample.instance_mask
    background = in_instance == False  # noqa: E712 - elementwise tensor comparison
    if background.any():
        point_cloud = pv.PolyData(sample.pos[background].numpy())
        point_cloud['label'] = sample.y[background].numpy()
        pl.add_points(point_cloud, cmap=cmap, clim=clim, point_size=1)
        pl2.add_points(point_cloud, color='gray', point_size=1)

    # Points of each instance, coloured by semantic label.
    for instance_id in range(1, sample.num_instances.item() + 1):
        instance_mask = sample.instance_labels == instance_id
        if not instance_mask.any():
            # This instance id has no points in the sample.
            continue
        point_cloud = pv.PolyData(sample.pos[instance_mask].numpy())
        point_cloud['label'] = sample.y[instance_mask].numpy()
        pl.add_points(point_cloud, cmap=cmap, clim=clim)

    # Each instance point shifted by its vote.
    if in_instance.any():
        centre = sample.pos[in_instance] + sample.vote_label[in_instance]
        pl2.add_points(centre.numpy(), color='red')

    pan.object = pl.ren_win
    pan2.object = pl2.ren_win

In [None]:
# Start with two empty plotters; `load_random_data` replaces the panes'
# render windows on every button click.
pl, pl2 = pv.Plotter(notebook=True), pv.Plotter(notebook=True)

# Both panes share the same display options.
pane_options = dict(sizing_mode='scale_both', aspect_ratio=1, orientation_widget=True)
pan = pn.panel(pl.ren_win, **pane_options)
pan2 = pn.panel(pl2.ren_win, **pane_options)

button = pn.widgets.Button(name='Load new model', button_type='primary')
button.on_click(load_random_data)

# Last expression of the cell: the layout is what the notebook displays.
pn.Row(
    pn.Column('## Scannet vizualiser', button),
    pn.Column(pan, 'INSTANCE LABELS'),
    pn.Column(pan2, 'CENTER LABELS'),
)