# Pointcloud-to-Hypergraph PointNet Lifting Tutorial

## Imports and utilities

In [None]:
# With this cell any imported module is reloaded before each cell execution
%load_ext autoreload
%autoreload 2
import os

import rootutils
from torch_geometric.data import Data

from modules.data.load.loaders import SimplicialLoader
from modules.data.preprocess.preprocessor import PreProcessor
from modules.utils.utils import (
    describe_data,
    load_dataset_config,
    load_model_config,
    load_transform_config,
)

## Loading the Dataset - Wall Shear Stress on the Artery

The dataset is described in detail in [this paper](https://arxiv.org/abs/2212.05023). Wall shear stress is a useful medical biomarker that has been linked to coronary artery disease. It can be roughly estimated from the shape of the artery.

In [None]:
dataset_name = "wall_shear_stress"
dataset_config = load_dataset_config(dataset_name)
loader = SimplicialLoader(dataset_config)

dataset = loader.load()

print("\nDataset:")
dataset

Our dataset consists of 2000 triangular meshes, which are concatenated and represented as a simplicial complex. The offsets in `slices_pos` and `slices_face` can be used to extract individual meshes. The node attribute is the geodesic distance to the artery inlet (which gives the mesh a direction), and the face attribute is the surface normal.

In [None]:
sample_idx = 0

# Extract individual sample
slice_pos = slice(dataset.slices_pos[sample_idx], dataset.slices_pos[sample_idx + 1])
slice_face = slice(dataset.slices_face[sample_idx], dataset.slices_face[sample_idx + 1])

data = Data(
    x=dataset.x[slice_pos],
    y=dataset.y[slice_pos],
    pos=dataset.pos[slice_pos],
    face=dataset.face[:, slice_face],
    # x_2=dataset.x_2[slice_face],
    # incidence_2=dataset.incidence_2[slice_pos, slice_face]  # not supported by PyTorch,
    num_features=dataset.num_features,
)

print("Data sample:")
data
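
To make the offset-based extraction above easier to follow, here is a minimal toy sketch of the same idea in NumPy. It assumes that the samples are stored back to back in one array and that the offset array holds the start index of each sample plus the total length as its last entry. The names `pos_all` and `get_sample` are hypothetical and are not part of the dataset or the `modules` package.

```python
import numpy as np

# Two toy "meshes" with 3 and 4 nodes, stored back to back in one array.
pos_a = np.zeros((3, 3))
pos_b = np.ones((4, 3))
pos_all = np.concatenate([pos_a, pos_b])

# Offsets mark where each sample starts; the last entry is the total length.
slices_pos = np.concatenate([[0], np.cumsum([len(pos_a), len(pos_b)])])


def get_sample(values, offsets, idx):
    """Return the rows of `values` that belong to sample `idx` (hypothetical helper)."""
    return values[offsets[idx]:offsets[idx + 1]]


second = get_sample(pos_all, slices_pos, 1)
print(second.shape)
```

The same pattern applies to per-face tensors, with the only difference that `face` is sliced along its second dimension in the cell above.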

## Loading and Applying the Lifting

In [None]:
transform_type = "liftings"
transform_id = "pointcloud2hypergraph/pointnet_lifting"
transform_config = {"lifting": load_transform_config(transform_type, transform_id)}

# Adapt sampling ratio and cluster radius
transform_config["lifting"]["sampling_ratio"] = 0.2
transform_config["lifting"]["cluster_radius"] = 0.1

lifted_data = PreProcessor(
    data, transform_config, os.path.join(rootutils.find_root(), dataset_config.data_dir)
)
describe_data(lifted_data)

This hypergraph corresponds to the first set abstraction layer used by [PointNet++](https://arxiv.org/abs/1706.02413). To construct a complete PointNet++ from it, we would have to apply the lifting recursively, treating the hyperedges of the previous level as new "hyper-nodes".
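
To illustrate what a set abstraction layer does with the sampling ratio and cluster radius, the following is a rough NumPy sketch, not the lifting's actual implementation. It assumes the usual PointNet++ recipe: pick a subset of center points by farthest point sampling, then group every point within a fixed radius of a center into one hyperedge. The function names `farthest_point_sampling` and `radius_hyperedges` are hypothetical, and the sketch omits the cap on neighbors per ball that practical implementations often apply.

```python
import numpy as np


def farthest_point_sampling(points, num_samples, start=0):
    """Greedily pick indices of points that are far apart from each other."""
    selected = [start]
    # Distance from every point to its nearest selected point so far.
    min_dist = np.linalg.norm(points - points[start], axis=1)
    for _ in range(num_samples - 1):
        next_idx = int(np.argmax(min_dist))
        selected.append(next_idx)
        dist_to_new = np.linalg.norm(points - points[next_idx], axis=1)
        min_dist = np.minimum(min_dist, dist_to_new)
    return np.array(selected)


def radius_hyperedges(points, sampling_ratio, cluster_radius):
    """Return center indices and a (num_points, num_centers) incidence matrix."""
    num_samples = max(1, int(np.ceil(sampling_ratio * len(points))))
    centers = farthest_point_sampling(points, num_samples)
    # Pairwise distances between all points and the sampled centers.
    diffs = points[:, None, :] - points[centers][None, :, :]
    dists = np.linalg.norm(diffs, axis=-1)
    # A point belongs to a hyperedge if it lies within the radius of its center.
    incidence = (dists <= cluster_radius).astype(np.int64)
    return centers, incidence


rng = np.random.default_rng(0)
toy_points = rng.random((50, 3))  # random toy point cloud, not the artery data
centers, incidence = radius_hyperedges(toy_points, sampling_ratio=0.2, cluster_radius=0.3)
print(incidence.shape, incidence.sum(axis=0))
```

Each column of `incidence` is one hyperedge, and every center lies in its own hyperedge. Under the same assumptions, a second level could be sketched by calling `radius_hyperedges` again on `toy_points[centers]` with a larger radius.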

## Running inference

In [None]:
from modules.models.hypergraph.unigcn import UniGCNModel

model_type = "hypergraph"
model_id = "unigcn"
model_config = load_model_config(model_type, model_id)

model = UniGCNModel(model_config, dataset_config)

print("\nModel output:")
model(lifted_data)