In [1]:
# import time
# import copy
import wandb
import torch
import warnings
warnings.filterwarnings('ignore')
# import torch_geometric
import os

from torch import nn, optim
from torch.nn import functional as F
from torch_geometric import nn as gnn
from torch_geometric.loader import DataLoader
from torch_geometric.datasets import ModelNet
# from IPython.display import display, clear_output
# from torch_geometric.utils import remove_self_loops
# from sklearn.model_selection import train_test_split
from torch_geometric.nn import MLP, PointNetConv, fps, global_max_pool, radius
# from torch_geometric.transforms import Compose, RandomRotate, SamplePoints, KNNGraph, NormalizeScale
import torch_geometric.transforms as T


import numpy as np
import matplotlib.pyplot as plt

from model import GNN, PNET, PointNet, PointViG, PointNet2

import random
from glob import glob
from tqdm.auto import tqdm

import wandb


# Provisional global torch seed. The config cell re-seeds torch (and Python's
# `random`) with `config.seed`, which is the value that actually governs training.
torch.manual_seed(42)

<torch._C.Generator at 0x1fb267ad630>

In [2]:
wandb_project = "pyg-point-cloud" #@param {"type": "string"}
wandb_run_name = "final-experiment/modelnet10/2" #@param {"type": "string"}

wandb.init(project=wandb_project, name=wandb_run_name, job_type="baseline-train")

# Set experiment configs to be synced with wandb
config = wandb.config
config.modelnet_dataset_alias = "ModelNet10" #@param ["ModelNet10", "ModelNet40"] {type:"raw"}

# Seed every RNG library the notebook imports: Python `random`
# (visualization sample selection), torch, and numpy.
config.seed = 4242 #@param {type:"number"}
random.seed(config.seed)
torch.manual_seed(config.seed)
np.random.seed(config.seed)

config.sample_points = 2048 #@param {type:"slider", min:256, max:4096, step:16}

# Directory the ModelNet datasets are downloaded to / loaded from. It must
# match the `root` passed to `ModelNet(...)` in the data-loading cell
# ('ModelNet/'). It is NOT the alias name, which is only a label.
MODELNET_ROOT = "ModelNet"

# Class names, in label order: the sorted sub-directory names of `<root>/raw`.
# This mirrors the sorted-directory order PyG's ModelNet uses to assign labels.
# NOTE(review): the list is empty until the dataset has been downloaded once.
config.categories = sorted([
    x.split(os.sep)[-2]
    for x in glob(os.path.join(
        MODELNET_ROOT, "raw", '*', ''
    ))
])

config.batch_size = 16 #@param {type:"slider", min:4, max:128, step:4}
config.num_workers = 6 #@param {type:"slider", min:1, max:10, step:1}

config.device = 'cuda' if torch.cuda.is_available() else 'cpu'
device = torch.device(config.device)

config.set_abstraction_ratio_1 = 0.748 #@param {type:"slider", min:0.1, max:1.0, step:0.01}
config.set_abstraction_radius_1 = 0.4817 #@param {type:"slider", min:0.1, max:1.0, step:0.01}
config.set_abstraction_ratio_2 = 0.3316 #@param {type:"slider", min:0.1, max:1.0, step:0.01}
config.set_abstraction_radius_2 = 0.2447 #@param {type:"slider", min:0.1, max:1.0, step:0.01}
config.dropout = 0.1 #@param {type:"slider", min:0.1, max:1.0, step:0.1}

