# Segmentation

### This notebook replicates the human body segmentation experiments in Section 6.3 of "Field Convolutions for Surface CNNs" (Mitchel et al., 2021).

## Dependencies

In [None]:
# File reading and progressbar
import os
import os.path as osp
import progressbar

# Numpy
import numpy as np

# PyTorch 
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import autograd

# PyTorch Geometric - used for data loading/processing
import torch_geometric.transforms as T
from torch_geometric.data import DataLoader

# Field Convolution modules
from nn import FCResNetBlock, LiftBlock, ECHOBlock, TangentPerceptron, LabelSmoothingLoss

# Transforms
from transforms import FCPrecomp, computeLogXPort, SupportGraph, NormalizeArea, NormalizeAxes

# Load the human body segmentation dataset (Maron et al., 2017)
from datasets import SHAPESEG

# Release unused cached GPU memory held by PyTorch (e.g. left over from an
# earlier run in this kernel). Skipped entirely on machines without CUDA.
if torch.cuda.is_available():
    torch.cuda.empty_cache()

## Hyperparameters

In [None]:
# --- Field convolution filters ---
band_limit = 2   # band-limit of the field convolution filters
n_rings = 6      # number of radial samples
ftype = 1        # filter type (see /nn/field_conv.py)

# --- Network width ---
nf = 48          # number of channels in the network

# --- ECHO descriptors (last layer of the network) ---
n_des = 48       # number of ECHO descriptors to compute
n_bins = 3       # descriptor bins per unit radius; descriptor resolution is
                 # approximately PI * (n_bins + 0.5) * (n_bins + 0.5)

# --- Filter support ---
epsilon = 0.2    # filter support radius

# --- Task ---
n_classes = 8    # number of segmentation classes

## Data pre-processing and loading

In [None]:
# Dataset location, relative to the working directory. The raw archive is
# expected at data/SHAPESEG/raw/SHAPESEG.zip.
path = osp.join('data', 'SHAPESEG')

# Pre-processing (passed as `pre_transform`):
#   - NormalizeArea: rescale each mesh to unit surface area
#   - SupportGraph: sample points (sample_n=1024) and build the convolution
#     support edges within radius `epsilon`
#   - computeLogXPort: logarithm maps + parallel transport
#   - NormalizeAxes
pre_transform = T.Compose((
    NormalizeArea(),
    SupportGraph(epsilon=epsilon, sample_n=1024),
    computeLogXPort(),
    NormalizeAxes(),
))

# Augmentation applied each time a training shape is drawn from the loader:
# a random scale in (0.85, 1.15), then a random rotation (45-degree setting)
# about each of axes 0, 1 and 2, in that order.
transform = T.Compose(
    (T.RandomScale((0.85, 1.15)),)
    + tuple(T.RandomRotate(45, axis=axis) for axis in (0, 1, 2))
)

# Test split: no augmentation, fixed order.
test_dataset = SHAPESEG(path, False, pre_transform=pre_transform)
test_loader = DataLoader(test_dataset, batch_size=1)

# Training split: augmented and shuffled every epoch.
train_dataset = SHAPESEG(path, True, pre_transform=pre_transform, transform=transform)
train_loader = DataLoader(train_dataset, shuffle=True, batch_size=1)

## FCNet

##### A succession of FCResNet blocks with an ECHO block as the final layer. At the beginning of the network, a learnable gradient-like operation lifts scalar features to isometry-equivariant tangent vector fields.

In [None]:
## Organizes edge data at run time to expidte convolutions
organizeEdges = FCPrecomp(band_limit=band_limit, n_rings=n_rings, epsilon=epsilon)

