In [None]:
#@title {vertical-output: true}

! pip install pytorch-lightning

Imports and utility functions

In [None]:
#@title {vertical-output: true}
import torch
import torch.nn as nn
from torch import nn
import torch.nn.parallel
import torch.utils.data
from torch.utils.data import DataLoader
from torch.utils.data import Dataset

import torch.nn.functional as F

from torch.autograd import Variable
import numpy as np
import numpy as np
import pytorch_lightning as pl

import plotly
import plotly.graph_objects as go
from plotly.graph_objs import Layout
from plotly.graph_objs import Layout
from plotly.subplots import make_subplots

from typing import Union, Sequence, Optional
from torch.nn.functional import nll_loss
import torchmetrics


def plot3d(
    x: Union[np.ndarray, torch.Tensor], c: Union[np.ndarray, torch.Tensor]
) -> None:
    """
    Plot the function c over the point cloud x.

    Args:
      x: point cloud; columns 0, 1 and 2 are used as the x, y and z coordinates.
      c: one value per point, used as the marker colour (viridis colorscale).

    Tensors are detached and moved to the CPU before plotting, so tensors that
    live on an accelerator or require grad are accepted, as the type hints promise.
    """

    def as_numpy(values):
        # NumPy conversion of a tensor fails if it lives on a non-CPU device or
        # requires grad, so detach and move it to the CPU first.
        if isinstance(values, torch.Tensor):
            return values.detach().cpu().numpy()
        return np.asarray(values)

    points = as_numpy(x)
    colors = as_numpy(c)
    fig = go.Figure(
        data=[
            go.Scatter3d(
                x=points[:, 0],
                y=points[:, 1],
                z=points[:, 2],
                mode="markers",
                marker=dict(color=colors, colorscale="viridis", size=5, showscale=True),
            )
        ],
        # "data" keeps the true aspect ratio of the point cloud.
        layout=Layout(scene=dict(aspectmode="data")),
    )
    fig.show()

def plot3d_shapes(
    x: Sequence[Union[np.ndarray, torch.Tensor]], c: Sequence[Union[np.ndarray, torch.Tensor]], subplot_titles: Optional[Sequence[str]] = None
) -> None:
    """
    Plot several point clouds side by side, each coloured by its own function.

    Args:
      x: point clouds, one subplot each. Columns 0, 1 and 2 of every cloud are
        used as the x, y and z coordinates.
      c: per-point values used as marker colours, one entry per cloud in ``x``.
      subplot_titles: optional titles, one per subplot.

    Raises:
      ValueError: if ``x`` and ``c`` differ in length. Without this check ``zip``
        would silently drop the unmatched entries and leave empty subplots.
    """
    if len(x) != len(c):
        raise ValueError(
            f"plot3d_shapes: got {len(x)} point clouds but {len(c)} colour arrays"
        )

    def as_numpy(values):
        # NumPy conversion of a tensor fails if it lives on a non-CPU device or
        # requires grad, so detach and move it to the CPU first.
        if isinstance(values, torch.Tensor):
            return values.detach().cpu().numpy()
        return np.asarray(values)

    n_shapes = len(x)
    fig = make_subplots(
        rows=1,
        cols=n_shapes,
        specs=[[{"is_3d": True}] * n_shapes],
        horizontal_spacing=0,
        vertical_spacing=0,
        subplot_titles=subplot_titles,
    )

    # Shared camera: y axis up, looking at the origin from (mostly) the +z side.
    myscene = dict(
        camera=dict(
            up=dict(x=0, y=1, z=0),
            center=dict(x=0, y=0, z=0),
            eye=dict(x=-0.25, y=0.25, z=2.75),
        ),
        aspectmode="data",
    )

    for col, (points, color) in enumerate(zip(x, c), start=1):
        points = as_numpy(points)
        fig.add_trace(
            go.Scatter3d(
                x=points[:, 0],
                y=points[:, 1],
                z=points[:, 2],
                mode="markers",
                marker=dict(color=as_numpy(color), colorscale="viridis", size=5, showscale=True),
            ),
            row=1,
            col=col,
        )

    # Each subplot's 3D scene is addressed as scene1..sceneN; give them all the shared camera.
    for col in range(1, n_shapes + 1):
        fig["layout"][f"scene{col}"].update(myscene)

    fig.update_layout(margin=dict(l=0, r=0, b=0, t=30))

    fig.show()



