In [None]:
%load_ext autoreload
%autoreload 2

import os
import sys

# Add the project's files to the python path
# file_path = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))  # for .py script
file_path = os.path.dirname(os.path.abspath(''))  # for .ipynb notebook
sys.path.append(file_path)

# Necessary for advanced config parsing with hydra and omegaconf
from omegaconf import OmegaConf
OmegaConf.register_new_resolver("eval", eval)

import hydra
from src.utils import init_config
import torch
from src.visualization import show
from src.datasets.dales import CLASS_NAMES, CLASS_COLORS
from src.datasets.dales import DALES_NUM_CLASSES as NUM_CLASSES
from src.transforms import *

## Parsing the config files
Hydra and OmegaConf are used to parse the `yaml` config files.

❗Make sure to **set the path to a relevant ckpt file**. 
You can use our pretrained models for this.

In [None]:
# Hydra overrides: select the DALES experiment and point `ckpt_path` to the
# checkpoint whose weights should be loaded (edit the placeholder below)
config_overrides = [
    "experiment=dales",
    "ckpt_path=path/to/your/checkpoint.ckpt",
]

# Compose and parse the full config with hydra
cfg = init_config(overrides=config_overrides)

## Datamodule and model instantiation

In [None]:
# Instantiate the datamodule, then run its data preparation and setup steps
datamodule = hydra.utils.instantiate(cfg.datamodule)
datamodule.prepare_data()
datamodule.setup()

# Instantiate the model from the config
model = hydra.utils.instantiate(cfg.model)

# The checkpoint is loaded below with `criterion=None`, so the criterion must
# be restored afterwards. Keep the one built by the instantiation above
# rather than instantiating the entire model a second time just for it
criterion = model.criterion

# Load pretrained weights from a checkpoint file, reusing the `net` built
# above, then restore the criterion
model = model.load_from_checkpoint(cfg.ckpt_path, net=model.net, criterion=None)
model.criterion = criterion
model = model.eval().cuda()

## Hierarchical partition loading and inference
SPT can process very large scenes at once.
Depending on the dataset stage you use (train, val, or test), inference runs either on a spherical sample of a tile (train) or on a whole million-point tile (val and test).

In [None]:
# Pick among train, val, and test datasets. It is important to note that
# the train dataset produces augmented spherical samples of large
# scenes, while the val and test datasets prepare entire tiles for inference
# dataset = datamodule.train_dataset
dataset = datamodule.val_dataset
# dataset = datamodule.test_dataset

# For the sake of visualization, we require that NAGAddKeysTo does not
# remove input Data attributes after moving them to Data.x, so we may
# visualize them
for t in dataset.on_device_transform.transforms:
    if isinstance(t, NAGAddKeysTo):
        t.delete_after = False

# Load a dataset item. This will return the hierarchical partition of an
# entire tile, within a NAG object
nag = dataset[0]

# Apply on-device transforms on the NAG object. For the train dataset,
# this will select a spherical sample of the larger tile and apply some
# data augmentations. For the validation and test datasets, this will
# prepare an entire tile for inference
nag = dataset.on_device_transform(nag.cuda())

# Inference. No gradient is needed here, so disable autograd. Otherwise the
# computation graph over the whole tile stays alive through `logits`,
# holding on to GPU memory for a backward pass that never happens
with torch.no_grad():
    logits = model(nag)

# If the model outputs multi-stage predictions, we take the first one,
# corresponding to level-1 predictions
if model.multi_stage_loss:
    logits = logits[0]

# Compute the level-0 (pointwise) predictions based on the predictions
# on level-1 superpoints. `super_index` presumably maps each point to its
# level-1 superpoint, so each point gets its superpoint's predicted label
l1_preds = torch.argmax(logits, dim=1)
l0_preds = l1_preds[nag[0].super_index]

# Save predictions for visualization in the level-0 Data attributes
nag[0].pred = l0_preds

## Visualizing an entire tile
SPT can process very large scenes at once. Let's visualize the output.

In [None]:
# Display options for the whole-tile view. `ignore=NUM_CLASSES` presumably
# designates the void/unlabeled class index (TODO confirm), and `max_points`
# presumably caps how many points are rendered
tile_viz_kwargs = dict(
    class_names=CLASS_NAMES,
    ignore=NUM_CLASSES,
    class_colors=CLASS_COLORS,
    max_points=100_000,
)

