# Correspondence

### This notebook replicates the dense correspondence experiments in section 6.4 of "Field Convolutions for Surface CNNs" (Mitchel et al. 2021).


In [None]:
# File reading and progressbar
import os
import os.path as osp
import progressbar

# Numpy
import numpy as np

# PyTorch 
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import autograd

# PyTorch Geometric - used for data loading/processing
import torch_geometric.transforms as T
from torch_geometric.data import DataLoader

# Field Convolution modules
from nn import FCResNetBlock, LiftBlock, ECHOBlock, TangentPerceptron

# Transforms
from transforms import FCPrecomp, computeLogXPort, SupportGraph, NormalizeArea, NormalizeAxes

# Load the remeshed FAUST dataset (Donati et al. 2020)
from datasets import FAUSTRM

# Release unoccupied GPU memory held by PyTorch's caching allocator
# (helpful when re-running the notebook in an already-used kernel)
torch.cuda.empty_cache()

## Hyperparameters

In [None]:
# --- Field convolution filters ---

# Angular band-limit of the filters
band_limit = 1

# How many radial samples each filter uses
n_rings = 3

# Which filter parameterization to use (see /nn/field_conv.py)
ftype = 1

# --- Network size ---

# Channel width used throughout the network
nf = 32

# --- ECHO descriptors (last layer of the network) ---

# How many ECHO descriptors are computed
n_des = 12

# Descriptor bins per unit radius.
# The descriptor resolution is roughly PI * (n_bins + 0.5) * (n_bins + 0.5)
n_bins = 2

# --- Filter support ---

# Radius of the filter support
epsilon = 0.0425

## Data pre-processing and loading

In [None]:
# Dataset root directory (relative to the current working directory).
# The .zip archive containing the data is expected at
# data/FAUSTRM/raw/FAUSTRM.zip
path = osp.join('data', 'FAUSTRM')


# Pre-processing, applied to the meshes at full resolution:
#   1. build the convolution support edges
#   2. compute logarithm maps + parallel transport
pre_transform = T.Compose([
    SupportGraph(epsilon=epsilon),
    computeLogXPort(),
])

# Applied every time a shape is drawn from a dataloader:
# center the shape, then randomly rotate it (up to 45 degrees) about each axis
transform = T.Compose(
    [T.Center()] + [T.RandomRotate(45, axis=ax) for ax in range(3)]
)

# Train split (second argument True) and test split (second argument False)
train_dataset = FAUSTRM(path, True,
                        pre_transform=pre_transform, transform=transform)
test_dataset = FAUSTRM(path, False,
                       pre_transform=pre_transform, transform=transform)

# One shape per batch; only the training loader shuffles
train_loader = DataLoader(train_dataset, batch_size=1, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=1)

# Size of the network's output layer
# (presumably one class per point of the reference shape -- TODO confirm)
n_classes = 4999


## FCNet

##### A succession of FCResNet blocks with an ECHO block as the final layer. A learnable gradient-like operation is used to lift scalar features to isometry-equivariant tangent vector fields at the beginning of the network.

In [None]:
## Organizes edge data at run time to expedite convolutions
organizeEdges = FCPrecomp(
    band_limit=band_limit,
    n_rings=n_rings,
    epsilon=epsilon,
)

In [None]:
class Net(torch.nn.Module):
    """FCNet used for the dense correspondence experiment.

    Layout: a LiftBlock, eight FCResNetBlocks (with four extra 'meta'
    residual connections routed through TangentPerceptrons), an ECHOBlock,
    and finally two linear layers with dropout that output ``n_classes``
    scores per point.

    Layer sizes and filter settings are read from the notebook-level
    hyperparameters (``nf``, ``band_limit``, ``n_rings``, ``ftype``,
    ``n_des``, ``n_bins``, ``n_classes``).
    """

    def __init__(self):
        super().__init__()

        # Filter settings shared by every field-convolution based block
        fc_args = dict(band_limit=band_limit, n_rings=n_rings, ftype=ftype)

        ## Learned 'gradient': lifts scalar features to tangent vector
        ## features at the very start of the network
        self.lift = LiftBlock(3, 16, n_rings=n_rings, ftype=ftype)

        ## Eight FCResNet blocks: 16 -> nf -> ... -> nf -> 16 channels
        self.resnet1 = FCResNetBlock(16, nf, **fc_args)
        self.resnet2 = FCResNetBlock(nf, nf, **fc_args)
        self.resnet3 = FCResNetBlock(nf, nf, **fc_args)
        self.resnet4 = FCResNetBlock(nf, nf, **fc_args)
        self.resnet5 = FCResNetBlock(nf, nf, **fc_args)
        self.resnet6 = FCResNetBlock(nf, nf, **fc_args)
        self.resnet7 = FCResNetBlock(nf, nf, **fc_args)
        self.resnet8 = FCResNetBlock(nf, 16, frontload=True, **fc_args)

        ## ECHO block, computing the descriptors fed to the dense layers
        self.echo = ECHOBlock(16, nf, n_des=n_des, n_bins=n_bins, **fc_args)

        ## 'Meta' residual connections (on top of the ones already inside
        ## the FCResNet blocks); each skips over a pair of blocks
        self.res1 = TangentPerceptron(16, nf)
        self.res2 = TangentPerceptron(nf, nf)
        self.res3 = TangentPerceptron(nf, nf)
        self.res4 = TangentPerceptron(nf, 16)

        ## Dropout between the two dense layers
        self.D = nn.Dropout(p=0.5)

        ## Dense head
        self.lin1 = nn.Linear(nf, 256)
        self.lin2 = nn.Linear(256, n_classes)

    def forward(self, data):
        """Compute per-point class scores for the batch ``data``.

        ``data`` must provide ``pos`` and ``sample_idx`` plus whatever
        ``organizeEdges`` reads from it.
        """

        # ---- Organize edge data ----
        supp_edges, supp_sten, ln, wxp = organizeEdges(data)

        # The lift block only receives a two-entry slice of the stencil
        lift_args = (supp_edges, supp_sten[..., band_limit:(band_limit + 2)])
        conv_args = (supp_edges, supp_sten)
        echo_args = (supp_edges, supp_sten, ln, wxp)

        # ---- Lift scalar features to vector fields ----
        # Input features are the positions of the sampled points
        positions = data.pos[data.sample_idx, :]
        v0 = self.lift(positions, *lift_args)

        # ---- Field convolutions ----
        # Blocks are applied in pairs; every pair is bypassed by a
        # 'meta' residual connection
        h = self.resnet1(v0, *conv_args)
        v1 = self.resnet2(h, *conv_args) + self.res1(v0)

        h = self.resnet3(v1, *conv_args)
        v2 = self.resnet4(h, *conv_args) + self.res2(v1)

        h = self.resnet5(v2, *conv_args)
        v3 = self.resnet6(h, *conv_args) + self.res3(v2)

        h = self.resnet7(v3, *conv_args)
        h = self.resnet8(h, *conv_args) + self.res4(v3)

        # ---- ECHO descriptors, then the dense layers ----
        out = self.echo(h, *echo_args)

        out = F.relu(self.lin1(out))
        out = self.D(out)
        out = self.lin2(out)

        return out

