## Imports

In [None]:
!git clone https://github.com/soumik12345/point-cloud-segmentation -q

In [None]:
!pip install wandb ml_collections -qqq

In [None]:
import sys

# The repo cloned above is not pip-installed; put it on the import path so its
# top-level packages (point_seg, configs) can be imported directly.
sys.path.append("point-cloud-segmentation")

# Project-local modules (resolvable only after the sys.path tweak above).
from point_seg import transform_block
from point_seg import ShapeNetCoreLoaderInMemory
from configs import shapenetcore

import json
import numpy as np
import pandas as pd
import tensorflow as tf
import matplotlib.pyplot as plt

## Constants

In [None]:
# Object category to visualise. Must be a string: it is used as a key into the
# metadata JSON and in the model file-name pattern below.
CATEGORY = "Airplane" #@param ["Airplane", "Bag", "Cap", "Car", "Chair", "Earphone", "Guitar", "Knife", "Lamp", "Laptop", "Motorbike", "Mug", "Pistol", "Rocket", "Skateboard", "Table"]
# Shared project configuration (point count, validation split, batch size).
CONFIGS = shapenetcore.get_config()

## Load metadata

In [None]:
# Download (or reuse the cached copy of) the per-category metadata. A specific
# fname avoids colliding with some other "metadata.json" in the shared Keras
# cache, which get_file would otherwise return without re-downloading.
metadata_path = tf.keras.utils.get_file(
    fname="shapenet_segmentation_metadata_v0.2.json",
    origin="https://github.com/soumik12345/point-cloud-segmentation/releases/download/v0.2/metadata.json",
)

with open(metadata_path, encoding="utf-8") as json_file:
    metadata = json.load(json_file)

if CATEGORY not in metadata:
    raise KeyError(
        f"Category {CATEGORY!r} not found in metadata; available: {sorted(metadata)}"
    )

# Show only the entry relevant to this run rather than the whole file.
category_metadata = metadata[CATEGORY]
print(category_metadata)

# NOTE: the key is spelled "lables" in the released JSON (presumably a typo in
# the data file) — keep it in sync with the file, not with English.
LABELS = category_metadata["lables"]
COLORS = category_metadata["colors"]

## Visualization utils

In [None]:
def visualize_data(point_cloud, labels):
    """Scatter-plot a point cloud in 3D with one colour-coded series per part label.

    Args:
        point_cloud: Array-like indexed as ``point_cloud[:, 0..2]`` for the x, y
            and z coordinates (presumably shape ``(num_points, 3)`` — TODO confirm).
        labels: Sequence with one label name per point, in the same order as
            ``point_cloud``. Only names present in the global ``LABELS`` are
            drawn; points with any other label (e.g. ``"none"``) are skipped.

    Reads the module-level ``LABELS`` and ``COLORS`` lists: ``COLORS[i]`` is the
    colour used for ``LABELS[i]``. Labels without a matching colour fall back to
    matplotlib's default colour cycle.
    """
    # Accept tensors as well as arrays; pandas expects numpy-compatible columns.
    points = np.asarray(point_cloud)
    df = pd.DataFrame(
        data={
            "x": points[:, 0],
            "y": points[:, 1],
            "z": points[:, 2],
            "label": labels,
        }
    )
    fig = plt.figure(figsize=(15, 10))
    ax = fig.add_subplot(projection="3d")
    for index, label in enumerate(LABELS):
        c_df = df[df["label"] == label]
        # Previously an IndexError here silently dropped the whole part; use
        # the default colour instead when the metadata lists too few colours.
        color = COLORS[index] if index < len(COLORS) else None
        # `color=` (not `c=`) so an RGB triple is never mistaken for per-point
        # values to be colour-mapped.
        ax.scatter(
            c_df["x"], c_df["y"], c_df["z"], label=label, alpha=0.5, color=color
        )
    ax.legend()
    plt.show()


def visualize_single_point_cloud(point_clouds, label_clouds, idx):
    """Plot one sample from a batch, coloured by its highest-scoring class per point.

    Args:
        point_clouds: Batch of point clouds; ``point_clouds[idx]`` is passed to
            ``visualize_data`` as the coordinates of one sample.
        label_clouds: Batch of per-point class scores (ground-truth labels,
            presumably one-hot, or model predictions). For each point, the class
            with the highest score along the last axis is used.
        idx: Index of the sample within the batch.
    """
    # Class index len(LABELS) maps to "none" — presumably an extra
    # "unlabelled" channel; TODO confirm against the dataset loader.
    label_map = LABELS + ["none"]
    point_cloud = np.asarray(point_clouds[idx])
    # One vectorised argmax over the class axis instead of a per-point loop.
    class_ids = np.argmax(np.asarray(label_clouds[idx]), axis=-1)
    visualize_data(point_cloud, [label_map[class_id] for class_id in class_ids])

## Load data

In [None]:
# Build the in-memory ShapeNetCore loader for the selected category.
# n_sampled_points presumably fixes the number of points per cloud — TODO confirm.
data_loader = ShapeNetCoreLoaderInMemory(
    object_category=CATEGORY,
    n_sampled_points=CONFIGS.num_points,
)
data_loader.load_data()
# Only the second returned dataset (validation) is used for inference below;
# the first (presumably the training split) is discarded.
_, val_dataset = data_loader.get_datasets(
    val_split=CONFIGS.val_split,
    batch_size=CONFIGS.batch_size,
)

## Fetch model location

**TODO:** Update `model_location` below once all the models have been released on GitHub.

In [None]:
model_location = "gs://pointnet-segmentation/final_models"
model_pattern = f"{model_location}/{CATEGORY}_*"
# Sort explicitly: glob ordering is not guaranteed. The last entry is assumed
# to be the newest export (names presumably sort chronologically — TODO confirm).
category_model_paths = sorted(tf.io.gfile.glob(model_pattern))
if not category_model_paths:
    raise FileNotFoundError(f"No saved model matches {model_pattern}")
category_model_location = category_model_paths[-1]
print(category_model_location)

## Load model and perform inference

In [None]:
# The saved model uses the project's custom OrthogonalRegularizer, so Keras
# must be given the class to deserialise it.
segmentation_model = tf.keras.models.load_model(
    category_model_location,
    custom_objects={"OrthogonalRegularizer": transform_block.OrthogonalRegularizer},
)
# Run inference on the first validation batch only. Element 0 holds the point
# clouds (model input); element 1 holds the ground-truth labels plotted below.
val_data_batch = next(iter(val_dataset))
val_predictions = segmentation_model.predict(val_data_batch[0])

## Visualize the predictions

In [None]:
idx = np.random.choice(len(val_data_batch[0]))
print(f"Index selected: {idx}")

# Draw the same randomly chosen sample twice, one after the other: first with
# its ground-truth labels, then with the model's predicted labels.
for banner, label_source in (
    ("***********Ground-truth***********", val_data_batch[1]),
    ("***********Predicted***********", val_predictions),
):
    print(banner)
    visualize_single_point_cloud(val_data_batch[0], label_source, idx)