Download and extract the dataset: the point-cloud `.npy` files and the region index `.txt` files used below.

In [None]:
#@title {vertical-output: true}

!wget -O 'data.zip' "https://www.dropbox.com/sh/qzh9bo3rbpd0k7d/AAA_dUzVFHBBqLrS8qslYCMJa?dl=1"
!unzip data.zip

We define an appropriate `Dataset`. Nothing special: each item provides a point cloud and the binary mask of the region to localize.

In [None]:
#@title {vertical-output: true}

from typing import Dict

class ShapeLocalizationDataset(Dataset):
  """Point clouds paired with a binary mask marking one fixed region.

  Every shape in ``shapes_data_name`` shares the same per-point mask, built
  from the point indices listed in ``region_data_name``.
  """

  def __init__(self, shapes_data_name: str , region_data_name: str):
    super().__init__()
    self.shapes_data_name = shapes_data_name
    self.region_data_name = region_data_name

    # Indexed as [shape, point, coord]; cast to float32 for the network.
    self.shapes = np.load(shapes_data_name).astype(np.float32)

    # The file's indices are shifted down by one, i.e. it is expected to be 1-based.
    self.region_idxs = np.loadtxt(region_data_name).astype(np.int64) - 1
    if self.region_idxs.size and self.region_idxs.min() < 0:
      # A 0 in the file would become -1 and silently mark the *last* point.
      raise ValueError(
          f"{region_data_name} must contain 1-based point indices (found a value < 1)"
      )
    self.mask = np.zeros(self.shapes.shape[1], dtype=np.int64)
    self.mask[self.region_idxs] = 1

  def __len__(self) -> int:
    return self.shapes.shape[0]

  def __getitem__(self, idx: int) -> Dict[str, Union[int, np.ndarray]]:
    """ Returns the points of a point cloud with shape [n, 3] and the region
    that must be localized with shape [n].

    The mask is copied, so a caller that modifies it cannot corrupt the mask
    shared by every item of the dataset.
    """
    return {
        'id': idx,
        'points': self.shapes[idx],
        'mask': self.mask.copy(),
    }

  def __repr__(self):
    return f"ShapeLocalizationDataset(shapes_data_name='{self.shapes_data_name}', region_data_name='{self.region_data_name}')"

In [None]:
#@title {vertical-output: true}

# Sanity check: construct a dataset from the downloaded files and show its repr.
print(repr(ShapeLocalizationDataset(shapes_data_name='2K_shapes_train.npy', region_data_name='head_idxs4template1K.txt')))

In [None]:
#@title Explore data { run: "auto" } {vertical-output: true}
shapes_data = "12k_shapes_train.npy" #@param ["12k_shapes_test.npy", "2K_shapes_train.npy", "12k_shapes_train.npy", "200_shapes_test.npy"]
region_data = "head_idxs4template1K.txt" #@param ["head_idxs4template1K.txt", "belly_idxs4template1K.txt"]
sample_idx = 33 #@param {type:"slider", min:0, max:100, step:1}


# The cell is re-run automatically on every form change, reloading the selected files.
dataset = ShapeLocalizationDataset(shapes_data_name=shapes_data, region_data_name=region_data)

# Colour each point of the chosen shape by its region mask (1 = in the region, 0 = not).
sample = dataset[sample_idx]
plot3d(sample['points'], sample['mask'])

# Set up data loaders

In [None]:
#@title {vertical-output: true}
# Decide which data to use: shape files for training and testing, and the region to localize.
shapes_data_train = "12k_shapes_train.npy"
shapes_data_test = "12k_shapes_test.npy"
region_data = "head_idxs4template1K.txt"

