In [None]:
%load_ext autoreload
%autoreload 2
import os
import sys
import panel as pn
import numpy as np
import pyvista as pv
pv.set_plot_theme("document")
import glob
from matplotlib.colors import ListedColormap
from omegaconf import OmegaConf
from torch_geometric.data import Data
import random
import torch

# Enable Panel's VTK pane so PyVista render windows can be embedded in the notebook.
pn.extension('vtk')
# Start a virtual X server on display :99 in the background (trailing `&`).
# NOTE(review): the binary path is hard-coded and the return code is ignored;
# re-running this cell launches another Xvfb attempt each time.
os.system('/usr/bin/Xvfb :99 -screen 0 1024x768x24 &')
# Point X clients started from this kernel at the virtual display above.
os.environ['DISPLAY'] = ':99'
# NOTE(review): these variables are set after `import pyvista` has already run;
# if PyVista only reads them at import time they have no effect here — TODO confirm.
os.environ['PYVISTA_OFF_SCREEN'] = 'True'
os.environ['PYVISTA_USE_PANEL'] = 'True'

# Treat the parent of the current working directory as the project root and
# make it importable (needed by the torch_points3d imports that follow).
DIR = os.path.dirname(os.getcwd())
sys.path.append(DIR)

from torch_points3d.datasets.panoptic.scannet import ScannetDataset, ScannetPanoptic
from torch_points3d.datasets.segmentation.scannet import Scannet, SCANNET_COLOR_MAP
from torch_points3d.datasets.segmentation import IGNORE_LABEL

In [None]:
# Location of the saved visualisation batches, relative to the project root.
VIZ_REL_DIR = "outputs/2020-07-13/17-57-38/viz"
VIZ_DIR = os.path.join(DIR, VIZ_REL_DIR)
# glob returns matches in arbitrary, filesystem-dependent order; sort so that
# indexing with `file_idx` selects the same file on every run and machine.
# NOTE(review): the order is lexicographic on the path, so index -1 is the
# lexicographically last file, not necessarily the most recent one.
all_viz_data = sorted(glob.glob(os.path.join(VIZ_DIR, '*.pt')))

In [None]:
def get_sample(data, sample_idx):
    """Extract a single sample from a batched visualisation object.

    Args:
        data: batched object exposing per-point tensors `pos`,
            `instance_labels`, `semantic_pred`, `y`, `vote` and `batch`,
            plus `clusters` (index tensors into the batch) and `cluster_type`.
        sample_idx: value of `data.batch` identifying the sample to keep.

    Returns:
        A `Data` object holding the per-point tensors restricted to the
        sample, with two extra attributes: `pos_instances` (points of the
        clusters whose `cluster_type` is 0) and `vote_instances` (points of
        every other cluster).
    """
    in_sample = data.batch == sample_idx
    per_point_fields = ("pos", "instance_labels", "semantic_pred", "y", "vote")
    out_data = Data(**{field: getattr(data, field)[in_sample] for field in per_point_fields})

    pos_instances, vote_instances = [], []
    for cluster_idx, cluster in enumerate(data.clusters):
        # A cluster is attributed to the sample that owns its first point.
        if data.batch[cluster[0]] == sample_idx:
            target = pos_instances if data.cluster_type[cluster_idx] == 0 else vote_instances
            target.append(data.pos[cluster])

    out_data.pos_instances = pos_instances
    out_data.vote_instances = vote_instances
    return out_data

In [None]:
def buil_cmap():
    """Build a 256-entry colormap for ScanNet semantic labels.

    The lookup table spans scalar values from `IGNORE_LABEL` up to
    `len(Scannet.VALID_CLASS_IDS) + 1`. Every entry whose scalar value is at
    least `i - 0.5` receives the colour of the i-th valid class, so later
    classes overwrite earlier ones; entries below -0.5 stay black.

    Returns:
        A matplotlib `ListedColormap` with RGB values in [0, 1].
    """
    num_classes = len(Scannet.VALID_CLASS_IDS)
    scalar_values = np.linspace(IGNORE_LABEL, num_classes + 1, 256)
    lookup_table = np.zeros((256, 3))
    for class_idx, raw_label in enumerate(Scannet.VALID_CLASS_IDS):
        rgb = np.asarray(Scannet.SCANNET_COLOR_MAP[raw_label]) / 255.
        lookup_table[scalar_values >= class_idx - 0.5] = rgb
    return ListedColormap(lookup_table)


