In [63]:
from model.graphconv import Conv
from utils.sampling import HealpixSampling
from model.unet import GraphCNNUnet
import torch

import numpy as np
import plotly.express as px

## 1. Convolution example

### 1.1 Spherical sampling
We first need to define the spherical sampling used by the convolution. It creates the Laplacian of the spherical graph and the spherical pooling operations, which depend on the spherical graph. We combine all of this information into a single sampling class (`HealpixSampling` below). **To leverage the proposed efficient hemispherical convolution, we introduce the argument *use_hemisphere*, which creates the proposed hemispherical Laplacian.**

In [64]:
# Spherical sampling used by the spherical convolution.
# HEALPix is convenient here: its hierarchical structure makes the
# pooling operations straightforward to define.

# Sphere / harmonics
n_side = 8              # HEALPix resolution
sh_degree = 8           # Spherical harmonic degree (the log reports 45 SH coefficients for it)
use_hemisphere = True   # If True, the sampling only considers the upper hemisphere.

# Hierarchy and spatial patch
depth = 2               # Number of hierarchical levels; the pooling operations are created automatically.
patch_size = 5          # Spatial size of the sampling, used to create the pooling operations efficiently.

# Pooling
pooling_mode = 'average'   # 'average' or 'max'
pooling_name = 'mixed'     # 'spatial', 'spherical', 'mixed' or 'bekkers'

sampling = HealpixSampling(
    n_side, depth, patch_size, sh_degree,
    pooling_mode, pooling_name, use_hemisphere,
)

# Laplacians and poolings, each ordered from the lowest to the highest resolution.
laps, pools = sampling.laps, sampling.pooling

--------------------------------------------------
----------  Create Healpix Sampling  ----------
Healpix resolution: 8
Depth: 2
Patch size: 5
Spherical harmonic degree: 8
Pooling mode: average
Pooling name: mixed
Hemisphere: True
Legacy: False
Sampling number SHC: 45
Create Laplacians
Laplacian at depth 0: torch.Size([384, 384]) and coordinates: (384, 3)
Laplacian at depth 1: torch.Size([96, 96]) and coordinates: (96, 3)
Create Poolings
Initial size: Patch size: 5 - Resolution: 8
5  -  8
Pooling after depth 0: <utils.pooling.MixedPooling object at 0x7f6f40507380> - Patch size: 4 - Resolution: 4
--------------------------------------------------


### 1.2 Single convolution layer