We create the data loaders for training and testing.

In [None]:
#@title {vertical-output: true}

train_dataset = ShapeLocalizationDataset(shapes_data_name=shapes_data_train, region_data_name=region_data)
test_dataset = ShapeLocalizationDataset(shapes_data_name=shapes_data_test, region_data_name=region_data)

# Hyperparameters
batch_size = 16
num_workers = 2  # number of parallel processes to use to prepare batches
# Pinned host memory only helps host->CUDA copies; requesting it without CUDA just warns.
pin_memory = torch.cuda.is_available()

# DataLoader defaults to shuffle=False: shuffle the training set so each epoch
# sees the shapes in a new order, but keep the evaluation order fixed.
train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers, pin_memory=pin_memory)
test_dataloader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers, pin_memory=pin_memory)

Let's look at what a data batch looks like.

In [None]:
#@title {vertical-output: true}
# Pull a single batch. The default collate function stacks each field into a
# tensor with a leading batch dimension, so a per-field summary is more readable
# than dumping the raw tensors.
batch = next(iter(train_dataloader))
for key, value in batch.items():
  print(f"{key}: shape={tuple(value.shape)}, dtype={value.dtype}")

In [None]:
#@title {vertical-output: true}

class STN3d(nn.Module):
    """Predicts a per-sample 3x3 transform for the input point coordinates.

    Input:  tensor of shape [batch, 3, num_points] (channels first, as Conv1d expects).
    Output: tensor of shape [batch, 3, 3], the identity plus a learned offset.
    """
    def __init__(self):
        super(STN3d, self).__init__()
        self.conv1 = torch.nn.Conv1d(3, 64, 1)
        self.conv2 = torch.nn.Conv1d(64, 128, 1)
        self.conv3 = torch.nn.Conv1d(128, 1024, 1)
        self.fc1 = nn.Linear(1024, 512)
        self.fc2 = nn.Linear(512, 256)
        self.fc3 = nn.Linear(256, 9)
        # Not used by forward (which calls F.relu); kept so the module tree is unchanged.
        self.relu = nn.ReLU()

        self.bn1 = nn.BatchNorm1d(64)
        self.bn2 = nn.BatchNorm1d(128)
        self.bn3 = nn.BatchNorm1d(1024)
        self.bn4 = nn.BatchNorm1d(512)
        self.bn5 = nn.BatchNorm1d(256)


    def forward(self, x):
        batchsize = x.size()[0]
        x = F.relu(self.bn1(self.conv1(x)))
        x = F.relu(self.bn2(self.conv2(x)))
        x = F.relu(self.bn3(self.conv3(x)))
        # Max over the points axis -> one 1024-d global descriptor per sample.
        x = torch.max(x, 2, keepdim=True)[0]
        x = x.view(-1, 1024)

        x = F.relu(self.bn4(self.fc1(x)))
        x = F.relu(self.bn5(self.fc2(x)))
        x = self.fc3(x)

        # Bias the prediction towards the identity. The identity is created directly on
        # x's device and dtype, so this works on every accelerator (CPU, CUDA, MPS, ...),
        # not only when x.is_cuda.
        iden = torch.eye(3, device=x.device, dtype=x.dtype).view(1, 9).repeat(batchsize, 1)
        x = x + iden
        x = x.view(-1, 3, 3)
        return x


