# Feature Matching

### This notebook replicates the feature matching experiments in section 6.5 of "Field Convolutions for Surface CNNs" (Mitchel et al. 2021).

## Dependencies

In [None]:
# File reading and progressbar
import os
import os.path as osp
import progressbar

# Numpy
import numpy as np

# Random
import random

# PyTorch 
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import autograd

# PyTorch Geometric - used for data loading/processing
import torch_geometric.transforms as T
from torch_geometric.data import DataLoader

# Field Convolution modules
from nn import FCResNetBlock, LiftBlock, TangentPerceptron, TwinLoss, TwinEval

# Feature magnitudes
from utils.field import softAbs

# Transforms
from transforms import FCPrecomp, computeLogXPort, SupportGraph

# Load the Isometric & Non-Isometric Shape Correspondence dataset (Dyke et al., 2019)
from datasets import SHREC19

# Release the unoccupied GPU memory currently cached by PyTorch's CUDA allocator
torch.cuda.empty_cache()

## Hyperparameters

In [None]:
# --- Field convolution filters ---
band_limit = 1     # band-limit of the filters
n_rings = 6        # number of radial samples
ftype = 1          # filter type (see /nn/field_conv.py)
epsilon = 0.1      # filter support radius

# --- Network ---
nf = 32            # number of channels in the network

# --- Mesh sampling ---
n_samples = 2048   # number of sample points on meshes

## Data pre-processing and loading

In [None]:
# Dataset root directory.
# The .zip file containing the data is expected at
# /data/SHREC19/raw/SHREC19.zip
path = osp.join('data', 'SHREC19')

# Pre-processing pipeline, applied in this order:
#   1. SupportGraph    - sample points on the meshes and compute the convolution support edges
#   2. computeLogXPort - compute logarithm maps + parallel transport
pre_transform = T.Compose((
    SupportGraph(epsilon=epsilon, sample_n=n_samples),
    computeLogXPort(),
))

# Run-time transform handed to the training splits only:
# center the shape, then apply a random rotation about each of the three axes
transform = T.Compose((
    T.Center(),
    T.RandomRotate(45, axis=0),
    T.RandomRotate(45, axis=1),
    T.RandomRotate(45, axis=2),
))

# Keyword arguments common to all four splits
_dataset_kwargs = dict(n_samples=n_samples, pre_transform=pre_transform)

# Training split (source shapes = 0, target shapes = 1), with the run-time transform
trainS_dataset = SHREC19(path, 0, transform=transform, **_dataset_kwargs)
trainT_dataset = SHREC19(path, 1, transform=transform, **_dataset_kwargs)

# Test split (source shapes = 2, target shapes = 3), without the run-time transform
testS_dataset = SHREC19(path, 2, **_dataset_kwargs)
testT_dataset = SHREC19(path, 3, **_dataset_kwargs)

## FCNet
##### Eight FCResNet blocks, followed by a linear layer to map network features to a 16-dimensional descriptor. A learnable gradient-like operation is used to lift scalar features to isometry-equivariant tangent vector fields at the beginning of the network.

In [None]:
## Callable that organizes edge data at run time to expedite convolutions
organizeEdges = FCPrecomp(
    band_limit=band_limit,
    n_rings=n_rings,
    epsilon=epsilon,
)

