In [None]:
# Reload edited project modules automatically before executing each cell
%load_ext autoreload
%autoreload 2

import os
import sys

# Add the project's files to the python path
# file_path = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))  # for .py script
# os.path.abspath('') is the kernel's current working directory, so this adds
# its PARENT directory. NOTE(review): assumes the notebook is launched from a
# direct subdirectory of the project root (e.g. a notebooks folder) — confirm
file_path = os.path.dirname(os.path.abspath(''))  # for .ipynb notebook
sys.path.append(file_path)

import hydra
from src.utils import init_config, compute_panoptic_metrics, \
    compute_panoptic_metrics_s3dis_6fold, grid_search_panoptic_partition, \
    oracle_superpoint_clustering
import torch
# Star imports: presumably these provide the `make_*_widget` helpers used in
# the next section — TODO confirm and replace with explicit imports
from src.transforms import *
from src.utils.widgets import *
from src.data import *

# Very ugly fix to ignore lightning's warning messages about the
# trainer and modules not being connected
# NOTE: this filter has no category/module restriction, so it silences EVERY
# warning (from any library) for the rest of the kernel session
import warnings
warnings.filterwarnings("ignore")

## Select your device, experiment, split, and pretrained model

In [None]:
# Interactive selectors for the device, task/experiment, data split, and
# checkpoint file. The cells below read their `.value` attributes.
# NOTE(review): everything downstream depends on these interactive values, so
# a plain "Restart & Run All" uses whatever the widgets default to — confirm
# the defaults are sensible
device_widget = make_device_widget()
task_widget, expe_widget = make_experiment_widgets()
split_widget = make_split_widget()
ckpt_widget = make_checkpoint_file_search_widget()

In [None]:
# Echo the current widget selections, one `name=value` line each, so the
# configuration used for the rest of the notebook is recorded in the output
print("\n".join(
    ["You chose:"]
    + [f"  - {label}={widget.value}" for label, widget in (
        ("device", device_widget),
        ("task", task_widget),
        ("split", split_widget),
        ("experiment", expe_widget),
        ("ckpt", ckpt_widget),
    )]
))

## Parsing the config files
Hydra and OmegaConf are used to parse the `yaml` config files.

❗Make sure you selected a **ckpt file relevant to your experiment** in the previous section. 
You can use our pretrained models for this, or your own checkpoints if you have already trained a model.

In [None]:
# Parse the configs using hydra.
# The checkpoint override is only passed when a checkpoint was actually
# selected: formatting `None` into the override string would produce
# `ckpt_path=None`, which Hydra's override grammar parses as the string
# "None" (only `null` is the null literal) instead of keeping the default
config_overrides = [f"experiment={task_widget.value}/{expe_widget.value}"]
if ckpt_widget.value is not None:
    config_overrides.append(f"ckpt_path={ckpt_widget.value}")
cfg = init_config(overrides=config_overrides)

## Datamodule and model instantiation

In [None]:
# Build the datamodule from its config, then run its data preparation and
# setup hooks so the split datasets are available
datamodule = hydra.utils.instantiate(cfg.datamodule)
datamodule.prepare_data()
datamodule.setup()

# Select the dataset of the chosen split. Keep in mind that the train
# dataset yields augmented spherical samples of large scenes, whereas the
# val and test datasets load entire tiles at once
if split_widget.value not in ('train', 'val', 'test'):
    raise ValueError(f"Unknown split '{split_widget.value}'")
dataset = getattr(datamodule, f"{split_widget.value}_dataset")

# Show a summary of the selected dataset's classes
dataset.print_classes()

# Build the model from its config, restore pretrained weights when a
# checkpoint was selected, then switch to eval mode on the chosen device
model = hydra.utils.instantiate(cfg.model)
if ckpt_widget.value is not None:
    model = model._load_from_checkpoint(cfg.ckpt_path)
model = model.eval().to(device_widget.value)

## Oracles on a tile sample
We design oracles to estimate the maximum performance our superpoint-graph-clustering approach can achieve on a point cloud. Note that these metrics are computed on a single tile, not on the entire dataset. The oracles are computed at a given superpoint partition level and, based on the quality of that partition, estimate the following:

- `semantic_segmentation_oracle`: assign to each superpoint the most frequent label among the points it contains
- `panoptic_segmentation_oracle`: same as for semantic segmentation + assign each superpoint to the target instance it overlaps the most
- `oracle_superpoint_clustering`: same as for semantic segmentation + assign to each edge the target affinity + compute the graph clustering to form instance predictions

Of course, all these oracles depend on how the superpoint partition was computed. In addition, `oracle_superpoint_clustering` also depends on the graph clustering parameters.

