# 1. Reading and Visualizing Tree Point Clouds

### Create Label-Mapping and Label Colors

In [None]:
import numpy as np

# Number of valid training classes. The same value doubles as the train ID
# assigned to every label that must be ignored.
FOR_Instance_num_classes = 3

# Lookup table from raw FORinstance class ID (the position in the array)
# to train ID (the value stored at that position)
ID2TRAINID = np.array([
    FOR_Instance_num_classes,   # 0 Unclassified        ->  3 Ignored
    1,                          # 1 Low vegetation      ->  1 Low vegetation
    0,                          # 2 Terrain             ->  0 Ground
    FOR_Instance_num_classes,   # 3 Out-points          ->  3 Ignored
    2,                          # 4 Stem                ->  2 Tree
    2,                          # 5 Live branches       ->  2 Tree
    2,                          # 6 Woody branches      ->  2 Tree
])

# Human-readable names, indexed by train ID
FOR_Instance_CLASS_NAMES = ['Ground', 'Low vegetation', 'Tree', 'Ignored']

# RGB palette in [0, 255], one row per train ID
FOR_Instance_CLASS_COLORS = np.array([
    [243, 214, 171],    # Ground
    [204, 213, 174],    # Low vegetation
    [ 70, 115,  66],    # Tree
    [  0,   0,   0],    # Ignored
])

### Preparing a Data reader

The `Data` object is a simple class, based on PyG's `Data` object, for holding point clouds (graphs).

#### .copy() Problem

In [None]:
import os
import sys
import torch

file_path = os.path.dirname(os.path.abspath(''))  # for .ipynb notebook
sys.path.append(file_path)

import laspy
from src.data import Data
from src.utils.color import to_float_rgb

# Two LAS files are opened side by side so their headers and point
# dimensions can be compared
las_filepath = "/home/valerio/git/superpoint_transformer_vschelbi/data/FORinstance/raw/NIBIO/plot_1_annotated.las"
las_vancouver_filepath = "/home/valerio/git/superpoint_transformer_vschelbi/data/491000_5454000/491000_5454000.las"

las = laspy.read(las_filepath)
las_vancouver = laspy.read(las_vancouver_filepath)

# Print the header and the available point dimensions of each file
for cloud in (las, las_vancouver):
    print(cloud.header)
    dimensions = cloud.point_format.dimension_names
    print("Available dimensions: ")
    for dim in dimensions:
        print(dim)


# Copy each raw coordinate array before handing it to torch
# NOTE(review): presumably the copy is what works around the ".copy()
# problem" named in the section title -- confirm
x_pos, y_pos, z_pos = (las[axis].copy() for axis in ("X", "Y", "Z"))
x_tensor, y_tensor, z_tensor = (
    torch.tensor(coords) for coords in (x_pos, y_pos, z_pos))
torch_pos = torch.stack((x_tensor, y_tensor, z_tensor), dim=-1)


# Same coordinate-loading logic as the reader function, run inline
xyz = True
data = Data()
if xyz:
    # Stack the raw X, Y, Z columns, then apply the scale provided by the
    # LAS header
    raw_axes = [torch.tensor(las[axis].copy()) for axis in ["X", "Y", "Z"]]
    pos = torch.stack(raw_axes, dim=-1)
    pos *= las.header.scale

    # Positions are stored relative to the first point; that point is
    # kept aside as the offset
    pos_offset = pos[0]
    data.pos_offset = pos_offset
    data.pos = (pos - pos_offset).float()


In [None]:
import os
import sys

# Add the project's files to the python path
# file_path = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))  # for .py script
file_path = os.path.dirname(os.path.abspath(''))  # for .ipynb notebook
sys.path.append(file_path)

import laspy
import torch
from src.data import Data, InstanceData
from src.utils.color import to_float_rgb
from torch_geometric.nn.pool.consecutive import consecutive_cluster


