# Point-based Learning Methods
In this exercise you will implement the PointNet architecture as well as a point-based convolution method.

To create a dataset of point clouds you need Open3D. This can be installed via `conda install open3d`.

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import openmesh as om
import numpy as np
import k3d
import matplotlib.pyplot as plt
from jupyterplot import ProgressPlot
from tqdm.notebook import tqdm
from sklearn.neighbors import KDTree

from dataset import ModelNet10

## PointNet
The PointNet architecture you will implement in this exercise is a slightly simplified version of the one presented in the lecture. We will omit the T-Net modules for point and feature alignment.
Your task is therefore to implement a network that transforms each point of a point cloud individually and then takes the maximum of each feature value over all points. The number of layers and the layer parameters should be the same as presented in the lecture.

We will test your PointNet implementation on the ModelNet10 dataset. We first need to download the dataset and, since it contains meshes, sample point clouds from it. This will take a couple of minutes.

Make sure to upload your `best_val.ckpt` checkpoint file so that we do not have to retrain your model. If the file is not included we **cannot** give you any points for this task.

In [None]:
class PointNet(nn.Module):
    """Simplified PointNet classifier without the T-Net alignment modules.

    Every point is transformed individually by a shared MLP, the per-point
    features are max-pooled over all points, and an MLP head produces the
    class logits. Submodules are named ``local`` and ``classifier`` so that
    checkpoints saved via ``state_dict`` keep loading.
    """

    def __init__(self, n_classes=10):
        super().__init__()

        ### BEGIN SOLUTION
        # Shared per-point MLP as 1x1 convolutions: b x 3 x n -> b x 1024 x n.
        # Each stage is Conv1d -> BatchNorm1d -> ReLU; the conv bias is
        # dropped because the following BatchNorm has its own shift.
        point_widths = [3, 64, 64, 64, 128, 1024]
        point_layers = []
        for c_in, c_out in zip(point_widths[:-1], point_widths[1:]):
            point_layers += [
                nn.Conv1d(c_in, c_out, 1, bias=False),
                nn.BatchNorm1d(c_out),
                nn.ReLU(),
            ]
        self.local = nn.Sequential(*point_layers)

        # Classifier head on the pooled 1024-d global feature:
        # 1024 -> 512 -> 256 (Linear -> BatchNorm1d -> ReLU each),
        # then a plain Linear layer producing n_classes logits.
        head_widths = [1024, 512, 256]
        head_layers = []
        for c_in, c_out in zip(head_widths[:-1], head_widths[1:]):
            head_layers += [
                nn.Linear(c_in, c_out, bias=False),
                nn.BatchNorm1d(c_out),
                nn.ReLU(),
            ]
        head_layers.append(nn.Linear(head_widths[-1], n_classes))
        self.classifier = nn.Sequential(*head_layers)
        ### END SOLUTION

    def forward(self, x):
        """Classify a batch of point clouds.

        Args:
            x: point coordinates laid out as (batch, 3, n_points).

        Returns:
            Unnormalised class scores of shape (batch, n_classes).
        """
        ### BEGIN SOLUTION
        per_point = self.local(x)
        # Max over the point axis makes the feature independent of point order.
        global_feature = per_point.max(dim=2).values
        x = self.classifier(global_feature)
        ### END SOLUTION
        return x

In [None]:
batch_size = 64  # lower this if you run out of memory

# The three ModelNet10 splits; only the training loader shuffles its samples.
train_data, val_data, test_data = (
    ModelNet10('./ModelNet10', mode=split) for split in ("train", "val", "test")
)
train_loader, val_loader, test_loader = (
    torch.utils.data.DataLoader(dataset, shuffle=shuffle, batch_size=batch_size)
    for dataset, shuffle in ((train_data, True), (val_data, False), (test_data, False))
)

# Run on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

model = PointNet(10).to(device)
optim = torch.optim.Adam(model.parameters(), lr=0.001)

In [None]:
n_epochs = 10

pp = ProgressPlot(plot_names=["loss", "accuracy"], line_names=["train", "val"],
                  x_lim=[0, n_epochs-1], y_lim=[[0,1], [0,1]])

# Best validation accuracy seen so far; -1 guarantees the first epoch is saved.
best_val_acc = -1