In [None]:
class Net(torch.nn.Module):
    """FCNet descriptor network used for feature matching.

    A LiftBlock lifts the sampled point positions to tangent vector features,
    eight FCResNetBlocks process them (each consecutive pair of blocks is
    wrapped by an additional TangentPerceptron skip connection), and a final
    TangentPerceptron followed by softAbs yields the 16-dimensional descriptor.
    """

    def __init__(self):
        super().__init__()

        # Settings shared by every field-convolution block
        conv_args = dict(band_limit=band_limit, n_rings=n_rings, ftype=ftype)

        ## Learned 'gradient': lifts scalar features to tangent vector features
        ## at the beginning of the network
        self.lift = LiftBlock(3, 16, n_rings=n_rings, ftype=ftype)

        ## FCNet - eight FCResNet blocks; only the first one changes the width
        self.resnet1 = FCResNetBlock(16, nf, **conv_args)
        self.resnet2 = FCResNetBlock(nf, nf, **conv_args)
        self.resnet3 = FCResNetBlock(nf, nf, **conv_args)
        self.resnet4 = FCResNetBlock(nf, nf, **conv_args)
        self.resnet5 = FCResNetBlock(nf, nf, **conv_args)
        self.resnet6 = FCResNetBlock(nf, nf, **conv_args)
        self.resnet7 = FCResNetBlock(nf, nf, **conv_args)
        self.resnet8 = FCResNetBlock(nf, nf, **conv_args)

        ## 'Meta' residual connections (in addition to those already inside
        ## the FCResNet blocks), one per pair of blocks
        self.res1 = TangentPerceptron(16, nf)
        self.res2 = TangentPerceptron(nf, nf)
        self.res3 = TangentPerceptron(nf, nf)
        self.res4 = TangentPerceptron(nf, nf)

        ## Linear layer mapping features to 16-dimensional descriptors
        self.out = TangentPerceptron(nf, 16)

    def forward(self, data):
        """Compute a descriptor for every sampled point of `data`.

        Reads `data.pos` and `data.sample_idx` directly; `data` is also passed
        to the module-level `organizeEdges` callable.
        """
        # Organize edge data
        edges, stencil, _, _ = organizeEdges(data)

        # The lift block only receives a two-entry slice of the stencil's last
        # axis (starting at index band_limit); the convolutions get all of it
        lift_attr = (edges, stencil[..., band_limit:(band_limit + 2)])
        conv_attr = (edges, stencil)

        # Lift scalar features (positions of the sampled points) to vector fields
        feats = self.lift(data.pos[data.sample_idx, :], *lift_attr)

        # Field convolutions: two FCResNet blocks per stage, plus a 'meta'
        # skip connection from the stage input to the stage output
        stages = (
            (self.resnet1, self.resnet2, self.res1),
            (self.resnet3, self.resnet4, self.res2),
            (self.resnet5, self.resnet6, self.res3),
            (self.resnet7, self.resnet8, self.res4),
        )
        for first, second, skip in stages:
            hidden = first(feats, *conv_attr)
            feats = second(hidden, *conv_attr) + skip(feats)

        # Map features to the output descriptor
        return softAbs(self.out(feats))

## Training

In [None]:
# Train on the GPU
device = torch.device('cuda')
model = Net().to(device)

# ADAM optimizer with an initial learning rate of 0.01.
# (The value was previously 0.001, which contradicted this comment and turned
# the later "decay" to 0.001 in the training loop into a no-op.)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

## Twin loss (used by train()) and twin evaluation (used by test())
loss = TwinLoss()
ev = TwinEval()

# Randomly sample 512 pairs of corresponding and non-corresponding points
n_pairs = 512

In [None]:
## Corresponding pairs are computed during data pre-processing
## To get non-corresponding pairs, we consider the complement at run time
def getNullPairs(pos_pairs, nSamples=n_samples):
    """Return every index pair that does NOT appear in `pos_pairs`.

    Args:
        pos_pairs: integer tensor of corresponding pairs, one pair per row
            (shape (P, 2)); every index is expected to lie in [0, nSamples).
        nSamples: number of sample points per mesh; the candidate pairs are
            all (a, b) with 0 <= a, b < nSamples.

    Returns:
        CPU long tensor of shape (N, 2) holding the complement of `pos_pairs`,
        in the same column order as `pos_pairs` and sorted by the linear
        index a * nSamples + b.
    """
    pos = pos_pairs.cpu().numpy()

    # One linear index per ROW of pos_pairs. The columns must be indexed here:
    # pos_pairs[0] / pos_pairs[1] would pick the first two pairs instead and
    # leave nearly all true correspondences inside the "null" set.
    pos_lin = pos[:, 0] * nSamples + pos[:, 1]

    # Complement within the full nSamples x nSamples grid (sorted, unique)
    null_lin = np.setdiff1d(np.arange(nSamples * nSamples), pos_lin)

    # Decode the linear indices back into pairs using integer arithmetic only
    first, second = np.divmod(null_lin, nSamples)

    return torch.from_numpy(np.stack((first, second), axis=1)).long()


