## Imports

In [1]:
import os
import json
import random
import numpy as np
import pandas as pd

import tensorflow as tf

from typing import List
from tqdm import tqdm
from glob import glob

## Data loading

In [2]:
dataset_url = "https://github.com/soumik12345/point-cloud-segmentation/releases/download/v0.1/shapenet.zip"

dataset_path = tf.keras.utils.get_file(
    fname="shapenet.zip",
    origin=dataset_url,
    cache_subdir="datasets",
    hash_algorithm="auto",
    extract=True,
    archive_format="auto",
    cache_dir="datasets",
)

In [3]:
with open("/tmp/.keras/datasets/PartAnnotation/metadata.json") as json_file:
    metadata = json.load(json_file)

metadata

{'Airplane': {'directory': '02691156',
  'lables': ['wing', 'body', 'tail', 'engine'],
  'colors': ['blue', 'green', 'red', 'pink']},
 'Bag': {'directory': '02773838',
  'lables': ['handle', 'body'],
  'colors': ['blue', 'green']},
 'Cap': {'directory': '02954340',
  'lables': ['panels', 'peak'],
  'colors': ['blue', 'green']},
 'Car': {'directory': '02958343',
  'lables': ['wheel', 'hood', 'roof'],
  'colors': ['blue', 'green', 'red']},
 'Chair': {'directory': '03001627',
  'lables': ['leg', 'arm', 'back', 'seat'],
  'colors': ['blue', 'green', 'red', 'pink']},
 'Earphone': {'directory': '03261776',
  'lables': ['earphone', 'headband'],
  'colors': ['blue', 'green']},
 'Guitar': {'directory': '03467517',
  'lables': ['head', 'body', 'neck'],
  'colors': ['blue', 'green', 'red']},
 'Knife': {'directory': '03624134',
  'lables': ['handle', 'blade'],
  'colors': ['blue', 'green']},
 'Lamp': {'directory': '03636649',
  'lables': ['canopy', 'lampshade', 'base'],
  'colors': ['blue', 'green

In [4]:
points_dir = "/tmp/.keras/datasets/PartAnnotation/{}/points".format(
    metadata["Airplane"]["directory"]
)
labels_dir = "/tmp/.keras/datasets/PartAnnotation/{}/points_label".format(
    metadata["Airplane"]["directory"]
)
LABELS = metadata["Airplane"]["lables"]
COLORS = metadata["Airplane"]["colors"]
VAL_SPLIT = 0.2
N_SAMPLE_POINTS = 1024

## Filtering the files that have labels

In [5]:
points_files_with_keys = set()
points_files = glob(os.path.join(points_dir, "*.pts"))


for point_file in tqdm(points_files):
    file_id = point_file.split("/")[-1].split(".")[0]
    label_data = {}

    for label in LABELS:
        label_file = os.path.join(labels_dir, label, file_id + ".seg")
        if os.path.exists(label_file):
            label_data[label] = 0  # Dummy assignment only used as a placeholder.
    try:
        label_data = np.vstack(tuple([label_data[key] for key in LABELS]))
        points_files_with_keys.add(point_file)
    except:
        continue

points_files_with_keys = list(points_files_with_keys)
len(points_files_with_keys)

100%|████████████████████████████████████| 4045/4045 [00:00<00:00, 16213.40it/s]


3694

## Prepare dataset utilities

In [6]:
def process_single_point_file(point_filepath: str):
    # Load the point cloud from disk.
    point_filepath = point_filepath.numpy().decode("utf-8")
    point_cloud = np.loadtxt(point_filepath)

    # Parse the file-id.
    file_id = point_filepath.split("/")[-1].split(".")[0]
    label_data, num_labels = {}, 0

    # Parse the labels.
    for label in LABELS:
        label_file = os.path.join(labels_dir, label, file_id + ".seg")
        label_data[label] = np.loadtxt(label_file).astype("float32")
        num_labels = len(label_data[label])

    label_map = ["none"] * num_labels
    for label in LABELS:
        for i, data in enumerate(label_data[label]):
            label_map[i] = label if data == 1 else label_map[i]
    label_data = np.vstack(tuple([label_data[key] for key in LABELS]))
    label_cloud = label_data.reshape(label_data.shape[1], label_data.shape[0])

    # Sample `N_SAMPLE_POINTS` from the point and label clouds randomly.
    sampled_point_cloud, sampled_label_cloud = random_sampler(point_cloud, label_cloud)

    # Normalizing point cloud.
    normalized_point_cloud = sampled_point_cloud - np.mean(sampled_point_cloud, axis=0)
    normalized_point_cloud /= np.max(np.linalg.norm(normalized_point_cloud, axis=1))

    return normalized_point_cloud, sampled_label_cloud


def random_sampler(point_cloud: np.ndarray, label_cloud: np.ndarray):
    n_points = len(point_cloud)

    # Randomly sampling respective indices.
    sampled_indices = random.sample(list(range(n_points)), N_SAMPLE_POINTS)

    # Sampling points corresponding to sampled indices.
    sampled_point_cloud = np.array([point_cloud[i] for i in sampled_indices])

    # Sampling corresponding one-hot encoded labels.
    sampled_label_cloud = np.array([label_cloud[i] for i in sampled_indices])

    return sampled_point_cloud, sampled_label_cloud


def tf_process_point_file(point_filepath: str):
    data_tuple = tf.py_function(
        process_single_point_file, [point_filepath], [tf.float64, tf.float32]
    )
    return data_tuple


def augment(point_cloud_batch, label_cloud_batch):
    # Jitter point and label clouds.
    noise = tf.random.uniform(
        tf.shape(label_cloud_batch), -0.005, 0.005, dtype=tf.float64
    )
    point_cloud_batch += noise[:, :, :3]
    label_cloud_batch += tf.cast(noise, tf.float32)

    return point_cloud_batch, label_cloud_batch


def prepare_dataset(
    point_filepaths: List[str], is_train: bool = True, batch_size: int = 16
):
    point_files_ds = tf.data.Dataset.from_tensor_slices(point_filepaths)
    if is_train:
        point_files_ds = point_files_ds.shuffle(batch_size * 100)

    point_ds = point_files_ds.map(
        tf_process_point_file, num_parallel_calls=tf.data.AUTOTUNE
    )
    point_ds = point_ds.batch(batch_size)
    if is_train:
        point_ds = point_ds.map(augment, num_parallel_calls=tf.data.AUTOTUNE)

    return point_ds

## Create `tf.data.Dataset` objects

In [7]:
point_files_ds = tf.data.Dataset.from_tensor_slices(points_files_with_keys)
single_pcloud_file = next(iter(point_files_ds))
single_pcloud_file

2021-10-05 16:39:02.238126: I tensorflow/core/platform/cpu_feature_guard.cc:142] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  AVX2 FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.


<tf.Tensor: shape=(), dtype=string, numpy=b'/tmp/.keras/datasets/PartAnnotation/02691156/points/ecc50d702133b1531e9da99095f71c63.pts'>

In [8]:
single_pcloud_file = next(iter(point_files_ds))
point_cloud, label_cloud = process_single_point_file(single_pcloud_file)
point_cloud.shape, label_cloud.shape

((1024, 3), (1024, 4))

In [9]:
point_ds = point_files_ds.map(
    tf_process_point_file, num_parallel_calls=tf.data.AUTOTUNE
)

for point_cloud, label_cloud in point_ds.take(5):
    print(point_cloud.shape, label_cloud.shape)

2021-10-05 16:39:02.384411: I tensorflow/compiler/mlir/mlir_graph_optimization_pass.cc:185] None of the MLIR Optimization Passes are enabled (registered 2)


(1024, 3) (1024, 4)
(1024, 3) (1024, 4)
(1024, 3) (1024, 4)
(1024, 3) (1024, 4)
(1024, 3) (1024, 4)


In [10]:
split_index = int(len(points_files_with_keys) * (1 - VAL_SPLIT))
train_point_cloud_files = points_files_with_keys[:split_index]
val_point_cloud_files = points_files_with_keys[split_index:]

print(f"Total training files: {len(train_point_cloud_files)}.")
print(f"Total validation files: {len(val_point_cloud_files)}.")

Total training files: 2955.
Total validation files: 739.


In [11]:
train_ds = prepare_dataset(train_point_cloud_files)
validation_ds = prepare_dataset(val_point_cloud_files, is_train=False)

In [12]:
train_ds.element_spec, validation_ds.element_spec

((TensorSpec(shape=<unknown>, dtype=tf.float64, name=None),
  TensorSpec(shape=<unknown>, dtype=tf.float32, name=None)),
 (TensorSpec(shape=<unknown>, dtype=tf.float64, name=None),
  TensorSpec(shape=<unknown>, dtype=tf.float32, name=None)))

In [13]:
for single_batch in train_ds.take(1):
    break
    
single_batch[0].shape, single_batch[1].shape

(TensorShape([16, 1024, 3]), TensorShape([16, 1024, 4]))