In [None]:
# Load the first sample of the selected split and keep only the panoptic
# annotations (`.obj`) of its level 1 (presumably the first superpoint
# partition level), which the oracle cells below operate on. The reference
# to the full sample is dropped immediately, as with a one-liner
tile_sample = dataset[0]
obj = tile_sample[1].obj
del tile_sample

In [None]:
# Compute the semantic segmentation oracle: the metrics obtained if each
# superpoint predicted the most frequent label among the points it contains.
# Left as the cell's last expression so its result is displayed
obj.semantic_segmentation_oracle(dataset.num_classes)

In [None]:
# Compute the panoptic segmentation oracle without graph clustering: same as
# the semantic oracle, plus each superpoint is assigned to the target
# instance it overlaps the most. `stuff_classes` are passed separately
# because they are presumably not split into instances — TODO confirm
obj.panoptic_segmentation_oracle(dataset.num_classes, stuff_classes=dataset.stuff_classes)

In [None]:
# Panoptic segmentation oracle WITH graph clustering: superpoint semantics
# come from the semantic oracle, edge affinities from the targets, and the
# instances from running the graph clustering with the parameters below.
# NOTE(review): `dataset[0]` is indexed again here; on the train split this
# presumably yields a different augmented sample than the cell above
oracle_superpoint_clustering(
    dataset[0],
    dataset.num_classes,
    dataset.stuff_classes,
    mode='pas',
    graph_kwargs={'radius': 0.1},
    partition_kwargs={'regularization': 0.1, 'x_weight': 1e-3, 'cutoff': 300})

## Grid-searching partition parameters on a tile sample
Our SuperCluster model is trained to predict the input for a graph clustering problem whose solution is a panoptic segmentation of the scene.
Interestingly, with our formulation, the model is **only supervised with local node-wise and edge-wise objectives, without ever needing to compute an actual panoptic partition of the scene during training**.

At inference time, however, we need to decide on some parameters for our graph clustering algorithm.
To this end, a simple post-training grid-search can be used.

We find that similar parameters maximize panoptic segmentation results on all our datasets. 
Here, we provide utilities to help you grid-search these parameters yourself. See the `grid_search_panoptic_partition` docstring for more details on how to use this tool.

In [None]:
# Grid-search the graph clustering parameters on the first cloud of the
# validation dataset, independently of the split selected above.
# Entries given as lists (`regularization`, `x_weight`) presumably define
# the search grid — see the `grid_search_panoptic_partition` docstring
output, partitions, results = grid_search_panoptic_partition(
    model,
    datamodule.val_dataset,
    i_cloud=0,
    graph_kwargs={'radius': 0.1},
    partition_kwargs={
        'regularization': [2e1, 1e1, 5],
        'x_weight': [5e-2, 1e-2, 1e-3, 1e-4],
        'cutoff': 300,
    },
    mode='pas',
)

## Running evaluation on a whole dataset
The above grid search only computes the panoptic segmentation metrics on a single point cloud.
In this section, we provide tools for computing the panoptic metrics on a whole dataset. 

In [None]:
# Evaluate the panoptic, instance, and semantic metrics over the whole 'val'
# stage with a single, fixed set of graph clustering parameters
panoptic, instance, semantic = compute_panoptic_metrics(
    model,
    datamodule,
    stage='val',
    graph_kwargs={'radius': 0.1},
    partition_kwargs={
        'regularization': 1e1,
        'x_weight': 5e-2,
        'cutoff': 300,
    },
)

### S3DIS 6-fold metrics
To compute S3DIS 6-fold metrics, we provide the following utility. It expects one checkpoint per fold: replace the placeholder paths below with your own checkpoints.

In [None]:
# One checkpoint per S3DIS fold, keyed by fold index 1 to 6. These are
# placeholder paths: point them to your own fold checkpoints
fold_ckpt = {
    fold: f"/path/to/your/s3dis/checkpoint/fold_{fold}.ckpt"
    for fold in range(1, 7)
}

# Experiment override built from the selected task and experiment, in the
# same `experiment=<task>/<experiment>` form used when parsing the configs
experiment_config = f"experiment={task_widget.value}/{expe_widget.value}"

In [None]:
# Compute the S3DIS 6-fold panoptic metrics from the per-fold checkpoints.
# The result is stored under a descriptive name so it can be inspected
# later. Assigning to `_` (as before) would shadow IPython's output-history
# variable: once the user defines `_`, IPython stops updating `_`, `__` and
# `___` for the rest of the session. A plain assignment already suppresses
# the cell's output display
s3dis_6fold_metrics = compute_panoptic_metrics_s3dis_6fold(
    fold_ckpt,
    experiment_config,
    stage='val',
    graph_kwargs=dict(
        radius=0.1),
    partition_kwargs=dict(
        regularization=10,
        x_weight=1e-3,
        cutoff=300),
    verbose=False)