pbar = tqdm(range(n_epochs))
for e in pbar:
    # --- training pass ---
    train_loss = 0
    train_acc = 0
    model.train()
    for (x, y) in train_loader:
        x, y = x.to(device), y.to(device)
        pred = model(x)
        loss = F.cross_entropy(pred, y)
        optim.zero_grad()
        loss.backward()
        optim.step()
        train_loss += loss.item()
        train_acc += (pred.max(-1).indices == y).float().sum().item()
    train_loss /= len(train_loader)  # mean of the per-batch losses
    train_acc /= len(train_data)     # fraction of correctly classified samples

    # --- validation pass ---
    model.eval()
    val_loss = 0
    val_acc = 0
    with torch.no_grad():
        for (x, y) in val_loader:
            x, y = x.to(device), y.to(device)
            pred = model(x)
            loss = F.cross_entropy(pred, y)
            val_loss += loss.item()
            val_acc += (pred.max(-1).indices == y).float().sum().item()
        val_loss /= len(val_loader)
        val_acc /= len(val_data)
        if val_acc > best_val_acc:
            # Record the new best. Without this update every epoch beats -1,
            # so best_val.ckpt would simply hold the last epoch.
            best_val_acc = val_acc
            torch.save({
            'epoch': e,
            'model_state_dict': model.state_dict(),
            'optim_state_dict': optim.state_dict(),
            'val_acc': val_acc,
            }, "best_val.ckpt")

    pp.update([[train_loss, val_loss], [train_acc, val_acc]])
    pbar.set_description(f"train loss: {train_loss:.4f}, train acc.: {train_acc:.4f}")
pp.finalize()

In [None]:
# Load the best checkpoint. map_location lets a checkpoint that was saved on a
# GPU be evaluated on a CPU-only machine (e.g. when grading without retraining).
checkpoint = torch.load("best_val.ckpt", map_location=device)
model.load_state_dict(checkpoint['model_state_dict'])
model.eval()
test_acc = 0
with torch.no_grad():
    for (x, y) in tqdm(test_loader):
        x, y = x.to(device), y.to(device)
        pred = model(x)
        test_acc += (pred.max(-1).indices == y).float().sum().item()
test_acc /= len(test_data)  # fraction of correctly classified test samples
print(f"test acc.: {test_acc}")
assert test_acc >= 0.7

## Point Convolutions
In this exercise you will implement point convolutions as presented in the lecture.
For this you need to implement three functions:

- `neighbourhood(self, points)` should return the indices of the `self.n_neighbours` closest points to each sample point. The output should be of size (n_points, n_neighbours)
- `correlation(self, positions, neighbours)` should implement the (linear) distance-based point correlation as described in slide 28 of "Point based Approaches". The output should be of size (n_points, n_kernel_points, n_neighbours). [Here](./correlation.html) you can see the nearest neighbours, kernel points and correlation for a specific vertex.
- `forward(self, features, points)` implements the complete point convolution using the previous functions. It should return a tensor of size (out_channels, n_points). The result should look like [this](./result.html).

In [2]:
def meshgrid(s, device=torch.device('cpu')):
    """Return a regular s x s x s grid of 3D points centred on the origin.

    Args:
        s: number of samples per axis.
        device: device on which the grid is allocated.

    Returns:
        Tensor of shape (3, s, s, s) whose entry [:, i, j, k] is the point
        (i, j, k) / (s - 1) - 0.5, so coordinates span [-0.5, 0.5].
        For s == 1 the single grid point is the origin.
    """
    if s == 1:
        # The general formula divides by (s - 1) == 0 and would yield NaN;
        # a one-point grid is just the centre.
        return torch.zeros(3, 1, 1, 1, device=device)
    r = torch.arange(s, device=device, dtype=torch.float)
    # Broadcast the 1D ramp along each axis (expand creates views, no copies).
    x = r[:, None, None].expand(s, s, s)
    y = r[None, :, None].expand(s, s, s)
    z = r[None, None, :].expand(s, s, s)
    return torch.stack([x, y, z], 0) / (s - 1) - 0.5