def read_FORinstance_plot(
    filepath,
    xyz=True,
    rgb=True,
    intensity=True,
    semantic=True,
    instance=False,
    remap=True,
    max_intensity=600):
    """Read a FORinstance file saved as LAS.

    Semantic label remapping relies on the module-level `ID2TRAINID`
    lookup table.

    :param filepath: str
        Absolute path to the LAS file
    :param xyz: bool
        Whether XYZ coordinates should be saved in the output Data.pos.
        Positions are expressed relative to the first point of the
        file, which is itself stored in Data.pos_offset
    :param rgb: bool
        Whether RGB colors should be saved in the output Data.rgb
    :param intensity: bool
        Whether intensity should be saved in the output Data.intensity
    :param semantic: bool
        Whether semantic labels should be saved in the output Data.y
    :param instance: bool
        Whether instance labels should be saved in the output Data.obj
    :param remap: bool
        Whether semantic labels should be mapped from their FORinstance ID
        to their train ID
    :param max_intensity: float
        Maximum value used to clip intensity signal before normalizing
        to [0, 1]
    :return: Data
        Data object holding the requested attributes
    """

    def _load_semantic_labels():
        # Read the per-point class IDs and, if requested, convert them
        # to train IDs. Each call builds a fresh tensor, so Data.y and
        # the instance labels never share memory
        y = torch.LongTensor(las['classification'])
        return torch.from_numpy(ID2TRAINID)[y] if remap else y

    # create an empty Data object
    data = Data()
    las = laspy.read(filepath)

    # populate the Data object with point coordinates
    if xyz:
        # Apply the scale provided by the LAS header
        # NOTE(review): the LAS header offset is not applied here, so
        # pos_offset is not expressed in absolute coordinates -- confirm
        # this is intended
        pos = torch.stack([
            torch.tensor(las[axis].copy())
            for axis in ["X", "Y", "Z"]], dim=-1)
        pos *= las.header.scale

        # Clone the first point so the stored offset is an independent
        # tensor rather than a view on the full coordinate tensor
        pos_offset = pos[0].clone()
        data.pos = (pos - pos_offset).float()
        data.pos_offset = pos_offset

    # Populate data with point RGB colors
    if rgb:
        # RGB stored in uint16 lives in [0, 65535]
        data.rgb = to_float_rgb(torch.stack([
            torch.FloatTensor(las[axis].astype('float32') / 65535)
            for axis in ["red", "green", "blue"]], dim=-1))

    # Populate data with point LiDAR intensity
    if intensity:
        # Heuristic to bring the intensity distribution in [0, 1]
        data.intensity = torch.FloatTensor(
            las['intensity'].astype('float32')
        ).clip(min=0, max=max_intensity) / max_intensity

    # Populate data with point semantic segmentation labels
    if semantic:
        data.y = _load_semantic_labels()

    # Populate data with point instance labels
    if instance:
        # Make the tree IDs contiguous
        obj = torch.LongTensor(las['treeID'])
        obj = consecutive_cluster(obj)[0]

        # Size the point index from the instance labels themselves, so
        # this branch does not depend on `xyz=True` having populated
        # `data.pos` beforehand
        idx = torch.arange(obj.shape[0])
        count = torch.ones_like(obj)
        y = _load_semantic_labels()
        data.obj = InstanceData(idx, obj, count, y, dense=True)

    return data
    


### Read Data and Visualize

In [None]:
# Load one annotated SCION plot. Colors are skipped (rgb=False); every
# other reader option keeps its default value
filepath = "/home/valerio/git/superpoint_transformer_vschelbi/data/FORinstance/raw/SCION/plot_87_annotated.las"
data = read_FORinstance_plot(filepath=filepath, rgb=False)

In [None]:
# Visualize the loaded plot with the FORinstance class names and palette,
# passing 'intensity' as an additional key
data.show(
    max_points=100000,
    keys=['intensity'],
    class_names=FOR_Instance_CLASS_NAMES,
    class_colors=FOR_Instance_CLASS_COLORS,
)

### No Tiling Done On This File

In [None]:
# Tiling code would go here. SPT provides two tiling strategies:
# 1. SampleXYTiling => when clouds have a simple, convex, axis-aligned horizontal layout
# 2. SampleRecursiveMainXYAxisTiling => when clouds have a complex horizontal layout (like streets etc.)

# Using a pretrained model for inference
First, retrieve the same transforms that the pretrained model used, in this case those of the DALES experiment (defined in `configs/experiment`).

#### Retrieving Config and transforms from pretrained model

In [None]:
from src.utils import init_config
from src.transforms import instantiate_datamodule_transforms

