In [None]:
%load_ext autoreload
%autoreload 2

import os
import sys

# Make the project's root directory importable so that the `src` package
# can be found. In a notebook, `os.path.abspath('')` is the current working
# directory (presumably the folder holding this notebook); its parent is
# taken as the project root.
# For a .py script, use instead:
#   file_path = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
file_path = os.path.dirname(os.path.abspath(''))

# Only append once: re-running this cell would otherwise keep adding the
# same entry to sys.path
if file_path not in sys.path:
    sys.path.append(file_path)

import hydra
from src.utils import init_config
import torch
from src.transforms import *
from src.utils.widgets import *
from src.data import *

## Select your device, experiment, split, and pretrained model

In [None]:
# Interactive selectors for the run. The cells below read each selection
# back through the widget's `.value` attribute
device_widget = make_device_widget()
(task_widget, expe_widget) = make_experiment_widgets()  # task + experiment pickers
split_widget = make_split_widget()
ckpt_widget = make_checkpoint_file_search_widget()  # checkpoint file search

In [None]:
# Echo the current widget selections, one per line, so the run is
# self-documenting
print(
    "You chose:",
    f"  - device={device_widget.value}",
    f"  - task={task_widget.value}",
    f"  - split={split_widget.value}",
    f"  - experiment={expe_widget.value}",
    f"  - ckpt={ckpt_widget.value}",
    sep="\n",
)

## Parsing the config files
Hydra and OmegaConf are used to parse the `yaml` config files.

❗Make sure you selected a **ckpt file relevant to your experiment** in the previous section. 
You can use our pretrained models for this, or your own checkpoints if you have already trained a model.

In [None]:
# Parse the configs using hydra.
# Hydra's override grammar spells the null value `null`: formatting Python's
# `None` into the override would instead set `ckpt_path` to the *string*
# "None" when no checkpoint has been selected
ckpt_override = 'null' if ckpt_widget.value is None else ckpt_widget.value
cfg = init_config(overrides=[
    f"experiment={task_widget.value}/{expe_widget.value}",
    f"ckpt_path={ckpt_override}",
    # Only needed when you want full-resolution predictions (used by the
    # full_res_*_pred() helpers further below)
    "datamodule.load_full_res_idx=True",
])

## Datamodule and model instantiation

In [None]:
# Build the datamodule described by the config, then call its
# `prepare_data()` and `setup()` hooks so its datasets become available
datamodule = hydra.utils.instantiate(cfg.datamodule)
datamodule.prepare_data()
datamodule.setup()

# Resolve the selected split to the matching dataset attribute of the
# datamodule. Keep in mind that the train dataset yields augmented
# spherical samples of large scenes, whereas the val and test datasets
# load entire tiles at once
split_to_dataset_attr = {
    'train': 'train_dataset',
    'val': 'val_dataset',
    'test': 'test_dataset',
}
if split_widget.value not in split_to_dataset_attr:
    raise ValueError(f"Unknown split '{split_widget.value}'")
dataset = getattr(datamodule, split_to_dataset_attr[split_widget.value])

# Summarize the classes of the selected dataset
dataset.print_classes()

# Build the model from the config and, when a checkpoint was selected,
# replace it with the model restored from that checkpoint
model = hydra.utils.instantiate(cfg.model)
if ckpt_widget.value is not None:
    model = model._load_from_checkpoint(cfg.ckpt_path)

# Switch to evaluation mode and move the model to the selected device
model = model.eval().to(device_widget.value)

## Hierarchical partition loading and inference
SPT can process very large scenes at once.
Depending on the split you selected, inference runs either on an entire (potentially million-point) tile (`val` and `test`) or on an augmented spherical sample of a tile (`train`).

In [None]:
# For visualization purposes, ask NAGAddKeysTo to keep the input Data
# attributes instead of removing them once they have been moved to Data.x
for transform in dataset.on_device_transform.transforms:
    if not isinstance(transform, NAGAddKeysTo):
        continue
    transform.delete_after = False

# Fetch the first dataset item: the hierarchical partition of an entire
# tile, stored as a NAG object
nag = dataset[0]

# Move the NAG to the selected device, then run the on-device transforms.
# For the train dataset, these select a spherical sample of the larger
# tile and apply data augmentations; for the val and test datasets, they
# prepare an entire tile for inference
nag = nag.to(device_widget.value)
nag = dataset.on_device_transform(nag)

# Run inference without tracking gradients; the model returns a
# task-specific output object carrying the predictions
with torch.no_grad():
    output = model(nag)

# Propagate the level-1 superpoint semantic predictions down to the
# level-0 voxels, stored as nag[0].semantic_pred for visualization
nag[0].semantic_pred = output.voxel_semantic_pred(super_index=nag[0].super_index)

# For the panoptic task, do the same with the panoptic predictions
if task_widget.value == 'panoptic':
    vox_y, vox_index, vox_obj_pred = output.voxel_panoptic_pred(
        super_index=nag[0].super_index)
    nag[0].obj_pred = vox_obj_pred

By design, our model only needs to produce predictions for the $P_1$ superpoints for training. 
This conveniently saves compute and memory at training and evaluation time.

At inference time however, we often **need the predictions on the $P_0$ voxel level or on the full-resolution input point cloud**.
To this end, we provide helper functions to recover voxel-wise and full-resolution predictions.
In the previous cell, for instance, `voxel_semantic_pred()` and `voxel_panoptic_pred()` were used for computing voxel-wise predictions and attaching them to our `NAG` object.

