# Classification

### This notebook replicates the shape classification experiments in section 6.2 of "Field Convolutions for Surface CNNs" (Mitchel et al. 2021).

In [None]:
# File reading and progressbar
import os
import os.path as osp
import progressbar

# Numpy
import numpy as np

# PyTorch 
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import autograd

# PyTorch Geometric - used for data loading/processing
import torch_geometric.transforms as T
from torch_geometric.data import DataLoader

# Field Convolution modules
from nn import FCResNetBlock, LiftBlock, FieldConv

# Feature magnitudes
from utils.field import softAbs

# Transforms
from transforms import FCPrecomp, computeLogXPort, SupportGraph, NormalizeArea, NormalizeAxes

# Load the SHREC '11 shape classification dataset 
from datasets import SHREC11

# Release unoccupied cached GPU memory held by PyTorch's CUDA caching allocator
torch.cuda.empty_cache()

## Hyperparameters

In [None]:
# Band-limit of the field convolution filters
band_limit = 2

# How many radial samples each filter uses
n_rings = 6

# Which filter type to use (the options live in /nn/field_conv.py)
ftype = 1

# Channel count used throughout the network
nf = 32

# Radius of the filter support
epsilon = 0.2

## Data pre-processing and loading

In [None]:
# Dataset root directory.
# The .zip file containing the data is expected at
# /data/SHREC11/raw/SHREC11.zip
path = osp.join('data', 'SHREC11')

# Pre-processing pipeline, applied in this order:
#   1. normalize meshes to have unit surface area
#   2. sample points on meshes and compute convolution support edges
#   3. compute logarithm maps + parallel transport
preprocessing_steps = (
    NormalizeArea(),
    SupportGraph(epsilon=epsilon),
    computeLogXPort(),
)
pre_transform = T.Compose(preprocessing_steps)

# Augmentation applied every time a shape is drawn from the dataloader:
# a random scale, then a random rotation (up to 45 degrees) about axes 0, 1 and 2
rotations = tuple(T.RandomRotate(45, axis=axis) for axis in range(3))
transform = T.Compose((T.RandomScale((0.85, 1.15)),) + rotations)

# Test split (no augmentation) first, then the augmented train split
test_dataset = SHREC11(path, False, pre_transform=pre_transform)
train_dataset = SHREC11(path, True, pre_transform=pre_transform, transform=transform)

# One shape per batch; only the training loader shuffles
test_loader = DataLoader(test_dataset, batch_size=1)
train_loader = DataLoader(train_dataset, batch_size=1, shuffle=True)

n_classes = train_dataset.num_classes

## FCNet

##### Two FCResNet blocks, followed by a field convolution that maps network features to class predictions. At the beginning of the network, a learnable gradient-like operation lifts scalar features to isometry-equivariant tangent vector fields.

In [None]:
## Organizes edge data at run time to expedite convolutions
organizeEdges = FCPrecomp(
    band_limit=band_limit,
    n_rings=n_rings,
    epsilon=epsilon,
)

In [None]:
class Net(torch.nn.Module):
    """FCNet classifier: lift -> two FCResNet blocks -> field convolution -> mean pool.

    Construction reads the notebook-level settings ``nf``, ``n_rings``,
    ``ftype``, ``band_limit`` and ``n_classes``; every forward pass calls the
    notebook-level ``organizeEdges`` transform.
    """

    def __init__(self):
        super().__init__()

        ## Learned 'gradient', lifting scalar features to tangent vector features
        ## at the beginning of the network
        self.lift = LiftBlock(3, nf, n_rings=n_rings, ftype=ftype)

        ## FCNet - two FCResNet Blocks, followed by a field convolution
        self.resnet1 = FCResNetBlock(nf, nf, band_limit=band_limit,
                                     n_rings=n_rings, ftype=ftype)

        self.resnet2 = FCResNetBlock(nf, nf, band_limit=band_limit,
                                     n_rings=n_rings, ftype=ftype)

        self.conv_out = FieldConv(nf, n_classes, band_limit=band_limit,
                                  n_rings=n_rings, ftype=ftype)

        ## Bias applied to output prediction, initialized to zero
        self.bias = nn.Parameter(torch.zeros(1, n_classes))

    def forward(self, data):
        """Compute class scores for one shape.

        Args:
            data: object exposing ``pos`` and ``sample_idx`` and accepted by
                ``organizeEdges``.

        Returns:
            The mean over the first axis (kept as a size-1 dimension) of the
            ``softAbs`` of the final field convolution output, plus the
            learned bias.
        """

        ##########################
        ### Organize edge data ###
        ##########################
        supp_edges, supp_sten, _, _ = organizeEdges(data)

        # The lift block gets only the two stencil entries starting at index
        # band_limit; the convolutions get the full stencil.
        # NOTE(review): LiftBlock is built without a band_limit, so the sliced
        # stencil is presumably what it expects - confirm against nn/.
        attr_lift = (supp_edges, supp_sten[..., band_limit:(band_limit + 2)])
        attr_conv = (supp_edges, supp_sten)

        #############################################
        ### Lift scalar features to vector fields ###
        #############################################

        # Input features are the positions of the sampled points
        x = data.pos[data.sample_idx, :]

        # Fix: previously called with attr_conv, leaving attr_lift unused
        x = self.lift(x, *attr_lift)

        ##########################
        ### Field Convolutions ###
        ##########################

        x = self.resnet1(x, *attr_conv)

        x = self.resnet2(x, *attr_conv)

        # Final field convolution with n_classes output channels
        x = self.conv_out(x, *attr_conv)

        ################################
        ### Mean pool for prediction ###
        ################################

        x = torch.mean(softAbs(x), dim=0, keepdim=True)

        return x + self.bias

