<a href="https://colab.research.google.com/github/shahabday/graph-neural-networks/blob/main/0_PointCloudSemanticSegmentation.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# PLEASE, MAKE A COPY OF THIS COLAB BEFORE RUNNING ANYTHING.

import torch
import os
# Expose the installed torch version as the TORCH environment variable so that
# `${TORCH}` in the pip commands below resolves to it and picks the matching
# prebuilt-wheel index on data.pyg.org.
os.environ['TORCH'] = torch.__version__
print(torch.__version__)

# Install required packages.
# PyG extension libraries from the wheel index that matches the torch version above.
!pip install -q torch-scatter -f https://data.pyg.org/whl/torch-${TORCH}.html
!pip install -q torch-sparse -f https://data.pyg.org/whl/torch-${TORCH}.html
!pip install -q torch-cluster -f https://data.pyg.org/whl/torch-${TORCH}.html
# PyTorch Geometric straight from the GitHub repository (not a pinned release).
!pip install -q git+https://github.com/pyg-team/pytorch_geometric.git
# Open3D, used below to read and sample meshes (Exercise 1).
!pip install -q open3d



In [None]:
import gdown
import logging
import numpy as np
import open3d as o3d
import random
import plotly.graph_objects as go

from multiprocessing import cpu_count
from plotly.subplots import make_subplots
from torch_geometric.data import Data, Dataset
from torch_geometric.loader import DataLoader
from torch_geometric.transforms import NormalizeScale, FixedPoints
from typing import Optional, Tuple

# Configure the root logger to show WARNING and above (basicConfig does nothing if the
# root logger already has handlers), and create the module logger used by the helpers.
logging.basicConfig(level=logging.WARNING)
logger = logging.getLogger(__name__)

In [None]:
# A helper function to set a random seed
def seed(seed):
    """Seed every random number generator used in this notebook.

    Seeds Python's ``random``, NumPy, and torch (CPU, the current CUDA device and all
    CUDA devices), then makes cuDNN deterministic and turns off its benchmark-based
    kernel selection so repeated runs behave the same way.

    Args:
        seed: The integer seed passed to every generator.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)

    cudnn = torch.backends.cudnn
    cudnn.benchmark = False
    cudnn.deterministic = True


seed(12345)

# A helper function that loads the point cloud data and leaves only X,Y,Z coordinates with the point labels.
def load_point_cloud(file_path: str) -> np.ndarray:
    """Load a ``.npy`` point cloud and keep only the X, Y, Z coordinates and the point label.

    The file must contain a 2D array with one point per row. The first three columns are
    taken as X, Y, Z and the last column as the label; any columns in between are dropped.

    Args:
        file_path: Path to a ``.npy`` file.

    Returns:
        An array of shape ``[N, 4]`` with columns ``X, Y, Z, label``.

    Raises:
        ValueError: If the file does not have a ``.npy`` extension, or the stored array is
            not 2D with at least 4 columns. With only 3 columns, the "label" would otherwise
            silently be a copy of the Z coordinate.
    """
    file_path = os.path.abspath(file_path)
    if not file_path.endswith('.npy'):
        raise ValueError("File must be a .npy file!")

    points = np.load(file_path)
    if points.ndim != 2 or points.shape[1] < 4:
        raise ValueError(
            f"Expected a 2D array with at least 4 columns (X, Y, Z, ..., label) in {file_path}, "
            f"got shape {points.shape}"
        )
    return points[:, [0, 1, 2, -1]]

# A helper visualization function that is versatile and can take either raw point clouds from a file or work directly with
# PyG data objects. This function can visualize either 1 or 2 point clouds side by side.
def visualize(
        point_cloud_fname1: Optional[str]=None,
        point_cloud_fname2: Optional[str]=None,
        point_cloud_graph1: Optional[Data]=None,
        point_cloud_graph2: Optional[Data]=None,
        edge_indices: Tuple[Optional[torch.Tensor], ...]=(None, None),
        indices: Tuple[Optional[torch.Tensor], ...]=(None, None),
        show_both: bool=False,
        name1: Optional[str]=None,
        name2: Optional[str]=None
) -> None:
    """Visualize one or two point clouds, either from files or PyG Data objects, optionally displaying edge connections.

    Args:
        point_cloud_fname1 (Optional[str]): File name for the 1st point cloud. Mutually exclusive with `point_cloud_graph1`.
        point_cloud_fname2 (Optional[str]): File name for the 2nd point cloud. Mutually exclusive with `point_cloud_graph2`.
        point_cloud_graph1 (Optional[Data]): PyG Data object for the 1st point cloud.
        Used if `point_cloud_fname1` is not given.
        point_cloud_graph2 (Optional[Data]): PyG Data object for the 2nd point cloud.
        Used if `point_cloud_fname2` is not given.
        edge_indices (Tuple[Optional[torch.Tensor], ...]): Tuple containing edge indices for the point clouds.
        Each entry should be a tensor of shape [2, num_edges] or None.
        indices (Tuple[Optional[torch.Tensor], ...]): Tuple containing indices of points to highlight in the point clouds.
        Each entry should be a tensor of indices or None. Highlighted points are coloured by label,
        all other points are drawn in light gray.
        show_both (bool): Whether to visualize one or two point clouds. If True, visualizes both point clouds side by side.
        name1 (Optional[str]): Optional name for the first point cloud to use in the plot title
        (defaults to "Point Cloud 1"; a file name, if given, is used instead).
        name2 (Optional[str]): Optional name for the second point cloud to use in the plot title
        (defaults to "Point Cloud 2"; a file name, if given, is used instead).

    Positions, labels, edge indices and highlight indices may be NumPy arrays or torch
    tensors (on any device); they are converted to NumPy before plotting. Points are
    coloured by their label: the last column of a loaded file, or `Data.y`.

    If no point cloud is given, if a required point cloud is missing (the 1st one, or
    the 2nd one when `show_both=True`), or if a file and a Data object are given for the
    same point cloud, a warning is logged and nothing is plotted.
    """
    fnames = (point_cloud_fname1, point_cloud_fname2)
    graphs = (point_cloud_graph1, point_cloud_graph2)

    if all(source is None for source in fnames + graphs):
        logger.warning("Provide at least one point cloud file name or PyG Data object to visualize!")
        return

    if any(fname is not None and graph is not None for fname, graph in zip(fnames, graphs)):
        logger.warning("Provide either a file with a specific point cloud or a PyG Data object, but not both")
        return

    num_clouds = 2 if show_both else 1
    for slot in range(num_clouds):
        if fnames[slot] is None and graphs[slot] is None:
            logger.warning(
                "Point cloud %d is required (show_both=%s), but neither a file name nor a PyG Data object was given",
                slot + 1, show_both,
            )
            return

    def to_numpy(values):
        # Plotly and NumPy boolean masking need host-side arrays: torch tensors
        # (possibly on a GPU) are detached and copied to the CPU first.
        if values is None:
            return None
        if isinstance(values, torch.Tensor):
            return values.detach().cpu().numpy()
        return np.asarray(values)

    def load_data(point_cloud_fname, point_cloud_graph):
        """Return (positions, labels) of one point cloud as NumPy arrays."""
        if point_cloud_fname is not None:
            points_labels = load_point_cloud(point_cloud_fname)
            return points_labels[:, :-1], points_labels[:, -1]
        return to_numpy(point_cloud_graph.pos), to_numpy(point_cloud_graph.y)

    pos1, labels1 = load_data(point_cloud_fname1, point_cloud_graph1)
    pos2, labels2 = load_data(point_cloud_fname2, point_cloud_graph2) if show_both else (None, None)

    if show_both:
        default_titles = [
            name if name else f"Point Cloud {num}"
            for num, name in enumerate([name1, name2], start=1)
        ]
        fig = make_subplots(
            rows=1, cols=2,
            specs=[[{'type': 'scatter3d'}, {'type': 'scatter3d'}]],
            subplot_titles=[
                os.path.basename(fname) if fname is not None else title
                for fname, title in zip(fnames, default_titles)
            ]
        )
    else:
        fig = go.Figure()

    def add_scatter_traces(fig, pos, labels, row=None, col=None, edge_index=None, index=None):
        """Add the edges (if any) and the points of one point cloud to `fig`."""
        if edge_index is not None:
            src, dst = to_numpy(edge_index)
            # Draw all edges as ONE line trace: every segment is start -> end -> None,
            # and the None gap stops Plotly from joining consecutive segments.
            # One trace per edge gets very slow for hundreds of edges.
            xs, ys, zs = [], [], []
            for (x0, y0, z0), (x1, y1, z1) in zip(pos[src, :3].tolist(), pos[dst, :3].tolist()):
                xs += [x0, x1, None]
                ys += [y0, y1, None]
                zs += [z0, z1, None]
            fig.add_trace(
                go.Scatter3d(
                    x=xs, y=ys, z=zs,
                    mode='lines',
                    line=dict(width=0.5, color='black'),
                    opacity=0.5
                ),
                row=row, col=col
            )
        if index is None:
            fig.add_trace(
                go.Scatter3d(
                    x=pos[:, 0], y=pos[:, 1], z=pos[:, 2],
                    mode='markers',
                    marker=dict(size=1.5, color=labels, colorscale="Viridis")
                ),
                row=row, col=col
            )
        else:
            mask = np.zeros(len(pos), dtype=bool)
            mask[to_numpy(index)] = True
            fig.add_trace(
                go.Scatter3d(
                    x=pos[~mask, 0], y=pos[~mask, 1], z=pos[~mask, 2],
                    mode='markers',
                    marker=dict(size=1.5, color='lightgray')
                ),
                row=row, col=col
            )
            fig.add_trace(
                go.Scatter3d(
                    x=pos[mask, 0], y=pos[mask, 1], z=pos[mask, 2],
                    mode='markers',
                    marker=dict(
                        size=1.5,
                        color=labels[mask] if labels is not None else None,
                        colorscale="Viridis"
                    )
                ),
                row=row, col=col
            )

    if show_both:
        add_scatter_traces(fig, pos1, labels1, row=1, col=1, edge_index=edge_indices[0], index=indices[0])
        add_scatter_traces(fig, pos2, labels2, row=1, col=2, edge_index=edge_indices[1], index=indices[1])
    else:
        add_scatter_traces(fig, pos1, labels1, edge_index=edge_indices[0], index=indices[0])

    # Hide the axes and keep the true aspect ratio in every 3D scene.
    hidden_axes_scene = dict(
        xaxis=dict(visible=False),
        yaxis=dict(visible=False),
        zaxis=dict(visible=False),
        aspectmode='data'
    )
    fig.update_layout(scene=hidden_axes_scene)
    if show_both:
        fig.update_layout(scene2=hidden_axes_scene)

    fig.show()

# Semantic segmentation with 3D point clouds and GNNs


#### Sources:

The notebook material was partially taken and modified from:

1. [PointNet++ paper](https://arxiv.org/abs/1706.02413)
2. [PyG PointNet++ examples](https://github.com/pyg-team/pytorch_geometric/blob/master/examples/pointnet2_classification.py)
3. [VKitti3D Dataset](https://github.com/VisualComputingInstitute/vkitti3D-dataset?tab=readme-ov-file)

#### Additional material:
- [Virtual KITTI 3D dataset and its description](https://github.com/VisualComputingInstitute/vkitti3D-dataset?tab=readme-ov-file)
- [The original PointNet++ paper](https://arxiv.org/abs/1706.02413)
- [The PyG tutorial on point cloud classification](https://colab.research.google.com/drive/1D45E5bUK3gQ40YpZo65ozs7hg5l-eo_U?usp=sharing#scrollTo=xqw5fnp5O832)
- [Good old Euler rotation matrices](https://en.wikipedia.org/wiki/Rotation_matrix)


<div align="center">
    <img src="https://drive.google.com/uc?export=view&id=1J1kH7nTOZ5XwViS7Fu5Yb9xcD3nWpYZL" width="500" height="500"/>
    <br>
</div>




**The goal of point cloud segmentation is to assign a specific label or category to every point in the cloud, effectively partitioning the point cloud into different meaningful regions or objects.**

Let's kick off with a small exercise.

## Exercise 1

1. Download a mesh object that represents a bunny made by Stanford University. Run the following command in a new cell to download the bunny:

    `!wget https://graphics.stanford.edu/~mdfisher/Data/Meshes/bunny.obj`

2. Read the bunny object with `open3d` package into a mesh object (don't worry if you don't know what a mesh is, it is not important here). Use this command:

    `bunny_mesh = o3d.io.read_triangle_mesh("bunny.obj")`

3. Use the function `mesh2cloud` provided below that transforms an open3D TriangleMesh object into a point cloud to get a *cloud bunny*. Choose something around 30K points:

    ```python
    def mesh2cloud(mesh_obj: o3d.geometry.TriangleMesh, num_points: int) -> np.array:
        point_cloud = mesh_obj.sample_points_uniformly(number_of_points=num_points)
        return np.asarray(point_cloud.points)
    ```
    
   You will get a 2D numpy array of the shape: $$N \times 3,$$
   **where $N$ is the number of points in your point cloud and 3 represent the 3 spatial coordinates, X, Y, Z. So each row of the matrix is nothing more than a position of the point in space.**

4. Assign to each point on the bunny's body a color that would gradually change from 0 to 1 for the whole bunny. Think of this color as an output of a neural network that segments out different body parts of a bunny and assigns a specific color to them. Use the following line to create the colors for each point and concatenate them as the 4th coordinate to the given ones, which are X, Y, Z.

    `colors = np.linspace(0, 1, num_points)`

   **hint:** use `np.concatenate()`; the resulting shape should be $N \times 4$.

5. **Your main contribution goes here:**
    - Permute in a random order all the **spatial** points in the point cloud, **leaving the colors unchanged**.
        - **hint:** you can use `np.random.permutation` to help you permute the points.
    - Visualize the original bunny and the bunny with permuted spatial coordinates, using the `visualize_point_clouds` function given at the end of the exercise description. You can start by visualizing just the original bunny, setting the `show_both` argument to `False`.
    
Think about the following:

- What "quantity(s)" is(are) permutation invariant in the setting of this toy experiment?
- What types of invariances/equivariances can you think of that would be desirable for a neural network operating on the point cloud domain?

    ```python
    # Visualization function
    def visualize_point_clouds(
        points1: np.ndarray,
        points2: np.ndarray,
        title1: str,
        title2: str,
        show_both: bool=True
    ):
        fig = make_subplots(
            rows=1, cols=2,
            specs=[[{'type': 'scatter3d'}, {'type': 'scatter3d'}]],
            subplot_titles=[title1, title2]
        )

        fig.add_trace(
            go.Scatter3d(
                x=points1[:, 0], y=points1[:, 1], z=points1[:, 2],
                mode='markers',
                marker=dict(size=1.5, color=points1[:, -1], colorscale='Viridis'),
                name=title1
            ),
            row=1, col=1
        )

        if show_both:
            fig.add_trace(
                go.Scatter3d(
                    x=points2[:, 0], y=points2[:, 1], z=points2[:, 2],
                    mode='markers',
                    marker=dict(size=1.5, color=points2[:, -1], colorscale='Viridis'),
                    name=title2
                ),
                row=1, col=2
            )

        fig.update_layout(
            scene1=dict(
                xaxis=dict(visible=False),
                yaxis=dict(visible=False),
                zaxis=dict(visible=False),
                aspectmode='data'
            ),
            scene2=dict(
                xaxis=dict(visible=False),
                yaxis=dict(visible=False),
                zaxis=dict(visible=False),
                aspectmode='data'
            )
        )
        
        fig.show()
    ```


Extra tiny fun for those who have finished the exercise and are getting bored:

- Use Euler 3D rotation matrices to rotate the bunny around some axis and visualize the output. They will come in handy soon.

    ```python
    # Rotation matricies around X, Y or Z-axis, theta given in radians
    def R_axis(theta: float, axis: str) -> np.ndarray:
        axis = axis.lower()
        assert axis in ['x', 'y', 'z'], "Axis must be either x, y or z"
        if axis == 'x':
            return np.array([
                [1, 0, 0],
                [0, np.cos(theta), -np.sin(theta)],
                [0, np.sin(theta), np.cos(theta)]
        ] )
        if axis == 'y':
            return np.array([
                [np.cos(theta), 0, np.sin(theta)],
                [0, 1, 0],
                [-np.sin(theta), 0, np.cos(theta)]
            ])
        if axis == 'z':
            return np.array([
                [np.cos(theta), -np.sin(theta), 0],
                [np.sin(theta), np.cos(theta), 0],
                [0, 0, 1]
            ])
    ```

In [None]:
# ..... YOUR CODE HERE .....

## Real-life point clouds

- Let's move on to a real-life example and consider a problem of object segmentation in self-driving cars that use a LiDAR (light detection and ranging) to understand the road situation around a car.
- We will work with a tiny subset of the Virtual KITTI 3D dataset.


### Dataset exploration

Let's download the dataset, visualize it and proceed with transforming it to a PyG dataset format.

In [None]:
def download_data(file_id: str, file_name: str, split: str, destination_dir: str="data") -> None:
    """Download one Google Drive file to ``<destination_dir>/<split>/raw/<file_name>``.

    The ``raw`` sub-directory is where the PyG ``Dataset`` below looks for raw files.

    Args:
        file_id: Google Drive file id.
        file_name: Name under which the file is saved.
        split: Dataset split sub-directory, e.g. ``"train"`` or ``"test"``.
        destination_dir: Root data directory.
    """
    url = f'https://drive.google.com/uc?id={file_id}'
    raw_dir = os.path.join(destination_dir, split, "raw")
    # exist_ok=True replaces the racy "check if it exists, then create it" pattern.
    os.makedirs(raw_dir, exist_ok=True)
    destination_path = os.path.join(raw_dir, file_name)
    # Depending on the gdown version, a failed download may return None instead of
    # raising; with quiet=True that failure would otherwise go unnoticed.
    if gdown.download(url, destination_path, quiet=True) is None:
        logger.warning("Failed to download %s (id=%s) to %s", file_name, file_id, destination_path)

# (Google Drive file id, file name, split) for every frame we use.
file_ids2file_names_splits = [
    ("1vHAkZLAqgqoIf9VowXUJLDoyA8JD6Zbs", "frame_0.npy", "train"),
    ("19C0AciVSdZeZPeAurpBI8RsKOwqgKvkS", "frame_1.npy", "train"),
    ("1Eo-ekBp06KSsJ4Au_PfTt22s7b5Wmxrr", "frame_2.npy", "train"),
    ("1RPveTSlgBA38QoBlf_GQImXrHhqhfXhs", "frame_3.npy", "train"),
    ("14Han3ESj_zmKJCKTlyKRtJgqsy1kWYeF", "frame_4.npy", "train"),
    ("1Lmg2PBNWpkovzlURrhkksI1EU_ZwAeub", "frame_0.npy", "test")
]

# Download the point clouds: 5 training frames and 1 test frame.
for file_id, file_name, split in file_ids2file_names_splits:
    download_data(file_id, file_name, split)

[None, None, None, None, None, None]

In [None]:
# Show the first raw training frame on its own (it will take around 1 minute).
visualize(show_both=False, point_cloud_fname1="data/train/raw/frame_0.npy")

In [None]:
# Gather some info about the point clouds we have
frames = [f"data/train/raw/frame_{num}.npy" for num in range(5)]
frames += ["data/test/raw/frame_0.npy"]

# Union of the labels seen in all frames processed so far. A label counts as "new"
# only if no earlier frame contained it, not merely the immediately preceding one.
seen_labels = set()

for num, frame in enumerate(frames):
    points_labels = load_point_cloud(frame)
    # Labels are stored as floats in the last column; cast to int for readable output.
    frame_labels = {int(label) for label in np.unique(points_labels[:, -1])}
    print(f"\nShape of frame {num} ({frame}) is: {points_labels.shape}")
    print(f"The number of unique label categories for frame {num} is: {len(frame_labels)}")

    new_labels = frame_labels - seen_labels
    if num > 0 and new_labels:
        print(f"New label(s) {sorted(new_labels)} appear in the current frame, which weren't present in any previous frame \n")
    seen_labels |= frame_labels

print(f"\nThe labels present in the frames: {sorted(seen_labels)}")


- The dataset consists of 90 frames of simulated road scenes, using the original KITTI dataset. We will work only with 6 frames.
- Each point cloud represents a LiDAR signal, collected from the roof of an imaginary car that drives along the virtual streets.
- Each frame has around 400K data points and 11 unique labels.
- The total number of labels is 14 and given by the table below:

| Label ID | Semantics  | RGB             | Color       |
|----------|------------|-----------------|-------------|
| 0  | Terrain          | [200, 90, 0]    | brown       |
| 1  | Tree             | [0, 128, 50]    | dark green  |
| 2  | Vegetation       | [0, 220, 0]     | bright green|
| 3  | Building         | [255, 0, 0]     | red         |
| 4  | Road             | [100, 100, 100] | dark gray   |
| 5  | GuardRail        | [200, 200, 200] | bright gray |
| 6  | TrafficSign      | [255, 0, 255]   | pink        |
| 7  | TrafficLight     | [255, 255, 0]   | yellow      |
| 8  | Pole             | [128, 0, 255]   | violet      |
| 9  | Misc             | [255, 200, 150] | skin        |
| 10 | Truck            | [0, 128, 255]   | dark blue   |
| 11 | Car              | [0, 200, 255]   | bright blue |
| 12 | Van              | [255, 128, 0]   | orange      |
| 13 | Don't care       | [0, 0, 0]       | black       |

What challenges can you foresee when working with such a dataset?

### Graph creation

- Many times people do not know how to create a graph, using their data, so we will go through all the required steps here and create a graph dataset using PyG and our point clouds.
- As the documentation of PyG says you **do not** have to create the dataset in this way, but it's one of the ways and it's pretty helpful when you work with larger datasets, especially those which don't fit into your RAM.
- Even though our dataset is still small, we will create a "Large" Dataset object that is used for datasets that do not fit into memory so you could potentially use it for a real-life problem in the future.

#### Logic of the Dataset creation:

One can implement 4 important methods to create a PyG Dataset, however some can be skipped:

1. `raw_file_names()`: A list of files in the `raw` directory which need to be found in order to skip the download.

2. `processed_file_names()`: A list of files in the `processed` directory which need to be found in order to skip the processing.

3. `download()`: Downloads raw data into `raw` directory. *We will skip this step, since we already downloaded the data manually.*

4. `process()`: Processes raw data and saves it into the `processed` directory.

In [None]:
class TinyVKittiDataset(Dataset):
    """A tiny on-disk PyG dataset built from Virtual KITTI 3D ``.npy`` frames.

    Raw files are expected at ``root/raw/frame_0.npy ... frame_{size-1}.npy``. Each raw
    frame becomes a :class:`~torch_geometric.data.Data` object with ``pos`` (X, Y, Z as
    float) and ``y`` (point label as long). It is optionally pre-transformed and then saved
    to ``root/processed/data_{i}.pt``. Frames are loaded from disk one at a time on access,
    so the whole dataset never has to be held in memory.
    """

    def __init__(self, root: str, size, pre_transofrm: callable=NormalizeScale(), transform: callable=None, **kwargs):
        """TinyVKitti Dataset class.
        Args:
            root (str): Root directory where the dataset should be saved.
            size (int): The number of total VKITTI frame files in the root directory.
            pre_transofrm (callable, optional): A function/transform that takes in
                a :class:`~torch_geometric.data.Data` object and returns a
                transformed version. The data object will be transformed before
                being saved to disk. The misspelled name is kept for backward
                compatibility; the correctly spelled ``pre_transform`` keyword is
                accepted too and takes precedence.
                (default: :obj:`NormalizeScale()`)
            transform (callable, optional): A function/transform that takes in a
                :class:`~torch_geometric.data.Data` object and returns a
                transformed version.
                The data object will be transformed before every access.
                (default: :obj:`None`)
            **kwargs: Forwarded to :class:`torch_geometric.data.Dataset` (e.g. ``log``).
                ``pre_filter`` is not applied by :meth:`process`.
        """
        pre_transform = kwargs.pop("pre_transform", pre_transofrm)
        # Must be set before super().__init__(), which inspects the raw/processed file
        # lists to decide whether processing can be skipped.
        self.num_frames = size
        super().__init__(root, transform=transform, pre_transform=pre_transform, **kwargs)

    @property
    def raw_file_names(self):
        # Same file names that download_data stores in root/raw.
        return [f"frame_{idx}.npy" for idx in range(self.num_frames)]

    @property
    def processed_file_names(self):
        # If all of these exist in root/processed, processing is skipped.
        return [f"data_{idx}.pt" for idx in range(self.num_frames)]

    def process(self):
        """Convert every raw frame into a (pre-transformed) Data object and save it."""
        for idx, raw_path in enumerate(self.raw_paths):
            points_labels = load_point_cloud(raw_path)
            data = Data(
                pos=torch.tensor(points_labels[:, :3], dtype=torch.float),
                y=torch.tensor(points_labels[:, -1], dtype=torch.long),
            )
            if self.pre_transform is not None:
                data = self.pre_transform(data)
            torch.save(data, os.path.join(self.processed_dir, self.processed_file_names[idx]))

    def len(self):
        """Return the number of frames in the dataset."""
        return self.num_frames

    def get(self, idx):
        """Load the ``idx``-th processed frame from disk."""
        # weights_only=False: recent torch versions default to weights_only=True,
        # which refuses to unpickle arbitrary objects such as PyG Data.
        return torch.load(
            os.path.join(self.processed_dir, self.processed_file_names[idx]),
            weights_only=False,
        )

Let's decrease the number of points in the cloud, since otherwise it would be unmanageable to process in a reasonable amount of time on this machine.

- We will use the **FixedPoints** PyG transform, which randomly samples $N$ points from a point cloud.

In [None]:
# Randomly sample 10K points per cloud (without replacement) on every access.
fixed_points_transform = FixedPoints(num=10_000, replace=False)

# Full-resolution training frames, used below for comparison with the sampled ones.
train_dataset_full = TinyVKittiDataset(root="data/train/", size=5, log=False)

# Down-sampled training and test datasets.
train_dataset = TinyVKittiDataset(
    root="data/train/",
    size=5,
    transform=fixed_points_transform,
    log=False,
)
test_dataset = TinyVKittiDataset(
    root="data/test/",
    size=1,
    transform=fixed_points_transform,
    log=False,
)

In [None]:
# Compare the first training frame before (left) and after (right) FixedPoints sampling.
visualize(
    point_cloud_graph1=train_dataset_full[0], name1="Original point cloud",
    point_cloud_graph2=train_dataset[0], name2="Randomly downsampled point cloud",
    show_both=True,
)

## PointNet++

We will re-implement the **[PointNet++](https://arxiv.org/abs/1706.02413)** architecture, which was a pioneering work for modeling point clouds directly, using principles usually applied by CNNs.

PointNet++ processes point clouds iteratively stacking together set abstraction (SA) layers. Each SA layer follows three phases: sampling, grouping and neighborhood feature aggregation.

1. The **sampling phase** implements a pooling scheme suitable for point clouds with potentially different sizes and selects a set of points from input points, which defines the centroids of local regions.

2. The **grouping phase** constructs a graph in which nearby points are connected. Typically, this is done via ball queries (which connects all points that are within a radius to the query point).

3. The **neighborhood feature aggregation phase** executes a Graph Neural Network layer that for each point aggregates information from its direct neighbors (given by the graph constructed in the previous phase). This step encodes the local region patterns into a feature vector.

<div align="center">
    <img src="https://drive.google.com/uc?export=view&id=1lLDeG2tRywCT0vRxA5mZbco0flNgvlZR"/>
    <br>
</div>


### Sampling phase

Let's understand and visualize the sampling phase and cover some of the caveats.

The sampling (or downsampling) phase uses the **Farthest Point Sampling** (FPS) method.

Given an input point set $\{ \mathbf{p}_1, \ldots, \mathbf{p}_n \}$, FPS iteratively selects a subset of points such that the sampled points are furthest apart. FPS has the following desired qualities:

- Compared with random sampling, this procedure is known to have better coverage of the entire point set.
-  In contrast to CNNs that scan the vector space agnostic of data distribution, FPS generates receptive fields in a data dependent manner.

PyG has a ready-to-use implementation of [`fps`](https://pytorch-geometric.readthedocs.io/en/latest/modules/nn.html#torch_geometric.nn.pool.fps), which takes in the positions of nodes and a sampling ratio, and returns the indices of nodes that have been sampled.

In [None]:
from torch_geometric.nn.pool import fps

# Farthest point sampling on one down-sampled training graph: keep half of its points
# and get back the indices of the selected ones.
graph = train_dataset[0]
fps_index = fps(graph.pos, ratio=0.5)

In [None]:
# Left: the whole point cloud. Right: the same cloud with only the FPS-selected points
# coloured by label, and all remaining points drawn in light gray.
visualize(
    point_cloud_graph1=graph, name1="Original point cloud",
    point_cloud_graph2=graph, name2="FPS downsampled point cloud",
    indices=[None, fps_index],
    show_both=True,
)

### Grouping phase via dynamic graph generation

The input for this phase:
- a set of points of the size $N \times (d + C)$, where $d$ is the number of spatial coordinates and $C$ is the number of extra features, which is zero in our case.
- the coordinates of a set of centroids of size $N^{\prime} \times d$

The output are groups of point sets of size $N^{\prime} \times K \times (d + C)$, where

- each group corresponds to a local region.
- $K$ is the number of points in the neighborhood of centroid points. Note that $K$ varies across groups but the succeeding aggregation layer is able to convert a variable number of points into a fixed-length local region feature vector.

The locality of a region is measured by the corresponding distance.
- For example, in CNNs it is the Chebyshev ($L_\infty$) distance between neighbouring pixels: a $k \times k$ kernel covers exactly the pixels within Chebyshev distance $(k-1)/2$ of its center.
- In general, our point clouds come from metric spaces with some predefined metric that allows to measure distances, angles and lengths. In the simplest case one can use Euclidean distance and ignore the curvature of the space.

## Exercise 2.

The authors of the paper offer 2 ways to group the sets and construct graphs made of the points of the set.

The first way is by using a simple **kNN** (k nearest neighbours) algorithm, which finds a fixed number of neighbouring points to a given point.

The second way is a **Ball query** method that finds all the points that are within a given radius of a query point, but in practice one usually limits the number of points to be no more than $K$.

1. For this exercise, please create another small instance of the `TinyVKittiDataset`, using the `FixedPoints` transform with 10000 points:
    - **hint:** check how we did it above.

2. Get a graph object of type `Data`, using the 0th graph from this dataset:
    - Sample 10 centroid points by applying the `fps` method.
        - **hint:** calculate the required `fps` ratio to achieve this.
    - Group centroids and their neighbours into sets of graphs by using kNN search and Ball query methods:
        - For the kNN search you can use the [`knn`](https://pytorch-geometric.readthedocs.io/en/latest/generated/torch_geometric.nn.pool.knn.html#torch_geometric.nn.pool.knn) method provided by PyG. Set the number of neighbours `k` equal to 32.
        - For the Ball query you can use the [`radius`](https://pytorch-geometric.readthedocs.io/en/latest/generated/torch_geometric.nn.pool.radius.html#torch_geometric.nn.pool.radius) function, which is also a part of PyG package. Set the radius equal to 0.1 and the `max_num_neighbors` equal 32.
        - **hint:** use the `fps` indices for the centroids that you got from the previous step to choose the points for which you need to find the nearest neighbours. Ask me if in doubt!
    - Both of the methods return a `torch.Tensor` as the output, however we need to have source and destination nodes separately (source indices will be our centroid indices), so please use the following and fill in the missing values yourself:
    ```python
    dest_idx, src_idx = knn(x=..., y=..., k=32)
    ```

3. Visualize both methods, using the `visualize` function provided above.
    - Keep the `indices` argument equal to `[None, None]`
    - Construct 2 `edge_index` tensors for both methods.
        - **Important:** check the `dest_idx` tensor and note that when you indexed your centroid positions with the indices given by the `fps` and then applied the kNN and Ball query search, the resulting `dest_idx` tensor doesn't have the original indices of the sampled centroids, so you will need to get them back, using the indices that you received from `fps`.
        - Stack the `src_idx` tensor with the `dest_idx`that you got from the step above to construct the `edge_index` tensor for both methods. Use `torch.stack([src_idx, your_final_dest_idx], dim=0)` method to achieve this.
    - Create a tuple of `edge_indices`, using 2 `edge_index` tensors from the previous step and feed it to the `visualize` function.

4. Questions to answer:
    - What difference do you see between the visualizations of two methods?
    - Based on the difference, can you guess which method is more desirable and why?
    - What type of a graph did we get applying both procedures?
        - **hint:** recall the graph types from the colab about graph theory basics.
    - Play around with the Ball query radius and see what changes. Which radius would you use for which purposes?


In [None]:
from torch_geometric.nn.pool import radius, knn

#### YOUR CODE HERE ####


### Neighbourhood feature aggregation phase

- The input to this phase is $N^{\prime}$ local regions with their centroids that we sampled at step 1 and converted to graphs at step 2. The input has the dimensions $N^{\prime} \times K \times (d + C)$.
- So each local region is represented by its centroid and the accumulated features of the centroid's neighbours.
- The output is a feature tensor of shape $N^{\prime} \times (d + C^{\prime})$, where $C^{\prime}$ is the dimension of the new local-region feature vector.
- The coordinates of local points in the region are translated into a local frame of the centroid coordinates, essentially performing the following operation:

$$
x_i^j = x_i^j - \hat{x}^j, \, \text{for} \, j \in \{1, ..., d\},
$$
  where $\hat{x}$ is the coordinate of the centroid and $d$ is the space dimensionality, which is 3 in our case.


The PointNet++ neighbourhood aggregation steps described above can be written as a simple **message passing scheme** defined via

$$
\mathbf{h}^{(\ell + 1)}_i = \max_{j \in \mathcal{N}(i)} \textrm{MLP} \left( \mathbf{h}_j^{(\ell)}, \mathbf{p}_j - \mathbf{p}_i \right)
$$
where
* $\mathbf{h}_i^{(\ell)} \in \mathbb{R}^{C_\ell}$ denotes the hidden features of point $i$ in layer $\ell$, with $C_\ell$ feature channels (not to be confused with the spatial dimension $d$)
* $\mathbf{p}_i \in \mathbb{R}^3$ denotes the position of point $i$.

**The PointNet++ can be viewed as the basic building block for local pattern learning. By using relative coordinates together with point features we can capture point-to-point relations in the local region.**

This time we will make use of the `MessagePassing` interface that helps us in **creating message passing graph neural networks** by automatically taking care of message propagation.
Here, we need to define its `message` function, as well as  which aggregation scheme to use, *e.g.*, `aggr="max"` (see [here](https://pytorch-geometric.readthedocs.io/en/latest/notes/create_gnn.html) for the accompanying tutorial):


### MessagePassing class

Let's have a closer look at the `MessagePassing` base class:

We only need to define the following:

- `MessagePassing(aggr="add", flow="source_to_target", node_dim=-2)`, which defines:
    - the aggregation scheme to use ("add", "mean" or "max")
    - the flow direction of message passing (either "source_to_target" or "target_to_source")
    - the node_dim attribute indicates along which axis to propagate, default is -2, so we propagate between nodes of dimension `batch, num_nodes, num_node_features`.

- `MessagePassing.message(...)` constructs messages to the node $i$ for each edge.
    - the message is constructed from node $j$ to $i$, which in PyG terminology means `source_to_target` because $i$ is the target node and $j$ is the source by convention.
    - $i$ is our centroid node and all nodes $j$ are the neighbouring nodes we want to run the `MessagePassing` from.

In [None]:
from torch_geometric.nn import MessagePassing

# +++++ WRITE TOGETHER +++++

class PointNetLayer(MessagePassing):
    """PointNet++ message-passing layer with max aggregation.

    Implements ``h_i' = max_{j in N(i)} nn([h_j, p_j - p_i])``. Every neighbour ``j``
    sends ``nn`` applied to its features concatenated with its position relative to
    the centroid ``i``. Each centroid keeps the element-wise maximum over all incoming
    messages.
    """

    def __init__(self, nn: torch.nn.Module):
        """
        Args:
            nn: Network applied to every message. Its input size must be
                ``in_channels + 3`` (or ``3`` when no node features are passed).
        """
        # `aggr="max"` realises the max over the neighbourhood. The default flow
        # "source_to_target" makes edge_index[0] the neighbours (j) and
        # edge_index[1] the centroids (i).
        super().__init__(aggr="max")
        self.nn = nn

    def forward(
            self, x: Tuple[Optional[torch.Tensor], ...], pos: Tuple[torch.Tensor, ...], edge_index: torch.Tensor
        ) -> torch.Tensor:
        """
        Performs the forward pass of the layer by calling ``self.propagate`` of the base class.

        Args:
            x: A tuple with the feature vectors of all neighbours and of the centroids,
                of shape ([num_neighbour_nodes, in_channels], [num_centroids, in_channels]).
                Either entry may be None when there are no node features. A single
                tensor (or None) is also accepted and used for both roles.
            pos: A tuple with the positions of all neighbours and of the centroids,
                of shape ([num_neighbour_nodes, 3], [num_centroids, 3]). A single tensor
                is also accepted and used for both roles.
            edge_index: With the default `source_to_target` flow, edge_index[0] holds the
                source (neighbouring) nodes and edge_index[1] the target (centroid) nodes.

        Returns:
            The aggregated centroid features, one row per target node.

        `propagate()` runs the message passing procedure
        `message()` -> `aggregate()` -> `update()`. Only `message()` is customised here;
        the default max aggregation and identity update are used.
        """
        # Accept plain tensors (or None) as well as (source, target) tuples.
        if not isinstance(x, tuple):
            x = (x, x)
        if not isinstance(pos, tuple):
            pos = (pos, pos)
        return self.propagate(edge_index, x=x, pos=pos)

    def message(self, x_j: Optional[torch.Tensor], pos_j: torch.Tensor, pos_i: torch.Tensor) -> torch.Tensor:
        """
        Creates a message from each neighbouring node to its centroid.

        Tensors passed to `propagate` can be indexed per edge by appending `_i`
        (target/centroid) or `_j` (source/neighbour) to the argument name. PyG performs
        the gathering under the hood.

        Args:
            x_j: Features of the neighbouring nodes, shape [num_edges, in_channels], or None.
            pos_j: Positions of the neighbouring nodes, shape [num_edges, 3].
            pos_i: Positions of the centroids, shape [num_edges, 3].

        Returns:
            The output of ``self.nn`` on ``[x_j, pos_j - pos_i]``, one message per edge.
        """
        # Relative coordinates make the message translation-invariant w.r.t. the centroid.
        edge_feat = pos_j - pos_i
        if x_j is not None:
            edge_feat = torch.cat([x_j, edge_feat], dim=-1)
        return self.nn(edge_feat)

### Local and Global Set Abstraction layers

- A Set Abstraction (SA) layer encompasses all the logic that we wrote above and implements 2 types of operations:
    - feature aggregation on the node level (local SA)
    - feature aggregation on the whole point cloud level (global SA).

Let's implement both of those layers together:

In [None]:
from torch_geometric.nn.pool import global_max_pool

# +++++ WRITE TOGETHER +++++

class SAModule(torch.nn.Module):
    """Local set abstraction: downsample to centroids, group neighbours, aggregate with a PointNetLayer."""

    def __init__(self, ratio, r, nn, max_num_neighbors: int = 64):
        """
        Args:
            ratio: Fraction of points kept as centroids by farthest point sampling.
            r: Radius of the ball query used to group neighbours around each centroid.
            nn: Network passed to the PointNetLayer; its input size must be ``in_channels + 3``.
            max_num_neighbors: Maximum number of neighbours returned per centroid by ``radius``.
        """
        super().__init__()
        self.ratio = ratio
        self.r = r
        self.max_num_neighbors = max_num_neighbors
        self.conv = PointNetLayer(nn)

    def forward(self, x: Optional[torch.Tensor], pos: torch.Tensor, batch: torch.Tensor) -> Tuple[torch.Tensor, ...]:
        """
        Performs the forward pass of the SAModule, which applies three operations:
        downsampling, grouping and message passing.

        Args:
            x: Node features of shape [num_nodes, num_features], or None when only positions are used.
            pos: Node positions of shape [num_nodes, 3].
            batch: A batch vector of shape [num_nodes] that assigns each node to a specific graph.

        Returns:
            The updated centroid features, the centroid positions, and the batch vector assigning
            each centroid to its graph.
        """
        # Imported locally because this cell only imports `global_max_pool`.
        from torch_geometric.nn import fps, radius

        # 1) Downsampling: pick centroids with farthest point sampling, per graph.
        idx = fps(pos, batch, ratio=self.ratio)
        # 2) Grouping: connect every centroid to the points of the same graph within radius r.
        #    `row` indexes the centroids (pos[idx]) and `col` indexes all points (pos).
        row, col = radius(pos, pos[idx], self.r, batch, batch[idx],
                          max_num_neighbors=self.max_num_neighbors)
        edge_index = torch.stack([col, row], dim=0)  # source = neighbour, target = centroid
        # 3) Message passing from the neighbours to the centroids.
        x_dst = None if x is None else x[idx]
        x = self.conv((x, x_dst), (pos, pos[idx]), edge_index)
        return x, pos[idx], batch[idx]


class GlobalSAModule(torch.nn.Module):
    """Global set abstraction: an MLP over ``[x, pos]`` followed by max pooling per point cloud."""

    def __init__(self, nn):
        """
        Args:
            nn: Network applied to ``[x, pos]``; its input size must be ``F_x + 3``.
        """
        super().__init__()
        self.nn = nn

    def forward(self, x: torch.Tensor, pos: torch.Tensor, batch: torch.Tensor) -> Tuple[torch.Tensor, ...]:
        """
        Computes graph-level embeddings from the centroid embeddings and their positions.

        Args:
            x: Centroid embeddings from the SAModule, shape [num_centroids, F_x].
            pos: Centroid positions, shape [num_centroids, 3].
            batch: A batch vector of shape [num_centroids] that assigns each centroid to a specific graph.

        Returns:
            A tuple ``(x, pos, batch)``. ``x`` holds one global embedding per point cloud,
            shape [num_graphs, F_out]. ``pos`` is a zero position per graph. ``batch`` is
            ``[0, ..., num_graphs - 1]``, so that the FP layers can interpolate from these
            single points.
        """
        x = self.nn(torch.cat([x, pos], dim=1))
        x = global_max_pool(x, batch)
        # Each graph is now represented by one "point" placed at the origin.
        pos = pos.new_zeros((x.size(0), 3))
        batch = torch.arange(x.size(0), device=batch.device)
        return x, pos, batch

### Point Feature Propagation

In each SA layer the point set is downsampled. How, then, do we obtain a label for every point of the original cloud?

**The solution is to propagate features from the downsampled points to the original points, hence perform upsampling!**

- We will propagate point features of dimension $N_l \times (d +C)$ to $N_{l-1}$, where $N_l$ and $N_{l-1}$ are the point cloud sizes of the output and the input of set abstraction layer $l$, respectively.
- $N_l \leq N_{l-1}$ because of the downsampling.
- The propagation is achieved by feature interpolation. The features are computed as the inverse distance weighted average for k nearest neighbours. The value $k$ is set to 3 in the paper. PyG has a corresponding method for that, [knn_interpolate](https://pytorch-geometric.readthedocs.io/en/latest/generated/torch_geometric.nn.unpool.knn_interpolate.html), that implements the formula below:
$$
f(y) = \frac{\sum_{i=1}^{k}w(x_i)f(x_i)}{\sum_{i=1}^kw(x_i)}, \, \text{where} \, w(x_i) = \frac{1}{d(\textbf{p}(y), \textbf{p}(x_i) )^2}, \\ \text{and} \,
\{x_1, ..., x_k\} \, \text{are the k nearest neighbours},
$$

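where $y$ is one of the $N_{l-1}$ input points, $\textbf{p}(\cdot)$ denotes a point's coordinates, $d(\cdot, \cdot)$ is the Euclidean distance, and $f(y)$ is the interpolated feature vector of $y$.
- The interpolated features on the $N_{l-1}$ points are then concatenated with skip-connection features. These are the features the same $N_{l-1}$ points had before set abstraction level $l$, i.e. the output of the previous level.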
- Then the concatenated features are passed through a neural network to get more descriptive embeddings.
- The process is repeated until the features are propagated to the original set of points.


In [None]:
from torch_geometric.nn.unpool import knn_interpolate

# +++++ WRITE TOGETHER +++++

class FPModule(torch.nn.Module):
    """Feature propagation: upsample features by kNN interpolation, add skip features, apply an MLP."""

    def __init__(self, k, nn):
        """
        Args:
            k: Number of nearest neighbours used by the inverse-distance interpolation.
            nn: Network applied to ``[interpolated_x, x_skip]`` (or to ``interpolated_x`` when
                ``x_skip`` is None).
        """
        super().__init__()
        self.k = k
        self.nn = nn

    def forward(
            self, x: torch.Tensor, pos: torch.Tensor, batch: torch.Tensor, x_skip: Optional[torch.Tensor],
            pos_skip: torch.Tensor, batch_skip: torch.Tensor) -> Tuple[torch.Tensor, ...]:
        """
        Propagates features from the downsampled output of an SA layer back to its input points.

        Args:
            x: Features we interpolate from, shape [num_source_nodes, F_x]
                (num_source_nodes may equal num_graphs after the global SA layer).
            pos: Positions of the nodes in ``x``, shape [num_source_nodes, 3].
            batch: A batch vector that assigns each node in ``x`` to a specific graph.
            x_skip: Skip-connection features of the points we upsample to, shape [num_nodes, F_skip],
                or None if those points have no features.
            pos_skip: Positions of the points we upsample to (the input of the corresponding SA layer).
            batch_skip: A batch vector that assigns each point in ``pos_skip`` to a specific graph.

        Returns:
            The upsampled node features, ``pos_skip``, and ``batch_skip``.
        """
        x = knn_interpolate(x, pos, pos_skip, batch, batch_skip, k=self.k)
        if x_skip is not None:
            x = torch.cat([x, x_skip], dim=1)
        x = self.nn(x)
        return x, pos_skip, batch_skip

### PointNet++ for point Segmentation

Everything is ready to implement the full PointNet++ architecture. Let's write a small network so that it runs in this Colab; one can always make it bigger.

#### MLP class of PyG

- PyG neatly implements an [MLP](https://pytorch-geometric.readthedocs.io/en/latest/generated/torch_geometric.nn.models.MLP.html#torch_geometric.nn.models.MLP) class, which we need for SA and FP layers.
- Instead of writing an MLP by hand, we can use this class to compose an MLP with ReLU nonlinearities as follows:

  ```python
  mlp = MLP([16, 32, 64, 128])
  ```

The notation above creates a 3-layer MLP that has:

- Layer-1 with 16 input channels and 32 output channels
- Layer-2 with 32 input channels and 64 output channels
- Layer-3 with 64 input channels and 128 output channels

One can also add dropout, a normalization function and other typical parameters.

In [None]:
from torch_geometric.nn import MLP

# +++++ WRITE TOGETHER +++++

class PointNetPP(torch.nn.Module):
    """A small PointNet++ network for per-point semantic segmentation.

    Encoder: two local set abstraction (SA) levels and one global SA level.
    Decoder: three feature propagation (FP) levels that upsample back to the input
    points, followed by a per-point MLP that returns raw logits (as expected by
    ``torch.nn.CrossEntropyLoss``).
    """

    def __init__(self, num_classes, in_channels: int = 0):
        """
        Args:
            num_classes: Number of semantic classes, i.e. the size of the output logits.
            in_channels: Number of per-point input features in ``data.x``. The default ``0``
                uses only the point positions and ignores ``data.x``.
        """
        super().__init__()
        seed(12345)  # reproducible weight initialisation
        self.in_channels = in_channels

        # NOTE(review): the radii 0.2 / 0.4 assume roughly unit-scale coordinates
        # (e.g. after NormalizeScale); TODO confirm against the dataset preprocessing.
        self.sa1_module = SAModule(0.2, 0.2, MLP([in_channels + 3, 32, 32, 64]))
        self.sa2_module = SAModule(0.25, 0.4, MLP([64 + 3, 64, 64, 128]))
        self.sa3_module = GlobalSAModule(MLP([128 + 3, 128, 256]))

        # k=1 for the first FP level: each graph has a single global point to interpolate from.
        self.fp3_module = FPModule(1, MLP([256 + 128, 128, 128]))
        self.fp2_module = FPModule(3, MLP([128 + 64, 128, 64]))
        self.fp1_module = FPModule(3, MLP([64 + in_channels, 64, 64]))

        self.mlp = MLP([64, 64, num_classes], dropout=0.5, norm=None)

    def forward(self, data):
        """Return per-point class logits of shape [num_points, num_classes].

        Assumes ``data`` stores coordinates in ``data.pos`` and the graph assignment in
        ``data.batch``. TODO confirm against TinyVKittiDataset.
        """
        x = data.x if self.in_channels > 0 else None
        sa0_out = (x, data.pos, data.batch)
        sa1_out = self.sa1_module(*sa0_out)
        sa2_out = self.sa2_module(*sa1_out)
        sa3_out = self.sa3_module(*sa2_out)

        # Each FP level upsamples to the points that entered the matching SA level.
        fp3_out = self.fp3_module(*sa3_out, *sa2_out)
        fp2_out = self.fp2_module(*fp3_out, *sa1_out)
        x, _, _ = self.fp1_module(*fp2_out, *sa0_out)

        return self.mlp(x)

### Training

In [None]:
# Pick the GPU when CUDA is available, otherwise run on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Build the segmentation network with one output per class and move it to `device`.
# NOTE(review): `train_dataset` is (re)created in the data-loader cell below. Unless it was
# already defined earlier in the notebook, run that cell first to avoid a NameError.
model = PointNetPP(train_dataset.num_classes)
model = model.to(device)

# Adam with a learning rate of 1e-3.
optimizer = torch.optim.Adam(params=model.parameters(), lr=1e-3)

# CrossEntropyLoss expects unnormalised logits from the model.
criterion = torch.nn.CrossEntropyLoss()

In [None]:
# Create data loaders
seed(12345)
# FixedPoints keeps 10,000 points per cloud, sampled without replacement.
fixed_points_transform = FixedPoints(num=10_000, replace=False)
train_dataset = TinyVKittiDataset(root="data/train/", size=5, transform=fixed_points_transform, log=False)
test_dataset = TinyVKittiDataset(root="data/test/", size=1, transform=fixed_points_transform, log=False)

# Only the training data needs shuffling. Evaluation order does not affect accuracy, and an
# unshuffled loader does not draw from the global RNG, which keeps training more reproducible.
train_loader = DataLoader(train_dataset, batch_size=5, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=1, shuffle=False)

In [None]:
def train(model, loader):
    """Run one training epoch over ``loader`` and return the average loss per graph.

    Uses the notebook-level ``optimizer``, ``criterion`` and ``device``.

    Args:
        model: The network, called as ``model(data)``; its output is passed to
            ``criterion`` together with ``data.y``.
        loader: An iterable of batches with a ``.dataset`` attribute. Each batch must
            provide ``.to(device)``, ``.y`` and ``.num_graphs``.

    Returns:
        The sum over batches of ``loss * num_graphs``, divided by ``len(loader.dataset)``.
    """
    model.train()
    loss_all = 0.0

    for data in loader:
        data = data.to(device)  # send tensors to GPU/CPU (rebinding also works if .to() copies)
        optimizer.zero_grad()  # remove all grads from the previous step
        logits = model(data)  # run inference
        loss = criterion(logits, data.y)  # compute loss
        loss.backward()  # compute gradients with respect to each model parameter
        loss_all += loss.item() * data.num_graphs  # weight the batch loss by its number of graphs
        optimizer.step()  # apply grads and update the weights
    # Normalise by the dataset of the loader we actually iterated, not a global loader.
    return loss_all / len(loader.dataset)


@torch.no_grad()
def test(model, loader):
    """Return the per-point accuracy of ``model`` over all batches in ``loader``.

    Runs in eval mode without gradient tracking. Each batch is moved to the notebook-level
    ``device``. A prediction is the arg-max over dim 1 of ``model(batch)`` and is compared
    with ``batch.y``. The result is the number of correct points divided by the total
    number of points.
    """
    model.eval()
    n_correct, n_points = 0, 0
    for batch in loader:
        # Relies on .to() moving the batch's tensors in place (as PyG Data does).
        batch.to(device)
        predicted = model(batch).argmax(dim=1)
        n_correct += (predicted == batch.y).sum().item()
        n_points += batch.num_nodes

    return n_correct / n_points



In [None]:
from IPython.display import Javascript  # Restrict height of output cell.
# Limit the output cell height. `display` comes from the IPython namespace, and the
# JavaScript call targets Google Colab's output API.
display(Javascript('''google.colab.output.setIframeHeight(0, true, {maxHeight: 300})'''))

# Training: epochs 1..199. After each optimisation pass, report the training loss and
# the accuracy on the training and test loaders.
for epoch in range(1, 200):
    loss = train(model, train_loader)
    train_acc, test_acc = test(model, train_loader), test(model, test_loader)
    print(f'Epoch: {epoch}, Train loss: {loss :.4f}, Train acc: {train_acc :.4f}, Test acc: {test_acc :.4f}')

