# Mobius Convolution Demo

### This notebook demonstrates the basic functionality of the Mobius convolution modules from the paper "Mobius Convolutions for Spherical CNNs" (Mitchel et al. 2022).

In [None]:
# Imports
import torch
import os

# Release GPU memory held by PyTorch's caching allocator (e.g. left over from a
# previous run in this kernel) before the demo allocates its own tensors.
torch.cuda.empty_cache()

## Hyperparameters

In [None]:
#####################################
####### Device, checkpointing #######
#####################################

## Run on the GPU when one is available; otherwise fall back to the CPU so the
## demo still executes (more slowly) instead of failing at the first `.to(device)`.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

## Checkpointing flag, passed as `checkpoint=` to MCResNetBlock.
## NOTE(review): presumably enables gradient checkpointing (trading compute for
## memory) -- confirm in the `nn` module.
checkpoint = True

In [None]:
#####################################
### Resolution, dimensionality ######
#####################################

## Band-limit of the spherical grid: spherical images are sampled on a
## 2*B x 2*B grid (see the TS2Kit documentation for details).
B = 64

C = 16          # number of feature channels
batch_size = 1  # number of signals per batch


In [None]:
###################################
############# Filters #############
###################################

## Filter resolution: these two band-limits set the number of learnable
## parameters per filter. Larger values give higher-resolution filters at the
## cost of more memory. These are the values used for the experiments in the paper.
D1 = 1  # radial band-limit of the learnable log-polar filters
D2 = 1  # angular band-limit of the learnable log-polar filters

## Discretization quality: these two parameters control how accurately the
## filters are represented. Larger values are more accurate but use more memory.
## These are the values used for the experiments in the paper.
M = D2 + 1  # angular band-limit of the representation
Q = 30      # number of radial quadrature samples in the representation

## Cache

In [None]:
## Several tensors are precomputed at initialization, and at higher band-limits
## this can take some time. To avoid recomputing these quantities on every initialization,
## the modules check whether the tensors have been saved in the cache/files directory and either
## (A) load the tensors directly from the cache, or (B) compute the tensors and save them
## to the cache directory so they can be loaded the next time the modules are initialized.

## The cache directory can be cleared of .pt files at any time by uncommenting the call below:


from cache.cache import clearCache

#clearCache();


## Mobius Convolution modules

In [None]:
from nn import MCResNetBlock, MobiusConv
from utils.rs_diff import rsDirichlet

## The principal module is the Mobius convolution ResNet block: two Mobius convolutions, each followed
## by a Dirichlet-energy filter response normalization + nonlinearity, with a residual connection between
## the input and output streams.

## On the first initialization (or after clearing the cache), various quantities will be precomputed and
## saved in the cache directory to be loaded on subsequent initializations.

## Example: initialize an MCResNet block (recommended for general use)
#MCRN = MCResNetBlock(C, C, B, D1, D2, M, Q, checkpoint=checkpoint)

## Here we'll use a simple MobiusConv with the output normalized by the Dirichlet energy (without the thresholded
## nonlinearity to better demonstrate equivariance)

class MCLayer(torch.nn.Module):
    """A single MobiusConv whose output is normalized by its Dirichlet energy.

    Unlike MCResNetBlock there is no thresholded nonlinearity, so the layer's
    equivariance can be measured directly. Sizes come from the notebook's
    configuration cells: C channels in and out, band-limit B, filter band-limits
    (D1, D2) and representation parameters (M, Q).
    """

    def __init__(self):
        super().__init__()
        # Mobius convolution: C -> C channels on the band-limit-B grid
        self.conv = MobiusConv(C, C, B, D1, D2, M, Q)
        # Dirichlet-energy functional used to normalize the convolution output
        self.E = rsDirichlet(B)

    def forward(self, x):
        """Convolve ``x``, then divide each output map by sqrt(energy + 1e-6)."""
        response = self.conv(x)
        energy = self.E(response)
        # Two trailing singleton axes let the energy broadcast over the grid dims;
        # the 1e-6 keeps the division finite when the energy vanishes.
        scale = torch.sqrt(energy[..., None, None] + 1.0e-6)
        return response / scale


## Equivariance demo