# Build the config with the DALES semantic segmentation experiment as
# override (plain string: there is nothing to interpolate)
cfg = init_config(overrides=["experiment=semantic/dales"])
#cfg.keys()

# Instantiate the transforms declared in the datamodule config. The result
# is indexed by transform name further down in this notebook
transforms_dict = instantiate_datamodule_transforms(cfg.datamodule)
#transforms_dict

# applying the transform:

#### Applying transform
1. apply `pre_transform`
2. simulate the dataset's input/output behavior, where
    only `point_load_keys` and `segment_load_keys` are loaded from disk
3. apply `on_device_test_transform`

In [None]:
# 1. Apply pre-transforms
from src.transforms import NAGRemoveKeys

# 1. Apply pre-transforms
nag = transforms_dict['pre_transform'](data)

# 2. Sim I/O behavior: at level 0, drop every key that is not listed in
# `point_load_keys`
point_keys_to_drop = [
    k for k in nag[0].keys
    if k not in cfg.datamodule.point_load_keys]
nag = NAGRemoveKeys(level=0, keys=point_keys_to_drop)(nag)

# Same for levels 1 and above, using `segment_load_keys`
segment_keys_to_drop = [
    k for k in nag[1].keys
    if k not in cfg.datamodule.segment_load_keys]
nag = NAGRemoveKeys(level='1+', keys=segment_keys_to_drop)(nag)

# 3. Move the data to the GPU, then apply the on-device test transforms
nag = nag.cuda()
nag = transforms_dict['on_device_test_transform'](nag)

In [None]:
# Show the representation of the transformed `nag` object as the cell output
nag

In [None]:
# Visualize the transformed `nag` with the FORinstance class names and
# palette, passing every level-0 key and enabling centroids and h_edge
nag.show(
    class_names=FOR_Instance_CLASS_NAMES,
    class_colors=FOR_Instance_CLASS_COLORS,
    keys=nag[0].keys,
    centroids=True,
    h_edge=True,
)

#### Instantiating the pretrained model from `configs/` and a `*.ckpt`

In [None]:
import hydra
from src.utils import init_config

# Path to the downloaded checkpoint file
ckpt_path = "/home/valerio/git/superpoint_transformer_vschelbi/pretrained_models/spt-2_dales.ckpt"

# Build the config with the DALES semantic segmentation experiment as
# override (plain string: there is nothing to interpolate)
cfg = init_config(overrides=["experiment=semantic/dales"])

# Instantiate the model and load the weights
model = hydra.utils.instantiate(cfg.model)
model = model._load_from_checkpoint(ckpt_path)

#### Applying the SPT

In [None]:
# Set model to inference mode and onto same device as the input data
model = model.eval().to(nag.device)

# Infernce, returns a task-specific output object carrying the model's predictions
with torch.no_grad():
    output = model(nag)

The output of the semantic segmentation is a `SemanticSegmentationOutput` object, a class dedicated to holding the predictions in `output.semantic_pred` and facilitating post-processing operations.

In [None]:
# Print explicitly: a bare expression that is not the last statement of a
# notebook cell is evaluated but never displayed
print(output.semantic_pred.shape, nag.num_points)

# Bring predictions to level-0 and save them under the semantic_pred attribute in Data
nag[0].semantic_pred = output.voxel_semantic_pred(super_index=nag[0].super_index)

#### Visualize the predictions
For better visualization, use the DALES `CLASS_NAMES` and `CLASS_COLORS`, as the model was trained on those classes. The predicted labels DO NOT align with those of the FORinstance dataset.

In [None]:
from src.datasets.dales import CLASS_NAMES as DALES_CLASS_NAMES
from src.datasets.dales import CLASS_COLORS as DALES_CLASS_COLORS

# Visualize with the DALES class names and palette
nag.show(class_names=DALES_CLASS_NAMES, class_colors=DALES_CLASS_COLORS)

### Outlook
This is as far as a pretrained model can get me on the FORinstance dataset. To actually predict the instances (of trees) and identify the classes of the FORinstance dataset, I will need to train a dedicated model on the FORinstance data. In addition, I will have to adjust the preprocessing steps in `pre_transform`: different parameters may produce partitions that better respect the semantic boundaries of the FORinstance dataset.