class PointConvolution(nn.Module):
    """Kernel-point convolution on an unstructured point cloud.

    Kernel points lie on a regular ``kernel_size**3`` grid around each point.
    Neighbour features are weighted by their distance-based correlation with
    each kernel point and then mixed by one ``in_channels x out_channels``
    weight matrix per kernel point.
    """

    def __init__(self, in_channels, out_channels, kernel_size=3, n_neighbours=30):
        super().__init__()

        self.kernel_size = kernel_size
        self.n_neighbours = n_neighbours
        self.kernel_radius = 0.1
        self.sigma = 0.2
        # Kernel point offsets, shape (3, kernel_size**3). Registered as a
        # non-persistent buffer so it follows .to(device) together with the
        # weights without adding an entry to the state_dict.
        self.register_buffer(
            'kernel_points',
            meshgrid(self.kernel_size).view(3, -1) * self.kernel_radius,
            persistent=False,
        )

        # in practice we would initialize the weights randomly
        # One (in_channels x out_channels) matrix per kernel point.
        self.weights = nn.Parameter(torch.ones((self.kernel_size**3, in_channels, out_channels),
                                               dtype=torch.float32))

    def neighbourhood(self, points):
        """Return the indices of the ``n_neighbours`` nearest points of every point.

        Args:
            points: point coordinates of shape (3, n_points).

        Returns:
            Integer numpy array of shape (n_points, n_neighbours).
        """
        ### BEGIN SOLUTION
        # sklearn expects an (n_samples, n_features) array; detach and move to
        # the CPU so CUDA tensors and tensors requiring grad are accepted too.
        coords = points.detach().cpu().numpy().T
        tree = KDTree(coords)
        _, indices = tree.query(coords, k=self.n_neighbours)
        return indices
        ### END SOLUTION

    def correlation(self, positions, neighbours):
        """Correlation between every neighbour and every kernel point.

        Args:
            positions: centre points, shape (3, n_points).
            neighbours: coordinates of each point's neighbours,
                shape (3, n_points, n_neighbours).

        Returns:
            Tensor of shape (n_points, n_kernel_points, n_neighbours).
        """
        ### BEGIN SOLUTION
        # Neighbour offsets relative to their centre point: (3, n, m).
        relative = neighbours - positions.unsqueeze(-1)
        # Offset of every neighbour from every kernel point: (3, n, m, k).
        differences = relative[:, :, :, None] - self.kernel_points[:, None, None, :]
        distances = torch.norm(differences, dim=0)  # (n, m, k)
        # Linear falloff, zero beyond sigma.
        correlation = torch.clamp(1 - distances / self.sigma, min=0.0)
        return correlation.transpose(1, 2)  # (n, k, m)
        ### END SOLUTION

    def forward(self, features, points):
        """Apply the point convolution.

        Args:
            features: input features, shape (in_channels, n_points).
            points: point coordinates, shape (3, n_points).

        Returns:
            Output features of shape (out_channels, n_points).
        """
        ### BEGIN SOLUTION
        # Neighbour indices (n_points, n_neighbours) and gathered coordinates/features.
        neighbour_indices = self.neighbourhood(points)
        neighbour_points = points[:, neighbour_indices]  # [3, n_points, n_neighbours]
        neighbour_features = features[:, neighbour_indices].permute(1, 2, 0)  # [n_points, n_neighbours, in_channels]

        # Correlation [n_points, n_kernel_points, n_neighbours]
        h = self.correlation(points, neighbour_points)

        # Correlation-weighted neighbour features [n_points, n_kernel_points, in_channels]
        weighted_features = torch.bmm(h, neighbour_features)

        # Per-kernel-point weights:
        # [n_kernel_points, n_points, in_channels] @ [n_kernel_points, in_channels, out_channels]
        # -> [n_kernel_points, n_points, out_channels]
        weighted_features = weighted_features.permute(1, 0, 2)
        kernel_outputs = torch.matmul(weighted_features, self.weights)

        # Sum over kernel points -> [n_points, out_channels], returned as [out_channels, n_points]
        return kernel_outputs.sum(dim=0).transpose(1, 0)
        ### END SOLUTION
        ### END SOLUTION

In [3]:
mesh = om.read_trimesh("spot.obj")
# Vertex positions as a float32 tensor laid out (3, n_vertices), the layout
# PointConvolution expects for its `points` argument.
pts = torch.from_numpy(mesh.points()).float().t()

In [4]:
%%timeit
conv = PointConvolution(8, 16, n_neighbours=30)
ones = torch.ones(8, 2930)  # constant 8-channel signal on every vertex
f = conv(ones, pts)

