# Shape classification

The notebooks in this folder replicate the experiments from [CNNs on Surfaces using Rotation-Equivariant Features](https://doi.org/10.1145/3386569.3392437).

The current notebook replicates the shape classification experiments from section `5.2 Comparisons`.

## Imports
We start by importing dependencies.

In [1]:
# File reading and progressbar
import os.path as osp
import progressbar

# PyTorch and PyTorch Geometric dependencies
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch_geometric.transforms as T
from torch_geometric.data import DataLoader
from torch_geometric.nn.inits import zeros
from torch import autograd

import numpy as np
# Harmonic Surface Networks components
# Layers
from nn import (ECResNetBlock, TransField, ExtConvCplx, LiftBlock, ECHOBlock,
                ParallelTransportPool, ParallelTransportUnpool,
                TangentLin, TangentNonLin, VectorDropout)
# Utility functions
from utils.harmonic import magnitudes, norm2D
# SHREC16 dataset
from datasets import Shrec16
# Transforms
from transforms import (HarmonicPrecomp, VectorHeat, MultiscaleRadiusGraph, 
                        ScaleMask, FilterNeighbours, NormalizeArea, NormalizeAxes, Subsample)

## Settings
Next, we set a few parameters for our network. You can change these settings to experiment with different configurations of the network. By default, they match the settings used in the paper.

In [2]:
# Band-limit for extended convolution
band_limit = 2

# Number of rings in the radial profile of the correlations
n_corr_rings = 6

# Number of rings in the radial profile of the convolutions
n_conv_rings = 6

# Learn radial offset for correlations
offset = True

# Number of filters per block
nf = [16, 32]

# Settings for the ECHO block (only used by the commented-out ECHOBlock in Net)
n_des = 16
n_bins = 3

# Currently unused
desDim = 1

# Ratios used for pooling
ratios = [1, 0.25]

# Radius of convolution for each scale
radii = [0.2, 0.4]

# Number of shapes per batch
batch_size = 1
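Together, `band_limit`, `n_conv_rings` and the per-scale radius describe a filter in polar coordinates around each vertex. The sketch below illustrates this under our own assumptions. It is not the implementation of `HarmonicPrecomp`. We assume that a neighbour at log-map coordinates (r, θ) spreads its weight linearly over `n_rings` evenly spaced rings between 0 and `max_r`, and that rotation order m adds an angular factor e^{imθ}. The helpers `ring_weights` and `basis` are hypothetical.

```python
import cmath
import math


def ring_weights(r, n_rings, max_r):
    """Hypothetical linear interpolation of distance r over n_rings evenly spaced rings.

    Rings sit at radii 0, max_r / (n_rings - 1), ..., max_r. A neighbour that lies
    between two rings splits its unit weight between them.
    """
    if n_rings < 2:
        raise ValueError("need at least two rings")
    spacing = max_r / (n_rings - 1)
    pos = min(max(r, 0.0), max_r) / spacing  # fractional ring index
    lower = min(int(math.floor(pos)), n_rings - 1)
    upper = min(lower + 1, n_rings - 1)
    frac = pos - lower
    weights = [0.0] * n_rings
    weights[lower] += 1.0 - frac
    weights[upper] += frac
    return weights


def basis(r, theta, m, n_rings, max_r):
    """Radial ring weights times the angular factor e^{i m theta} for rotation order m."""
    phase = cmath.exp(1j * m * theta)
    return [w * phase for w in ring_weights(r, n_rings, max_r)]


# Values from the settings above; using rotation orders 0..band_limit is an assumption
band_limit, n_conv_rings, max_r = 2, 6, 0.2
for m in range(band_limit + 1):
    b = basis(0.05, math.radians(30), m, n_conv_rings, max_r)
    print(m, [round(abs(v), 3) for v in b])
```

Rotating the tangent frame at a vertex shifts every θ by the same angle. This multiplies each order-m basis element by the same unit-magnitude phase and leaves the ring weights unchanged, which is the property that makes such filters rotation-equivariant.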


## Dataset
To get our dataset ready for training, we need to perform the following steps:
1. Provide a path to load and store the dataset.
2. Define transformations to be performed on the dataset:
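    - A transformation that normalizes the area of each shape, computes a multi-scale radius graph and precomputes the logarithmic map.
    - A transformation that randomly scales and rotates each training shape (data augmentation).
    - A transformation that masks the edges and vertices per scale and precomputes convolution components. This one is applied inside the network's `forward` method.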
3. Assign and load the datasets.
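The multi-scale radius graph can be sketched in a few lines of plain Python. This is an illustration under stated assumptions, not the code of `MultiscaleRadiusGraph`. We assume that each scale keeps a `ratio` fraction of all vertices, chosen by farthest point sampling. Each scale then connects every pair of kept vertices whose Euclidean distance is at most that scale's radius, including self-loops (as with `loop=True`). All function names are hypothetical.

```python
import math
import random


def farthest_point_sample(points, ratio):
    """Greedy farthest point sampling: keep ceil(ratio * n) well-spread point indices."""
    n = len(points)
    k = max(1, math.ceil(ratio * n))
    if k >= n:
        return list(range(n))
    selected = [0]  # start from the first point
    min_dist = [math.dist(points[0], p) for p in points]
    while len(selected) < k:
        # Pick the point that is farthest from everything selected so far
        nxt = max(range(n), key=lambda i: min_dist[i])
        selected.append(nxt)
        min_dist = [min(d, math.dist(points[nxt], p)) for d, p in zip(min_dist, points)]
    return selected


def radius_graph(points, vertices, r, loop=True):
    """Directed edges (i, j) between kept vertices within Euclidean distance r."""
    return [(i, j) for i in vertices for j in vertices
            if (loop or i != j) and math.dist(points[i], points[j]) <= r]


def multiscale_radius_graph(points, ratios, radii):
    """One (vertex indices, edge list) pair per scale."""
    scales = []
    for ratio, r in zip(ratios, radii):
        vertices = farthest_point_sample(points, ratio)
        scales.append((vertices, radius_graph(points, vertices, r)))
    return scales


random.seed(0)
pts = [(random.random(), random.random(), random.random()) for _ in range(40)]
scales = multiscale_radius_graph(pts, ratios=[1, 0.25], radii=[0.2, 0.4])
for (vertices, edges), r in zip(scales, [0.2, 0.4]):
    print(f"radius {r}: {len(vertices)} vertices, {len(edges)} edges")
```

The actual pre-transform also precomputes the logarithmic map with `VectorHeat`. This sketch leaves that step out.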

In [3]:
# 1. Provide a path to load and store the dataset.
# Make sure that a folder 'data' exists in the working directory
# and that you have downloaded the raw dataset and moved it there
path = osp.join('data', 'shrec')

# 2. Define transformations to be performed on the dataset:
# Transformation that normalizes the area, computes a multi-scale radius graph and precomputes the logarithmic map.
pre_transform = T.Compose((
    NormalizeArea(),
    MultiscaleRadiusGraph(ratios, radii, loop=True, flow='target_to_source'),
    VectorHeat(max_lvl=len(ratios)-1),
))
# Apply a random scale and random rotation to each training shape
transform = T.Compose((
    T.RandomScale((0.85, 1.15)),
    T.RandomRotate(45, axis=0),
    T.RandomRotate(45, axis=1),
    T.RandomRotate(45, axis=2),
))

# Transformations that mask the edges and vertices per scale and precompute convolution components.
scale0_transform = T.Compose((
    ScaleMask(0),
    FilterNeighbours(radii[0]),
    HarmonicPrecomp(n_conv_rings, n_corr_rings, band_limit, max_r=radii[0]),
))
scale1_transform = T.Compose((
    ScaleMask(1),
    FilterNeighbours(radii[1]),
    HarmonicPrecomp(n_conv_rings, n_corr_rings, band_limit, max_r=radii[1]),
))

# 3. Assign and load the datasets.
test_dataset = Shrec16(path, False, pre_transform=pre_transform)
test_loader = DataLoader(test_dataset, batch_size=batch_size)
train_dataset = Shrec16(path, True, pre_transform=pre_transform, transform=transform)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

n_classes = train_dataset.num_classes
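The following plain-Python sketch shows roughly what `transform` does to the vertex positions of each training shape. It assumes the documented PyTorch Geometric behaviour. `T.RandomScale((0.85, 1.15))` multiplies all positions by one factor drawn uniformly from that interval. `T.RandomRotate(45, axis=k)` rotates around axis k by an angle drawn uniformly from [-45°, 45°]. PyG's rotation matrices may use a different sign convention from ours. The helper names are hypothetical.

```python
import math
import random


def random_scale(points, lo=0.85, hi=1.15):
    """Multiply all coordinates by one factor drawn uniformly from [lo, hi]."""
    s = random.uniform(lo, hi)
    return [tuple(s * c for c in p) for p in points]


def rotation_matrix(angle, axis):
    """3x3 rotation matrix around coordinate axis 0 (x), 1 (y) or 2 (z)."""
    c, s = math.cos(angle), math.sin(angle)
    if axis == 0:
        return [[1, 0, 0], [0, c, -s], [0, s, c]]
    if axis == 1:
        return [[c, 0, s], [0, 1, 0], [-s, 0, c]]
    return [[c, -s, 0], [s, c, 0], [0, 0, 1]]


def random_rotate(points, degrees, axis):
    """Rotate around `axis` by an angle drawn uniformly from [-degrees, degrees]."""
    R = rotation_matrix(math.radians(random.uniform(-degrees, degrees)), axis)
    return [tuple(sum(R[row][k] * p[k] for k in range(3)) for row in range(3)) for p in points]


def augment(points):
    """Random scale followed by random rotations around x, y and z."""
    points = random_scale(points)
    for axis in range(3):
        points = random_rotate(points, 45, axis)
    return points


random.seed(1)
pts = [(random.random(), random.random(), random.random()) for _ in range(5)]
out = augment(pts)
# Every pairwise distance changes by the same factor: the sampled scale
print(math.dist(out[0], out[1]) / math.dist(pts[0], pts[1]))
```

Every step is a scaling or a rotation, so all pairwise distances change by the same factor. The shape stays the same, and only its size and orientation vary from epoch to epoch.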

## Network architecture
Now, we define the network architecture as a new `nn.Module` called `Net`. We first set up each layer in the `__init__` method of the `Net` class and define the steps to perform for each batch in the `forward` method. We use the left part of the U-ResNet architecture from the following figure, with half of its ResNet blocks. Concretely, the network stays on the first scale and consists of a lifting block, two ResNet blocks and a final convolution:

<img src="img/resnet_architecture.png" width="800px" />

Let's get started!

In [4]:
class Net(torch.nn.Module):
    def __init__(self):
        super(Net, self).__init__()

        # Lift the 3D vertex positions to nf[1] features
        self.lift = LiftBlock(3, nf[1], n_corr_rings=n_corr_rings)

        # Two ResNet blocks on the first scale
        self.resnet1 = ECResNetBlock(nf[1], nf[1], band_limit=band_limit, n_conv_rings=n_conv_rings)
        self.resnet2 = ECResNetBlock(nf[1], nf[1], band_limit=band_limit, n_conv_rings=n_conv_rings)

        # Final convolution to one feature per class
        self.conv_final = ExtConvCplx(nf[1], n_classes, band_limit=band_limit, n_rings=n_conv_rings, offset=offset)
        # Alternative classification head (disabled):
        # self.echo = ECHOBlock(nf[0], n_des, n_bins, n_classes, band_limit, n_conv_rings, classify=True)

        # Learnable per-class bias, added after pooling
        self.bias = nn.Parameter(torch.Tensor(n_classes))
        zeros(self.bias)


    def forward(self, data):
        # Select only the edges and precomputed components of the first scale
        data_scale0 = scale0_transform(data)

        # Attributes for the lifting block
        attr_grad = (data_scale0.edge_index, data_scale0.pcmp_gather)
        # Attributes for the convolutions; `connection` enables parallel transport
        attr_conv = (data_scale0.edge_index, data_scale0.pcmp_scatter, data_scale0.connection)
        # Attributes for the (disabled) ECHO block
        attr_echo = (data_scale0.edge_index, data_scale0.pcmp_scatter, data_scale0.pcmp_echo, data_scale0.connection)

        # Use the vertex positions as input features
        x = data.pos

        x = self.lift(x, *attr_grad)
        x = self.resnet1(x, *attr_conv)
        x = self.resnet2(x, *attr_conv)
        # x = self.echo(x, *attr_echo)
        x = self.conv_final(x, *attr_conv)

        # Take the magnitude of each complex feature and average over all vertices
        # (global mean pool). This pools over every vertex in the batch, so it
        # assumes batch_size = 1.
        x = torch.mean(norm2D(x), dim=0, keepdim=True)
        x = x + self.bias
        return F.log_softmax(x, dim=1)
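The last lines of `forward` turn rotation-equivariant complex features into rotation-invariant class scores. The sketch below illustrates why. It assumes (our assumption) that each output channel holds a complex value of some rotation order m, so that changing the tangent frame at a vertex by an angle φ multiplies the value by a unit-magnitude phase. Taking magnitudes removes that phase, and averaging over vertices removes any dependence on vertex order. Python complex numbers stand in for the tensors passed to `norm2D`, and the helper names are hypothetical.

```python
import cmath


def norm2d(features):
    """Per-entry magnitude of complex features (stand-in for the notebook's norm2D)."""
    return [[abs(z) for z in row] for row in features]


def mean_pool(rows):
    """Average each channel over all vertices (global mean pool)."""
    return [sum(col) / len(rows) for col in zip(*rows)]


def rotate_frames(features, orders, phis):
    """Rotate the frame of vertex v by phis[v]; an order-m channel gets the phase e^{i m phi}.

    The sign of the phase depends on convention; only its unit magnitude matters here.
    """
    return [[z * cmath.exp(1j * m * phi) for z, m in zip(row, orders)]
            for row, phi in zip(features, phis)]


# Three vertices with two complex output channels of (assumed) rotation orders 1 and 2
feats = [[1 + 2j, 0.5 - 1j], [-1 + 0.5j, 2j], [0.3 + 0.1j, -0.7 + 0j]]
orders = [1, 2]
phis = [0.3, 1.7, -2.2]  # an arbitrary, different frame rotation at each vertex

before = mean_pool(norm2d(feats))
after = mean_pool(norm2d(rotate_frames(feats, orders, phis)))
print(before, after)
```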

To test the network without applying parallel transport, remove the `connection` argument (`data_scale0.connection`) from the convolution attributes `attr_conv`.
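Here is a toy sketch of parallel transport for complex features, which shows what the `connection` argument contributes. It assumes (our assumption about what `connection` encodes) that each edge carries an angle φ between the tangent frames at its two endpoints, and that a feature of rotation order m is aligned by multiplying it with e^{imφ}. In the example, two neighbours store the same tangent vector in differently rotated frames. With transport, their contributions add up coherently. Without it, they partly cancel, which is what removing `connection` amounts to in this toy model. The helper names are hypothetical.

```python
import cmath


def transport(z, m, phi):
    """Rotate an order-m feature by the connection angle phi of an edge."""
    return z * cmath.exp(1j * m * phi)


def aggregate(neighbours, m, use_connection=True):
    """Sum neighbour features (z, phi) at a vertex, optionally aligning frames first."""
    return sum((transport(z, m, phi) if use_connection else z for z, phi in neighbours), 0j)


m = 1
v = 1 + 0j  # the same tangent vector, expressed in the centre vertex's frame
# Each neighbour stores v in its own frame, rotated by -phi relative to the centre;
# the connection angle phi rotates it back into the centre's frame.
neighbours = [(v * cmath.exp(-1j * m * 0.5), 0.5), (v * cmath.exp(1j * m * 1.0), -1.0)]

with_pt = aggregate(neighbours, m)
without_pt = aggregate(neighbours, m, use_connection=False)
print(abs(with_pt), abs(without_pt))
```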

## Training and testing

Phew, we're through the hard part. Now, let's get to training. First, move the network to the GPU and set up an optimizer.

In [5]:
# We want to train on a GPU. It'll take a long time on a CPU,
# but fall back to the CPU if no GPU is available
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Move the network to the device
model = Net().to(device)
# Set up the Adam optimizer with a learning rate of 0.01
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

Next, define the training and test functions.

In [6]:
def train(epoch, acc):
    # Set model to 'train' mode
    model.train()

    # For the two-stream architecture, it helps to decrease
    # the learning rate early to stabilise training
    if epoch > 5 or acc >= 0.9:
        for param_group in optimizer.param_groups:
            param_group['lr'] = 0.001

    # Optional second decay step (disabled):
    # if epoch > 20 or acc >= 0.97:
    #     for param_group in optimizer.param_groups:
    #         param_group['lr'] = 0.0005

    # Set up the progress bar
    n_data = len(train_loader)
    widgets = [progressbar.Percentage(), progressbar.Bar(),
               progressbar.AdaptiveETA(), ' | Loss:', progressbar.Variable('loss')]
    bar = progressbar.ProgressBar(max_value=n_data, widgets=widgets)

    total_loss = 0.0
    for i, data in enumerate(train_loader, start=1):
        # Move training data to the device and optimize parameters.
        # To debug NaN gradients, wrap this body in `with autograd.detect_anomaly():`
        data = data.to(device)
        optimizer.zero_grad()
        loss = F.nll_loss(model(data), data.y)
        loss.backward()
        optimizer.step()

        # Show the running average loss
        total_loss += loss.item()
        bar.update(i, loss=total_loss / i)
    bar.finish()


def test():
    # Set model to 'evaluation' mode
    model.eval()
    correct = 0
    total_num = 0
    # Gradients are not needed for evaluation
    with torch.no_grad():
        for data in test_loader:
            data = data.to(device)
            pred = model(data).max(1)[1]
            correct += pred.eq(data.y).sum().item()
            # Count shapes, not batches
            total_num += data.y.size(0)
    return correct / total_num

Train for up to 50 epochs. Training stops early once the best test accuracy exceeds 0.999.

In [None]:
print('Start training...')
# Best test accuracy so far; train() uses it to decide when to lower the learning rate
acc = 0
for epoch in range(50):
    train(epoch, acc)
    test_acc = test()
    acc = max(acc, test_acc)
    print("Epoch {} - Test: {:06.4f}".format(epoch, test_acc))
    if acc > 0.999:
        break

  0%|                                       |ETA:  00:00:00 | Loss:loss:    3.4

Start training...


  0%|                                       |ETA:   0:00:18 | Loss:loss:    3.0

Epoch 0 - Test: 0.2000


  0%|                                       |ETA:   0:00:18 | Loss:loss:  0.192

Epoch 1 - Test: 0.6917


  0%|                                       |ETA:   0:00:18 | Loss:loss: 0.0546

Epoch 2 - Test: 0.8250


  0%|                                      |ETA:   0:00:16 | Loss:loss: 0.00323

Epoch 3 - Test: 0.8750


  0%|                                       |ETA:   0:00:16 | Loss:loss:  0.378

Epoch 4 - Test: 0.8000


  0%|                                      |ETA:   0:00:16 | Loss:loss: 0.00846

Epoch 5 - Test: 0.9833


  0%|                                      |ETA:   0:00:16 | Loss:loss: 0.00275

Epoch 6 - Test: 0.9833


 80%|###############################        |ETA:   0:00:06 | Loss:loss: 0.0129
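As a worked example of the learning-rate rule in `train`, the sketch below replays it on the test accuracies printed above. `learning_rate` is our restatement of the `if` statement in `train`, not a function from the notebook. Because `train` receives the best accuracy from previous epochs, both conditions first hold at epoch 6. The best accuracy reached 0.9833 in epoch 5's test, and epoch 6 is also the first epoch with `epoch > 5`.

```python
def learning_rate(epoch, best_acc, base_lr=0.01, low_lr=0.001):
    """Restatement of the rule in train(): drop the learning rate after epoch 5 or at 90% accuracy."""
    return low_lr if epoch > 5 or best_acc >= 0.9 else base_lr


# Test accuracies from the training log above (epochs 0-6)
logged = [0.2000, 0.6917, 0.8250, 0.8750, 0.8000, 0.9833, 0.9833]

best = 0.0
schedule = []
for epoch, test_acc in enumerate(logged):
    # train(epoch, acc) only sees the best accuracy of the previous epochs
    schedule.append((epoch, best, learning_rate(epoch, best)))
    best = max(best, test_acc)

for epoch, seen, lr in schedule:
    print(f"epoch {epoch}: best so far {seen:.4f} -> lr {lr}")
```

The epoch counter only grows and the best accuracy never decreases. Once the condition is met it therefore stays true, so setting the learning rate inside `train` every epoch is equivalent to a single drop.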