## Training

In [None]:
# The model lives (and is trained) on the GPU
device = torch.device('cuda')
model = Net()
model = model.to(device)

# ADAM optimizer with an initial learning rate of 0.01
optimizer = torch.optim.Adam(params=model.parameters(), lr=0.01)

## Cross-entropy loss on the per-point class scores
loss = nn.CrossEntropyLoss()

In [None]:
## Training function
## Optional batch_step parameter for gradient accumulation (not used in the paper)
def train(batch_step=1):
    """Run one training epoch over ``train_loader``.

    Uses the notebook-level ``model``, ``optimizer``, ``loss``, ``device``
    and ``train_loader``. A progress bar shows the mean loss over the most
    recent 50 samples.

    Args:
        batch_step: Number of samples whose gradients are accumulated
            before each optimizer step. Each loss is divided by this value
            before back-propagation. Must be >= 1.

    Raises:
        ValueError: If ``batch_step`` is smaller than 1.
    """
    from collections import deque

    if batch_step < 1:
        raise ValueError('batch_step must be >= 1, got {}'.format(batch_step))

    model.train()

    # Sliding window holding the most recent losses; a bounded deque
    # discards the oldest entry automatically once it is full
    window_size = 50
    recent_losses = deque(maxlen=window_size)

    n_data = len(train_loader)
    widgets = [progressbar.Percentage(), progressbar.Bar(),
               progressbar.AdaptiveETA(), ' | ', progressbar.Variable('Loss')]

    bar = progressbar.ProgressBar(max_value=n_data, widgets=widgets)

    ## Zero-out any stale gradients before accumulating
    optimizer.zero_grad()

    for i, data in enumerate(train_loader, start=1):

        data = data.to(device)
        pred = model(data)

        L = loss(pred, data.y.squeeze())

        # Update progress bar with the windowed average loss
        recent_losses.append(L.item())
        bar.update(i, Loss=sum(recent_losses) / len(recent_losses))

        ## Scale so that the accumulated gradient averages over batch_step samples
        (L / batch_step).backward()

        # Step every batch_step samples, and once more at the end of the
        # epoch so that a trailing partial accumulation is not dropped
        if i % batch_step == 0 or i == n_data:
            optimizer.step()
            optimizer.zero_grad()

    # Finalize the bar so that later output starts on a fresh line
    bar.finish()


## Testing

In [None]:
## Compute the mean loss on the test set
def test():
    """Evaluate ``model`` on ``test_loader`` and return the mean loss.

    Uses the notebook-level ``model``, ``loss``, ``device`` and
    ``test_loader``. The model is switched to evaluation mode and no
    gradients are recorded.

    Returns:
        The average of the per-batch losses over all batches produced by
        ``test_loader``, or NaN if the loader yields no batches.
    """
    model.eval()
    total_loss = 0.0
    n_batches = 0

    with torch.no_grad():
        # Iterating the loader directly (instead of enumerate) lets the
        # progress bar pick up its length
        for data in progressbar.progressbar(test_loader):
            data = data.to(device)
            total_loss += loss(model(data), data.y.squeeze()).item()
            n_batches += 1

    if n_batches == 0:
        # The mean over an empty test set is undefined
        return float('nan')

    # Divide by the number of batches actually evaluated rather than a
    # hard-coded test-set size
    return total_loss / n_batches

## Train, then test
###### We train for 60 epochs, as in the paper. 

In [None]:
print('Training...')

for epoch in range(60):

    print('Epoch {}'.format(epoch))
    train()

    # Once epoch 40 has finished, drop the learning rate to 0.001
    # for the remaining epochs
    if epoch == 40:
        for param_group in optimizer.param_groups:
            param_group['lr'] = 0.001


# Evaluate once, after all training epochs are done
testL = test()
print('Mean test loss: {}'.format(testL), flush=True)