# Correspondence

The notebooks in this folder replicate the experiments as performed for [CNNs on Surfaces using Rotation-Equivariant Features](https://doi.org/10.1145/3386569.3392437).

The current notebook replicates the Correspondence experiments from section `5.2 Comparisons`.

## Imports
We start by importing dependencies.

In [1]:
# File reading and progressbar
import os.path as osp
import progressbar

# PyTorch and PyTorch Geometric dependencies
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch_geometric.transforms as T
from torch_geometric.data import DataLoader
from torch_geometric.io import read_off
import trimesh as tm
import numpy as np


# Harmonic Surface Networks components
# Layers
from nn import (ECResNetBlock, ExtConvCplx, LiftBlock, ECHOBlock, VectorDropout,
                ParallelTransportPool, ParallelTransportUnpool,
                TangentLin, TangentNonLin, TangentPerceptron)
# Utility functions
from utils.harmonic import magnitudes, norm2D
# Transforms
from transforms import (HarmonicPrecomp, VectorHeat, MultiscaleRadiusGraph,
                        NormalizeArea, ScaleMask, FilterNeighbours, Subsample)

## Settings
Next, we set a few parameters for our network. You can change these settings to experiment with different configurations of the network. The values below are the ones used in the paper.

In [2]:
# Band-limit for extended convolution
band_limit = 1

# Number of rings in the radial profile
n_corr_rings = 2

# Number of conv rings
n_conv_rings = 3

# Learn radial offset for correlations
offset = True


# Number of filters per block
nf = [32, 32]

n_des = 12
n_bins = 2

desDim = 32

# Ratios used for pooling
ratios = [1, 0.25]

# Number of shapes per batch
batch_size = 1

# Dataset switch: True for FAUST, False for SCAPE (see the save paths in the Testing section)
ForS = True

# Use the remeshed FAUST dataset
from datasets import FAUSTRemeshed as FAUST
faust_dir = 'FAUST_5k'
# Radii of the multi-scale radius graph (alternative setting: [0.07, 0.14])
radii = [0.0425, 0.2]


## Dataset
To get our dataset ready for training, we need to perform the following steps:
1. Provide a path to load and store the dataset.
2. Define transformations to be performed on the dataset:
    - A transformation that computes a multi-scale radius graph and precomputes the logarithmic map.
    - A transformation that masks the edges and vertices per scale and precomputes convolution components.
3. Assign and load the datasets.

In [3]:
# 1. Provide a path to load and store the dataset.
# Make sure that you have created a folder 'data' somewhere
# and that you have downloaded and moved the raw datasets there
path = osp.join('data', faust_dir)

# 2. Define transformations to be performed on the dataset:
# Transformation that computes a multi-scale radius graph and precomputes the logarithmic map.
pre_transform = T.Compose((
    MultiscaleRadiusGraph(ratios, radii, loop=True, flow='target_to_source'),
    VectorHeat(max_lvl=len(ratios)-1),
    Subsample()
))

# Center each shape and apply a random rotation of up to 45 degrees about each axis
# (set transform = None to disable this augmentation)
transform = T.Compose((
    T.Center(),
    T.RandomRotate(45, axis=0),
    T.RandomRotate(45, axis=1),
    T.RandomRotate(45, axis=2)
))



# Transformations that mask the edges and vertices per scale and precompute convolution components.
# Note: FCPrecomp is not among the imports above; it must be imported before running this cell.
scale0_transform = T.Compose((
    FilterNeighbours(radius=radii[0]),
    FCPrecomp(n_conv_rings, n_corr_rings, band_limit, max_r=radii[0])
))


# 3. Assign and load the datasets.
train_dataset = FAUST(path, True, pre_transform=pre_transform, transform=transform)
test_dataset = FAUST(path, False, pre_transform=pre_transform, transform=transform)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size)

# Load template shape for evaluation
pathT = osp.join(path, 'raw/template/tr_reg_000.off')

    
template = read_off(pathT)
n_classes = 4999
pos_tem = template.pos.cpu().numpy()
faces_tem = template.face.cpu().numpy().T
num_nodes = train_dataset[0].num_nodes




## Network architecture
Now, we create the network architecture by creating a new `nn.Module`, `Net`. We first set up each layer in the `__init__` method of the `Net` class and then define the steps to perform for each batch in the `forward` method. The following figure shows a schematic of the architecture we will be implementing:

<img src="img/resnet_architecture.png" width="800px" />

Let's get started!

In [4]:
class Net(torch.nn.Module):
    def __init__(self):
        super(Net, self).__init__()

        # Lift the 3D vertex positions to 16 feature channels
        self.lift = LiftBlock(3, 16, n_corr_rings, offset, MLP=False)

        # Eight residual convolution blocks
        self.resnet1 = ECResNetBlock(16, nf[0], band_limit, n_conv_rings)
        self.resnet2 = ECResNetBlock(nf[0], nf[0], band_limit, n_conv_rings)
        self.resnet3 = ECResNetBlock(nf[0], nf[0], band_limit, n_conv_rings)
        self.resnet4 = ECResNetBlock(nf[0], nf[0], band_limit, n_conv_rings)
        self.resnet5 = ECResNetBlock(nf[0], nf[0], band_limit, n_conv_rings)
        self.resnet6 = ECResNetBlock(nf[0], nf[0], band_limit, n_conv_rings)
        self.resnet7 = ECResNetBlock(nf[0], nf[0], band_limit, n_conv_rings)
        self.resnet8 = ECResNetBlock(nf[0], 16, band_limit, n_conv_rings, back=True)

        # ECHO descriptor block
        self.echo = ECHOBlock(16, n_des, n_bins, desDim, band_limit, n_conv_rings, classify=False, mlpC=[128, 64])

        # Dropout layers (VectorDropout is disabled with p=0.0)
        self.D = nn.Dropout(p=0.5)
        self.d = VectorDropout(p=0.0)

        # Per-vertex layers used on the skip connections in forward()
        self.res1 = TangentPerceptron(16, nf[0])
        self.res2 = TangentPerceptron(nf[0], nf[0])
        self.res3 = TangentPerceptron(nf[0], nf[0])
        self.res4 = TangentPerceptron(nf[0], 16)

        # Final per-vertex classifier over the template vertices
        self.lin1 = nn.Linear(32, 256)
        self.lin2 = nn.Linear(256, n_classes)

    def forward(self, data):
        # Mask edges to the first radius and precompute convolution components
        data_scale0 = scale0_transform(data)

        # Precomputed attributes for the lift, convolution, and ECHO layers
        attr_grad = (data_scale0.edge_index, data_scale0.pcmp_gather)
        attr_conv = (data_scale0.edge_index, data_scale0.pcmp_scatter,
                     data_scale0.connection)
        attr_echo = (data_scale0.edge_index, data_scale0.pcmp_scatter,
                     data_scale0.pcmp_echo, data_scale0.connection)

        # Use the vertex positions as input features
        x = data.pos
        x1 = self.lift(x, *attr_grad)

        # Four pairs of residual blocks, each pair closed by a skip connection
        x = self.resnet1(x1, *attr_conv)
        x2 = self.d(self.resnet2(x, *attr_conv) + self.res1(x1))

        x = self.resnet3(x2, *attr_conv)
        x3 = self.d(self.resnet4(x, *attr_conv) + self.res2(x2))

        x = self.resnet5(x3, *attr_conv)
        x4 = self.d(self.resnet6(x, *attr_conv) + self.res3(x3))

        x = self.resnet7(x4, *attr_conv)
        x = self.d(self.resnet8(x, *attr_conv) + self.res4(x4))

        # Compute per-vertex descriptors
        x = self.echo(x, *attr_echo)

        # Classify each vertex
        x = F.relu(self.lin1(x))
        x = self.D(x)
        x = self.lin2(x)

        return F.log_softmax(x, dim=1)
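
The network ends in `log_softmax` over `n_classes` outputs, so correspondence is treated as per-vertex classification: each input vertex is assigned one vertex of the template. The training cell pairs this output with `F.nll_loss`, and the test cell takes the arg-max. The following is a minimal, framework-free sketch of that last step for a single vertex. The functions `log_softmax` and `nll` below are illustrative stand-ins written for this example, not the PyTorch implementations, and the four scores are made-up values.

```python
import math

def log_softmax(scores):
    """Numerically stable log-softmax over a list of raw scores."""
    m = max(scores)
    log_z = m + math.log(sum(math.exp(s - m) for s in scores))
    return [s - log_z for s in scores]

def nll(log_probs, target):
    """Negative log-likelihood of the target class."""
    return -log_probs[target]

# Toy example: one vertex and four candidate template vertices
# (the notebook uses n_classes = 4999 candidates per vertex).
scores = [0.5, 2.0, -1.0, 0.1]
log_probs = log_softmax(scores)

# Predicted template vertex: index of the largest log-probability
pred = max(range(len(log_probs)), key=log_probs.__getitem__)

# Loss when the ground-truth template vertex is index 1
loss = nll(log_probs, target=1)
print(pred, loss)
```

Because the loss is the negative log-probability of the correct template vertex, it only rewards exact matches. A prediction that lands on a neighbouring vertex is penalised like any other miss.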

## Training

Phew, we're through the hard part. Now, let's get to training. First, move the network to the GPU and set up an optimizer.

In [5]:
# We want to train on a GPU. It'll take a long time on a CPU
device = torch.device('cuda')
# Move the network to the GPU
model = Net().to(device)
# Set up the Adam optimizer with an initial learning rate of 0.01
rate = 0.01
optimizer = torch.optim.Adam(model.parameters(), lr=rate)

# Decay factors for the exponential schedule that is commented out in train()
decay1 = 0.96
decay2 = 0.99

# The target should be the index of each vertex
#target = torch.arange(num_nodes, dtype=torch.long, device=device)

In [6]:
def test(saveFile=None):
    # Set model to 'evaluation' mode
    model.eval()
    meanError = 0
    totalNodes = 0
    errors = []

    for i, data in progressbar.progressbar(enumerate(test_loader)):
        if not ForS:
            data.y = torch.sub(data.y, 1)

        with torch.no_grad():
            # Predicted template vertex for each input vertex
            pred = model(data.to(device)).max(1)[1].long()

        # Correspondence error between predicted and ground-truth template vertices
        error = tm.computeError(pos_tem, faces_tem, pred.cpu().numpy(), data.y.squeeze().cpu().numpy())
        errors.append(error)
        meanError += error.sum()
        totalNodes += data.pos.size(0)

    if saveFile is not None:
        # Store the errors of all test shapes, one row per shape
        np.save(saveFile, np.vstack(errors))
    return meanError / totalNodes

Next, define the training function. Every 10 epochs it also evaluates the model with the `test` function above.

In [7]:
def train(epoch):
    # Set model to 'train' mode
    model.train()
    
    '''
    if (epoch > 0 and epoch <= 50):
        for param_group in optimizer.param_groups:
            param_group['lr'] = rate * np.power(decay1, epoch) #param_group['lr'] * decay;
            
    if (epoch > 50):
        for param_group in optimizer.param_groups:
            param_group['lr'] = rate * np.power(decay1, 50) * np.power(decay2, epoch-50)  
            
    '''
            
    if epoch > 40:
        for param_group in optimizer.param_groups:
            param_group['lr'] = 0.001
            
    if epoch > 60: 
        for param_group in optimizer.param_groups:
            param_group['lr'] = 0.0005
                

    
            
    # Set up the progress bar
    n_data = len(train_loader)
    widgets = [progressbar.Percentage(), progressbar.Bar(),
               progressbar.AdaptiveETA(), ' | Loss:', progressbar.Variable('loss')]
    bar = progressbar.ProgressBar(max_value=n_data, widgets=widgets)

    i = 0
    totalL = 0.0
    for data in train_loader:
        # Move training data to the GPU and optimize parameters
        data = data.to(device)
        optimizer.zero_grad()

        if not ForS:
            data.y = torch.sub(data.y, 1)

        L = F.nll_loss(model(data), data.y.squeeze())

        i = i + 1
        totalL += L.item()

        # Report the running mean of the loss
        bar.update(i, loss=(totalL / i))
        L.backward()

        optimizer.step()
    
    if (epoch > 0 and (epoch % 10) == 0):
        test_acc = test(None)
        print('Test: {:.6f}'.format(test_acc * np.sqrt(np.pi)))
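
The step schedule hard-coded in `train` can be read as a small pure function of the epoch. The sketch below is only an illustration: `lr_for_epoch` is a hypothetical helper that does not exist in the notebook, and its default base rate of 0.01 is taken from the `rate` variable set earlier.

```python
def lr_for_epoch(epoch, base_lr=0.01):
    """Hypothetical helper mirroring the step schedule in train()."""
    if epoch > 60:
        return 0.0005
    if epoch > 40:
        return 0.001
    # Up to and including epoch 40 the optimizer keeps its initial rate
    return base_lr

# Learning rate at the boundaries of each phase
for epoch in (0, 40, 41, 60, 61, 79):
    print(epoch, lr_for_epoch(epoch))
```

The drop in training loss between epochs 40 and 41 in the log below coincides with the first step of this schedule.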
        

Train for 80 epochs.

In [8]:
#sP= '/home/tommy/Dropbox/specialMath/Harmonic/ECHONet/V7/plot/FAUST_5K/states/VFC_xyz_FULL_60'
#model.load_state_dict(torch.load(sP))

for epoch in range(80):
    print('Epoch {}'.format(epoch))
    train(epoch)

Epoch 0


100%|#######################################|ETA:  00:00:00 | Loss:loss:   6.78

Epoch 1


100%|#######################################|ETA:  00:00:00 | Loss:loss:   4.99

Epoch 2


100%|#######################################|ETA:  00:00:00 | Loss:loss:    3.9

Epoch 3


100%|#######################################|ETA:  00:00:00 | Loss:loss:   3.32

Epoch 4


100%|#######################################|ETA:  00:00:00 | Loss:loss:   2.95

Epoch 5


100%|#######################################|ETA:  00:00:00 | Loss:loss:   2.67

Epoch 6


100%|#######################################|ETA:  00:00:00 | Loss:loss:   2.44

Epoch 7


100%|#######################################|ETA:  00:00:00 | Loss:loss:   2.29

Epoch 8


100%|#######################################|ETA:  00:00:00 | Loss:loss:   2.13

Epoch 9


100%|#######################################|ETA:  00:00:00 | Loss:loss:   2.06

Epoch 10


| |                                         #        | 19 Elapsed Time: 0:01:44


Test: 0.054833
Epoch 11


100%|#######################################|ETA:  00:00:00 | Loss:loss:   1.87

Epoch 12


100%|#######################################|ETA:  00:00:00 | Loss:loss:   1.83

Epoch 13


100%|#######################################|ETA:  00:00:00 | Loss:loss:   1.79

Epoch 14


100%|#######################################|ETA:  00:00:00 | Loss:loss:   1.72

Epoch 15


100%|#######################################|ETA:  00:00:00 | Loss:loss:   1.68

Epoch 16


100%|#######################################|ETA:  00:00:00 | Loss:loss:   1.63

Epoch 17


100%|#######################################|ETA:  00:00:00 | Loss:loss:    1.6

Epoch 18


100%|#######################################|ETA:  00:00:00 | Loss:loss:   1.58

Epoch 19


100%|#######################################|ETA:  00:00:00 | Loss:loss:   1.55

Epoch 20


| |                                                # | 19 Elapsed Time: 0:01:44


Test: 0.028119
Epoch 21


100%|#######################################|ETA:  00:00:00 | Loss:loss:   1.49

Epoch 22


100%|#######################################|ETA:  00:00:00 | Loss:loss:   1.49

Epoch 23


100%|#######################################|ETA:  00:00:00 | Loss:loss:   1.48

Epoch 24


100%|#######################################|ETA:  00:00:00 | Loss:loss:   1.44

Epoch 25


100%|#######################################|ETA:  00:00:00 | Loss:loss:   1.43

Epoch 26


100%|#######################################|ETA:  00:00:00 | Loss:loss:    1.4

Epoch 27


100%|#######################################|ETA:  00:00:00 | Loss:loss:    1.4

Epoch 28


100%|#######################################|ETA:  00:00:00 | Loss:loss:   1.38

Epoch 29


100%|#######################################|ETA:  00:00:00 | Loss:loss:   1.36

Epoch 30


| |                                           #      | 19 Elapsed Time: 0:01:44


Test: 0.024442
Epoch 31


100%|#######################################|ETA:  00:00:00 | Loss:loss:   1.38

Epoch 32


100%|#######################################|ETA:  00:00:00 | Loss:loss:   1.37

Epoch 33


100%|#######################################|ETA:  00:00:00 | Loss:loss:   1.33

Epoch 34


100%|#######################################|ETA:  00:00:00 | Loss:loss:   1.34

Epoch 35


100%|#######################################|ETA:  00:00:00 | Loss:loss:    1.3

Epoch 36


100%|#######################################|ETA:  00:00:00 | Loss:loss:    1.3

Epoch 37


100%|#######################################|ETA:  00:00:00 | Loss:loss:   1.29

Epoch 38


100%|#######################################|ETA:  00:00:00 | Loss:loss:   1.27

Epoch 39


100%|#######################################|ETA:  00:00:00 | Loss:loss:   1.27

Epoch 40


| |                                         #        | 19 Elapsed Time: 0:01:44


Test: 0.021844
Epoch 41


100%|#######################################|ETA:  00:00:00 | Loss:loss:    1.1

Epoch 42


100%|#######################################|ETA:  00:00:00 | Loss:loss:   1.03

Epoch 43


100%|#######################################|ETA:  00:00:00 | Loss:loss:   1.02

Epoch 44


100%|#######################################|ETA:  00:00:00 | Loss:loss:   1.01

Epoch 45


100%|#######################################|ETA:  00:00:00 | Loss:loss:  0.994

Epoch 46


100%|#######################################|ETA:  00:00:00 | Loss:loss:  0.988

Epoch 47


100%|#######################################|ETA:  00:00:00 | Loss:loss:  0.979

Epoch 48


100%|#######################################|ETA:  00:00:00 | Loss:loss:  0.974

Epoch 49


100%|#######################################|ETA:  00:00:00 | Loss:loss:   0.97

Epoch 50


| |                                              #   | 19 Elapsed Time: 0:01:45


Test: 0.018272
Epoch 51


100%|#######################################|ETA:  00:00:00 | Loss:loss:  0.959

Epoch 52


100%|#######################################|ETA:  00:00:00 | Loss:loss:  0.952

Epoch 53


100%|#######################################|ETA:  00:00:00 | Loss:loss:  0.948

Epoch 54


100%|#######################################|ETA:  00:00:00 | Loss:loss:  0.944

Epoch 55


100%|#######################################|ETA:  00:00:00 | Loss:loss:  0.941

Epoch 56


100%|#######################################|ETA:  00:00:00 | Loss:loss:  0.937

Epoch 57


100%|#######################################|ETA:  00:00:00 | Loss:loss:  0.934

Epoch 58


100%|#######################################|ETA:  00:00:00 | Loss:loss:   0.93

Epoch 59


100%|#######################################|ETA:  00:00:00 | Loss:loss:  0.925

Epoch 60


| |                                               #  | 19 Elapsed Time: 0:01:45


Test: 0.018430
Epoch 61


100%|#######################################|ETA:  00:00:00 | Loss:loss:   0.91

Epoch 62


100%|#######################################|ETA:  00:00:00 | Loss:loss:  0.905

Epoch 63


100%|#######################################|ETA:  00:00:00 | Loss:loss:  0.903

Epoch 64


100%|#######################################|ETA:  00:00:00 | Loss:loss:  0.902

Epoch 65


100%|#######################################|ETA:  00:00:00 | Loss:loss:    0.9

Epoch 66


100%|#######################################|ETA:  00:00:00 | Loss:loss:  0.898

Epoch 67


100%|#######################################|ETA:  00:00:00 | Loss:loss:  0.895

Epoch 68


100%|#######################################|ETA:  00:00:00 | Loss:loss:  0.895

Epoch 69


100%|#######################################|ETA:  00:00:00 | Loss:loss:  0.893

Epoch 70


| |                               #                  | 19 Elapsed Time: 0:01:43


Test: 0.018668
Epoch 71


100%|#######################################|ETA:  00:00:00 | Loss:loss:  0.891

Epoch 72


100%|#######################################|ETA:  00:00:00 | Loss:loss:  0.886

Epoch 73


100%|#######################################|ETA:  00:00:00 | Loss:loss:  0.887

Epoch 74


100%|#######################################|ETA:  00:00:00 | Loss:loss:  0.886

Epoch 75


100%|#######################################|ETA:  00:00:00 | Loss:loss:  0.884

Epoch 76


100%|#######################################|ETA:  00:00:00 | Loss:loss:  0.881

Epoch 77


100%|#######################################|ETA:  00:00:00 | Loss:loss:   0.88

Epoch 78


100%|#######################################|ETA:  00:00:00 | Loss:loss:  0.879

Epoch 79


100%|#######################################|ETA:  00:00:00 | Loss:loss:  0.877

In [9]:
# Save the trained model state
statePath = '/home/tommy/Dropbox/specialMath/Harmonic/ECHONet/V7/plot/FAUST_5K/states/VFC_xyz_FULL_3'
torch.save(model.state_dict(), statePath) 

## Testing
Finally, we evaluate the model on the test dataset using the `test` function defined above.

The cells below report the mean correspondence error on the test set. Note that the exact accuracy for Remeshed FAUST will be very low. The Princeton benchmark (see paper, Fig. 9) gives a better picture of the accuracy of the method in practice.

In [10]:
if ForS:
    savePath = '/home/tommy/Dropbox/specialMath/Harmonic/ECHONet/V7/plot/FAUST_5K/'
else:
    savePath = '/home/tommy/Dropbox/specialMath/Harmonic/ECHONet/V7/plot/SCAPE_5K/'

ID = 'VFC_xyz_FULL_3'

rawFile = savePath + ID + '_raw.npy'

plotFile = savePath + ID + '_plot.txt'

nrmFile = savePath + ID + '_plot_nrm.txt'

In [11]:
statePath = '/home/tommy/Dropbox/specialMath/Harmonic/ECHONet/V7/plot/FAUST_5K/states/VFC_xyz_FULL_3'
model.load_state_dict(torch.load(statePath))

test_acc = test(rawFile)

print('Test: {:.6f}'.format(test_acc * np.sqrt(np.pi)))

| |                                        #         | 19 Elapsed Time: 0:01:44


Test: 0.018693


In [12]:
maxE = 0.2
samples = 100

# Flatten and sort the errors of all test shapes
rawError = np.sort(np.load(rawFile), axis=None)

# Smallest non-zero error, used to shift the normalized curve
mD = np.amin(rawError[np.nonzero(rawError)])

plotError = tm.errorToPlot(rawError, maxE, samples)

nrmError = tm.errorToPlot(rawError - mD, maxE, samples)

np.savetxt(plotFile, plotError, fmt='%f')

np.savetxt(nrmFile, nrmError, fmt='%f')
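
The helper `tm.errorToPlot` is not shown in this notebook. Assuming it produces the kind of cumulative curve used in the Princeton benchmark, that is, the fraction of correspondences whose error is at most a given threshold, a minimal sketch could look as follows. The function `cumulative_curve` is a hypothetical stand-in written for this example, and the five error values are made up.

```python
from bisect import bisect_right

def cumulative_curve(errors, max_error, samples):
    """Fraction of errors <= t for `samples` thresholds t evenly spaced in [0, max_error].

    Hypothetical stand-in for tm.errorToPlot; requires samples >= 2.
    """
    ordered = sorted(errors)
    n = len(ordered)
    curve = []
    for k in range(samples):
        t = max_error * k / (samples - 1)
        # bisect_right counts how many sorted errors are <= t
        curve.append((t, bisect_right(ordered, t) / n))
    return curve

# Toy errors: two exact matches, two near misses, one outlier beyond max_error
curve = cumulative_curve([0.0, 0.0, 0.01, 0.06, 0.3], max_error=0.2, samples=5)
for threshold, fraction in curve:
    print(threshold, fraction)
```

The first point of such a curve is the fraction of exact matches, and the curve never decreases as the threshold grows.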