In [None]:
class Net(torch.nn.Module):
    """FCNet for per-point segmentation.

    Layout: a learned 'gradient' (LiftBlock) that lifts the 3-channel input
    features to `nf` tangent-vector channels, four FCResNet blocks, and a final
    ECHO block that outputs `n_classes` predictions.

    The hyperparameters (nf, band_limit, n_rings, ftype, n_des, n_bins,
    n_classes) and the `organizeEdges` precomputation are read from the
    notebook's global scope.
    """

    def __init__(self):
        super().__init__()

        # Keyword arguments shared by the field-convolution blocks.
        fc_kwargs = dict(band_limit=band_limit, n_rings=n_rings, ftype=ftype)

        # Learned 'gradient': lifts scalar features to tangent vector features.
        self.lift = LiftBlock(3, nf, n_rings=n_rings, ftype=ftype)

        # Four FCResNet blocks. The attribute names and registration order are
        # kept fixed because they determine the state_dict keys and parameter order.
        self.resnet1 = FCResNetBlock(nf, nf, **fc_kwargs)
        self.resnet2 = FCResNetBlock(nf, nf, **fc_kwargs)
        self.resnet3 = FCResNetBlock(nf, nf, **fc_kwargs)
        self.resnet4 = FCResNetBlock(nf, nf, **fc_kwargs)

        # Final ECHO block: computes ECHO descriptors and outputs predictions.
        self.echo = ECHOBlock(nf, n_classes, n_des=n_des, n_bins=n_bins, **fc_kwargs)

    def forward(self, data):
        # Organize the support-edge data for this shape.
        supp_edges, supp_sten, ln, wxp = organizeEdges(data)

        # The lift block receives only indices band_limit and band_limit + 1
        # of the last dimension of supp_sten.
        lift_sten = supp_sten[..., band_limit:band_limit + 2]

        # Input features: positions of the sampled points, lifted to vector fields.
        x = data.pos[data.sample_idx, :]
        x = self.lift(x, supp_edges, lift_sten)

        # Field convolutions.
        for block in (self.resnet1, self.resnet2, self.resnet3, self.resnet4):
            x = block(x, supp_edges, supp_sten)

        # ECHO descriptors -> predictions.
        return self.echo(x, supp_edges, supp_sten, ln, wxp)


## Training

In [None]:
# Train on the GPU when one is available. Otherwise fall back to the CPU
# (expect it to be much slower) instead of crashing on a CUDA-less machine.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
if device.type == 'cpu':
    print("CUDA is not available; training on the CPU (this will be slow).", flush=True)
model = Net().to(device)

# Adam optimizer, lr = 0.01
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

# Cross-entropy loss with label smoothing (smoothing = 0.2)
loss = LabelSmoothingLoss(classes=n_classes, smoothing=0.2, dim=1)

# Alternatively, train without label smoothing:
# loss = torch.nn.CrossEntropyLoss()



In [None]:
## Training function
## Optional batch_step parameter for gradient accumulation (not used in the paper)

def train(batch_step=1):

    model.train()
    
    # Sort out progress bar, displays average loss over last 50 samples
    wW = 50;
    window = torch.FloatTensor(wW).fill_(0)
    
    n_data = train_loader.__len__()
    widgets = [progressbar.Percentage(), progressbar.Bar(), 
              progressbar.AdaptiveETA(), ' | ', progressbar.Variable('Loss'),]

    bar = progressbar.ProgressBar(max_value=n_data, widgets=widgets)

    
    ## Zero-out
    optimizer.zero_grad() 

    i = 0;
    for data in train_loader:
            
        pred = model(data.to(device))

        L = loss(pred, data.y.to(device)) 
        if (i < wW):
            window[i] = L.item()
            wAvg = torch.mean(window[:i])
        else:
            window = torch.cat((window[1:], torch.tensor([L.item() * batch_step])), dim=0)
            wAvg = torch.mean(window)

        # Update progress bar
        i = i + 1
        bar.update(i, Loss = torch.mean(window[:i]))
        
        ## Update loss
        L = L / batch_step;
        L.backward()

        if ( i % batch_step == 0 or i == n_data ):
            optimizer.step()
            model.zero_grad()


## Testing

In [None]:
## Overall segmentation accuracy on the test dataset
def test():
    """Return the per-point segmentation accuracy of `model` on `test_loader`.

    Returns:
        Number of correctly labelled points divided by the total number of
        labelled points, accumulated over every shape in the test set.
    """
    model.eval()
    correct = 0
    total_num = 0
    # Inference only: no_grad avoids building and storing autograd graphs.
    with torch.no_grad():
        for data in test_loader:
            data = data.to(device)
            # argmax over the class dimension; a log_softmax first would not
            # change which class is largest.
            pred = model(data).argmax(dim=1)
            correct += pred.eq(data.y).sum().item()
            total_num += data.y.size(0)

    acc = correct / total_num

    return acc

## Train, then test
###### We train for 15 epochs, as in the paper.

In [None]:
n_epochs = 15  # number of training epochs, as in the paper

print('Training...')

for epoch in range(n_epochs):
    print("Epoch {}".format(epoch), flush=True)
    train()

# Evaluate once, after all training epochs have finished.
acc = test()
print("Test accuracy: {:06.4f}".format(acc), flush=True)
