In [None]:
%load_ext autoreload
%autoreload 2

import os
import sys

# Add the project's files to the python path
# file_path = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))  # for .py script
file_path = os.path.dirname(os.path.abspath(''))  # for .ipynb notebook
sys.path.append(file_path)

import torch
from src.visualization import show
from src.datasets.s3dis import CLASS_NAMES, CLASS_COLORS
from src.datasets.s3dis import S3DIS_NUM_CLASSES as NUM_CLASSES
from src.transforms import *

The main data structures of this project are `Data` and `NAG`.

`Data` stores a single-level graph. 
It inherits from `torch_geometric`'s `Data` and has a similar behavior (see the [official documentation](https://pytorch-geometric.readthedocs.io/en/latest/generated/torch_geometric.data.Data.html#torch_geometric.data.Data) for more on this). 
The main features specific to our `Data` object are:
- `Data.super_index` stores the parent's index for each node in `Data`
- `Data.sub` holds a `Cluster` object indicating the children of each node in `Data`
- `Data.to_trimmed()` works like `torch_geometric`'s `Data.coalesce()` with the additional constraint that (i,j) and (j,i) edges are considered duplicates
- `Data.save()` and `Data.load()` allow optimized, memory-friendly I/O operations
- `Data.select()` indexes the nodes à la numpy
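The two helpers below are hypothetical plain-PyTorch sketches of the ideas behind `super_index`/`sub` and `to_trimmed()`. They are not the project's `Cluster` or `Data` code. The first inverts a child-to-parent `super_index` into a CSR-style pointer/value layout, which is only an assumed representation for "the children of each node". The second treats (i, j) and (j, i) as the same edge by putting each edge's endpoints in (min, max) order before deduplicating; it ignores edge attributes, which a real implementation would also need to merge.

```python
import torch


def children_from_super_index(super_index: torch.Tensor, num_parents: int):
    """Hypothetical helper: invert a child -> parent map into CSR form.

    Children of parent p are values[pointers[p]:pointers[p + 1]].
    """
    # Stable sort by parent id so each parent's children are contiguous
    values = torch.sort(super_index, stable=True).indices
    # Count children per parent and accumulate the counts into pointers
    counts = torch.bincount(super_index, minlength=num_parents)
    pointers = torch.zeros(num_parents + 1, dtype=torch.long)
    pointers[1:] = torch.cumsum(counts, dim=0)
    return pointers, values


def trim_undirected(edge_index: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: deduplicate edges, treating (i, j) == (j, i)."""
    # Canonical orientation: source is the smaller endpoint
    canonical = torch.stack([
        edge_index.min(dim=0).values,
        edge_index.max(dim=0).values,
    ])
    # torch.unique over columns removes duplicate edges (and sorts them)
    return torch.unique(canonical, dim=1)


# 5 points grouped into 3 superpoints
super_index = torch.tensor([1, 0, 1, 2, 0])
pointers, values = children_from_super_index(super_index, num_parents=3)
print(values[pointers[1]:pointers[2]])  # children of superpoint 1

# (0, 1)/(1, 0) and (1, 2)/(2, 1) collapse into single edges
print(trim_undirected(torch.tensor([[0, 1, 1, 2], [1, 0, 2, 1]])))
```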

`NAG` (Nested Acyclic Graph) stores the hierarchical partition in the form of a list of `Data` objects.
The main features specific to our `NAG` object are:
- `NAG[i]` returns a `Data` object holding the partition level `i`
- `NAG.get_super_index()` returns the index mapping nodes from any level `i` to `j` with `i<j`
- `NAG.get_sampling()` produces indices for sampling the superpoints with certain constraints
- `NAG.save()` and `NAG.load()` allow optimized, memory-friendly I/O operations
- `NAG.select()` indexes the nodes of a specified partition level à la numpy and updates the rest of the `NAG` structure accordingly
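Mapping nodes from level `i` to their ancestors at level `j` amounts to chaining the per-level parent indices. A minimal sketch, assuming the `super_index` tensors of levels `i` to `j-1` are available as a plain Python list; the helper `compose_super_index` is hypothetical and not the `NAG` API:

```python
import torch


def compose_super_index(super_indices, i: int, j: int) -> torch.Tensor:
    """Hypothetical helper: map each level-i node to its level-j ancestor.

    super_indices[k] maps level-k nodes to their parents at level k + 1.
    """
    assert i < j, "only upward mappings (i < j) are supported"
    mapping = super_indices[i]
    for k in range(i + 1, j):
        # Look up the level-(k+1) parent of each current level-k ancestor
        mapping = super_indices[k][mapping]
    return mapping


# Level 0 has 5 points, level 1 has 3 superpoints, level 2 has 2
super_indices = [torch.tensor([0, 0, 1, 2, 2]), torch.tensor([0, 1, 1])]
print(compose_super_index(super_indices, 0, 2))
```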

## Load a NAG

In [None]:
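# `weights_only=False` is needed on recent PyTorch versions (>=2.6), which
# otherwise refuse to unpickle custom classes such as `NAG`. Only do this
# for files you trust
nag = torch.load('demo_nag.pt', weights_only=False)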

In [None]:
# Print general info about the NAG
print(nag)

In [None]:
# Loop over the partition levels and print general info about each Data
for i_level, data in enumerate(nag):
    print(f"Level-{i_level}:\n{data}\n")

## Visualizing a NAG

In [None]:
# Visualize the hierarchical partition
show(
    nag, 
    class_names=CLASS_NAMES, 
    ignore=NUM_CLASSES,
    class_colors=CLASS_COLORS,
    max_points=100000,
    centroids=True,
    h_edge=True
)

## Selecting a portion of the hierarchical partition
The NAG structure can be subselected using `nag.select()`.

This function expects an `int` specifying the partition level from which to select, along with an index or a mask in the form of a `list`, `numpy.ndarray`, `torch.Tensor`, or `slice`.
This index/mask describes which nodes to select at the specified level.

The output NAG will only contain children, parents and edges of the selected nodes.
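To build intuition for how a level-0 selection can propagate upwards, here is a simplified sketch in plain PyTorch. It is not `NAG.select()`'s implementation: it assumes a single parent level described by a `super_index` tensor and a parent-level `edge_index`, and the helpers `select_and_propagate` and `filter_edges` are hypothetical. Parents with no selected child are dropped, the surviving parents are renumbered, and only edges between surviving parents are kept. Downward propagation to children, needed when selecting at a higher level, is not covered.

```python
import torch


def select_and_propagate(super_index: torch.Tensor, idx: torch.Tensor):
    """Hypothetical helper: select level-0 nodes, derive surviving parents.

    Returns the selected nodes' new super_index (pointing into the kept
    parents) and the kept parents' ids in the original numbering.
    """
    selected_parents = super_index[idx]
    # Kept parents are those with at least one selected child;
    # return_inverse gives each child's position among the kept parents
    kept, new_super_index = torch.unique(selected_parents, return_inverse=True)
    return new_super_index, kept


def filter_edges(edge_index: torch.Tensor, kept: torch.Tensor, num_nodes: int):
    """Hypothetical helper: keep edges between kept nodes, renumbered."""
    # Old id -> new id lookup table; -1 marks dropped nodes
    remap = torch.full((num_nodes,), -1, dtype=torch.long)
    remap[kept] = torch.arange(kept.numel())
    new_edges = remap[edge_index]
    # An edge survives only if both endpoints survive
    return new_edges[:, (new_edges >= 0).all(dim=0)]


super_index = torch.tensor([0, 0, 1, 2, 2])          # 5 points, 3 parents
parent_edges = torch.tensor([[0, 1, 0], [1, 2, 2]])  # parent adjacency
new_si, kept = select_and_propagate(super_index, torch.tensor([0, 3]))
print(kept, new_si, filter_edges(parent_edges, kept, num_nodes=3))
```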

In [None]:
# Pick a center and radius for the spherical sample
center = nag[0].pos.mean(dim=0)
radius = 1

# Create an index of the level-0 nodes (i.e. points) lying within the
# sphere, to be used for indexing the NAG structure
mask = torch.where(torch.linalg.norm(nag[0].pos - center, dim=1) < radius)[0]

# Subselect the hierarchical partition based on the level-0 mask
nag_visu = nag.select(0, mask)
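The cell above turns the spherical condition into an integer index with `torch.where`. Since `select()` is described as accepting either an index or a mask, the boolean tensor itself should describe the same selection. The plain-PyTorch check below runs on synthetic points standing in for `nag[0].pos` (an assumption, not the demo data; `sphere_selection` is a hypothetical helper) and illustrates that both forms pick identical rows.

```python
import torch


def sphere_selection(pos: torch.Tensor, center: torch.Tensor, radius: float):
    """Hypothetical helper: boolean mask and integer index of the points
    lying strictly inside a sphere."""
    # Boolean mask: one entry per point, True inside the sphere
    mask = torch.linalg.norm(pos - center, dim=1) < radius
    # Integer index of the same points, as built with torch.where above
    index = torch.where(mask)[0]
    return mask, index


# Synthetic points in a 4 x 4 x 4 cube
torch.manual_seed(0)
pos = torch.rand(1000, 3) * 4
mask, index = sphere_selection(pos, pos.mean(dim=0), radius=1)
# Both forms select exactly the same rows
print(torch.equal(pos[mask], pos[index]))
```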

In [None]:
# Visualize the sample
show(
    nag_visu,
    class_names=CLASS_NAMES,
    ignore=NUM_CLASSES,
    class_colors=CLASS_COLORS, 
    max_points=100000,
    centroids=True,
    h_edge=True
)