In [None]:
## We'll compare Mobius Convolution against a standard 2D convolution layer, with output normalized 
## by the Dirichlet energy
class Conv2dLayer(torch.nn.Module):
    """Planar Conv2d baseline whose output is normalized by its Dirichlet energy.

    Applies a bias-free C -> C ``torch.nn.Conv2d`` (default zero padding) directly
    to the 2*B x 2*B grid and divides the response by sqrt(E + 1e-6), where E is
    ``rsDirichlet(B)`` evaluated on that response.

    Args:
        k: Spatial kernel size. Must be a positive odd integer. Padding
            (k - 1) // 2 preserves the spatial size only for odd k. An even k
            would shrink each spatial dimension by one pixel, so the output would
            no longer match the 2*B x 2*B grid the rest of the demo expects.

    Raises:
        ValueError: If ``k`` is not a positive odd integer.
    """

    def __init__(self, k=7):
        # Validate before building any submodules so a bad k fails fast.
        if k < 1 or k % 2 == 0:
            raise ValueError(
                "Conv2dLayer kernel size k must be a positive odd integer, got {}".format(k)
            )
        super().__init__()

        p = (k - 1) // 2  # "same" padding for odd k

        self.conv = torch.nn.Conv2d(C, C, kernel_size=k, padding=p, bias=False)

        self.E = rsDirichlet(B)

    def forward(self, x):
        """Convolve ``x`` and normalize each output map by its Dirichlet energy."""
        xC = self.conv(x)
        # The energy gets two trailing singleton axes so it broadcasts over the grid
        # dims (assumes one energy per (batch, channel) -- confirm in utils.rs_diff).
        # The 1e-6 keeps the division finite when the energy vanishes.
        return xC / torch.sqrt(self.E(xC)[..., None, None] + 1.0e-6)


In [None]:
## We can compare the equivariance error of Mobius convolutions and a standard Conv2d module
## by measuring how closely each module commutes with a Mobius transformation.

from utils.demo import randSignal, bilinearInterpolant, randMobius
from utils.rs_diff import rsNorm2
from TS2Kit.ts2kit import FTSHT

## Sampler of random Mobius transformations for the band-limit-B grid
randM = randMobius(B)

## Squared-norm functional (rsNorm2) and spherical harmonic transform (FTSHT)
## at band-limit B; both are used below to score each layer's equivariance
norm2 = rsNorm2(B)
SHT = FTSHT(B)


In [None]:
import random

import numpy as np

## Seed every RNG the demo might draw from so the layer weights, the Mobius
## transformation and the input signal -- and hence the printed errors -- are
## reproducible across runs.
## NOTE(review): which RNG(s) randMobius / randSignal use internally is not visible
## here, so Python's, NumPy's and torch's generators are all seeded.
SEED = 0
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)


def commutativity_error(layer, x, gx, interp):
    """Relative error of ``layer`` failing to commute with a Mobius transformation.

    Args:
        layer: Module applied to signals on ``device``.
        x: Input signal on the CPU (assumed shape (1, C, 2B, 2B); the metric
            below squeezes only the leading batch axis, so batch size 1 is assumed).
        gx: ``interp(x)``, the Mobius-transformed input, on the CPU.
        interp: The Mobius transformation, applied to CPU tensors.

    Returns:
        Channel-averaged squared norm of ``interp(layer(x)) - layer(gx)``, divided
        by the channel-averaged variance of ``layer(x)`` about its mean.
    """
    # Gradients are never needed here; no_grad avoids building autograd graphs
    # (and keeping their activations alive on the device).
    with torch.no_grad():
        Lx = layer(x.to(device)).to('cpu')
        diff = interp(Lx) - layer(gx.to(device)).to('cpu')

        # Squared norm of the commutator, averaged over channels
        EV = torch.sum(norm2(diff.squeeze(0)), dim=0) / C

        # Mean of the output, read from SHT coefficient index [:, B-1, 0] and
        # averaged over channels (presumably the degree-0 term -- confirm against
        # TS2Kit's coefficient layout)
        mu = torch.sum(SHT(Lx.squeeze(0))[:, (B-1), 0].real, dim=0).item() / C

        # Variance of the output about that mean, averaged over channels
        Var = torch.sum(norm2(Lx.squeeze(0) - mu), dim=0) / C

    return EV / Var


## Regular 2D convolution module
RN = Conv2dLayer().to(device)
RN.eval()

## Mobius Conv module
MC = MCLayer().float().to(device)
MC.eval()

## Draw random mobius transformation
thetaM, phiM = randM(0.3) # Maximum scale factor ~ 12
interp = bilinearInterpolant(thetaM, phiM)

## Draw a random signal and transform it
x = randSignal(batch_size, B, C).real
gx = interp(x)

## Score both layers on the same signal and transformation
errorC2d = commutativity_error(RN, x, gx, interp)
errorMC = commutativity_error(MC, x, gx, interp)

del RN, MC

print("Conv2d layer commutativity error = {}".format(errorC2d), flush=True)
print("Mobius Convolution layer commutativity error = {}.".format(errorMC), flush=True)