class STNkd(nn.Module):
    """Predicts a per-sample k x k transform for k-dimensional point features.

    Input:  tensor of shape [batch, k, num_points] (channels first, as Conv1d expects).
    Output: tensor of shape [batch, k, k], the identity plus a learned offset.
    """
    def __init__(self, k=64):
        super(STNkd, self).__init__()
        self.conv1 = torch.nn.Conv1d(k, 64, 1)
        self.conv2 = torch.nn.Conv1d(64, 128, 1)
        self.conv3 = torch.nn.Conv1d(128, 1024, 1)
        self.fc1 = nn.Linear(1024, 512)
        self.fc2 = nn.Linear(512, 256)
        self.fc3 = nn.Linear(256, k*k)
        # Not used by forward (which calls F.relu); kept so the module tree is unchanged.
        self.relu = nn.ReLU()

        self.bn1 = nn.BatchNorm1d(64)
        self.bn2 = nn.BatchNorm1d(128)
        self.bn3 = nn.BatchNorm1d(1024)
        self.bn4 = nn.BatchNorm1d(512)
        self.bn5 = nn.BatchNorm1d(256)

        self.k = k

    def forward(self, x):
        batchsize = x.size()[0]
        x = F.relu(self.bn1(self.conv1(x)))
        x = F.relu(self.bn2(self.conv2(x)))
        x = F.relu(self.bn3(self.conv3(x)))
        # Max over the points axis -> one 1024-d global descriptor per sample.
        x = torch.max(x, 2, keepdim=True)[0]
        x = x.view(-1, 1024)

        x = F.relu(self.bn4(self.fc1(x)))
        x = F.relu(self.bn5(self.fc2(x)))
        x = self.fc3(x)

        # Bias the prediction towards the identity. The identity is created directly on
        # x's device and dtype, so this works on every accelerator (CPU, CUDA, MPS, ...),
        # not only when x.is_cuda.
        iden = torch.eye(self.k, device=x.device, dtype=x.dtype).view(1, self.k * self.k).repeat(batchsize, 1)
        x = x + iden
        x = x.view(-1, self.k, self.k)
        return x

class PointNetfeat(nn.Module):
    """PointNet encoder producing global (or per-point + global) features.

    Input: tensor of shape [batch, 3, num_points].
    forward returns a tuple ``(features, input_transform, feature_transform)``:
      * features: [batch, 1024] when ``global_feat`` is True; otherwise
        [batch, 1024 + 64, num_points], the global descriptor tiled over every
        point and concatenated with the 64-d per-point features.
      * input_transform: [batch, 3, 3] matrix predicted by ``STN3d``.
      * feature_transform: [batch, 64, 64] matrix from ``STNkd``, or None when
        ``feature_transform`` is False.
    """
    def __init__(self, global_feat = True, feature_transform = False):
        super(PointNetfeat, self).__init__()
        self.stn = STN3d()
        self.conv1 = torch.nn.Conv1d(3, 64, 1)
        self.conv2 = torch.nn.Conv1d(64, 128, 1)
        self.conv3 = torch.nn.Conv1d(128, 1024, 1)
        self.bn1 = nn.BatchNorm1d(64)
        self.bn2 = nn.BatchNorm1d(128)
        self.bn3 = nn.BatchNorm1d(1024)
        self.global_feat = global_feat
        self.feature_transform = feature_transform
        if self.feature_transform:
            self.fstn = STNkd(k=64)

    def forward(self, x):
        num_points = x.size(2)

        # Align the raw coordinates: [B, N, 3] @ [B, 3, 3], then back to channels-first.
        input_trans = self.stn(x)
        x = torch.bmm(x.transpose(2, 1), input_trans).transpose(2, 1)
        x = F.relu(self.bn1(self.conv1(x)))

        # Optionally align the 64-d per-point features the same way.
        feat_trans = None
        if self.feature_transform:
            feat_trans = self.fstn(x)
            x = torch.bmm(x.transpose(2, 1), feat_trans).transpose(2, 1)

        per_point = x
        x = F.relu(self.bn2(self.conv2(x)))
        x = self.bn3(self.conv3(x))
        # Symmetric max-pool over points -> order-invariant global descriptor.
        global_desc = torch.max(x, 2, keepdim=True)[0].view(-1, 1024)

        if self.global_feat:
            return global_desc, input_trans, feat_trans

        tiled = global_desc.view(-1, 1024, 1).repeat(1, 1, num_points)
        return torch.cat([tiled, per_point], 1), input_trans, feat_trans