## Training 

In [None]:
# Build the network and place it on the GPU
device = torch.device('cuda')
model = Net()
model = model.to(device)

## Cross Entropy Loss
loss = nn.CrossEntropyLoss()

# ADAM Optimizer, lr = 0.01
optimizer = torch.optim.Adam(model.parameters(), lr=1e-2)

In [None]:
## Training function
## Optional batch_step parameter for gradient accumulation (not used in the paper)
def train(batch_step=1):
    """Run one pass over ``train_loader``, updating ``model`` via ``optimizer``.

    Uses the notebook-level ``model``, ``optimizer``, ``loss``, ``device`` and
    ``train_loader``. A progress bar shows the average loss over the most
    recent 10 samples.

    Args:
        batch_step: number of samples whose gradients are accumulated before
            each optimizer step. Each sample's loss is divided by this value
            before ``backward()``. The last group of an epoch is stepped even
            if it is incomplete.

    Raises:
        ValueError: if ``batch_step`` is smaller than 1.
    """
    from collections import deque

    if batch_step < 1:
        raise ValueError("batch_step must be at least 1, got {}".format(batch_step))

    model.train()

    # Sort out progress bar, displays average loss over last 10 samples.
    # A bounded deque drops the oldest loss automatically once it is full.
    wW = 10
    window = deque(maxlen=wW)

    n_data = len(train_loader)
    widgets = [progressbar.Percentage(), progressbar.Bar(),
               progressbar.AdaptiveETA(), ' | ', progressbar.Variable('Loss')]

    bar = progressbar.ProgressBar(max_value=n_data, widgets=widgets)

    ## Zero-out
    optimizer.zero_grad()

    for i, data in enumerate(train_loader, start=1):

        # Bind the result so this works whether or not .to() moves in place
        data = data.to(device)

        pred = model(data)

        L = loss(pred, data.y.to(device))

        # Update progress bar with the windowed mean as a plain float
        window.append(L.item())
        bar.update(i, Loss=sum(window) / len(window))

        ## Update loss (scaled for gradient accumulation)
        (L / batch_step).backward()

        # Step after every full accumulation group and at the end of the epoch
        if i % batch_step == 0 or i == n_data:
            optimizer.step()
            optimizer.zero_grad()

    # Close the bar so the next epoch's bar starts on a fresh line
    bar.finish()


## Testing

In [None]:
## Compute classification accuracy on test set   
def test():
    model.eval()
    correct = 0
    total_num = 0
    for i, data in enumerate(test_loader):
        pred = model(data.to(device)).max(1)[1]
        correct += pred.eq(data.y).sum().item()
        total_num += 1
    return correct / total_num

## Train, then test
###### We train for 30 epochs, as in the paper. 

In [None]:
print('Training...')

# Number of passes over the training set (30, as in the paper)
n_epochs = 30
for epoch in range(n_epochs):
    print("Epoch {}".format(epoch), flush=True)
    train()

# Evaluate once, after the final epoch
acc = test()
print("Test accuracy: {:06.4f}".format(acc), flush=True)