# Shared colormap used by the label-coloured plots.
cmap = buil_cmap()

In [None]:
# file_idx: position in `all_viz_data` of the batch file to load (-1 = last entry).
# sample_idx: default sample index; not referenced by the cells visible in this view.
file_idx, sample_idx = -1, 0

In [None]:
# Load the selected batch onto the CPU: downstream cells call `.numpy()` on
# its tensors, which fails for CUDA tensors, and the file may have been saved
# from a GPU run.
# NOTE(review): torch.load unpickles arbitrary objects — only open trusted files.
data = torch.load(all_viz_data[file_idx], map_location="cpu")
data

In [None]:
# Load the batch onto the CPU so the `.numpy()` conversions below always work,
# even when the file was written from a GPU run.
data = torch.load(all_viz_data[file_idx], map_location="cpu")


def load_random_sample(event):
    """Button callback: draw a random sample from `data` and refresh the five views.

    Reads the module-level `data`, `cmap` and the Panel panes `pan1`..`pan5`
    (which must exist before the callback fires), builds fresh PyVista
    plotters and assigns their render windows to the panes:

    * pan1 - ground-truth instances coloured by semantic label `y`
    * pan2 - clusters stored in `sample.pos_instances`, one random colour each
    * pan3 - all points coloured by `semantic_pred`
    * pan4 - clusters stored in `sample.vote_instances`, one random colour each
    * pan5 - `pos + vote` for every point whose instance label is not 0

    Args:
        event: click event supplied by Panel; unused.
    """
    # Convert the 0-dim tensor to a plain int instead of relying on implicit
    # conversion inside numpy / range.
    num_samples = int(data.batch.max().item()) + 1
    sample = get_sample(data, np.random.randint(0, num_samples))
    label_clim = [-1, len(Scannet.VALID_CLASS_IDS) + 1]

    gt_plot, pos_plot, sem_plot, vote_cluster_plot, vote_plot = (
        pv.Plotter(notebook=True) for _ in range(5)
    )

    # Points with instance label 0 are drawn as a faint gray backdrop in every
    # view except the semantic one.
    unlabeled = sample.pos[sample.instance_labels == 0].numpy()
    # Skip empty point sets: PyVista is expected to reject empty meshes
    # (TODO confirm for the installed version).
    if len(unlabeled) > 0:
        for plotter in (gt_plot, pos_plot, vote_cluster_plot, vote_plot):
            plotter.add_points(unlabeled, color="gray", opacity=0.4)

    # Ground truth: iterate over the instance ids actually present (positive
    # ids only, ascending) so ids missing from this sample never produce an
    # empty point cloud.
    for instance_id in torch.unique(sample.instance_labels).tolist():
        if instance_id <= 0:
            continue
        instance_mask = sample.instance_labels == instance_id
        instance_cloud = pv.PolyData(sample.pos[instance_mask].numpy())
        instance_cloud['label'] = sample.y[instance_mask].numpy()
        gt_plot.add_points(instance_cloud, cmap=cmap, clim=label_clim)

    # Proposed clusters: one random RGB colour per cluster.
    for plotter, instances in (
        (pos_plot, sample.pos_instances),
        (vote_cluster_plot, sample.vote_instances),
    ):
        for instance in instances:
            color = [random.random() for _ in range(3)]
            plotter.add_points(instance.numpy(), color=color)

    # Semantic predictions for every point of the sample.
    semantic_cloud = pv.PolyData(sample.pos.numpy())
    semantic_cloud['label'] = sample.semantic_pred.numpy()
    sem_plot.add_points(semantic_cloud, cmap=cmap, clim=label_clim)

    # Voted centres: each labelled point shifted by its vote offset.
    labelled = sample.instance_labels != 0
    vote_centre = (sample.pos[labelled] + sample.vote[labelled]).numpy()
    if len(vote_centre) > 0:
        vote_plot.add_points(vote_centre, color='red')

    pan1.object = gt_plot.ren_win
    pan2.object = pos_plot.ren_win
    pan3.object = sem_plot.ren_win
    pan4.object = vote_cluster_plot.ren_win
    pan5.object = vote_plot.ren_win

In [None]:
# Five empty plotters provide the initial render windows; the button callback
# later swaps each pane's `object` for a freshly populated render window.
pl1, pl2, pl3, pl4, pl5 = (pv.Plotter(notebook=True) for _ in range(5))
pan1, pan2, pan3, pan4, pan5 = (
    pn.panel(plotter.ren_win, sizing_mode='scale_both', aspect_ratio=1, orientation_widget=True)
    for plotter in (pl1, pl2, pl3, pl4, pl5)
)