class PointNetDenseCls(nn.Module):
    """Per-point (dense) PointNet classifier.

    Input: tensor of shape [batch, 3, num_points].
    forward returns ``(log_probs, input_transform, feature_transform)`` where
    ``log_probs`` has shape [batch, num_points, k] and holds log-softmax scores,
    meant to be paired with ``nll_loss`` rather than cross-entropy.
    """
    def __init__(self, k = 2, feature_transform=False):
        super(PointNetDenseCls, self).__init__()
        self.k = k
        self.feature_transform=feature_transform
        self.feat = PointNetfeat(global_feat=False, feature_transform=feature_transform)
        # 1088 input channels = 1024 (tiled global descriptor) + 64 (per-point features).
        self.conv1 = torch.nn.Conv1d(1088, 512, 1)
        self.conv2 = torch.nn.Conv1d(512, 256, 1)
        self.conv3 = torch.nn.Conv1d(256, 128, 1)
        self.conv4 = torch.nn.Conv1d(128, self.k, 1)
        self.bn1 = nn.BatchNorm1d(512)
        self.bn2 = nn.BatchNorm1d(256)
        self.bn3 = nn.BatchNorm1d(128)

    def forward(self, x):
        batch_size = x.size(0)
        num_points = x.size(2)
        hidden, input_trans, feat_trans = self.feat(x)

        for conv, norm in ((self.conv1, self.bn1), (self.conv2, self.bn2), (self.conv3, self.bn3)):
            hidden = F.relu(norm(conv(hidden)))

        # [B, k, N] -> [B, N, k]; contiguous() so the flattening view below is legal.
        scores = self.conv4(hidden).transpose(2, 1).contiguous()
        flat_log_probs = F.log_softmax(scores.view(-1, self.k), dim=-1)  # do not use crossentropy
        return flat_log_probs.view(batch_size, num_points, self.k), input_trans, feat_trans

In [None]:
#@title {vertical-output: true}

class RegionLocalizationModule(pl.LightningModule):
  """LightningModule training a PointNet segmentation network that classifies
  each point of a shape as outside (0) or inside (1) the target region.
  """

  def __init__(self):
    super().__init__()
    k = 2  # two classes per point: outside / inside the region
    self.pointnet = PointNetDenseCls(k=k, feature_transform=True)

    # Separate metric objects so train and test statistics accumulate independently.
    self.train_accuracy = torchmetrics.Accuracy(task='multiclass',num_classes=k)
    self.test_accuracy = torchmetrics.Accuracy(task='multiclass',num_classes=k)

  def forward(self, points: torch.Tensor) -> torch.Tensor:
    """
    Defines the behaviour in the forward pass

    Args:
      points: shape points with shape [batch, num_points, xyz]

    Returns:
      per-point log-probabilities over the k classes, shape [batch, num_points, k]
    """
    # PointNet's Conv1d layers expect channels first: [batch, xyz, num_points].
    points = points.transpose(1, 2)
    # NOTE(review): the input/feature transform matrices are discarded here, so no
    # feature-transform regularization term reaches the loss; the reference PointNet
    # recipe adds one when feature_transform=True. Confirm this is intended.
    out, _, _ = self.pointnet(points)
    return out

  def _shared_step(self, batch, accuracy):
    """Compute the NLL loss of one batch and update ``accuracy`` with its predictions.

    Args:
      batch: collated dict with 'points' [batch, num_points, xyz] and 'mask'
        [batch, num_points] holding the class index of every point.
      accuracy: the torchmetrics metric to update (train or test).

    Returns:
      the scalar loss tensor.
    """
    points = batch['points']
    y = batch['mask']

    # nll_loss and the multiclass Accuracy both expect the class axis at dim 1:
    # [batch, k, num_points].
    y_pred = self(points).transpose(1, 2)
    loss = nll_loss(y_pred, y)

    # exp() turns the log-probabilities back into probabilities for the metric.
    accuracy(y_pred.exp(), y)
    return loss

  def training_step(self, batch, batch_idx):
    """
    Defines the training logic: loss + accuracy on one training batch.
    """
    loss = self._shared_step(batch, self.train_accuracy)
    self.log_dict({'train_loss': loss, 'train_acc': self.train_accuracy}, on_step=True, on_epoch=True, prog_bar=True)
    return loss

  def test_step(self, batch, batch_idx):
    """
    Defines the test logic: loss + accuracy on one test batch, logged per epoch.
    """
    loss = self._shared_step(batch, self.test_accuracy)
    self.log_dict({'test_loss': loss, 'test_acc': self.test_accuracy}, on_epoch=True, prog_bar=True)
    return loss

  def configure_optimizers(self):
    """
    Configure optimizers: AdamW over all parameters with PyTorch's default
    hyperparameters (lr=1e-3, weight_decay=1e-2).
    """
    return torch.optim.AdamW(self.parameters())