Once we have the spherical/hemispherical graph Laplacian, we create our convolution operator. In addition to the proposed convolution, we provide the spherical and spatial convolutions, as well as the spatio-spherical convolution proposed in PONITA [Bekkers '24](https://github.com/ebekkers/ponita). All these convolutions work with either the full spherical sampling or the more efficient hemispherical sampling, depending on the value of *use_hemisphere* chosen in the previous step. **Notice that you can pick an anisotropic spatial kernel by setting the argument *isoSpa* to False, and enable or disable the dense matrix multiplication used in the spherical/hemispherical operation with the argument *dense*.**

In [65]:
# Define a convolution layer for the highest-resolution sampling defined above
in_channels = 2  # Number of input channels
out_channels = 2  # Number of output channels

kernel_sizeSph = 3  # Spherical kernel size
kernel_sizeSpa = 3  # Spatial kernel size

lap = laps[-1]  # Laplacian of the highest-resolution graph (`laps` goes from lowest to highest resolution)
bias = True  # Add a bias after the convolution

# Four convolutions are implemented: 'spherical', 'spatial', 'mixed' and 'bekkers'
conv_name = 'spherical'
dense = False  # Use dense matrix multiplication in the spherical operation. Setting it to True improves efficiency.
isoSpa = True  # Use an isotropic spatial filter; set to False for an anisotropic spatial kernel.
conv = Conv(in_channels, out_channels, lap, kernel_sizeSph, kernel_sizeSpa, bias, conv_name, isoSpa, dense)

Add convolution: spherical - in_channels: 2 - out_channels: 2 - lap: torch.Size([384, 384]) - kernel_sizeSph: 3 - kernel_sizeSpa: 3 - isoSpa: True


### 1.3 Forward pass

The convolution can readily be applied to a random $\mathbb{R}^3 \times S^2$ signal.

In [66]:
# Generate a random R3 x S2 signal; fix the seed so the figures below are reproducible.
torch.manual_seed(0)
batch_size = 2
# Convolution input layout:
# Batch x Feature channels x Spherical vertices x Patch size x Patch size x Patch size
x_conv = torch.rand(batch_size, in_channels, lap.shape[0], patch_size, patch_size, patch_size)

# Smooth the random signal by mapping the vertex values to spherical-harmonic
# coefficients and back (presumably a band-limiting projection, given the
# S2SH / SH2S names). This is only for a nicer visualization; the convolution
# does not need it.
# torch.as_tensor replaces the legacy torch.Tensor(...) constructor and
# explicitly matches x_conv's dtype.
s2sh = torch.as_tensor(sampling.sampling.S2SH, dtype=x_conv.dtype)
sh2s = torch.as_tensor(sampling.sampling.SH2S, dtype=x_conv.dtype)
x_sh = torch.einsum('bfvijk,vc->bfcijk', x_conv, s2sh)
x_conv = torch.einsum('bfcijk,cv->bfvijk', x_sh, sh2s)

# The spherical convolution is only applied to the spherical vertices,
# without knowledge of the spatial neighborhoods.
# Inference only: no autograd graph is needed.
with torch.no_grad():
    out_conv = conv(x_conv).cpu()

In [67]:
# Visualize one spatial slice of the signal before and after the convolution.
# The vertices of each voxel are drawn as a small (hemi)sphere, and the
# spheres are tiled on a patch_size x patch_size grid.

# Vertex coordinates as (3, V); assumes vec[-1] is the (V, 3) coordinate array
# of the highest-resolution sampling (matching `lap = laps[-1]`) — TODO confirm.
grid = np.asarray(sampling.vec[-1]).T
n_vertices = grid.shape[-1]

spacing = 3  # Distance between neighboring sphere centers
offsets = spacing * np.arange(patch_size)

# grid_cat[:, i, j, v]: coordinates of vertex v of the sphere placed at tile (i, j).
grid_cat = np.zeros((3, patch_size, patch_size, n_vertices))
grid_cat[0] = grid[0] + offsets[:, None, None]  # shift along x with i
grid_cat[1] = grid[1] + offsets[None, :, None]  # shift along y with j
grid_cat[2] = grid[2]

# Custom camera placed on the z axis, looking down at the tiled spheres
# (these are NOT Plotly's default camera values).
# NOTE(review): `up` is parallel to the viewing direction; values kept as originally chosen.
camera = dict(
    up=dict(x=0, y=0, z=10),
    center=dict(x=0, y=0, z=0),
    eye=dict(x=0, y=0, z=3),
)
# Hide grid, background, axis labels and ticks on all three scene axes.
axis_style = dict(
    showgrid=False,
    showbackground=False,
    showaxeslabels=False,
    showticklabels=False,
    showticksuffix='none',
)

for title, x_plot in [('Input signal', x_conv), ('Convolution output', out_conv)]:
    # x_plot is B x F x V x P x P x P. Take batch 0, channel 0 and the first slice
    # along the first spatial axis -> (V, P, P), then move V last so the flattened
    # colors line up with the flattened (P, P, V) coordinates of grid_cat.
    color = x_plot[0, 0, :, 0].permute(1, 2, 0).reshape(-1).detach().cpu().numpy()

    fig = px.scatter_3d(
        x=grid_cat[0].ravel(), y=grid_cat[1].ravel(), z=grid_cat[2].ravel(),
        color=color, color_continuous_scale='RdBu_r', title=title,
    )
    fig.update_traces(marker=dict(size=4))
    fig.update_layout(
        width=600, height=600,
        margin=dict(t=40, r=0, l=0, b=0),  # leave room for the title
        scene_aspectmode='data',
        scene_camera=camera,
        coloraxis_showscale=False,
    )
    fig.update_scenes(xaxis=axis_style, yaxis=axis_style, zaxis=axis_style)
    fig.show()


## 2. Spatio-Hemispherical Unet

Again, we first define the spherical sampling.

In [68]:
# Spherical sampling for the UNet. HEALPix is used for its hierarchical
# structure, which makes the pooling operations easy to define.

# Sphere / harmonics
n_side = 8              # HEALPix resolution
sh_degree = 8           # Spherical harmonic degree
use_hemisphere = True   # If True, only the upper hemisphere is sampled.

# Hierarchy and spatial patch
depth = 2               # Number of hierarchical levels; the poolings are built automatically.
patch_size = 5          # Spatial size of the sampling, used to build the poolings efficiently.

# Pooling
pooling_mode = 'average'   # 'average' or 'max'
pooling_name = 'mixed'     # 'spatial', 'spherical', 'mixed' or 'bekkers'

sampling = HealpixSampling(
    n_side, depth, patch_size, sh_degree,
    pooling_mode, pooling_name, use_hemisphere,
)

# Everything the UNet needs from the sampling. Laplacians and poolings are
# ordered from the lowest to the highest resolution.
laps, pools, vecs = sampling.laps, sampling.pooling, sampling.vec
patch_size_list = sampling.patch_size_list        # Spatial patch size of each level
nvec_out = sampling.sampling.vectors.shape[0]     # Number of vectors in the underlying sampling

--------------------------------------------------
----------  Create Healpix Sampling  ----------
Healpix resolution: 8
Depth: 2
Patch size: 5
Spherical harmonic degree: 8
Pooling mode: average
Pooling name: mixed
Hemisphere: True
Legacy: False
Sampling number SHC: 45
Create Laplacians
Laplacian at depth 0: torch.Size([384, 384]) and coordinates: (384, 3)
Laplacian at depth 1: torch.Size([96, 96]) and coordinates: (96, 3)
Create Poolings
Initial size: Patch size: 5 - Resolution: 8
5  -  8
Pooling after depth 0: <utils.pooling.MixedPooling object at 0x7f6f3b6ff3b0> - Patch size: 4 - Resolution: 4
--------------------------------------------------


We then create the UNet.

In [69]:
# Build the spatio-hemispherical UNet on the sampling defined above.
in_channels = 1  # Number of input channels
out_channels = 5  # Number of output channels
filter_start = 1  # Number of filters after the first convolution; it doubles after each pooling
block_depth = 1  # Number of blocks (convolution + bn + activation) between two poolings in the encoder
in_depth = 1  # Number of blocks (convolution + bn + activation) before unpooling in the decoder
kernel_sizeSph = 3  # Spherical kernel size
kernel_sizeSpa = 3  # Spatial kernel size
poolings = pools  # List of poolings
conv_name = pooling_name  # Convolution type: reuses the pooling name ('mixed')
isoSpa = True  # Use an isotropic spatial filter to get E(3) equivariance
keepSphericalDim = True  # For the output, keep the spherical dimension (True) or average across vertices (False)
vec = vecs  # List of input vertex coordinates; only used by the 'bekkers' convolution
# `laps`, `patch_size_list` and `nvec_out` are used as defined in the sampling cell above:
#   laps            - list of Laplacians
#   patch_size_list - spatial patch size of each layer, to ensure correct kernel sizes
#   nvec_out        - number of output vertices; only used by the 'spatial' convolution

unet = GraphCNNUnet(in_channels, out_channels, filter_start, block_depth, in_depth,
                    kernel_sizeSph, kernel_sizeSpa, poolings, laps, conv_name, isoSpa,
                    keepSphericalDim, patch_size_list, vec, nvec_out)

# Generate a random R3 x S2 signal
batch_size = 1
# Input layout: Batch x Feature channels x Spherical vertices x Patch size x Patch size x Patch size
x = torch.rand(batch_size, in_channels, laps[-1].shape[0], patch_size, patch_size, patch_size)  # B x F_in x V x P x P x P

# Inference only: no autograd graph is needed.
with torch.no_grad():
    y = unet(x).cpu()  # B x F_out x (V or 1) x P x P x P

Unet in channels: 1 - out channels 5 - nvecs out 384
Create Encoder with 1 levels and 1 convolution per level
Add convolution: mixed - in_channels: 1 - out_channels: 1 - lap: torch.Size([384, 384]) - kernel_sizeSph: 3 - kernel_sizeSpa: 3 - isoSpa: True
Create Decoder with 2 levels and 2 convolution per level and a head with 1 convolution
Add convolution: mixed - in_channels: 1 - out_channels: 2 - lap: torch.Size([96, 96]) - kernel_sizeSph: 3 - kernel_sizeSpa: 3 - isoSpa: True
Add convolution: mixed - in_channels: 2 - out_channels: 1 - lap: torch.Size([96, 96]) - kernel_sizeSph: 3 - kernel_sizeSpa: 3 - isoSpa: True
Add convolution: mixed - in_channels: 2 - out_channels: 1 - lap: torch.Size([384, 384]) - kernel_sizeSph: 3 - kernel_sizeSpa: 3 - isoSpa: True
Add convolution: mixed - in_channels: 1 - out_channels: 1 - lap: torch.Size([384, 384]) - kernel_sizeSph: 3 - kernel_sizeSpa: 3 - isoSpa: True
Add convolution: mixed - in_channels: 1 - out_channels: 5 - lap: torch.Size([384, 384]) - ke