# Clicking the button loads a random sample into all five panes.
button = pn.widgets.Button(name='Load new model', button_type='primary')
button.on_click(load_random_sample)

# Dashboard: controls on the left, then three columns of captioned views.
pn.Row(
    pn.Column('## Scannet vizualise', button),
    pn.Column(
        pn.Column(pan1, 'Ground truth'),
        pn.Column(pan2, 'Pos based predictions'),
    ),
    pn.Column(
        pn.Column(pan3, 'Semantic predictions'),
        pn.Column(pan4, 'Vote clusters'),
    ),
    pn.Column(
        pn.Column(pan5, 'Votes'),
    ),
)

## Exploration of the clustering

In [None]:
from torch_points_kernels import region_grow

In [None]:
# Reload the batch on the CPU (the plotting cells call `.numpy()` on its
# tensors) and pick one of its samples at random.
data = torch.load(all_viz_data[file_idx], map_location="cpu")
# Explicit int conversion: `data.batch.max()` is a 0-dim tensor.
i = np.random.randint(0, int(data.batch.max()) + 1)
sample = get_sample(data, i)

In [None]:
# Index, within the loaded batch, of the randomly selected sample.
i

In [None]:
# Distinct instance ids present in the selected sample.
torch.unique(sample.instance_labels)


In [None]:
# Cluster the selected sample on its voted positions (pos + vote), grouping
# points that share a predicted semantic label. Use the GPU when one is
# available and fall back to the CPU otherwise.
cluster_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
clusters = region_grow(
    sample.pos.to(cluster_device) + sample.vote.to(cluster_device),
    sample.semantic_pred.to(cluster_device),
    # Single sample, so every point belongs to batch 0.
    torch.zeros(sample.pos.shape[0], dtype=torch.long, device=cluster_device),
    ignore_labels=[-1, 0, 1],
    radius=0.075,
    min_cluster_size=32,
    nsample=16,
)
# NOTE(review): these indices are relative to `sample`, yet they overwrite the
# batch-level `data.clusters`; `get_sample(data, ...)` pairs `data.clusters`
# with `data.cluster_type`, so reload `data` before calling it again.
data.clusters = [cl.cpu() for cl in clusters]

In [None]:
# Number of clusters returned by the most recent region_grow call.
len(clusters)

In [None]:
# Total number of points covered by the clusters currently stored on `data`.
num_points = sum(len(cl) for cl in data.clusters)
num_points

In [None]:
# Show the clusters found for the selected sample: points with instance
# label 0 in gray, every cluster in its own random colour.
mask = sample.instance_labels == 0
pl = pv.Plotter(notebook=True)
pl.add_points(sample.pos[mask].numpy(), color="gray")
for cl in clusters:
    color = [random.random() for _ in range(3)]
    # `sample.pos` lives on the CPU while `clusters` may hold CUDA index
    # tensors (they were computed from CUDA inputs); move the indices to the
    # CPU before indexing. `.cpu()` is a no-op for tensors already on the CPU.
    pl.add_points(sample.pos[cl.cpu()].numpy(), color=color)
pn.panel(pl.ren_win, aspect_ratio=1, orientation_widget=True)

In [None]:
# Cluster the whole batch on the CPU using the raw point positions (no vote
# offset), with the same region-growing parameters as the per-sample run.
clusters = region_grow(
    data.pos.cpu(),
    data.semantic_pred.cpu(),
    data.batch.cpu(),
    ignore_labels=[-1, 0, 1],
    radius=0.075,
    min_cluster_size=32,
    nsample=16,
)

In [None]:
# Show the batch-level clusters: points with instance label 0 in gray, every
# cluster in its own random colour.
mask = data.instance_labels == 0
pl = pv.Plotter(notebook=True)
background_points = data.pos[mask].numpy()
pl.add_points(background_points, color="gray")
for cl in clusters:
    color = [random.random() for _ in range(3)]
    cluster_points = data.pos[cl].numpy()
    pl.add_points(cluster_points, color=color)
pn.panel(pl.ren_win, aspect_ratio=1, orientation_widget=True)

In [None]:
# Distinct predicted semantic labels present across the whole loaded batch.
torch.unique(data.semantic_pred)