In [None]:
#@title {vertical-output: true}
# Seed Python, NumPy and torch RNGs so weight initialisation and training-set
# shuffling are reproducible across Restart & Run All.
pl.seed_everything(42)

# Instantiate the model
model = RegionLocalizationModule()

# Instantiate the trainer. Training stops at whichever limit is hit first:
# max_steps optimizer steps or max_epochs passes over the training loader.
trainer = pl.Trainer(accelerator="auto", max_steps=5000, max_epochs=3)

In [None]:
#@title {vertical-output: true}
# Baseline: evaluate the still-untrained model on the test split.
results = trainer.test(model, dataloaders=test_dataloader)

In [None]:
#@title {vertical-output: true}
# Train the model on the training split.
trainer.fit(model, train_dataloaders=train_dataloader)

In [None]:
#@title {vertical-output: true}
# Evaluate again after training; compare with the baseline run before fit.
results = trainer.test(model, dataloaders=test_dataloader)

In [None]:
#@title Explore predictions { run: "auto" }
shapes_data = "12k_shapes_test.npy" #@param ["12k_shapes_test.npy", "200_shapes_test.npy"]
region_data = "head_idxs4template1K.txt" #@param ["head_idxs4template1K.txt", "belly_idxs4template1K.txt"]
dataset = ShapeLocalizationDataset(shapes_data_name=shapes_data, region_data_name=region_data)

sample_idx = 920 #@param {type:"slider", min:0, max:1000, step:1}
permute_points = True #@param {type:"boolean"}
sample_points = True #@param {type:"boolean"}
number_of_sampled_points = 494 #@param {type:"slider", min:5, max:1000, step:1}

model = model.cpu()
# eval() makes BatchNorm use running statistics, which also allows a batch of 1.
model.eval()

# Valid indices are 0 .. len(dataset) - 1, so len(dataset) itself is already out of range.
if sample_idx >= len(dataset):
  sample_idx = len(dataset) - 1
  print(f'Sample idx over the selected dataset length. Set to: {sample_idx}')

sample = dataset[sample_idx]
# [1, n_points, xyz]
points = torch.from_numpy(sample['points'])[None, ...]
# [n_points]; a tensor, so it can be indexed with the same torch permutations as `points`.
y = torch.from_numpy(sample['mask'])
n_points = points.shape[1]

if permute_points:
  perm_points = torch.randperm(n_points)
  points = points[:, perm_points, :]
  y = y[perm_points]

if sample_points:
  points_to_keep = torch.randperm(n_points)[:number_of_sampled_points]
  points = points[:, points_to_keep, :]
  y = y[points_to_keep]


# Inference only: no autograd graph is needed.
with torch.no_grad():
  # [1, n_points, k] log-probabilities
  y_pred = model(points)

# [n_points]
y_pred = y_pred.argmax(-1).squeeze(0)


plot3d_shapes([points[0], points[0]], [y, y_pred], ['Ground truth', 'Prediction'])