In the following cell, we show how to efficiently recover the full-resolution predictions with `full_res_semantic_pred()` and `full_res_panoptic_pred()` (requires `datamodule.load_full_res_idx=True` in the config).

In [None]:
# Both full-resolution helpers take the same two index mappings: level-0
# voxels -> level-1 superpoints, and level-0 voxels -> raw input points
full_res_kwargs = dict(
    super_index_level0_to_level1=nag[0].super_index,
    sub_level0_to_raw=nag[0].sub,
)

# Full-resolution semantic labels, ordered like the points of the
# corresponding raw input file. Recovering the matching full-resolution
# positions, colors, etc. is not covered by this pipeline
raw_semseg_y = output.full_res_semantic_pred(**full_res_kwargs)

# Full-resolution panoptic prediction, for the panoptic task only. It
# returns, in order: the semantic prediction, the instance index, and an
# InstanceData object holding the same information in another format
if task_widget.value == 'panoptic':
    raw_pano_y, raw_index, raw_obj_pred = output.full_res_panoptic_pred(
        **full_res_kwargs)

## Visualizing an entire tile
SPT can process very large scenes at once. Let's visualize the output.

In [None]:
# Render the hierarchical partition of the whole tile with the dataset's
# class names and colors; `max_points` caps how many points are displayed
nag.show(
    max_points=100_000,
    num_classes=dataset.num_classes,
    stuff_classes=dataset.stuff_classes,
    class_colors=dataset.class_colors,
    class_names=dataset.class_names,
)

However, for memory reasons, the visualization cannot display every point of a large tile (see the `max_points` argument). Let's have a look at a smaller area.

## Selecting a portion of the hierarchical partition
The NAG structure can be subselected using `nag.select()`.

This function expects an `int` specifying the partition level from which we should select, along with an index or a mask in the form of a `list`, `numpy.ndarray`, `torch.Tensor`, or `slice`.
This index/mask describes which nodes to select at the specified level.

The output NAG will only contain children, parents and edges of the selected nodes.

Specifying `radius` and `center` for the `show()` function will make use of this `nag.select()` method internally.

In [None]:
# Center the close-up views on the mean position of the level-0 points
center = nag[0].pos.mean(dim=0).view(1, -1)

# Close-up radius for each dataset, matched against the experiment name
# (first match wins, otherwise `default_radius`). Feel free to modify
# these values
radius_per_dataset = {
    'dales': 10,
    'kitti360': 10,
    'scannet': 10,
    's3dis': 3,
}
default_radius = 3
radius = next(
    (r for name, r in radius_per_dataset.items() if name in expe_widget.value),
    default_radius)

In [None]:
# Zoom in on the area selected by `center` and `radius`
nag.show(
    center=center,
    radius=radius,
    num_classes=dataset.num_classes,
    class_names=dataset.class_names,
    class_colors=dataset.class_colors,
    stuff_classes=dataset.stuff_classes,
    max_points=100_000,
)

## Visualizing random samples centered on a class of interest
You may be interested in seeing random samples of a given class. To this end, you can simply use the `BaseDataset.show_examples()` method.

In [None]:
# Show random train-set samples centered on a class of interest. The class
# may be given either as an int label or by its class name.
# NOTE(review): label 7 is dataset-specific; check it against
# dataset.class_names for the selected experiment
class_of_interest = 7
datamodule.train_dataset.show_examples(
    class_of_interest, max_examples=5, radius=radius)

## Visualizing the superpoint graphs
Let's take a closer look at the graph connecting superpoints, which we can display by setting `centroids=True` and `h_edge=True`.

In [None]:
# Same close-up, now also drawing the superpoint centroids and the
# horizontal edges of the superpoint graph
nag.show(
    center=center,
    radius=radius,
    centroids=True,
    h_edge=True,
    h_edge_width=2,
    num_classes=dataset.num_classes,
    class_names=dataset.class_names,
    class_colors=dataset.class_colors,
    stuff_classes=dataset.stuff_classes,
    max_points=100_000,
)

## Side-by-side visualization mode
By setting `gap` to a chosen 3D offset, we can visualize all partition levels at once. Besides, setting `v_edge=True` will display the vertical edges connecting superpoints with their children.

In [None]:
# Side-by-side view: `gap` offsets the partition levels from one another,
# and `v_edge` draws the vertical edges linking superpoints to children
nag.show(
    center=center,
    radius=radius,
    gap=[0, 0, 4],
    centroids=True,
    v_edge=True,
    v_edge_width=2,
    figsize=1000,
    num_classes=dataset.num_classes,
    class_names=dataset.class_names,
    class_colors=dataset.class_colors,
    stuff_classes=dataset.stuff_classes,
    max_points=100_000,
)

## Exporting your visualization to HTML
You can export your interactive visualization to HTML. 
You can then share it: the resulting file can be opened in any web browser with an internet connection.

To export a visualization, simply specify a `path` to which the file should be saved.
Additionally, you may set a `title` to be displayed in your HTML.

In [None]:
# Render the close-up and export it to an HTML file at `path`, with
# `title` displayed in the page
nag.show(
    path="my_interactive_visualization.html",
    title="My Interactive Visualization Partition",
    center=center,
    radius=radius,
    figsize=1600,
    num_classes=dataset.num_classes,
    class_names=dataset.class_names,
    class_colors=dataset.class_colors,
    stuff_classes=dataset.stuff_classes,
    max_points=100_000,
)

## Going further with visualization
See the commented code in `src.visualization` for more visualization options.