In [None]:
## Twin network: the same network computes features on the source and target mesh
## and the loss function compares features at corresponding + non-corresponding points
def train(batch_step=1):
    """Run one training epoch over the paired source/target training shapes.

    Uses the module-level `model`, `optimizer`, `loss`, `device`, `n_pairs`,
    `trainS_dataset` and `trainT_dataset`.

    Args:
        batch_step: number of shape pairs whose gradients are accumulated
            before each optimizer step (1 = step after every pair).
    """
    model.train()

    # Progress bar displays the average loss over (up to) the last wW samples
    wW = 15
    window = torch.zeros(wW)

    n_data = len(trainS_dataset)
    widgets = [progressbar.Percentage(), progressbar.Bar(),
               progressbar.AdaptiveETA(), ' | ', progressbar.Variable('Loss'),]

    bar = progressbar.ProgressBar(max_value=n_data, widgets=widgets)

    # Manually shuffle data
    order = torch.randperm(n_data).long()

    # Start the epoch with clean gradients
    optimizer.zero_grad()

    for i in range(n_data):

        # Load source and target meshes
        # NOTE(review): `.get()` is called directly; in torch_geometric it is
        # `Dataset.__getitem__` that applies `transform` - confirm that
        # SHREC19.get applies the random rotations as intended.
        dataS = trainS_dataset.get(order[i])
        dataT = trainT_dataset.get(order[i])

        # Get non-corresponding pairs
        dataT.null_pairs = getNullPairs(dataT.pos_pairs)

        # Compute feature descriptors on source and target models
        FS = model(dataS.to(device))
        FT = model(dataT.to(device))

        # Randomly sample n_pairs corresponding and non-corresponding pairs each
        p_ = torch.randperm(dataT.pos_pairs.size(0))[:n_pairs]
        n_ = torch.randperm(dataT.null_pairs.size(0))[:n_pairs]

        # Compute loss
        L = loss(FS, FT, dataT.pos_pairs[p_, :], dataT.null_pairs[n_, :])

        # Update progress bar. The window always stores the unscaled loss
        # (L is only divided by batch_step further down, for backprop).
        loss_value = L.item()
        if i < wW:
            window[i] = loss_value
            # Include the current sample; window[:i] would be empty (NaN) at i == 0
            wAvg = torch.mean(window[:i + 1])
        else:
            window = torch.cat((window[1:], torch.tensor([loss_value])), dim=0)
            wAvg = torch.mean(window)

        bar.update(i + 1, Loss=wAvg)

        # Accumulate gradients, scaled so that a full batch averages the losses
        L = L / batch_step
        L.backward()

        # Step once batch_step samples have been accumulated, and once more at
        # the end of the epoch so that a trailing partial batch is not dropped
        if (i + 1) % batch_step == 0 or (i + 1) == n_data:
            optimizer.step()
            model.zero_grad()


## Testing

In [None]:
## Compute the rate of false positive and false negative matches on the test dataset
## You want both to decay for a good PR curve

def test():
    """Evaluate the model on the paired source/target test shapes.

    Returns:
        (rateFP, rateFN): the accumulated false-positive count divided by the
        total number of corresponding pairs, and the accumulated false-negative
        count divided by the total number of non-corresponding pairs.
    """
    model.eval()

    # Running totals over the whole test split
    false_pos_total = 0
    false_null_total = 0
    pos_total = 0
    null_total = 0

    for idx in progressbar.progressbar(range(len(testS_dataset))):
        with torch.no_grad():
            source = testS_dataset.get(idx)
            target = testT_dataset.get(idx)

            # Non-corresponding pairs are the complement of the stored pairs
            target.null_pairs = getNullPairs(target.pos_pairs)

            # Descriptors on both shapes, computed by the same network
            feat_source = model(source.to(device))
            feat_target = model(target.to(device))

            nFP, nFN = ev(feat_source, feat_target,
                          target.pos_pairs, target.null_pairs)

            false_pos_total += nFP
            false_null_total += nFN

            pos_total += target.pos_pairs.size(0)
            null_total += target.null_pairs.size(0)

    return false_pos_total / pos_total, false_null_total / null_total

## Train, then test

In [None]:
print('Training...')
for epoch in range(80):

    print('Epoch {}'.format(epoch))
    train()

    ## Only the pass that has just finished epoch 40 touches the learning rate
    if epoch != 40:
        continue

    ## Set the learning rate of every parameter group to 0.001
    for group in optimizer.param_groups:
        group['lr'] = 0.001

## Evaluate once on the test split after the last epoch
rateFP, rateFN = test()
print("Test split: FP: {:06.4f}, FN: {:06.4f}, Err: {:06.4f}".format(rateFP, rateFN, rateFP + rateFN))