config.learning_rate = 1e-4 #@param {type:"number"}
config.epochs = 10 #@param {type:"slider", min:1, max:100, step:1}
config.num_visualization_samples = 20 #@param {type:"slider", min:1, max:100, step:1}

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33msaleemkheralden[0m ([33mtechnionsaleem[0m). Use [1m`wandb login --relogin`[0m to force relogin


In [3]:
pre_transform = T.NormalizeScale()
# SamplePoints is a random transform applied each time a sample is accessed,
# so validation metrics include point-sampling noise from epoch to epoch.
transform = T.SamplePoints(config.sample_points)

# Both splits share the 'ModelNet/' root, and the PointNet2 head is sized for
# 10 classes, so only ModelNet10 is supported. Fail loudly if the wandb config
# advertises another variant instead of silently training on ModelNet10.
MODELNET_NAME = '10'
assert config.modelnet_dataset_alias == f"ModelNet{MODELNET_NAME}", (
    f"config.modelnet_dataset_alias={config.modelnet_dataset_alias!r}, but this "
    f"notebook loads ModelNet{MODELNET_NAME} from 'ModelNet/'"
)

train_dataset = ModelNet(
    'ModelNet/',
    name=MODELNET_NAME,
    train=True,
    transform=transform,
    pre_transform=pre_transform
)
train_loader = DataLoader(
    train_dataset,
    batch_size=config.batch_size,
    shuffle=True,
    num_workers=config.num_workers
)

val_dataset = ModelNet(
    'ModelNet/',
    name=MODELNET_NAME,
    train=False,
    transform=transform,
    pre_transform=pre_transform
)
val_loader = DataLoader(
    val_dataset,
    batch_size=config.batch_size,
    shuffle=False,
    num_workers=config.num_workers
)

# A fixed, seeded subset of validation clouds for the wandb visualization table.
# Indexing val_dataset here applies the transform once, so these clouds stay fixed.
random_indices = random.sample(
    list(range(len(val_dataset))),
    config.num_visualization_samples
)
vizualization_loader = DataLoader(
    [val_dataset[idx] for idx in random_indices],
    batch_size=1,
    shuffle=False,
    num_workers=config.num_workers
)

In [4]:
class SetAbstraction(torch.nn.Module):
    """PointNet++ set-abstraction layer.

    Picks centroids with farthest point sampling (keeping a `ratio` fraction of
    points), groups up to 64 neighbours within radius `r` around each
    centroid, and aggregates every group with a `PointNetConv` whose local MLP
    is `nn`.

    forward(x, pos, batch) returns (centroid features, centroid positions,
    centroid batch vector).
    """

    def __init__(self, ratio, r, nn):
        super().__init__()
        self.ratio, self.r = ratio, r
        self.conv = PointNetConv(nn, add_self_loops=False)

    def forward(self, x, pos, batch):
        centroid_idx = fps(pos, batch, ratio=self.ratio)
        centroid_pos = pos[centroid_idx]
        centroid_batch = batch[centroid_idx]

        # radius() yields (centroid, neighbour) index pairs; PointNetConv wants
        # edges as [source=neighbour, target=centroid].
        centroid_ids, neighbour_ids = radius(
            pos, centroid_pos, self.r, batch, centroid_batch,
            max_num_neighbors=64,
        )
        edge_index = torch.stack([neighbour_ids, centroid_ids], dim=0)

        centroid_x = x[centroid_idx] if x is not None else None
        features = self.conv((x, centroid_x), (pos, centroid_pos), edge_index)
        return features, centroid_pos, centroid_batch


In [5]:

class GlobalSetAbstraction(torch.nn.Module):
    """Final PointNet++ stage: encode each point with `nn`, then max-pool per cloud.

    forward(x, pos, batch) concatenates point features with positions, applies
    `nn`, and max-pools over each graph in `batch`. It returns (pooled
    features, zero positions of shape [num_graphs, 3], and a batch vector
    0..num_graphs-1).
    """

    def __init__(self, nn):
        super().__init__()
        self.nn = nn

    def forward(self, x, pos, batch):
        point_features = self.nn(torch.cat((x, pos), dim=1))
        pooled = global_max_pool(point_features, batch)
        num_clouds = pooled.size(0)
        return (
            pooled,
            pos.new_zeros((num_clouds, 3)),
            torch.arange(num_clouds, device=batch.device),
        )


In [6]:

class PointNet2(torch.nn.Module):
    """PointNet++ classifier: two local set-abstraction stages, global pooling, MLP head.

    NOTE(review): this class shadows the `PointNet2` imported from `model` in
    the first cell; everything after this cell uses this notebook version.

    Args:
        set_abstraction_ratio_1, set_abstraction_ratio_2: fraction of points
            kept by farthest point sampling in the first / second stage.
        set_abstraction_radius_1, set_abstraction_radius_2: neighbourhood
            radius used by the first / second stage.
        dropout: dropout probability of the classification MLP.
        num_classes: size of the output layer (default 10, i.e. ModelNet10).
    """

    def __init__(
        self,
        set_abstraction_ratio_1, set_abstraction_ratio_2,
        set_abstraction_radius_1, set_abstraction_radius_2, dropout,
        num_classes=10,
    ):
        super().__init__()

        # Each stage's MLP input is (previous features + 3 relative-position
        # channels). The first stage has only the 3 position channels, which
        # assumes the input clouds carry no node features (data.x is None).
        self.sa1_module = SetAbstraction(
            set_abstraction_ratio_1,
            set_abstraction_radius_1,
            MLP([3, 64, 64, 128])
        )
        self.sa2_module = SetAbstraction(
            set_abstraction_ratio_2,
            set_abstraction_radius_2,
            MLP([128 + 3, 128, 128, 256])
        )
        self.sa3_module = GlobalSetAbstraction(MLP([256 + 3, 256, 512, 1024]))

        self.mlp = MLP([1024, 512, 256, num_classes], dropout=dropout, norm=None)

    def forward(self, data):
        """Return per-cloud class log-probabilities, shape [num_graphs, num_classes]."""
        x, pos, batch = self.sa1_module(data.x, data.pos, data.batch)
        x, pos, batch = self.sa2_module(x, pos, batch)
        x, _, _ = self.sa3_module(x, pos, batch)
        return self.mlp(x).log_softmax(dim=-1)

## Training PointNet++ and Logging Metrics on Weights & Biases

In [7]:
# Build the PointNet++ model from the wandb config. Arguments are passed by
# keyword so the ratio/radius values cannot be silently mis-ordered.
model = PointNet2(
    set_abstraction_ratio_1=config.set_abstraction_ratio_1,
    set_abstraction_ratio_2=config.set_abstraction_ratio_2,
    set_abstraction_radius_1=config.set_abstraction_radius_1,
    set_abstraction_radius_2=config.set_abstraction_radius_2,
    dropout=config.dropout,
).to(device)

# Adam over all model parameters at the configured learning rate.
optimizer = torch.optim.Adam(model.parameters(), lr=config.learning_rate)


In [8]:

def train_step(epoch):
    """Run one training epoch over `train_loader` and log metrics to wandb.

    Logs "Train/Loss" (per-sample mean NLL loss), "Train/Accuracy" and the
    epoch number in a single `wandb.log` call.
    """
    model.train()
    total_loss, correct = 0.0, 0

    progress_bar = tqdm(
        train_loader,
        desc=f"Training Epoch {epoch}/{config.epochs}"
    )
    for data in progress_bar:
        data = data.to(device)

        optimizer.zero_grad()
        prediction = model(data)
        loss = F.nll_loss(prediction, data.y)
        loss.backward()
        optimizer.step()

        # nll_loss returns the batch *mean*; re-weight by batch size so a
        # smaller final batch doesn't skew the epoch average.
        total_loss += loss.item() * data.num_graphs
        correct += prediction.argmax(dim=1).eq(data.y).sum().item()

    num_samples = max(len(train_loader.dataset), 1)  # guard against an empty dataset
    wandb.log({
        "Train/Loss": total_loss / num_samples,
        "Train/Accuracy": correct / num_samples,
        "epoch": epoch,
    })


def val_step(epoch):
    """Evaluate the model on `val_loader` and log metrics to wandb.

    Logs "Validation/Loss" (per-sample mean NLL loss), "Validation/Accuracy"
    and the epoch number in a single `wandb.log` call.
    """
    model.eval()
    total_loss, correct = 0.0, 0

    progress_bar = tqdm(
        val_loader,
        desc=f"Validation Epoch {epoch}/{config.epochs}"
    )
    with torch.no_grad():
        for data in progress_bar:
            data = data.to(device)
            prediction = model(data)

            # Re-weight the batch-mean loss by batch size (see train_step).
            total_loss += F.nll_loss(prediction, data.y).item() * data.num_graphs
            correct += prediction.argmax(dim=1).eq(data.y).sum().item()

    num_samples = max(len(val_loader.dataset), 1)  # guard against an empty dataset
    wandb.log({
        "Validation/Loss": total_loss / num_samples,
        "Validation/Accuracy": correct / num_samples,
        "epoch": epoch,
    })


def visualize_evaluation(table, epoch):
    """Append one row of per-sample validation results for `epoch` to a wandb Table.

    Runs the model once on every sample of `vizualization_loader` (batch size 1)
    and records, per sample: the point cloud, its NLL loss, the predicted and
    ground-truth class names (looked up in `config.categories`), and 1/0 for
    correctness. Returns the same `table` object.

    NOTE(review): `config.categories` must be non-empty and in label order,
    or the name lookup fails.
    """
    # Callers may invoke this right after training; force eval behaviour
    # (dropout off, BatchNorm running stats).
    model.eval()
    point_clouds, losses, predictions, ground_truths, is_correct = [], [], [], [], []
    progress_bar = tqdm(
        vizualization_loader,
        desc=f"Generating Visualizations for Epoch {epoch}/{config.epochs}"
    )

    # Iterate the loader a single time so every selected sample is visualised
    # exactly once. Calling next(iter(loader)) per step would restart the
    # unshuffled loader and return its first sample every time.
    with torch.no_grad():
        for data in progress_bar:
            data = data.to(device)
            prediction = model(data)
            predicted_label = int(prediction.argmax(dim=1).item())
            true_label = int(data.y.item())

            point_clouds.append(wandb.Object3D(data.pos.cpu().numpy()))
            losses.append(F.nll_loss(prediction, data.y).item())
            predictions.append(config.categories[predicted_label])
            ground_truths.append(config.categories[true_label])
            is_correct.append(int(predicted_label == true_label))

    table.add_data(
        epoch, point_clouds, losses, predictions, ground_truths, is_correct
    )
    return table


def save_checkpoint(epoch):
    """Write model/optimizer state to ``checkpoint.pt`` and log it to wandb.

    The file is uploaded as a ``checkpoint``-type artifact named from the
    current run's name and id, tagged with the aliases ``latest`` and
    ``epoch-<epoch>``.
    """
    checkpoint_path = "checkpoint.pt"
    state = dict(
        epoch=epoch,
        model_state_dict=model.state_dict(),
        optimizer_state_dict=optimizer.state_dict(),
    )
    torch.save(state, checkpoint_path)

    run = wandb.run
    safe_name = wandb.util.make_artifact_name_safe(f"{run.name}-{run.id}-checkpoint")

    artifact = wandb.Artifact(safe_name, type="checkpoint")
    artifact.add_file(checkpoint_path)
    wandb.log_artifact(artifact, aliases=["latest", f"epoch-{epoch}"])

# Optional per-epoch extras. Both upload data to wandb and are off by default.
VISUALIZE_EVALUATION = False
SAVE_CHECKPOINTS = False

table = wandb.Table(
    columns=[
        "Epoch",
        "Point-Clouds",
        "Losses",
        "Predicted-Classes",
        "Ground-Truth",
        "Is-Correct"
    ]
)
try:
    for epoch in range(1, config.epochs + 1):
        train_step(epoch)
        val_step(epoch)
        if VISUALIZE_EVALUATION:
            visualize_evaluation(table, epoch)
        if SAVE_CHECKPOINTS:
            save_checkpoint(epoch)

    # The table only gains rows via visualize_evaluation; don't log it empty.
    if VISUALIZE_EVALUATION:
        wandb.log({"Evaluation": table})
finally:
    # Always close the run, even if training raises, so it isn't left open.
    wandb.finish()


Training Epoch 1/10:   0%|          | 0/250 [00:00<?, ?it/s]

Validation Epoch 1/10:   0%|          | 0/57 [00:00<?, ?it/s]

Training Epoch 2/10:   0%|          | 0/250 [00:00<?, ?it/s]

Validation Epoch 2/10:   0%|          | 0/57 [00:00<?, ?it/s]

Training Epoch 3/10:   0%|          | 0/250 [00:00<?, ?it/s]

Validation Epoch 3/10:   0%|          | 0/57 [00:00<?, ?it/s]

Training Epoch 4/10:   0%|          | 0/250 [00:00<?, ?it/s]

Validation Epoch 4/10:   0%|          | 0/57 [00:00<?, ?it/s]

Training Epoch 5/10:   0%|          | 0/250 [00:00<?, ?it/s]

Validation Epoch 5/10:   0%|          | 0/57 [00:00<?, ?it/s]

Training Epoch 6/10:   0%|          | 0/250 [00:00<?, ?it/s]

Validation Epoch 6/10:   0%|          | 0/57 [00:00<?, ?it/s]

Training Epoch 7/10:   0%|          | 0/250 [00:00<?, ?it/s]

Validation Epoch 7/10:   0%|          | 0/57 [00:00<?, ?it/s]

Training Epoch 8/10:   0%|          | 0/250 [00:00<?, ?it/s]

Validation Epoch 8/10:   0%|          | 0/57 [00:00<?, ?it/s]

Training Epoch 9/10:   0%|          | 0/250 [00:00<?, ?it/s]

Validation Epoch 9/10:   0%|          | 0/57 [00:00<?, ?it/s]

Training Epoch 10/10:   0%|          | 0/250 [00:00<?, ?it/s]

Validation Epoch 10/10:   0%|          | 0/57 [00:00<?, ?it/s]

VBox(children=(Label(value='0.002 MB of 0.002 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
Train/Accuracy,▁▄▅▆▇▇▇███
Train/Loss,█▅▄▃▂▂▂▁▁▁
Validation/Accuracy,▁▁▁▆▃▄▁▃█▃
Validation/Loss,▆▇█▃▇▄▇▅▁▄

0,1
Train/Accuracy,0.82561
Train/Loss,0.51886
Validation/Accuracy,0.31278
Validation/Loss,1.88544
