## Imports

In [15]:
import sys

sys.path.append("..")

from point_seg import transform_block
from point_seg import ShapeNetCoreLoaderInMemory
from configs import shapenetcore

import numpy as np
import pandas as pd
import tensorflow as tf
import matplotlib.pyplot as plt

## Constants

In [8]:
CATEGORY = "Airplane"
CONFIGS = shapenetcore.get_config()

## Load metadata

In [9]:
with open("/tmp/.keras/datasets/PartAnnotation/metadata.json") as json_file:
    metadata = json.load(json_file)

print(metadata)

LABELS = metadata[CATEGORY]["lables"]
COLORS = metadata[CATEGORY]["colors"]

{'Airplane': {'directory': '02691156', 'lables': ['wing', 'body', 'tail', 'engine'], 'colors': ['blue', 'green', 'red', 'pink']}, 'Bag': {'directory': '02773838', 'lables': ['handle', 'body'], 'colors': ['blue', 'green']}, 'Cap': {'directory': '02954340', 'lables': ['panels', 'peak'], 'colors': ['blue', 'green']}, 'Car': {'directory': '02958343', 'lables': ['wheel', 'hood', 'roof'], 'colors': ['blue', 'green', 'red']}, 'Chair': {'directory': '03001627', 'lables': ['leg', 'arm', 'back', 'seat'], 'colors': ['blue', 'green', 'red', 'pink']}, 'Earphone': {'directory': '03261776', 'lables': ['earphone', 'headband'], 'colors': ['blue', 'green']}, 'Guitar': {'directory': '03467517', 'lables': ['head', 'body', 'neck'], 'colors': ['blue', 'green', 'red']}, 'Knife': {'directory': '03624134', 'lables': ['handle', 'blade'], 'colors': ['blue', 'green']}, 'Lamp': {'directory': '03636649', 'lables': ['canopy', 'lampshade', 'base'], 'colors': ['blue', 'green', 'red']}, 'Laptop': {'directory': '0364280

## Visualization utils

In [10]:
def visualize_data(point_cloud, labels):
    df = pd.DataFrame(
        data={
            "x": point_cloud[:, 0],
            "y": point_cloud[:, 1],
            "z": point_cloud[:, 2],
            "label": labels,
        }
    )
    fig = plt.figure(figsize=(15, 10))
    ax = plt.axes(projection="3d")
    for index, label in enumerate(LABELS):
        c_df = df[df["label"] == label]
        try:
            ax.scatter(
                c_df["x"], c_df["y"], c_df["z"], label=label, alpha=0.5, c=COLORS[index]
            )
        except IndexError:
            pass
    ax.legend()
    plt.show()


def visualize_single_point_cloud(point_clouds, label_clouds, idx):
    label_map = LABELS + ["none"]
    point_cloud = point_clouds[idx]
    label_cloud = label_clouds[idx]
    visualize_data(point_cloud, [label_map[np.argmax(label)] for label in label_cloud])

## Load data

In [6]:
data_loader = ShapeNetCoreLoaderInMemory(
    object_category=CATEGORY,
    n_sampled_points=CONFIGS.num_points,
)
data_loader.load_data()
_, val_dataset = data_loader.get_datasets(
    val_split=CONFIGS.val_split,
    batch_size=CONFIGS.batch_size,
)

Downloading data from https://github.com/soumik12345/point-cloud-segmentation/releases/download/v0.1/shapenet.zip


100%|███████████████████████████████████████| 4045/4045 [03:42<00:00, 18.15it/s]
100%|██████████████████████████████████████| 3694/3694 [00:06<00:00, 529.65it/s]
2021-10-29 14:50:29.360935: I tensorflow/core/platform/cpu_feature_guard.cc:142] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  AVX2 FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.


## Fetch model location

This needs to be updated after we release all the models on GitHub.

In [13]:
model_location = "gs://pointnet-segmentation/final_models"
category_model_location = tf.io.gfile.glob(model_location + f"/{CATEGORY}_*")[-1]
print(category_model_location)

gs://pointnet-segmentation/final_models/Airplane_211028-145411


## Load model and perform inference

In [None]:
segmentation_model = tf.keras.models.load_model(
    category_model_location,
    custom_objects={"OrthogonalRegularizer": transform_block.OrthogonalRegularizer},
)
val_image_batch, val_label_batch = val_dataset.take(1)
val_predictions = segmentation_model.predict(val_image_batch)

## Visualize the predictions

In [None]:
idx = np.random.choice(len(validation_batch[0]))
print(f"Index selected: {idx}")

# Plotting with ground-truth.
visualize_single_point_cloud(val_image_batch, val_label_batch, idx)

# Plotting with predicted labels.
visualize_single_point_cloud(val_image_batch, val_predictions, idx)