# Visualize the hierarchical partition of the whole tile
show(nag, **tile_viz_kwargs)

However, for memory reasons, the visualization cannot display all points. Let's have a look at a smaller area.

## Selecting a portion of the hierarchical partition
The NAG structure can be subselected using `nag.select()`.

This function expects an `int` specifying the partition level from which we should select, along with an index or a mask in the form of a `list`, `numpy.ndarray`, `torch.Tensor`, or `slice`.
This index/mask describes which nodes to select at the specified level.

The output NAG will only contain children, parents and edges of the selected nodes.

In [None]:
# Pick a center and radius for the spherical sample. The center is built
# with the same dtype and device as the point positions it is compared to
center = torch.tensor(
    [[40, 115, 0]], dtype=nag[0].pos.dtype, device=nag.device)
radius = 10

# Create an index of the level-0 nodes (ie points) lying strictly inside the
# sphere, to be used for indexing the NAG structure
mask = torch.where(torch.linalg.norm(nag[0].pos - center, dim=1) < radius)[0]

# Fail early with an explicit message if the sphere contains no point,
# rather than silently selecting an empty hierarchical partition
if mask.numel() == 0:
    raise ValueError(
        f"No point lies within radius={radius} of center={center.tolist()}. "
        "Pick a `center` that lies inside the tile.")

# Subselect the hierarchical partition based on the level-0 mask
nag_visu = nag.select(0, mask)

In [None]:
# Same display options as for the whole tile, this time applied to the
# spherical sample `nag_visu`
sample_viz_kwargs = {
    "class_names": CLASS_NAMES,
    "ignore": NUM_CLASSES,
    "class_colors": CLASS_COLORS,
    "max_points": 100_000,
}
show(nag_visu, **sample_viz_kwargs)

## Visualizing the superpoint graphs
Let's take a closer look at the graph connecting superpoints, which we can display by setting `centroids=True` and `h_edge=True`.

In [None]:
# Base display options shared with the previous views
base_viz_kwargs = {
    "class_names": CLASS_NAMES,
    "ignore": NUM_CLASSES,
    "class_colors": CLASS_COLORS,
    "max_points": 100_000,
}

# Extra options to draw the superpoint centroids and the edges of the
# superpoint graph (`h_edge`), with lines of width `h_edge_width`
graph_viz_kwargs = {
    "centroids": True,
    "h_edge": True,
    "h_edge_width": 2,
}

# Visualize the sample along with its superpoint graph
show(nag_visu, **{**base_viz_kwargs, **graph_viz_kwargs})

## Side-by-side visualization mode
By setting `gap` to a chosen 3D offset, we can visualize all partition levels at once. Besides, setting `v_edge=True` will display the vertical edges connecting superpoints with their children.

In [None]:
# Side-by-side view: `gap` offsets each partition level by a 3D vector
# (here 10 along the third axis), so all levels are visible at once, and
# `v_edge` draws the vertical edges linking superpoints to their children
side_by_side_kwargs = dict(
    figsize=1000,
    class_names=CLASS_NAMES,
    ignore=NUM_CLASSES,
    class_colors=CLASS_COLORS,
    max_points=100_000,
    centroids=True,
    v_edge=True,
    v_edge_width=1,
    gap=[0, 0, 10],
)
show(nag_visu, **side_by_side_kwargs)

## Exporting your visualization to HTML
You can export your interactive visualization to HTML. 
You can then share your visualization; the resulting file can be opened in any web browser with an internet connection.

To export a visualization, simply specify a `path` to which the file should be saved.
Additionally, you may set a `title` to be displayed in your HTML.

In [None]:
# Title displayed in the exported page and destination of the HTML file
html_title = "My Interactive Visualization Partition"
html_path = "my_interactive_visualization.html"

# Visualize the sample and export the interactive figure to `html_path`
show(
    nag_visu,
    figsize=1600,
    class_names=CLASS_NAMES,
    ignore=NUM_CLASSES,
    class_colors=CLASS_COLORS,
    max_points=100_000,
    title=html_title,
    path=html_path,
)

## Going further with visualization
See the commented code in `src.visualization` for more visualization options.