216 ms ± 3.47 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [None]:
# Nearest neighbours and kernel correlation for every vertex; used by the
# visualisation cell below.
conv = PointConvolution(1, 1, n_neighbours=50)
neighbours = conv.neighbourhood(pts)         # (n_points, 50) indices
neighbour_points = pts[:, neighbours]        # (3, n_points, 50) coordinates
h = conv.correlation(pts, neighbour_points)  # (n_points, n_kernel_points, 50)

In [None]:
# Colour vertex 0's neighbours by their correlation with kernel point 13
# (the centre of the default 3x3x3 grid); all other vertices stay at 0.
colors = torch.zeros(pts.shape[1])
colors[neighbours[0]] = h[0, 13]

viridis = k3d.colormaps.matplotlib_color_maps.viridis
plot = k3d.plot()
plot += k3d.mesh(mesh.points(), mesh.fv_indices(), attribute=colors, color_map=viridis)
# Kernel points shifted to vertex 0, and vertex 0's neighbours in red.
plot += k3d.points(conv.kernel_points.transpose(0, 1) + pts.transpose(1, 0)[0], point_size=0.01)
plot += k3d.points(neighbour_points[:, 0].transpose(0, 1), point_size=0.01, color=0xFF0000)
plot

In [None]:
# Check neighbourhood() against reference neighbour indices. In the expected
# rows the first entry is the query vertex itself.
conv = PointConvolution(1,1, n_neighbours=50)
neighbours = conv.neighbourhood(pts)
assert((neighbours[0][:10] == [0, 764, 812, 767, 813, 1158, 768, 1165, 1159, 197]).all())
assert((neighbours[2431][:10] == [2431, 1307, 2428, 1308, 2430, 73, 624, 2424, 2429, 2422]).all())
### BEGIN HIDDEN TESTS
assert((neighbours[1578][10:20] == [317, 381, 1576, 422, 1568, 1638, 1569, 1227, 1226, 1479]).all())
### END HIDDEN TESTS

In [None]:
# Check correlation() against reference values (4 significant digits).
# h has shape (n_points, n_kernel_points, n_neighbours).
conv = PointConvolution(1,1, n_neighbours=50)
neighbours = conv.neighbourhood(pts)
neighbour_points = pts[:, neighbours]
h = conv.correlation(pts, neighbour_points)

np.testing.assert_approx_equal(h[0,0,0], 0.5670, significant=4)
np.testing.assert_approx_equal(h[0,0].sum(), 7.7668, significant=4)
np.testing.assert_approx_equal(h[579,14,8], 0.3578, significant=4)
np.testing.assert_approx_equal(h[579,14].sum(), 8.3978, significant=4)

### BEGIN HIDDEN TESTS
np.testing.assert_approx_equal(h[2136,24,21], 0.5467, significant=4)
np.testing.assert_approx_equal(h[2136,24].sum(), 14.2310, significant=4)
### END HIDDEN TESTS

In [None]:
# Convolve a constant single-channel signal. With the all-ones weights the
# output at each vertex is the total correlation mass around it.
constant_signal = torch.ones(1, 2930)
conv = PointConvolution(1, 1, n_neighbours=50)
f = conv(constant_signal, pts)

In [None]:
# Per-vertex convolution output (single output channel) as the colour attribute.
colors = f.squeeze().detach()
plot = k3d.plot()
plot += k3d.mesh(mesh.points(), mesh.fv_indices(), attribute=colors,
                 color_map=k3d.colormaps.matplotlib_color_maps.viridis)
plot

In [None]:
# Check forward() on a constant 8-channel signal against reference values.
# f has shape (out_channels, n_points) = (16, 2930).
# NOTE(review): 9 significant digits exceeds float32 precision (~7 digits);
# these checks may be sensitive to summation order / BLAS backend.
conv = PointConvolution(8,16, n_neighbours=30)
f = conv(torch.ones(8,2930), pts)

np.testing.assert_approx_equal(f.sum(), 140892496, significant=9)
np.testing.assert_approx_equal(f[:,0].sum(), 26717.59375, significant=9)
np.testing.assert_approx_equal(f[0,0], 1669.84960, significant=9)

### BEGIN HIDDEN TESTS
np.testing.assert_approx_equal(f[:,2458].sum(), 51558.9453, significant=9)
np.testing.assert_approx_equal(f[12,2458], 3222.43432, significant=9)
### END HIDDEN TESTS