# Rotated MNIST

The notebooks in this folder replicate the experiments from [CNNs on Surfaces using Rotation-Equivariant Features](https://doi.org/10.1145/3386569.3392437).

The current notebook replicates the Rotated MNIST experiments from section `5.3 Evaluation`.

## Imports
We start by importing dependencies.

In [None]:
# File reading and progressbar
import os.path as osp
import progressbar

# PyTorch and PyTorch Geometric dependencies
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch_geometric.transforms as T
from torch_geometric.nn import global_mean_pool
from torch_geometric.nn.inits import zeros
from torch_geometric.data import DataLoader

# Harmonic Surface Networks components
# Layers
from nn import HarmonicConv, ComplexNonLin, ParallelTransportPool
# Utility functions
from utils.harmonic import magnitudes, c_batch_norm
# Rotated MNIST dataset
from datasets import MNISTSphere
# Transforms
from transforms import VectorHeat, HarmonicPrecomp, MultiscaleRadiusGraph, ScaleMask

## Settings
Next, we set a few parameters for our network. You can change these settings to experiment with different configurations of the network. The values below are the ones used in the paper.

In [None]:
# Maximum rotation order for streams
max_order = 1

# Number of rings in the radial profile
n_rings = 4

# Number of filters per block
nf = [8, 16, 32]

# Ratios used for pooling
ratios = [1, 0.5, 0.25]

# Radius of convolution for each scale
radii = [0.3, 0.45, 0.8]

# Number of shapes per batch
batch_size = 32
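
Each of `nf`, `ratios` and `radii` holds one entry per scale, so editing one list without the others is an easy mistake. The check below is a minimal sketch that verifies the lists line up and derives the batch-norm width `(max_order + 1) * nf[i]` used later in the network. The helper name `check_settings` is hypothetical and not part of the repository, and the allowed ranges for ratios and radii are assumptions.

```python
def check_settings(max_order, nf, ratios, radii):
    """Validate per-scale settings and return the batch-norm width per block.

    Hypothetical helper for illustration; not part of the HSN code base.
    """
    if not (len(nf) == len(ratios) == len(radii)):
        raise ValueError('nf, ratios and radii need one entry per scale')
    # Assumption: a ratio is the fraction of points kept, so it lies in (0, 1]
    if any(not 0 < r <= 1 for r in ratios):
        raise ValueError('pooling ratios must lie in (0, 1]')
    # Assumption: a convolution radius is a positive distance
    if any(r <= 0 for r in radii):
        raise ValueError('radii must be positive')
    # One group of nf[i] channels per rotation-order stream (m = 0..max_order)
    return [(max_order + 1) * n for n in nf]


bn_widths = check_settings(max_order=1, nf=[8, 16, 32],
                           ratios=[1, 0.5, 0.25], radii=[0.3, 0.45, 0.8])
print(bn_widths)  # [16, 32, 64]
```

With the paper's settings, the three blocks use batch-norm layers of width 16, 32 and 64.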

## Dataset
To get our dataset ready for training, we need to perform the following steps:
1. Provide a path to load and store the dataset.
2. Define transformations to be performed on the dataset:
    - A transformation that computes a multi-scale radius graph and precomputes the logarithmic map.
    - A transformation that masks the edges and vertices per scale and precomputes convolution components.
3. Assign and load the datasets.

Note that for the sphere, the logarithmic map and precomputed components are equal for every datapoint. Hence, we cache these computations and reuse them for every shape. To store the caches, we create a cache folder first.

In [None]:
!mkdir -p cache
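
The caching idea described above (compute once, store on disk, reuse for every shape) can be sketched as follows. This is only an illustration of the pattern: the helper `cached` and the use of `pickle` are assumptions, and the repository's transforms may store and load their `cache_file` differently.

```python
import os
import pickle
import tempfile


def cached(cache_file, compute):
    """Return the result stored in cache_file, computing and saving it on a miss.

    Hypothetical helper that mimics the idea behind the `cache_file` arguments.
    """
    if os.path.exists(cache_file):
        with open(cache_file, 'rb') as f:
            return pickle.load(f)
    result = compute()
    with open(cache_file, 'wb') as f:
        pickle.dump(result, f)
    return result


calls = []


def expensive():
    # Stand-in for an expensive precomputation such as the logarithmic map
    calls.append(1)
    return {'value': 42}


with tempfile.TemporaryDirectory() as cache_dir:
    cache_file = os.path.join(cache_dir, 'demo.pkl')
    first = cached(cache_file, expensive)   # computed and written to disk
    second = cached(cache_file, expensive)  # read back from disk

print(len(calls))  # 1: the computation ran only once
```

If the real transforms behave similarly, files in `cache` would need to be deleted after changing `ratios`, `radii`, `n_rings` or `max_order`, since a cache keyed only by file name cannot tell that the settings changed.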

In [None]:
# 1. Provide a path to load and store the dataset.
# Make sure that you have created a folder 'data' somewhere
# and that you have downloaded and moved the raw dataset there
path = osp.join('data', 'MNISTSphere')

# 2. Define transformations to be performed on the dataset:
# Transformation that computes a multi-scale radius graph and precomputes the logarithmic map.
transform = T.Compose((
    MultiscaleRadiusGraph(ratios, radii, loop=True, flow='target_to_source', cache_file='cache/radius_sphere.pt'),
    VectorHeat(cache_file='cache/logmap_sphere.pt')
))
# Transformations that mask the edges and vertices per scale and precompute convolution components.
scale0_transform = T.Compose((
    ScaleMask(0),
    HarmonicPrecomp(n_rings, max_order, max_r=radii[0], cache_file='cache/p0_sphere.pt')
))
scale1_transform = T.Compose((
    ScaleMask(1),
    HarmonicPrecomp(n_rings, max_order, max_r=radii[1], cache_file='cache/p1_sphere.pt')
))
scale2_transform = T.Compose((
    ScaleMask(2),
    HarmonicPrecomp(n_rings, max_order, max_r=radii[2], cache_file='cache/p2_sphere.pt')
))

# 3. Assign and load the datasets.
train_dataset = MNISTSphere(path, True, transform=transform)
test_dataset = MNISTSphere(path, False, transform=transform)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

## Network architecture
Now, we create the network architecture by defining a new `nn.Module`, `Net`. We first set up each layer in the `__init__` method of the `Net` class and then define the steps to perform for each batch in the `forward` method. The following figure shows a schematic of the architecture we will be implementing:

<img src="img/classification_architecture.png" width="500px" />

Let's get started!

In [None]:
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()

        # Block 1, scale 0
        # Because we start with an m=0 input, we set prev_order to 0
        self.conv1 = HarmonicConv(1, nf[0], max_order, n_rings, prev_order=0)
        self.nonlin1 = ComplexNonLin(nf[0])
        self.conv2 = HarmonicConv(nf[0], nf[0], max_order, n_rings)
        self.bn1 = nn.BatchNorm1d((max_order + 1) * nf[0], eps=1e-3, momentum=0.01)

        # Pooling to scale 1
        self.pool1 = ParallelTransportPool(1, scale1_transform)

        # Block 2, scale 1
        self.conv3 = HarmonicConv(nf[0], nf[1], max_order, n_rings)
        self.nonlin3 = ComplexNonLin(nf[1])
        self.conv4 = HarmonicConv(nf[1], nf[1], max_order, n_rings)
        self.bn2 = nn.BatchNorm1d((max_order + 1) * nf[1], eps=1e-3, momentum=0.01)

        # Pooling to scale 2
        self.pool2 = ParallelTransportPool(2, scale2_transform)

        # Block 3, scale 2
        self.conv5 = HarmonicConv(nf[1], nf[2], max_order, n_rings)
        self.nonlin5 = ComplexNonLin(nf[2])
        self.conv6 = HarmonicConv(nf[2], nf[2], max_order, n_rings)
        self.bn3 = nn.BatchNorm1d((max_order + 1) * nf[2], eps=1e-3, momentum=0.01)

        # Final Harmonic Convolution
        # We set offset to False, 
        # because we will only use the radial component of the features after this
        self.conv7 = HarmonicConv(nf[2], 10, max_order, n_rings, offset=False)
        self.bias = nn.Parameter(torch.Tensor(10))
        zeros(self.bias)

    def forward(self, data):
        # The input x is fed to our convolutional layers as a complex number and organized by rotation orders.
        # Layout: [n_nodes, n_streams, channels, complex]. The input has a single m=0 stream;
        # the outputs of the convolutions have max_order + 1 streams.
        x = torch.stack((data.x, torch.zeros_like(data.x)), dim=-1).unsqueeze(1)
        batch_size = data.num_graphs
        n_nodes = x.size(0)

        # Block 1, scale 0
        # Mask correct edges and nodes
        data_scale0 = scale0_transform(data)
        # Get edge indices and precomputations for scale 0
        attributes = (data_scale0.edge_index, data_scale0.precomp, data_scale0.connection)
        # Apply convolutions
        x = self.conv1(x, attributes[0], attributes[1])
        x = self.nonlin1(x)
        x = self.conv2(x, *attributes)
        x = c_batch_norm(x, batch_size, self.bn1, F.relu)
        
        # Pooling to scale 1
        x, data, data_pooled = self.pool1(x, data)
        # Get edge indices and precomputations for scale 1
        attributes_pooled = (data_pooled.edge_index, data_pooled.precomp, data_pooled.connection)

        # Block 2, scale 1
        x = self.conv3(x, *attributes_pooled)
        x = self.nonlin3(x)
        x = self.conv4(x, *attributes_pooled)
        x = c_batch_norm(x, batch_size, self.bn2, F.relu)

        # Pooling to scale 2
        x, data, data_pooled = self.pool2(x, data)
        # Get edge indices and precomputations for scale 2
        attributes_pooled = (data_pooled.edge_index, data_pooled.precomp, data_pooled.connection)

        # Block 3, scale 2
        x = self.conv5(x, *attributes_pooled)
        x = self.nonlin5(x)
        x = self.conv6(x, *attributes_pooled)
        x = c_batch_norm(x, batch_size, self.bn3, F.relu)

        # Final convolution
        x = self.conv7(x, *attributes_pooled)
        # Take radial component of each complex feature
        x = magnitudes(x, keepdim=False)
        # Sum the two streams
        x = x.sum(dim=1)

        # Global mean pool to retrieve classification
        x = global_mean_pool(x, data.batch)
        x = x + self.bias
        return F.log_softmax(x, dim=1)
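
The final steps of `forward` take the magnitude of each complex feature before pooling. As a rough illustration of why this suits rotated digits, assume (as in Harmonic Networks) that a feature in the stream of rotation order $m$ is multiplied by $e^{i m \theta}$ when the local frame is rotated by $\theta$. The magnitude is unaffected by that factor. The snippet uses plain Python complex numbers rather than the repository's tensors, and `rotate_feature` is a hypothetical name.

```python
import cmath
import math


def rotate_feature(z, m, theta):
    """Apply the assumed transformation rule for a stream of rotation order m."""
    return cmath.exp(1j * m * theta) * z


z = complex(0.6, -0.8)   # an arbitrary complex feature
theta = math.pi / 3      # rotate the frame by 60 degrees

for m in (0, 1):
    rotated = rotate_feature(z, m, theta)
    # The phase changes for m = 1, but the magnitude is the same for every m
    print(m, abs(z), abs(rotated))
```

Under this assumption, summing the magnitudes of the streams yields a quantity that does not depend on the choice of local frame.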

## Training

Phew, we're through the hard part. Now, let's get to training. First, move the network to the GPU and set up an optimizer.

In [None]:
# We want to train on a GPU. It'll take a long time on a CPU
device = torch.device('cuda')
# Move the network to the GPU
model = Net().to(device)
# Set up the Adam optimizer with a learning rate of 0.0076 (as used in H-Nets)
optimizer = torch.optim.Adam(model.parameters(), lr=0.0076)

Next, define a training function.

In [None]:
import numpy as np

def train(epoch):
    # Set model to 'train' mode
    model.train()

    for param_group in optimizer.param_groups:
        # Decrease the learning rate at the start of every epoch.
        # The factor multiplies the current rate, so the decay compounds across epochs.
        param_group['lr'] = param_group['lr'] * np.power(0.1, epoch / 50)

    for data in train_loader:
        # Move training data to the GPU and optimize parameters
        optimizer.zero_grad()
        F.nll_loss(model(data.to(device)), data.y).backward()
        optimizer.step()
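
To see what the learning-rate update in `train` does over several epochs, the sketch below replays it in plain Python, without an optimizer. Because the factor `0.1 ** (epoch / 50)` multiplies the current rate, the rate after epoch $e$ is `base_lr * 0.1 ** (e * (e + 1) / 100)`. The second list is shown only for comparison and is not what the notebook does: it applies the factor to the base rate instead.

```python
base_lr = 0.0076
n_epochs = 5

# (a) In-place update as in the training loop: the factor multiplies the current rate
lr = base_lr
compounded = []
for epoch in range(n_epochs):
    lr = lr * 0.1 ** (epoch / 50)
    compounded.append(lr)

# (b) For comparison only (not used in this notebook): factor applied to the base rate
from_base = [base_lr * 0.1 ** (epoch / 50) for epoch in range(n_epochs)]

for epoch in range(n_epochs):
    print(epoch, compounded[epoch], from_base[epoch])
```

Both schedules start at 0.0076 in epoch 0; from epoch 2 onwards the compounded schedule is strictly lower.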

Train for 100 epochs.

In [None]:
print('This may take a while...')
# Try with fewer epochs if you're in a time crunch
for epoch in progressbar.progressbar(range(100), redirect_stdout=True):
    train(epoch)

## Testing
Finally, we test our model on the test dataset. Set up a test function:

In [None]:
def test():
    # Set model to 'evaluation' mode
    model.eval()
    correct = 0

    # Gradients are not needed for evaluation
    with torch.no_grad():
        for data in progressbar.progressbar(test_loader):
            # Move test data to the GPU and return a prediction
            data = data.to(device)
            pred = model(data).argmax(dim=1)
            correct += pred.eq(data.y).sum().item()
    # Return the fraction of correctly classified shapes
    return correct / len(test_dataset)
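
The accuracy computation in `test` boils down to taking the highest-scoring class per shape and counting matches with the labels. Here is a small sketch of that logic on plain Python lists, with made-up log-probabilities; the function name `accuracy` is hypothetical.

```python
def accuracy(log_probs, labels):
    """Fraction of rows whose highest-scoring class matches the label."""
    correct = 0
    for row, label in zip(log_probs, labels):
        pred = max(range(len(row)), key=row.__getitem__)  # argmax over classes
        correct += int(pred == label)
    return correct / len(labels)


# Three made-up examples with three classes each (values are illustrative)
log_probs = [
    [-0.1, -2.5, -3.0],   # highest score: class 0
    [-1.8, -0.3, -2.2],   # highest score: class 1
    [-0.9, -1.1, -1.4],   # highest score: class 0
]
labels = [0, 1, 2]
print(accuracy(log_probs, labels))  # 2 of 3 predictions match
```

The notebook sums the correct predictions over all batches and divides by the dataset size, so a smaller final batch does not skew the result.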

Then print the accuracy on the test set:

In [None]:
test_acc = test()
print('Test: {:.6f}'.format(test_acc))