<a href="https://colab.research.google.com/github/soumik12345/point-cloud-segmentation/blob/inference/notebooks/train_gpu.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!git clone https://github.com/soumik12345/point-cloud-segmentation
!pip install -qqq wandb ml_collections

In [None]:
import sys

# Make the cloned repository importable (the `point_seg` imports further down
# are assumed to resolve from this directory).
# Guarded so that re-running the cell does not pile up duplicate entries.
repo_dir = "point-cloud-segmentation"
if repo_dir not in sys.path:
    sys.path.append(repo_dir)

In [None]:
# Authenticate the current user through Colab's auth helper.
# NOTE(review): none of the visible cells consume these credentials directly;
# presumably they are needed by the data loaders for remote storage — confirm.
from google.colab import auth
auth.authenticate_user()

In [None]:
import os

import wandb
import wandb.keras
from datetime import datetime

from tensorflow.keras import optimizers, callbacks
from tensorflow.keras import mixed_precision

from point_seg import TFRecordLoader, ShapeNetCoreLoaderInMemory
from point_seg import models, utils

In [None]:
#@title Configs
#@markdown Get your `wandb_api_key` from https://wandb.ai/authorize.
wandb_api_key = "" #@param {type:"string"}
object_category = "Bag" #@param ["Airplane", "Bag", "Cap", "Car", "Chair", "Earphone", "Guitar", "Knife", "Lamp", "Laptop", "Motorbike", "Mug", "Pistol", "Rocket", "Skateboard", "Table"] {type:"raw"}
num_points = 1024 #@param {type:"integer"}
batch_size = 32 #@param {type:"integer"}
val_split = 0.2 #@param {type:"number"}
epochs = 60 #@param {type:"integer"}
initial_lr = 1e-3 #@param {type:"number"}
drop_every = 10 #@param {type:"integer"}
decay_factor = 0.5 #@param {type:"number"}

# Snapshot of the hyper-parameters chosen in the form above; this mapping is
# what gets attached to the experiment-tracking run as its config.
# The API key is deliberately left out of it.
config_dict = dict(
    object_category=object_category,
    num_points=num_points,
    batch_size=batch_size,
    val_split=val_split,
    epochs=epochs,
    initial_lr=initial_lr,
    lr_drop_epoch=drop_every,
    decay_factor=decay_factor,
)

In [None]:
# UTC timestamp used to give the run, log directory and checkpoint file
# a unique name.
timestamp = datetime.utcnow().strftime("%y%m%d-%H%M%S")

# Distribution strategy chosen by the project helper for the available device(s).
strategy = utils.initialize_device()

# Turn the per-replica batch size selected in the Configs form into the global
# batch size. The value is read from `config_dict` (not from `batch_size`
# itself, and not from a hard-coded 32) so that the form setting is honoured
# and re-running this cell does not multiply the batch size again.
batch_size = config_dict["batch_size"] * strategy.num_replicas_in_sync

In [None]:
# Log in with the API key entered in the Configs form when one was supplied;
# with an empty key wandb falls back to its own credential lookup / prompt.
if wandb_api_key:
    wandb.login(key=wandb_api_key)

# Start the tracking run, named after the object category and the timestamp,
# and attach the chosen hyper-parameters as the run config.
wandb.init(
    project='pointnet_shapenet_core',
    name=f"{object_category}_{timestamp}",
    entity="pointnet",
    config=config_dict,
)

In [None]:
# Apply mixed-precision policy [OPTIONAL]
# The policy is installed globally by name; the resulting policy object is then
# read back and kept in `policy` for inspection.
policy_name = "mixed_float16"
mixed_precision.set_global_policy(policy_name)
policy = mixed_precision.global_policy()

In [None]:
# Build the in-memory loader for the selected object category, sampling
# `num_points` points per example.
data_loader = ShapeNetCoreLoaderInMemory(
    n_sampled_points=num_points, object_category=object_category
)

# Load the data before asking the loader for datasets.
data_loader.load_data()

# Split into training / validation datasets using the configured fraction and
# batch size.
(train_dataset, val_dataset) = data_loader.get_datasets(
    batch_size=batch_size, val_split=val_split
)

In [None]:
# Learning-rate schedule driven by the Configs values.
# NOTE(review): presumably multiplies the LR by `decay_factor` every
# `drop_every` epochs — confirm against `utils.StepDecay`.
lr_scheduler = utils.StepDecay(initial_lr, drop_every, decay_factor)
lr_callback = callbacks.LearningRateScheduler(
    lambda epoch: lr_scheduler(epoch), verbose=True
)

# Tensorboard Callback
logs_dir = os.path.join(
    "logs", f"{object_category}_{timestamp}"
)
tb_callback = callbacks.TensorBoard(log_dir=logs_dir)

# ModelCheckpoint Callback
checkpoint_path = os.path.join(
    "training_checkpoints",
    f"{object_category}_{timestamp}.h5",
)
# Create the checkpoint directory up front instead of relying on the callback
# to do so; otherwise the first save could fail only after a full epoch.
os.makedirs(os.path.dirname(checkpoint_path), exist_ok=True)
# Only the weights of the best model seen so far are written.
checkpoint_callback = callbacks.ModelCheckpoint(
    filepath=checkpoint_path, save_best_only=True, save_weights_only=True,
)

# Callbacks handed to `model.fit`.
callback_list = [
    tb_callback,
    checkpoint_callback,
    lr_callback,
    wandb.keras.WandbCallback()
]

In [None]:
# Create the optimizer and model, and compile, all under the distribution
# strategy's scope.
with strategy.scope():
    optimizer = optimizers.Adam(learning_rate=initial_lr)

    # Peek at one training batch and take the size of the labels' last axis as
    # the number of classes (assumes one-hot labels, consistent with the
    # categorical cross-entropy loss below — confirm against the loader).
    _, y = next(iter(train_dataset))
    num_classes = y.shape[-1]
    model = models.get_shape_segmentation_model(num_points, num_classes)

    # Compile inside the scope as well, so that anything compile() creates
    # (metrics, optimizer wrappers) is set up under the same strategy.
    model.compile(
        optimizer=optimizer, loss="categorical_crossentropy", metrics=["accuracy"]
    )

In [None]:
# Train for the configured number of epochs, validating on the held-out split
# and running the callbacks assembled earlier (logging, checkpointing, LR schedule).
model.fit(
    x=train_dataset,
    epochs=epochs,
    validation_data=val_dataset,
    callbacks=callback_list,
)