# Image to Sphere Via Induced Representations For Pose Estimation

This notebook introduces induced image-to-sphere (which we call Induced_I2S) for the orientation estimation step of single-view pose prediction problems. Induced_I2S is based on the Image to Sphere (I2S) architecture (https://openreview.net/forum?id=_2bDpAtr7PI).
There are a few fundamental differences between I2S and Induced_I2S, which we enumerate here:
1. Instead of the orthographic projection used in I2S, Induced_I2S uses a fully differentiable induction layer, which accepts a c-channel image and outputs a set of matrix-valued spherical harmonic coefficients. The orthographic projection is a specific instance of the induction layer. Unlike the orthographic projection, the induction layer creates a signal with non-zero support everywhere on the sphere.


2. The SO(3) convolution is performed using the method of Passaro and Zitnick (https://arxiv.org/abs/2302.03655), which reduces the computational cost from O(L^{6}) to O(L^{3}), where L is the maximum total angular momentum (i.e. the maximum degree). (To be implemented.)

3. There is no a priori reason why we need to induce from the plane to the sphere. We can also use the induced representation to map from the plane directly into SO(3).
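To make the counts behind these points concrete: a bandlimited signal on the sphere with maximum degree L has (L+1)^2 spherical harmonic coefficients, whereas a signal on SO(3) has one (2l+1) x (2l+1) Wigner-D block per degree, i.e. sum_l (2l+1)^2 coefficients. The plain-Python sketch below (the function names are our own, chosen for illustration) tabulates both, together with the leading-order ratio L^6 / L^3 = L^3 for the cost reduction quoted in point 2. Constant factors are ignored, so the ratio is only indicative.

```python
def num_s2_coeffs(lmax: int) -> int:
    """Number of spherical harmonic coefficients for degrees 0..lmax on S^2."""
    return sum(2 * l + 1 for l in range(lmax + 1))  # equals (lmax + 1) ** 2


def num_so3_coeffs(lmax: int) -> int:
    """Number of Wigner-D matrix entries for degrees 0..lmax on SO(3)."""
    return sum((2 * l + 1) ** 2 for l in range(lmax + 1))


def cost_ratio(lmax: int) -> float:
    """Leading-order speed-up L^6 / L^3 = L^3 (constants ignored)."""
    return lmax ** 6 / lmax ** 3


for lmax in (6, 10, 20):
    print(lmax, num_s2_coeffs(lmax), num_so3_coeffs(lmax), cost_ratio(lmax))
```

The SO(3) count grows cubically in L while the S^2 count grows quadratically, which is one concrete cost of inducing directly into SO(3) as suggested in point 3.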


# Conceptual Questions: To Do List:
1. Conceptual question: what should the S^2 representation be? Re-read the spherical CNN paper, as it discusses choosing optimal convolutions.
2. We really need to include pyramid features to deal with discretization error. What is the best way to do this?
Specifically, we need to include both low-resolution and high-resolution discretizations.
3. There is no obvious reason why the sphere is needed; we can potentially go directly to SO(3).
4. Which non-linearities should be put after the induction layer, i.e. which non-linearities commute with the induction map?

# Implementation: To Do List:
1. Put in the induction directly from the plane to SO(3). The matrix coefficients satisfy a different set of equations.
2. Implement the convolution method of Passaro and Zitnick (arXiv:2302.03655).
3. Implement a multiscale architecture (e.g. a feature pyramid network, FPN).

In [7]:
### import relevant packages
import torch
import numpy as np
from e2cnn import gspaces
from e2cnn import nn
from e2cnn import group
from e3nn import o3
import e3nn
import represenations_opps as rep_ops
import healpy as hp

import time
from torch.utils.data import Dataset
from torch.utils.data import TensorDataset, DataLoader

###import pickle5 as pickle
import pickle

Check for available GPUs

In [8]:
### check whether CUDA is available
torch.cuda.is_available()

### use GPUs if available, otherwise use the CPU
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

### number of GPUs available
num_gpu = torch.cuda.device_count() 

print("Number of GPUs available:", num_gpu )

### get names of GPUs if available
#name = torch.cuda.get_device_name(0)
#print(name)

Number of GPUs available: 0


# Defining an SO(2) Convolution Layer

In [9]:
### defining an SO(2) convolution layer
### global SO(2) action on R^2 (also used by convert_SO2 below)
SO2_act = gspaces.Rot2dOnR2(N=-1,maximum_frequency=50)
class SO2_Convolution_Layer(torch.nn.Module):
    ''' SO(2)-equivariant convolution followed by a norm non-linearity
    :rep_in      : input SO(2) representation as a list of irreps
    :rep_out     : output SO(2) representation as a list of irreps
    :k_max       : maximum frequency of the SO(2) harmonics
    :kernel_size : size of the (square) convolution kernel
    '''
    def __init__(self, rep_in:list , rep_out:list , k_max:int , kernel_size:int ):
        super().__init__()
        self.kernel_size = kernel_size
        self.maximum_frequency = k_max
        
        ### Defining the SO(2) action on R^{2}
        ### with maximum frequency k_max
        SO2_act = gspaces.Rot2dOnR2(N=-1,maximum_frequency=self.maximum_frequency)
    
        ### input and output representations
        self.rep_in = rep_in
        self.rep_out = rep_out
        
        ### input and output feature types
        self.feat_type_in = nn.FieldType( SO2_act, self.rep_in  )
        self.feat_type_out = nn.FieldType( SO2_act, self.rep_out  )
        
        ### convolution
        self.conv = nn.R2Conv( self.feat_type_in, self.feat_type_out, kernel_size=self.kernel_size )
        
        ### non-linearity
        self.non_lin = nn.NormNonLinearity(self.feat_type_out)
        
    def forward(self, x):
        return self.non_lin( self.conv(x) )
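SO2_Convolution_Layer follows the convolution with `nn.NormNonLinearity`. The reason is that a non-linearity acting only on the norm of each frequency-k field commutes with rotations, because rotating a 2D feature vector leaves its norm unchanged; a pointwise ReLU does not. The numpy sketch below checks this for a single frequency-k field. It is independent of e2cnn, and `norm_relu` with a fixed bias is an illustrative assumption, not e2cnn's exact implementation. Rotating the image also permutes pixels, which commutes with any pixel-wise map, so only the fiber action is checked here.

```python
import numpy as np


def rotation_k(theta: float, k: int) -> np.ndarray:
    """2x2 matrix of the frequency-k irrep of SO(2) at angle theta."""
    c, s = np.cos(k * theta), np.sin(k * theta)
    return np.array([[c, -s], [s, c]])


def norm_relu(v: np.ndarray, bias: float = 0.5, eps: float = 1e-12) -> np.ndarray:
    """Rescale each 2-vector by relu(|v| - bias) / |v|; the direction is preserved."""
    norm = np.linalg.norm(v, axis=-1, keepdims=True)
    return v * np.maximum(norm - bias, 0.0) / (norm + eps)


rng = np.random.default_rng(0)
field = rng.normal(size=(16, 2))  # 16 pixels of one frequency-k field (2 channels)
theta, k = 0.7, 3
R = rotation_k(theta, k)

# norm non-linearity: rotate-then-act equals act-then-rotate
rotate_then_act = norm_relu(field @ R.T)
act_then_rotate = norm_relu(field) @ R.T

# pointwise ReLU on the two components breaks equivariance
pointwise = np.allclose(np.maximum(field @ R.T, 0.0), np.maximum(field, 0.0) @ R.T)
print(np.allclose(rotate_then_act, act_then_rotate), pointwise)
```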
    

In [10]:
### convert an SO(2) multiplicity dict {frequency: multiplicity} to a list of e2cnn irreps
### TODO: this should be moved to a separate script
def convert_SO2(  input_rep_dict ):
    total_rep = []
    for k in input_rep_dict.keys():
        multiplicity = input_rep_dict[k]        
        total_rep = total_rep + multiplicity * [SO2_act.irrep(int(k))]

    return total_rep
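`convert_SO2` expands a dictionary {frequency: multiplicity} into a list of e2cnn SO(2) irreps. Since the trivial irrep is 1-dimensional and every real irrep of frequency k >= 1 is 2-dimensional, the total field dimension follows directly from the dictionary. The pure-Python sketch below is our guess at what `rep_ops.compute_SO2_dimension` returns (that module is not shown here, so the helper names are hypothetical); it reproduces the 9 channels used for `fmap_shape` later in the notebook.

```python
def so2_irrep_dim(k: int) -> int:
    """Real SO(2) irreps: frequency 0 is 1-dimensional, frequency k >= 1 is 2-dimensional."""
    return 1 if k == 0 else 2


def so2_dimension(rep_dict: dict) -> int:
    """Total dimension of a direct sum given as {str(frequency): multiplicity}."""
    return sum(mult * so2_irrep_dim(int(k)) for k, mult in rep_dict.items())


hidden = {'0': 1, '1': 1, '2': 1, '3': 1, '4': 1}
print(so2_dimension(hidden))  # 1 + 4 * 2 = 9
```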

# Image2Sphere Orthographic Projection Baseline

In [11]:
def s2_healpix_grid(rec_level: int=0, max_beta: float=np.pi/6):
    """Returns healpix grid up to a max_beta
    """
    n_side = 2**rec_level
    npix = hp.nside2npix(n_side)
    m = hp.query_disc(nside=n_side, vec=(0,0,1), radius=max_beta)
    beta, alpha = hp.pix2ang(n_side, m)
    alpha = torch.from_numpy(alpha)
    beta = torch.from_numpy(beta)
    return torch.stack((alpha, beta)).float()

### image 2 sphere orthographic projection
class Image2SphereProjector(torch.nn.Module):
  
    def __init__(self,
               fmap_shape, 
               sphere_fdim: int,
               lmax: int,
               coverage: float = 0.9,
               sigma: float = 0.2,
               max_beta: float = np.radians(90),
               taper_beta: float = np.radians(75),
               rec_level: int = 2,
               n_subset: int = 20,
              ):
        super().__init__()
        self.lmax = lmax
        self.n_subset = n_subset

        # point-wise linear operation to convert to proper dimensionality if needed
        if fmap_shape[0] != sphere_fdim:
          self.conv1x1 = torch.nn.Conv2d(fmap_shape[0], sphere_fdim, 1)
        else:
          self.conv1x1 = torch.nn.Identity()

        # determine sampling locations for orthographic projection
        self.kernel_grid = s2_healpix_grid(max_beta=max_beta, rec_level=rec_level)
        self.xyz = o3.angles_to_xyz(*self.kernel_grid)

        # orthographic projection
        max_radius = torch.linalg.norm(self.xyz[:,[0,2]], dim=1).max()
        sample_x = coverage * self.xyz[:,2] / max_radius # range -1 to 1
        sample_y = coverage * self.xyz[:,0] / max_radius

        gridx, gridy = torch.meshgrid(2*[torch.linspace(-1, 1, fmap_shape[1])], indexing='ij')
        scale = 1 / np.sqrt(2 * np.pi * sigma**2)
        data = scale * torch.exp(-((gridx.unsqueeze(-1) - sample_x).pow(2) \
                                    +(gridy.unsqueeze(-1) - sample_y).pow(2)) / (2*sigma**2) )
        data = data / data.sum((0,1), keepdims=True)

        # apply mask to taper magnitude near border if desired
        betas = self.kernel_grid[1]
        if taper_beta < max_beta:
            mask = ((betas - max_beta)/(taper_beta - max_beta)).clamp(max=1).view(1, 1, -1)
        else:
            mask = torch.ones_like(data)

        data = (mask * data).unsqueeze(0).unsqueeze(0).to(torch.float32)
        self.weight = torch.nn.Parameter(data= data, requires_grad=True)

        self.n_pts = self.weight.shape[-1]
        self.ind = torch.arange(self.n_pts)

        self.register_buffer(
            "Y", o3.spherical_harmonics_alpha_beta(range(lmax+1), *self.kernel_grid, normalization='component')
        )

    def forward(self, x):
        '''
        :x: float tensor of shape (B, C, H, W)
        :return: spherical harmonic coefficients of shape (B, sphere_fdim, (lmax+1)**2)
        '''
        
        #### x must be a plain tensor (use .tensor on an e2cnn GeometricTensor)
        x = self.conv1x1(x)

        if self.n_subset is not None:
            self.ind = torch.randperm(self.n_pts)[:self.n_subset]

        x = (x.unsqueeze(-1) * self.weight[..., self.ind]).sum((2,3))
        x = torch.relu(x)
        x = torch.einsum('ni,xyn->xyi', self.Y[self.ind], x) / self.ind.shape[0]**0.5
        return x
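The projector places a normalized Gaussian on the image plane at the orthographic image of each sampled sphere point. The numpy sketch below mirrors that weight construction (`orthographic_weights` is our own name). The projection drops the y coordinate, as in the cell above; a few hand-picked unit vectors stand in for the HEALPix grid, an assumption made for brevity. Each point's weights sum to one over the feature map, which is what `data / data.sum((0,1))` enforces before the taper mask.

```python
import numpy as np


def orthographic_weights(xyz: np.ndarray, fmap_size: int,
                         coverage: float = 0.9, sigma: float = 0.2) -> np.ndarray:
    """Gaussian sampling weights of shape (H, W, P) for P unit vectors xyz of shape (P, 3)."""
    max_radius = np.linalg.norm(xyz[:, [0, 2]], axis=1).max()
    sample_x = coverage * xyz[:, 2] / max_radius  # in [-coverage, coverage]
    sample_y = coverage * xyz[:, 0] / max_radius
    grid = np.linspace(-1, 1, fmap_size)
    gridx, gridy = np.meshgrid(grid, grid, indexing='ij')
    d2 = (gridx[..., None] - sample_x) ** 2 + (gridy[..., None] - sample_y) ** 2
    data = np.exp(-d2 / (2 * sigma ** 2))
    return data / data.sum(axis=(0, 1), keepdims=True)  # normalize per sphere point


# hypothetical stand-in for s2_healpix_grid + o3.angles_to_xyz
pts = np.array([[0.0, 1.0, 0.0], [1.0, 0.0, 0.0], [0.0, 0.0, 1.0], [0.6, 0.0, 0.8]])
w = orthographic_weights(pts, fmap_size=16)
print(w.shape, w.sum(axis=(0, 1)))
```

Because every weight is concentrated near one projected point, the resulting sphere signal is only supported where the image "sees" it, which is the limitation the induction layer is meant to remove.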



# Defining an SO(2) to SO(3) Induction Layer


In [6]:
### defining an induction layer from SO(2) to SO(3)
class Group_Induction_Layer_I(torch.nn.Module):
    
    ''' The induction layer is a linear layer that takes an SO(2) representation and outputs SO(3) representations.
    For more details on the induction layer, please read the attached notes!
    
    Returns matrix-valued coefficients of spherical harmonics.
    :channels: number of channels in the image
    :image_shape: integer side length; images must be square!!!
    :k_max: maximum degree of the SO(2) harmonics
    :L_max: maximum degree of the SO(3) harmonics
    :dict_rep_in: input SO(2) representation as a dict
    :dict_rep_out: output SO(3) representation as a dict
    
    '''

    def __init__(self, channels:int , image_shape:int , k_max:int , L_max: int, dict_rep_in :dict , dict_rep_out:dict ):
        
        super().__init__()
        self.k_max = k_max
        self.lmax = L_max
        self.channels = channels
        self.image_shape = image_shape
        
        ### all convolution layers (one per degree l), registered as submodules
        self.convs = torch.nn.ModuleList( [] )
    
        ### tensor product representations as list
        self.tensor_reps = []
        
        ### input and output representations as dict
        self.dict_rep_in = dict_rep_in
        self.dict_rep_out = dict_rep_out
        
        ###input and output space
        self.d_in = rep_ops.compute_SO2_dimension(  self.dict_rep_in )
        self.d_out = rep_ops.compute_SO3_dimension(  self.dict_rep_out )
        
        ### defining SO2 action
        SO2_act = gspaces.Rot2dOnR2(N=-1,maximum_frequency=self.k_max)
        

        ### Defining tensor product input and output features
        for l in range( 0 , L_max + 1 ):
            
            ### output type of so3 rep
            tensor_rep_out = rep_ops.compute_tensor_SO3_l_fold( self.dict_rep_out , l )
            restrict = rep_ops.compute_restriction_SO3( tensor_rep_out )
            rep_out = convert_SO2(  restrict  )            
            
            ### compute the input so2 rep
            tensor_rep_in = rep_ops.compute_tensor_SO2_l_fold( self.dict_rep_in , l )
            rep_in = convert_SO2( tensor_rep_in )
            
            d_in = rep_ops.compute_SO2_dimension( tensor_rep_in )

            d_out_so3 = rep_ops.compute_SO3_dimension(  tensor_rep_out ) 
            d_out_so2 = rep_ops.compute_SO2_dimension(  restrict )
            
            ### print(d_in  , d_out_so2, d_out_so3 )
            
            ###feature types
            self.feat_type_in = nn.FieldType( SO2_act, rep_in  )
            self.feat_type_out = nn.FieldType( SO2_act, rep_out  )
            
            ### convs filters
            conv = nn.R2Conv( self.feat_type_in, self.feat_type_out, kernel_size=self.image_shape)
            self.convs.append( conv )   

        
    def forward(self, x):
        
        ### convert x to a tensor if x is geometric tensor
        x = x.tensor
        
        ### compute the l-th spherical harmonic matrix coefficient (l = 0, ..., lmax)
        l_coefs = []
        for l in range( 0 ,  self.lmax + 1 ):
            
            ## get filters
            F_val = self.convs[l].expand_parameters()[0]
            
            ### split the output axis into blocks of size d_out
            F_val_cat = torch.split( F_val , split_size_or_sections= self.d_out , dim=0)
            g_split = torch.stack( F_val_cat ,dim=0 )
            
            ### split the input axis into blocks of size d_in
            g_split_full = torch.split( g_split , split_size_or_sections= self.d_in , dim=2)
            g_split_done = torch.stack( g_split_full ,dim=1 )
            
            ### contract the input channels and the image plane with x
            out = torch.einsum('ijklmn , almn-> aijk',  g_split_done , x )
            
            ### flatten the (2l+1) x (2l+1) block so that different degrees can be concatenated
            out = out.reshape( out.shape[0] , -1 , out.shape[-1] )
            
            l_coefs.append( out )
        
        l_tensor = torch.cat(l_coefs,dim=1)
            
        return l_tensor
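In `forward`, the two split/stack passes only rearrange the expanded filter `F_val` of shape (n*d_out, m*d_in, H, W) into blocks indexed by (output block, input block). The numpy sketch below uses arbitrary small sizes and random data (`torch.split` with a section size behaves like `np.split` with the matching number of sections here). It checks that the same rearrangement is a single reshape plus transpose, which in torch would be `F_val.reshape(...).permute(...)` and avoids building Python lists of chunks.

```python
import numpy as np

rng = np.random.default_rng(0)
n, m, d_out, d_in, H, W = 3, 2, 4, 5, 6, 6
F = rng.normal(size=(n * d_out, m * d_in, H, W))  # stand-in for expand_parameters()[0]

# the cell's approach: split + stack twice
g = np.stack(np.split(F, n, axis=0), axis=0)  # (n, d_out, m*d_in, H, W)
g = np.stack(np.split(g, m, axis=2), axis=1)  # (n, m, d_out, d_in, H, W)

# equivalent single reshape + transpose
g_fast = F.reshape(n, d_out, m, d_in, H, W).transpose(0, 2, 1, 3, 4, 5)
print(np.array_equal(g, g_fast))

x = rng.normal(size=(7, d_in, H, W))  # batch of 7 feature maps with d_in channels
out = np.einsum('ijklmn,almn->aijk', g_fast, x)
print(out.shape)  # (7, n, m, d_out)
```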


In [40]:
### defining an induction layer from SO(2) to SO(3)
class SO3_Induction_Layer(torch.nn.Module):
    
    def __init__(self, input_shape: tuple , sphere_fdim: int , lmax: int = 6 ):
        
        super().__init__()
        
        self.sphere_fdim = sphere_fdim
        self.k_max = 8
        ### number of degrees l = 0, ..., lmax
        self.lmax = lmax + 1
        self.channels = input_shape[0]
        self.image_shape = input_shape[1]
        
        ### dropout probability (not used in forward)
        self.p = 0.1
        
        ### the input is just many copies of the trivial representation
        dict_rep_in = { '0' : self.channels }


        #### the output rep
        ### this can be any representation of dimension self.sphere_fdim
        ### consider using a lower-dimensional representation
        dict_rep_out = { '0' : self.sphere_fdim }
        
        ### all convolution layers (one per degree l), registered as submodules
        self.convs = torch.nn.ModuleList( [] )
    
        ### tensor product representations as list
        
        ### input and output representations as dict
        self.dict_rep_in = dict_rep_in
        self.dict_rep_out = dict_rep_out
        
        ### input and output space dimensions
        self.d_in = rep_ops.compute_SO2_dimension(  self.dict_rep_in )
        self.d_out = rep_ops.compute_SO3_dimension(  self.dict_rep_out )
        
        ### print( self.d_in , self.d_out )
        
        ### defining the SO2 action
        SO2_act = gspaces.Rot2dOnR2(N=-1,maximum_frequency=self.k_max)
        
        ### Defining tensor product input and output features
        for l in range( 0 , self.lmax ):
            
            ### output type of so3 rep
            tensor_rep_out = rep_ops.compute_tensor_SO3_l_fold( self.dict_rep_out , l )
            restrict = rep_ops.compute_restriction_SO3( tensor_rep_out )
            rep_out = convert_SO2(  restrict  )            
            
            ### compute the input so2 rep
            tensor_rep_in = rep_ops.compute_tensor_SO2_l_fold( self.dict_rep_in , l )
            rep_in = convert_SO2( tensor_rep_in )
            
            d_in = rep_ops.compute_SO2_dimension( tensor_rep_in )

            d_out_so3 = rep_ops.compute_SO3_dimension(  tensor_rep_out ) 
            d_out_so2 = rep_ops.compute_SO2_dimension(  restrict )
            
            print(d_in  , d_out_so2, d_out_so3 )
            
            ###feature types
            self.feat_type_in = nn.FieldType( SO2_act, rep_in  )
            self.feat_type_out = nn.FieldType( SO2_act, rep_out  )
            
            ### convs filters
            conv = nn.R2Conv( self.feat_type_in, self.feat_type_out, kernel_size=self.image_shape)
            self.convs.append( conv )   

        
    def forward(self, x):
        
        ### convert x to a tensor if x is geometric tensor
        x = x.tensor
        
        ### compute the l-th spherical harmonic matrix coefficient
        l_coefs = []
        for l in range( 0 ,  self.lmax  ):
            
            ## get filters
            F_val = self.convs[l].expand_parameters()[0]
            
            ### split the output axis into blocks of size d_out
            F_val_cat = torch.split( F_val , split_size_or_sections= self.d_out , dim=0)
            g_split = torch.stack( F_val_cat ,dim=0 )
            
            ### split the input axis into blocks of size d_in
            g_split_full = torch.split( g_split , split_size_or_sections= self.d_in , dim=2)
            g_split_done = torch.stack( g_split_full ,dim=1 )
            
            ### contract the input channels and the image plane with x
            out = torch.einsum('ijklmn , almn-> aijk',  g_split_done , x )
            
            ### flatten the (2l+1) x (2l+1) block so that different degrees can be concatenated
            out = out.reshape( out.shape[0] , -1 , out.shape[-1] )
            
            l_coefs.append( out )
        
        l_tensor = torch.cat(l_coefs,dim=1)
            
        return l_tensor

In [None]:
rep_in = [ SO2_act.irrep(0) ]
feat_type_in = nn.FieldType( SO2_act, rep_in  )

# ### specify the hidden SO(2) layer multiplicities
hidden_mulplicities_SO2 = { '0' : 1 , '1': 1 , '2':1 , '3':1   }
hidden_so2 = convert_SO2( hidden_mulplicities_SO2 )

kmax = 20
lmax = 6
kernel_size = 3
channels_in = 210
image_size = 10

batch_size = 5
### SO2 convolution layer
SO2_conv_1 = SO2_Convolution_Layer( rep_in = rep_in, rep_out = hidden_so2 ,k_max = kmax,  kernel_size= kernel_size )

# ### the output multiplicities of the induced SO(3) layer
mulplicities_SO3 = { '0' :1 , '1' : 1 , '2' : 1   }

# ##### the induction representation layer, 
# ### compute the number of output channels of hidden rep
## channels_in = rep_ops.compute_SO2_dimension( hidden_mulplicities_SO2  )

### first induction layer method
## induce_I = Group_Induction_Layer_I( channels = channels_in, image_shape=(image_size - kernel_size +1) , k_max =kmax, L_max=lmax , dict_rep_in = hidden_mulplicities_SO2 , dict_rep_out = mulplicities_SO3 )
input_shape = [channels_in,image_size]
so3_induce = SO3_Induction_Layer( input_shape = input_shape , sphere_fdim = 15  )

print(so3_induce)


210 15 15
630 45 45
1050 75 75
1470 105 105
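The printed triples (d_in, d_out_so2, d_out_so3) grow as (2l+1) times the base dimensions: 210 input channels and `sphere_fdim = 15` give 210(2l+1) and 15(2l+1). Assuming the l-fold tensor helpers in `rep_ops` scale dimensions this way (consistent with the printed values, but the module is not shown), the per-degree sizes and the total output size can be tabulated with a small hypothetical helper:

```python
def induction_dims(channels: int, sphere_fdim: int, lmax: int):
    """(d_in, d_out) per degree l, assuming both scale as base_dim * (2l + 1)."""
    return [(channels * (2 * l + 1), sphere_fdim * (2 * l + 1)) for l in range(lmax + 1)]


dims = induction_dims(channels=210, sphere_fdim=15, lmax=3)
for l, (d_in, d_out) in enumerate(dims):
    print(l, d_in, d_out)

# total output size summed over degrees: sphere_fdim * (lmax + 1) ** 2
total_out = sum(d_out for _, d_out in dims)
print(total_out)
```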


In [39]:
start_time = time.time()
x = torch.rand(batch_size,1,image_size,image_size) #z = induce_I( y )
stop_time = time.time()

print(x)

print( stop_time-start_time )



# Defining an SO(2) to Sphere Induction Layer

In [24]:
### The naive method: this is much too slow
### defining an induction layer from SO(2) to SO(3)
class Induction_Layer_I(torch.nn.Module):
    
    ''' The induction layer is a linear layer that takes an SO(2) representation and outputs SO(3) representations.
    For more details on the induction layer, please read the attached notes!
    
    Returns matrix-valued coefficients of spherical harmonics.
    :channels: number of channels in the image
    :image_shape: integer side length; images must be square!!!
    :k_max: maximum degree of the SO(2) harmonics
    :L_max: maximum degree of the SO(3) harmonics
    :dict_rep_in: input SO(2) representation as a dict
    :dict_rep_out: output SO(3) representation as a dict
    
    '''

    def __init__(self, channels:int , image_shape:int , k_max:int , L_max: int, dict_rep_in :dict , dict_rep_out:dict ):
        
        super().__init__()
        self.k_max = k_max
        self.lmax = L_max
        self.channels = channels
        self.image_shape = image_shape
        
        ### all convolution layers (one per degree l), registered as submodules
        self.convs = torch.nn.ModuleList( [] )
    
        ### tensor product representations as list
        self.tensor_reps = []
        
        ### input and output representations as dict
        self.dict_rep_in = dict_rep_in
        self.dict_rep_out = dict_rep_out
        
        ### defining SO2 action
        SO2_act = gspaces.Rot2dOnR2(N=-1,maximum_frequency=self.k_max)
        
        ### compute restriction of SO(3) output SO(2) representation
        restrict = rep_ops.compute_restriction_SO3( self.dict_rep_out )
        self.rep_out = convert_SO2(  restrict  )
        
        ### output feature types
        self.feat_type_out = nn.FieldType( SO2_act, self.rep_out  )
        
        ### Defining tensor product features
        for l in range( 0 , L_max + 1 ):
            
            tensor_rep_in = rep_ops.compute_tensor_SO2_l_fold( self.dict_rep_in , l )
            
            
            #### convert the l-fold tensor product representation to a list of e2cnn irreps
            rep_in = convert_SO2( tensor_rep_in )
            #print( tensor_rep_in )
            #### compute the dimension of an so2 rep
            d_v = rep_ops.compute_tensor_SO2_dimension( tensor_rep_in  )
            
            self.feat_type_in = nn.FieldType( SO2_act, rep_in  )
                
            conv = nn.R2Conv( self.feat_type_in, self.feat_type_out, kernel_size=self.image_shape)
            self.convs.append( conv )   

            
    def forward(self, x):
        
        ### convert x to a tensor if x is geometric tensor
        x = x.tensor
        
        ### compute the l-th spherical harmonic matrix coefficient (l = 0, ..., lmax)
        l_coefs = []
        for l in range( 0 ,  self.lmax + 1 ):
            
            ## get filters
            F_val = self.convs[l].expand_parameters()[0]
            
            l_k_coefs = []
            for k in range(2*l+1):
                
                ### slice the k-th block of `channels` input channels
                F_k = F_val[:,self.channels*k:self.channels*k + self.channels*1,:,:]
                out = torch.einsum('ijkl , ajkl-> ai',  F_k , x )
                l_k_coefs.append(out)

            
            l_k_tensor = torch.stack( l_k_coefs )
            #l_k_tensor = torch.einsum('ijk -> jki ', l_k_tensor )
            
            l_coefs.append( l_k_tensor )
        
        l_tensor = torch.cat(l_coefs,dim=0)
        l_tensor = torch.einsum('ijk -> jki ', l_tensor )
        
        return l_tensor
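The inner loop over k in `Induction_Layer_I` slices consecutive blocks of `channels` input channels. Reshaping the input-channel axis into (2l+1, channels) removes that loop. The numpy sketch below (arbitrary sizes, random data) checks that the two agree for one degree l; applying the same idea across all degrees at once is roughly what `Induction_Layer_II` does with `torch.concat` and `torch.split`.

```python
import numpy as np

rng = np.random.default_rng(1)
out_dim, channels, H, W, B, l = 4, 3, 5, 5, 2, 2
F_val = rng.normal(size=(out_dim, channels * (2 * l + 1), H, W))  # filters for degree l
x = rng.normal(size=(B, channels, H, W))

# loop version, as in Induction_Layer_I.forward
loop = np.stack([
    np.einsum('ijkl,ajkl->ai', F_val[:, channels * k:channels * (k + 1)], x)
    for k in range(2 * l + 1)
])  # (2l+1, B, out_dim)

# vectorized: split the input-channel axis into (k, channel) and contract once
F_blocks = F_val.reshape(out_dim, 2 * l + 1, channels, H, W)
vec = np.einsum('okchw,achw->kao', F_blocks, x)  # (2l+1, B, out_dim)
print(np.allclose(loop, vec))
```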



In [25]:
### faster version, but still too slow
### defining an induction layer from SO(2) to SO(3)
### aka the Ondrej method
class Induction_Layer_II( torch.nn.Module ):
    
    ''' The induction layer is a linear layer that takes an SO(2) representation and outputs SO(3) representations.
    For more details on the induction layer, please read the attached notes!
    
    Returns matrix-valued coefficients of spherical harmonics.
    :channels: number of channels in the image
    :image_shape: integer side length; images must be square!!!
    :k_max: maximum degree of the SO(2) harmonics
    :L_max: maximum degree of the SO(3) harmonics
    :dict_rep_in: input SO(2) representation as a dict
    :dict_rep_out: output SO(3) representation as a dict
    
    '''

    def __init__(self, channels:int , image_shape:int , k_max:int , L_max: int, dict_rep_in :dict , dict_rep_out:dict ):
        
        super().__init__()
        self.k_max = k_max
        self.lmax = L_max
        self.channels = channels
        self.image_shape = image_shape
        
        ### all convolution layers (one per degree l), registered as submodules
        self.convs = torch.nn.ModuleList( [] )
    
        ### tensor product representations as list
        self.tensor_reps = []
        
        ### input and output representations as dict
        self.dict_rep_in = dict_rep_in
        self.dict_rep_out = dict_rep_out
        
        ### defining SO2 action
        SO2_act = gspaces.Rot2dOnR2(N=-1,maximum_frequency=self.k_max)
        
        ### compute restriction of SO(3) output SO(2) representation
        restrict = rep_ops.compute_restriction_SO3( self.dict_rep_out )
        self.rep_out = convert_SO2(  restrict  )
        
        ### output feature types
        self.feat_type_out = nn.FieldType( SO2_act, self.rep_out  )
        
        
        ### Defining tensor product features
        convs_list = []
        for l in range( 0 , L_max+1 ):
            
            tensor_rep_in = rep_ops.compute_tensor_SO2_l_fold( self.dict_rep_in , l )
            
            rep_in = convert_SO2( tensor_rep_in )
            self.d_v = rep_ops.compute_tensor_SO2_dimension( tensor_rep_in  )
            self.feat_type_in = nn.FieldType( SO2_act, rep_in  ) 
            conv = nn.R2Conv( self.feat_type_in, self.feat_type_out, kernel_size=self.image_shape)
            self.convs.append( conv )   
            
        #self.convs = convs_list
            
    def forward(self, x):
        
        ### convert x to a tensor if x is geometric tensor
        x = x.tensor
        
        ### There should be a faster way to do this...
        
        ### compute the l-th spherical harmonic matrix coefficient (l = 0, ..., lmax)
        l_coefs = []
        F_val_list = []
        for l in range( 0 ,  self.lmax + 1 ):
            ## get filters
            F_val = self.convs[l].expand_parameters()[0]
            F_val_list.append(F_val)
            #print("F_val.shape", F_val.shape)
        F_val_cat = torch.concat(F_val_list, dim=1)
        #print("F_val_cat.shape", F_val_cat.shape)
                
        ###split and stack
        g = torch.split( F_val_cat , split_size_or_sections= self.channels , dim=1)
        #print("g[0].shape", g[0].shape)
        g_tens = torch.stack( g ,dim=0 )
        #print("g_tens.shape", g_tens.shape)

        ### compute the sum over input channels and the image plane
        # torch.Size([23, 36, 5, 100, 100])
        test_v = torch.einsum('ijklm , aklm-> iaj',  g_tens , x )
        l_tensor = test_v
                    
        ### cat
        # l_tensor = torch.cat(l_coefs,dim=0)    
        #print("l_tensor", l_tensor.shape)
        l_tensor = l_tensor.swapaxes(0,1).swapaxes(1,2)
        
        return l_tensor

In [26]:
### Different approach, done using just matrix multiplications
### This is still slow, but on par with the projection method
### defining an induction layer from SO(2) to SO(3)
class Induction_Layer_III( torch.nn.Module ):
    
    ''' The induction layer is a linear layer that takes an SO(2) representation and outputs SO(3) representations.
    For more details on the induction layer, please read the attached notes!
    
    Returns matrix-valued coefficients of spherical harmonics.
    :channels: number of channels in the image
    :image_shape: integer side length; images must be square!!!
    :k_max: maximum degree of the SO(2) harmonics
    :L_max: maximum degree of the SO(3) harmonics
    :dict_rep_in: input SO(2) representation as a dict
    :dict_rep_out: output SO(3) representation as a dict
    
    '''

    def __init__(self, channels:int , image_shape:int , k_max:int , L_max: int, dict_rep_in :dict , dict_rep_out:dict ):
        
        super().__init__()
        self.k_max = k_max
        self.lmax = L_max
        self.channels = channels
        self.image_shape = image_shape
        
        ### tensor product representations as list
        self.tensor_reps = []
        
        ### input and output representations as dict
        self.dict_rep_in = dict_rep_in
        self.dict_rep_out = dict_rep_out
        
        ### defining SO2 action
        SO2_act = gspaces.Rot2dOnR2(N=-1,maximum_frequency=self.k_max)
        
        ### compute restriction of SO(3) output SO(2) representation
        restrict = rep_ops.compute_restriction_SO3( self.dict_rep_out )
        self.rep_out = convert_SO2(  restrict  )
        
        ### output feature types
        self.feat_type_out = nn.FieldType( SO2_act, self.rep_out  )
        
        
        ### compute the direct sum representation over degrees l = 0, ..., L_max
        total_rep = []
        for l in range( 0 , L_max + 1 ):
            
            tensor_rep_in = rep_ops.compute_tensor_SO2_l_fold( self.dict_rep_in , l )
            
            rep_in = convert_SO2( tensor_rep_in )

            
            total_rep = total_rep + rep_in
            
        feat_type_in = nn.FieldType( SO2_act, total_rep  )
            
        self.conv = nn.R2Conv( feat_type_in , self.feat_type_out , kernel_size=self.image_shape)        
        
        
            
    def forward(self, x):
                
        F = self.conv.expand_parameters()[0]
        F_val_cat = torch.split( F , split_size_or_sections= self.channels , dim=1)
        g_split = torch.stack( F_val_cat ,dim=0 )
            
        ###now contract
        y = torch.einsum('ijklm , aklm -> aji' ,  g_split , x.tensor )
        
        return y
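The split/stack/einsum in `Induction_Layer_III.forward` is a single linear contraction over the flattened (C, H, W) axes, so it can also be written as one matrix product without splitting the filter at all. The numpy sketch below (arbitrary sizes, random data) checks that the two forms agree. Whether the matrix-product form is actually faster in torch (e.g. with `torch.tensordot` on the reshaped filter) depends on the backend and has not been timed here.

```python
import numpy as np

rng = np.random.default_rng(2)
B, C, H, W, out_dim, n = 3, 4, 6, 6, 5, 9  # n = number of harmonic coefficients
F = rng.normal(size=(out_dim, n * C, H, W))  # stand-in for conv.expand_parameters()[0]
x = rng.normal(size=(B, C, H, W))

# Induction_Layer_III: split the input axis into n blocks of C, stack, einsum
g = np.stack(np.split(F, n, axis=1), axis=0)    # (n, out_dim, C, H, W)
y_einsum = np.einsum('ijklm,aklm->aji', g, x)   # (B, out_dim, n)

# the same contraction as one matrix product over the flattened (C, H, W) axes
F_mat = F.reshape(out_dim, n, C * H * W)
y_matmul = np.tensordot(x.reshape(B, C * H * W), F_mat, axes=([1], [2]))  # (B, out_dim, n)
print(np.allclose(y_einsum, y_matmul))
```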

In [27]:
rep_in = [ SO2_act.irrep(0) ]
feat_type_in = nn.FieldType( SO2_act, rep_in  )

# ### specify the hidden SO(2) layer multiplicities
hidden_mulplicities_SO2 = { '0' : 1 , '1': 1 , '2':1 , '3':1 , '4':1   }
hidden_so2 = convert_SO2( hidden_mulplicities_SO2 )

kmax = 10
lmax = 6
kernel_size = 20
image_size = 200

batch_size = 50

### SO2 convolution layer
SO2_conv_1 = SO2_Convolution_Layer( rep_in = rep_in, rep_out = hidden_so2 ,k_max = kmax,  kernel_size= kernel_size )

# ### the output multiplicities of the induced SO(3) layer
mulplicities_SO3 = { '0' :1 , '1' : 1 , '2' : 1 , '3':1   }

# ##### the induction representation layer, 
# ### compute the number of output channels of hidden rep
channels_in = rep_ops.compute_SO2_dimension( hidden_mulplicities_SO2  )

### first induction layer method
induce_I = Induction_Layer_I( channels = channels_in, image_shape=(image_size - kernel_size +1) , k_max =kmax, L_max=lmax , dict_rep_in = hidden_mulplicities_SO2 , dict_rep_out = mulplicities_SO3 )

### second induction layer method
induce_II = Induction_Layer_II( channels = channels_in, image_shape=(image_size - kernel_size +1) , k_max =kmax, L_max=lmax , dict_rep_in = hidden_mulplicities_SO2 , dict_rep_out = mulplicities_SO3 )

### third induction layer method
induce_III = Induction_Layer_III( channels = channels_in, image_shape=(image_size - kernel_size +1) , k_max =kmax, L_max=lmax , dict_rep_in = hidden_mulplicities_SO2 , dict_rep_out = mulplicities_SO3 )



### compare with orthographic projection
orthographic_proj = Image2SphereProjector( fmap_shape=(9,image_size -kernel_size+1,image_size-kernel_size+1), sphere_fdim= 50, lmax=lmax,
               coverage = 0.9,
               sigma = 0.2,
               max_beta = np.radians(90),
               taper_beta = np.radians(75),
               rec_level = 2,
               n_subset = 20 )





# Runtime Comparisons: Orthographic Projection vs Induced Map
In order for our induced method to train in a reasonable amount of time, the evaluation time of the induced mapping needs to be on par with that of the orthographic projection method. Intuitively this should be possible, because both methods compute spherical harmonic coefficients. However, e2cnn is built to perform efficient convolutions, not necessarily integrals, so some creativity is required to write efficient code.
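Single `time.time()` measurements are noisy and, on a GPU, can under-report because CUDA kernels launch asynchronously. A small helper such as the hypothetical `benchmark` below (standard library only) takes untimed warm-up calls, repeats the measurement with `time.perf_counter`, and returns the median; passing `sync=torch.cuda.synchronize` would make GPU timings wait for queued kernels.

```python
import time
from statistics import median
from typing import Callable, Optional


def benchmark(fn: Callable[[], object], repeats: int = 5, warmup: int = 1,
              sync: Optional[Callable[[], None]] = None) -> float:
    """Median wall-clock time of fn() in seconds, after `warmup` untimed calls."""
    for _ in range(warmup):
        fn()
    if sync is not None:
        sync()  # make sure warm-up work has finished before timing
    times = []
    for _ in range(repeats):
        start = time.perf_counter()
        fn()
        if sync is not None:
            sync()  # wait for asynchronous work (e.g. CUDA kernels) to finish
        times.append(time.perf_counter() - start)
    return median(times)


print(benchmark(lambda: sum(range(100_000))))
```

For example, building a fixed GeometricTensor `x` once and timing `benchmark(lambda: induce_III(SO2_conv_1(x)))` would exclude random input generation from the measurement; wrapping the call in `torch.no_grad()` would additionally exclude autograd bookkeeping if only inference cost matters.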

In [None]:
### Orthographic Projection timing
start_time = time.time()
x = torch.rand(batch_size,1,image_size,image_size)
x = nn.GeometricTensor( x , feat_type_in  )
y = SO2_conv_1.forward(x)
z = orthographic_proj( y.tensor )
stop_time = time.time()


total_time =  stop_time - start_time
print("Orthographic Projection Timing:" , total_time)


### Induced method I timing
start_time = time.time()
x = torch.rand(batch_size,1,image_size,image_size) 
x = nn.GeometricTensor( x , feat_type_in  )
y = SO2_conv_1.forward(x)
z = induce_I( y )
stop_time = time.time()


total_time =  stop_time - start_time

print("Induced Map Method_I:" , total_time)

### Induced method II timing
start_time = time.time()
x = torch.rand(batch_size,1,image_size,image_size) 
x = nn.GeometricTensor( x , feat_type_in  )
y = SO2_conv_1.forward(x)
z = induce_II( y )
stop_time = time.time()


total_time =  stop_time - start_time
print("Induced Map Method_II:" , total_time)



### Induced method III timing
start_time = time.time()
x = torch.rand(batch_size,1,image_size,image_size) 
x = nn.GeometricTensor( x , feat_type_in  )
y = SO2_conv_1.forward(x)
z = induce_III( y )
stop_time = time.time()


total_time =  stop_time - start_time
print("Induced Map Method_III:" , total_time)





# Spherical Convolution



In [28]:
### TODO: ask David whether there is an optimal max_beta
def s2_healpix_grid(rec_level: int=0, max_beta: float=np.pi/6):
    """Returns the (alpha, beta) angles of the HEALPix pixels whose centres lie within max_beta of the north pole

    :return: float tensor of shape (2, n_pts)
    """
    n_side = 2**rec_level
    m = hp.query_disc(nside=n_side, vec=(0,0,1), radius=max_beta)
    beta, alpha = hp.pix2ang(n_side, m)
    alpha = torch.from_numpy(alpha)
    beta = torch.from_numpy(beta)
    return torch.stack((alpha, beta)).float()

def flat_wigner(lmax, alpha, beta, gamma):
    return torch.cat([ (2 * l + 1) ** 0.5 * o3.wigner_D(l, alpha, beta, gamma).flatten(-2) for l in range(lmax + 1) ], dim=-1)

### this should be changed,
### or just set output to be equal to hidden so3 layer
def s2_irreps(lmax):
    return o3.Irreps([(1, (l, 1)) for l in range(lmax + 1)])

def so3_irreps(lmax):
    return o3.Irreps([(2 * l + 1, (l, 1)) for l in range(lmax + 1)])



### defining convolution over sphere
### just make sure that input so3 rep matches output so3 rep
class S2Conv(torch.nn.Module):

    def __init__(self, f_in: int, f_out: int, lmax: int , kernel_grid: torch.Tensor):
        super().__init__()

        # filter weight parametrized over spatial grid on S2
        self.register_parameter(
          "w", torch.nn.Parameter(torch.randn(f_in, f_out, kernel_grid.shape[1]))
        )  # [f_in, f_out, n_s2_pts]

        # linear projection to convert filter weights to fourier domain
        self.register_buffer(
          "Y", o3.spherical_harmonics_alpha_beta(range(lmax + 1), *kernel_grid, normalization="component")
        )  # [n_s2_pts, (lmax+1)**2]


        # defines group convolution using appropriate irreps
        self.lin = o3.Linear( s2_irreps(lmax) , so3_irreps(lmax) ,  f_in=f_in, f_out=f_out, internal_weights=False)

    def forward(self, x):
        psi = torch.einsum("ni,xyn->xyi", self.Y, self.w) / self.Y.shape[0] ** 0.5
        return self.lin(x , weight=psi)



# SO(3) Convolution

In [29]:
def so3_healpix_grid(rec_level: int=3):
    """Returns healpix grid over so3 of equally spaced rotations
   
    https://github.com/google-research/google-research/blob/4808a726f4b126ea38d49cdd152a6bb5d42efdf0/implicit_pdf/models.py#L272
    alpha: 0-2pi around Y
    beta: 0-pi around X
    gamma: 0-2pi around Y
    rec_level | num_points | bin width (deg)
    ----------------------------------------
         0    |         72 |    60
         1    |        576 |    30
         2    |       4608 |    15
         3    |      36864 |    7.5
         4    |     294912 |    3.75
         5    |    2359296 |    1.875
         
    :return: tensor of shape (3, npix)
    """
    n_side = 2**rec_level
    npix = hp.nside2npix(n_side)
    beta, alpha = hp.pix2ang(n_side, torch.arange(npix))
    gamma = torch.linspace(0, 2*np.pi, 6*n_side + 1)[:-1]

    alpha = alpha.repeat(len(gamma))
    beta = beta.repeat(len(gamma))
    gamma = torch.repeat_interleave(gamma, npix)
    return torch.stack((alpha, beta, gamma)).float()

### convolution over SO(3)
### there may be a faster way to do this
class SO3Conv(torch.nn.Module):
    def __init__(self, f_in: int, f_out: int, lmax: int, kernel_grid: torch.Tensor):
        super().__init__()

        # filter weight parametrized over spatial grid on SO3
        self.register_parameter(
          "w", torch.nn.Parameter(torch.randn(f_in, f_out, kernel_grid.shape[1]))
        )  # [f_in, f_out, n_so3_pts]

        # wigner D matrices used to project spatial signal to irreps of SO(3)
        self.register_buffer("D", flat_wigner(lmax, *kernel_grid))  # [n_so3_pts, sum_l^L (2*l+1)**2]

        # defines group convolution using appropriate irreps
        self.lin = o3.Linear(so3_irreps(lmax), so3_irreps(lmax), f_in=f_in, f_out=f_out, internal_weights=False)

    def forward(self, x):
        '''Perform SO3 group convolution to produce signal over irreps of SO(3).
        First project filter into fourier domain then perform convolution

        :x: tensor of shape (B, f_in, sum_l^L (2*l+1)**2), signal over SO3 irreps
        :return: tensor of shape (B, f_out, sum_l^L (2*l+1)**2)
        '''
        psi = torch.einsum("ni,xyn->xyi", self.D, self.w) / self.D.shape[0] ** 0.5
        return self.lin(x, weight=psi)

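In `SO3Conv` both input and output use `so3_irreps`, so for each degree l the signal is a (2l+1)×(2l+1) block of Wigner-D coefficients, and the Fourier-domain convolution reduces to one matrix product per degree, summed over input channels. The numpy sketch below makes that explicit under stated assumptions: the name is hypothetical, normalization is omitted, and the transpose convention e3nn uses internally may differ. Done this way, the cost is Σ_l (2l+1)³ multiply-adds per (batch, input channel, output channel) triple.

```python
import numpy as np

def so3_conv_fourier(x, psi, lmax):
    """Sketch of an SO(3) -> SO(3) convolution in the Fourier domain.

    x:   (B, f_in, sum_l (2l+1)**2)      flattened Wigner-D coefficients of the input
    psi: (f_in, f_out, sum_l (2l+1)**2)  flattened Wigner-D coefficients of the filter
    returns (B, f_out, sum_l (2l+1)**2)
    """
    out, start = [], 0
    for l in range(lmax + 1):
        d = 2 * l + 1
        sl = slice(start, start + d * d)
        xl = x[..., sl].reshape(*x.shape[:2], d, d)        # (B, f_in, d, d)
        pl = psi[..., sl].reshape(*psi.shape[:2], d, d)    # (f_in, f_out, d, d)
        # matrix product filter @ input for every channel pair, summed over input channels i
        yl = np.einsum("iouv,bivm->boum", pl, xl)
        out.append(yl.reshape(yl.shape[0], yl.shape[1], d * d))
        start += d * d
    return np.concatenate(out, axis=-1)

# usage: degrees 0..2 give 1 + 9 + 25 = 35 coefficients per channel
rng = np.random.default_rng(0)
x_so3 = rng.standard_normal((2, 3, 35))
psi_so3 = rng.standard_normal((3, 4, 35))
y_conv = so3_conv_fourier(x_so3, psi_so3, lmax=2)
```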


# Loss functions

In [None]:
def compute_trace(rotA, rotB):
    '''
    rotA, rotB are tensors of shape (*,3,3)
    returns Tr(rotA @ rotB.T), a tensor of shape (*)
    '''
    prod = torch.matmul(rotA, rotB.transpose(-1, -2))
    trace = prod.diagonal(dim1=-1, dim2=-2).sum(-1)
    return trace

def rotation_error(rotA, rotB):
    '''
    rotA, rotB are tensors of shape (*,3,3)
    returns rotation error in radians, tensor of shape (*)
    '''
    trace = compute_trace(rotA, rotB)
    return torch.arccos(torch.clamp( (trace - 1)/2, -1, 1))

def nearest_rotmat(src, target):
    '''return index of target that is nearest to each element in src
    maximizes Tr(src @ target.T), which is equivalent to minimizing the rotation angle
    and avoids an arccos operation
    :src: tensor of shape (B, 3, 3)
    :target: tensor of shape (N, 3, 3)
    :return: long tensor of shape (B,)
    '''
    trace = compute_trace(src.unsqueeze(1), target.unsqueeze(0))
    return torch.max(trace, dim=1)[1]
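The angle of the relative rotation R = A Bᵀ satisfies tr(R) = 1 + 2 cos θ. That is why `rotation_error` applies arccos((tr − 1)/2), and why `nearest_rotmat` can simply take the largest trace: tr(R) decreases monotonically as θ goes from 0 to π. A numpy sketch of both ideas (helper names are hypothetical), checked on rotations about the z-axis:

```python
import numpy as np

def rot_z(theta):
    """Rotation by theta radians about the z-axis."""
    c, s = np.cos(theta), np.sin(theta)
    return np.array([[c, -s, 0.0], [s, c, 0.0], [0.0, 0.0, 1.0]])

def geodesic_angle(A, B):
    """Angle of the relative rotation A @ B.T: arccos((tr(A B^T) - 1) / 2)."""
    trace = np.trace(A @ B.T)
    return np.arccos(np.clip((trace - 1.0) / 2.0, -1.0, 1.0))

def nearest_index(R, grid):
    """Index of the grid rotation closest to R; a larger trace means a smaller angle."""
    traces = np.einsum("ij,nij->n", R, grid)  # tr(R G^T) = sum_ij R_ij G_ij
    return int(np.argmax(traces))

# usage: 12 rotations about z spaced 30 degrees apart; 0.6 rad is closest to the 30-degree one
z_grid = np.stack([rot_z(t) for t in np.linspace(0, 2 * np.pi, 12, endpoint=False)])
```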


# Defining the Standard I2S Network

In [None]:
### I2S network
class I2S(torch.nn.Module):
    
    ### Instantiate the standard I2S network for predicting distributions over SO(3)
    ### from a single image, using the orthographic projection onto the sphere
    
    def __init__(self, lmax=20 , kmax = 50 , image_size = 200 , so2_kernel_size = 25 ):
        
        super().__init__()
        self.lmax = lmax
        self.kmax = kmax
        self.image_size = image_size
        self.kernel_size = so2_kernel_size

        ### no image encoder, can add this later
        ### self.encoder = ImageEncoder()

        ### defining the SO2 action
        SO2_act = gspaces.Rot2dOnR2(N=-1,maximum_frequency=self.kmax)

        ### suppose that the input is the trivial SO(2) rep
        rep_in = [ SO2_act.irrep(0) ]

        ### specify the hidden SO(2) layer multiplicities
        hidden_mulplicities_SO2 = { '0' : 1 , '1': 1 , '2':1 , '3':1  , '4':1   }
        hidden_so2 = convert_SO2( hidden_mulplicities_SO2 )
        
        ### SO2 convolution layer
        self.SO2_conv_1 = SO2_Convolution_Layer( rep_in = rep_in, rep_out = hidden_so2 ,k_max = self.kmax,  kernel_size=self.kernel_size )

        ### the output multiplicities of the SO3 layer
        mulplicities_SO3 = { '0' :1 , '1' :1 , '2' : 1 , '3': 1 , '4':1  }

        ##### the orthographic projection layer
        ### compute the number of output channels of hidden rep
        channels_in = rep_ops.compute_SO2_dimension( hidden_mulplicities_SO2  )
    
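        ### the feature map is (image_size - kernel_size + 1) pixels wide after the SO(2) convolution;
        ### lmax - 1 keeps the projector's degrees (0..lmax-1) consistent with s2_conv below
        self.proj = Image2SphereProjector( fmap_shape=(channels_in, image_size - self.kernel_size + 1, image_size - self.kernel_size + 1), sphere_fdim=50, lmax=self.lmax - 1,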
               coverage = 0.9,
               sigma = 0.2,
               max_beta = np.radians(90),
               taper_beta = np.radians(75),
               rec_level = 2,
               n_subset = 20 )
        
        
        ### output format is: batch, number of output channels, number of input channels
        ### these are all in form of 
        s2_kernel_grid = s2_healpix_grid(max_beta=np.inf, rec_level=1)

        ### THIS IS L_MAX - 1 !!! Need to standardize notations
        ### the S2 convolution takes one input channel per projector output channel (sphere_fdim=50 above)
        f_in = 50
        self.s2_conv = S2Conv( f_in = f_in , f_out= 105 , lmax=self.lmax-1 , kernel_grid = s2_kernel_grid )

        #### also L_max - 1 !!! Need to standardize notations
        so3_kernel_grid = so3_healpix_grid(rec_level=3)
        ### output is one dimensional so can use logits 
        self.so3_conv = SO3Conv( f_in = 105 , f_out=1 , lmax=self.lmax-1 , kernel_grid = so3_kernel_grid )
        self.so3_act = e3nn.nn.SO3Activation( self.lmax-1 , self.lmax-1 , act=torch.relu, resolution=20)
        
        output_xyx = so3_healpix_grid(rec_level=2)
        self.register_buffer( "output_wigners", flat_wigner( self.lmax - 1 , *output_xyx).transpose(0,1) )
        self.register_buffer( "output_rotmats", o3.angles_to_matrix(*output_xyx) )


    def forward(self, x):
        
        ###'''Returns so3 irreps
        ###:x: the input image, tensor of shape (B, 1, image_size, image_size)
        ## x must be a geometric tensor

        x = self.SO2_conv_1( x )
        x = self.proj( x.tensor )
        x = self.s2_conv( x )
        x = self.so3_act( x )
        x = self.so3_conv( x )
    
        return x
  
    def compute_loss(self, img, gt_rot):
        ###'''Compute cross entropy loss using ground truth rotation, the correct label
        ###is the nearest rotation in the spatial grid to the ground truth rotation

        ### :img: GeometricTensor of shape (B, 1, image_size, image_size)
        ### :gt_rot: valid rotation matrices, tensor of shape (B, 3, 3)

        ### run image through network
        x = self.forward(img)
        
        grid_signal = torch.matmul(x, self.output_wigners ).squeeze(1)
        rotmats = self.output_rotmats


        # find nearest grid point to ground truth rotation matrix
        rot_id = nearest_rotmat( gt_rot , rotmats )
            
        loss = torch.nn.CrossEntropyLoss()( grid_signal.type( torch.float )  , rot_id  )
        
        with torch.no_grad():
            pred_id = grid_signal.max(dim=1)[1]
            pred_rotmat = rotmats[pred_id]
            acc = rotation_error(gt_rot, pred_rotmat)

        return loss, acc.cpu().numpy()
   
    @torch.no_grad()
    def compute_probabilities(self, img, wigners):
        x = self.forward(img)
        logits = torch.matmul(x, wigners).squeeze(1)
        return torch.nn.Softmax(dim=1)(logits)
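`compute_loss` turns pose estimation into classification. The predicted Fourier coefficients are evaluated on a fixed grid of rotations (`grid_signal = x @ output_wigners`), and the label is the grid rotation nearest to the ground truth. The numpy sketch below computes the same loss for a single example, with the log-softmax written out explicitly. The function name is hypothetical, and `torch.nn.CrossEntropyLoss` additionally averages over the batch by default.

```python
import numpy as np

def grid_cross_entropy(coeffs, wigners, label):
    """Cross-entropy of one prediction over a discretized SO(3) grid.

    coeffs:  (n_coeffs,) SO(3) Fourier coefficients of the prediction
    wigners: (n_coeffs, n_grid) flattened Wigner-D basis evaluated at the grid rotations
    label:   index of the grid rotation nearest to the ground truth
    """
    logits = coeffs @ wigners                 # signal value at each grid rotation
    shifted = logits - logits.max()           # subtract the max for numerical stability
    log_probs = shifted - np.log(np.exp(shifted).sum())
    return -log_probs[label]

# usage: all-zero coefficients give a uniform distribution, so the loss is log(n_grid)
rng = np.random.default_rng(0)
uniform_loss = grid_cross_entropy(np.zeros(4), rng.standard_normal((4, 10)), 3)
```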





# Defining the Induced Network

In [None]:
### Induced_I2S network
class Induced_I2S(torch.nn.Module):
    
    ### Instantiate I2S-style network for predicting distributions over SO(3) from
    ### predictions made on single image using an induction layer 
    
    def __init__(self, lmax=20 , kmax = 50 , image_size = 200 , so2_kernel_size = 25 ):
        
        super().__init__()
        self.lmax = lmax
        self.kmax = kmax
        self.image_size = image_size
        self.kernel_size = so2_kernel_size

        ### no image encoder, can add this later
        ### self.encoder = ImageEncoder()

        ### defining the SO2 action
        SO2_act = gspaces.Rot2dOnR2(N=-1,maximum_frequency=self.kmax)

        ### suppose that the input is the trivial SO(2) rep
        rep_in = [ SO2_act.irrep(0) ]

        ### specify the hidden SO(2) layer multiplicities
        hidden_mulplicities_SO2 = { '0' : 1 , '1': 1 , '2':1 , '3':1  , '4':1   }
        hidden_so2 = convert_SO2( hidden_mulplicities_SO2 )
        
        ### SO2 convolution layer
        self.SO2_conv_1 = SO2_Convolution_Layer( rep_in = rep_in, rep_out = hidden_so2 ,k_max = self.kmax,  kernel_size=self.kernel_size )

        ### the output multiplicities of the induced SO3 layer
        mulplicities_SO3 = { '0' :1 , '1' :1 , '2' : 1 , '3': 1 , '4':1  }

        ##### the induction representation layer, 
        ### compute the number of output channels of hidden rep
        channels_in = rep_ops.compute_SO2_dimension( hidden_mulplicities_SO2  )
        self.induce = Induction_Layer_II( channels = channels_in, image_shape=(self.image_size - self.kernel_size +1) , k_max =self.kmax, L_max=self.lmax , dict_rep_in = hidden_mulplicities_SO2 , dict_rep_out = mulplicities_SO3 )

        ### output format is: batch, number of output channels, number of input channels
        ### these are all in form of 
        s2_kernel_grid = s2_healpix_grid(max_beta=np.inf, rec_level=1)

        ### THIS IS L_MAX - 1 !!! Need to standardize notations
        ### compute the dimension of s2 input features
        f_in = rep_ops.compute_SO3_dimension( mulplicities_SO3 )
        self.s2_conv = S2Conv( f_in = f_in , f_out= 105 , lmax=self.lmax-1 , kernel_grid = s2_kernel_grid )

        #### also L_max - 1 !!! Need to standardize notations
        so3_kernel_grid = so3_healpix_grid(rec_level=3)
        ### output is one dimensional so can use logits 
        self.so3_conv = SO3Conv( f_in = 105 , f_out=1 , lmax=self.lmax-1 , kernel_grid = so3_kernel_grid )
        self.so3_act = e3nn.nn.SO3Activation( self.lmax-1 , self.lmax-1 , act=torch.relu, resolution=20)
        
        output_xyx = so3_healpix_grid(rec_level=2)
        self.register_buffer( "output_wigners", flat_wigner( self.lmax - 1 , *output_xyx).transpose(0,1) )
        self.register_buffer( "output_rotmats", o3.angles_to_matrix(*output_xyx) )


    def forward(self, x):
        
        ###'''Returns so3 irreps
        ###:x: the input image, tensor of shape (B, 1, image_size, image_size)
        ## x must be a geometric tensor

        x = self.SO2_conv_1(x)
        x = self.induce( x )
        x = self.s2_conv( x )
        x = self.so3_act( x )
        x = self.so3_conv( x )
    
        return x
  
    def compute_loss(self, img, gt_rot):
        ###'''Compute cross entropy loss using ground truth rotation, the correct label
        ###is the nearest rotation in the spatial grid to the ground truth rotation

        ### :img: GeometricTensor of shape (B, 1, image_size, image_size)
        ### :gt_rot: valid rotation matrices, tensor of shape (B, 3, 3)

        ### run image through network
        x = self.forward(img)
        
        grid_signal = torch.matmul(x, self.output_wigners ).squeeze(1)
        rotmats = self.output_rotmats


        # find nearest grid point to ground truth rotation matrix
        rot_id = nearest_rotmat( gt_rot , rotmats )
            
        loss = torch.nn.CrossEntropyLoss()( grid_signal.type( torch.float )  , rot_id  )
        
        with torch.no_grad():
            pred_id = grid_signal.max(dim=1)[1]
            pred_rotmat = rotmats[pred_id]
            acc = rotation_error(gt_rot, pred_rotmat)

        return loss, acc.cpu().numpy()
   
    @torch.no_grad()
    def compute_probabilities(self, img, wigners):
        x = self.forward(img)
        logits = torch.matmul(x, wigners).squeeze(1)
        return torch.nn.Softmax(dim=1)(logits)



In [None]:
lmax=4
induced_arch = Induced_I2S( lmax=lmax , kmax = 50 , image_size = 120 , so2_kernel_size=65 )

output_xyx = so3_healpix_grid(rec_level=3) # 37K points
output_wigners = flat_wigner( lmax - 1 , *output_xyx).transpose(0, 1)
output_rotmats = o3.angles_to_matrix(*output_xyx)
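The cell above evaluates the network on the `rec_level=3` grid ("37K points"). The arithmetic behind that count, and behind the size of `output_wigners`, is easy to check. `so3_healpix_grid` combines 12·n_side² HEALPix pixels with 6·n_side values of gamma per pixel, and a flattened Wigner-D expansion up to degree L has Σ_l (2l+1)² coefficients. The sketch below uses hypothetical helper names, and the memory figure assumes 4-byte float32 entries.

```python
def so3_grid_size(rec_level):
    """Number of rotations returned by so3_healpix_grid(rec_level)."""
    n_side = 2 ** rec_level
    n_pix = 12 * n_side ** 2   # HEALPix pixels on the sphere, i.e. hp.nside2npix(n_side)
    n_gamma = 6 * n_side       # gamma samples per pixel
    return n_pix * n_gamma

def so3_dim(lmax):
    """Flattened Wigner-D coefficient count for degrees 0..lmax: sum_l (2l+1)**2."""
    return sum((2 * l + 1) ** 2 for l in range(lmax + 1))

n_grid = so3_grid_size(3)    # rec_level=3, as in the cell above
n_coeffs = so3_dim(4 - 1)    # lmax=4, and the SO(3) layers work with lmax - 1
print(n_grid, n_coeffs)                            # 36864 84
print(round(n_grid * n_coeffs * 4 / 1e6, 1))       # 12.4 (MB for output_wigners, if float32)
```

Evaluating `so3_grid_size` for levels 0 to 5 reproduces the point counts listed in the `so3_healpix_grid` docstring.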

In [None]:
standard_i2s = I2S( lmax=lmax , kmax = 50 , image_size = 120 , so2_kernel_size = 65 )

In [None]:
# print( induced_arch )
# print( standard_i2s )

In [None]:
import matplotlib.pyplot as plt
def plot_so3_distribution(probs: torch.Tensor,
                          rots: torch.Tensor,
                          gt_rotation=None,
                          fig=None,
                          ax=None,
                          display_threshold_probability=0.000005,
                          show_color_wheel: bool=True,
                          canonical_rotation=torch.eye(3),
                         ):
    '''
    Taken from https://github.com/google-research/google-research/blob/master/implicit_pdf/evaluation.py
    '''
    cmap = plt.cm.hsv

    def _show_single_marker(ax, rotation, marker, edgecolors=True, facecolors=False):
        alpha, beta, gamma = o3.matrix_to_angles(rotation)
        color = cmap(0.5 + gamma.repeat(2) / 2. / np.pi)[-1]
        ax.scatter(alpha, beta-np.pi/2, s=2000, edgecolors=color, facecolors='none', marker=marker, linewidth=5)
        ax.scatter(alpha, beta-np.pi/2, s=1500, edgecolors='k', facecolors='none', marker=marker, linewidth=2)
        ax.scatter(alpha, beta-np.pi/2, s=2500, edgecolors='k', facecolors='none', marker=marker, linewidth=2)

    if ax is None:
        fig = plt.figure(figsize=(8, 4), dpi=200)
        fig.subplots_adjust(0.01, 0.08, 0.90, 0.95)
        ax = fig.add_subplot(111, projection='mollweide')

    rots = rots @ canonical_rotation
    scatterpoint_scaling = 3e3
    alpha, beta, gamma = o3.matrix_to_angles(rots)

    # offset alpha and beta so different gammas are visible
    R = 0.02
    alpha += R * np.cos(gamma)
    beta += R * np.sin(gamma)

    which_to_display = (probs > display_threshold_probability)

    # Display the distribution
    ax.scatter(alpha[which_to_display],
               beta[which_to_display]-np.pi/2,
               s=scatterpoint_scaling * probs[which_to_display],
               c=cmap(0.5 + gamma[which_to_display] / 2. / np.pi))
    if gt_rotation is not None:
        if len(gt_rotation.shape) == 2:
            gt_rotation = gt_rotation.unsqueeze(0)
        gt_rotation = gt_rotation @ canonical_rotation
        _show_single_marker(ax, gt_rotation, 'o')
    ax.grid()
    ax.set_xticklabels([])
    ax.set_yticklabels([])

    if show_color_wheel:
        # Add a color wheel showing the tilt angle to color conversion.
        ax = fig.add_axes([0.86, 0.17, 0.12, 0.12], projection='polar')
        theta = np.linspace(-3 * np.pi / 2, np.pi / 2, 200)
        radii = np.linspace(0.4, 0.5, 2)
        _, theta_grid = np.meshgrid(radii, theta)
        colormap_val = 0.5 + theta_grid / np.pi / 2.
        ax.pcolormesh(theta, radii, colormap_val.T, cmap=cmap)
        ax.set_yticklabels([])
        ax.set_xticklabels([r'90$\degree$', None,
                            r'180$\degree$', None,
                            r'270$\degree$', None,
                            r'0$\degree$'], fontsize=14)
        ax.spines['polar'].set_visible(False)
        plt.text(0.5, 0.5, 'Tilt', fontsize=14,
                 horizontalalignment='center',
                 verticalalignment='center', transform=ax.transAxes)

    plt.show()

In [None]:
class CustomImageDataset(Dataset):
    def __init__(self, labels_file, img_dir ):
        
        with open( labels_file , 'rb') as handle:
            self.img_labels = pickle.load(handle)
            
        self.img_dir = img_dir 
        
    def __len__(self):
        return len( self.img_labels.keys()  ) 

    def __getitem__(self, idx):
        
        file_path = self.img_dir +'img_num:'+str(idx) +'.npy'
        image = np.load( file_path, allow_pickle=True )
        image = np.float32(image)
        image = torch.from_numpy(image)

        label = self.img_labels[ str(idx) ]
        label = np.float32(label)
        label = torch.from_numpy(label)
        
        

        return image, label

        
        
img_dir = './data/image_data/'
labels_file = './data/image_data/rotations.pickle'

data_set = CustomImageDataset( labels_file = labels_file,  img_dir =img_dir )
train_dataloader = DataLoader( data_set, batch_size=1, shuffle=True)
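To smoke-test the training cells without the real data, a hypothetical helper can write a small synthetic dataset in the layout `CustomImageDataset` reads. That layout is one (H, W) array per image named `img_num:<idx>.npy`, plus a pickled dict that maps `str(idx)` to a 3×3 rotation matrix. The random rotations come from the QR decomposition of a Gaussian matrix, a standard way to sample uniformly from SO(3). The helper names and default sizes are assumptions. Because `CustomImageDataset` builds file paths by string concatenation, `img_dir` must end with `/`.

```python
import os
import pickle
import tempfile
import numpy as np

def random_rotation(rng):
    """Uniformly random 3x3 rotation from the QR decomposition of a Gaussian matrix."""
    q, r = np.linalg.qr(rng.standard_normal((3, 3)))
    q = q * np.sign(np.diag(r))   # fix column signs so q is Haar-distributed on O(3)
    if np.linalg.det(q) < 0:      # flip one column to land in SO(3) rather than O(3)
        q[:, 0] = -q[:, 0]
    return q

def write_toy_dataset(img_dir, n_images=4, image_size=120, seed=0):
    """Write random images and rotation labels in the layout CustomImageDataset expects."""
    rng = np.random.default_rng(seed)
    os.makedirs(img_dir, exist_ok=True)
    labels = {}
    for idx in range(n_images):
        np.save(os.path.join(img_dir, f"img_num:{idx}.npy"), rng.random((image_size, image_size)))
        labels[str(idx)] = random_rotation(rng)
    with open(os.path.join(img_dir, "rotations.pickle"), "wb") as handle:
        pickle.dump(labels, handle)
    return labels

# usage: write the toy data to a temporary directory (note the trailing '/')
toy_dir = tempfile.mkdtemp() + '/'
toy_labels = write_toy_dataset(toy_dir)
```

The directory can then be passed as `CustomImageDataset(labels_file=toy_dir + 'rotations.pickle', img_dir=toy_dir)`. Colons are not allowed in Windows file names, so this layout only works on POSIX file systems.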


# Training the Induced_I2S Model


In [None]:
optimizer = torch.optim.SGD( induced_arch.parameters(), lr=0.003, momentum=0.03  )

rep_in = [ SO2_act.irrep(0) ]
feat_type_in = nn.FieldType( SO2_act, rep_in  )

num_epoch = 3
for epoch in range(num_epoch):
    epoch_loss, n_batches = 0.0, 0
    for img, label in train_dataloader:

        # add a channel dimension: (B, H, W) -> (B, 1, H, W)
        img = torch.unsqueeze(img, 1)
        
        optimizer.zero_grad()
        x = nn.GeometricTensor( img , feat_type_in  )
       
        loss , a = induced_arch.compute_loss( x , label ) 
        loss.backward()
        optimizer.step()
        
        ### accumulate the loss so the epoch mean can be reported
        epoch_loss += loss.item()
        n_batches += 1
        print(loss.item())
        
    print()
    print('Epoch Number:' , epoch  )
    print("Mean Training Loss:" , epoch_loss / max(n_batches, 1) )
    print()

        

### save model weights to file
PATH = 'Induced_I2S_model.pt'
torch.save(induced_arch.state_dict(), PATH)

### post training
### note: this reuses the training set; no held-out split is defined here
induced_arch.eval()
test_dataloader = DataLoader( data_set, batch_size=1, shuffle=True)
with torch.no_grad():
    for img, label in test_dataloader:
        
        img = torch.unsqueeze(img, 1)  
        x = nn.GeometricTensor( img , feat_type_in  )
        y = induced_arch.forward(x)
        
        
        logits = torch.matmul(y, output_wigners).squeeze(1)
        probs = torch.nn.Softmax(dim=1)(logits)        

        plot_so3_distribution(probs[0], output_rotmats, gt_rotation=label)


    


# Training the I2S Model

In [None]:
optimizer = torch.optim.SGD( standard_i2s.parameters(), lr=0.003, momentum=0.03  )

rep_in = [ SO2_act.irrep(0) ]
feat_type_in = nn.FieldType( SO2_act, rep_in  )

num_epoch = 3
for epoch in range(num_epoch):
    epoch_loss, n_batches = 0.0, 0
    for img, label in train_dataloader:

        # add a channel dimension: (B, H, W) -> (B, 1, H, W)
        img = torch.unsqueeze(img, 1)
        
        optimizer.zero_grad()
        x = nn.GeometricTensor( img , feat_type_in  )
       
        loss , a = standard_i2s.compute_loss( x , label ) 
        loss.backward()
        optimizer.step()
        
        ### accumulate the loss so the epoch mean can be reported
        epoch_loss += loss.item()
        n_batches += 1
        print(loss.item())
        
    print()
    print('Epoch Number:' , epoch  )
    print("Mean Training Loss:" , epoch_loss / max(n_batches, 1) )
    print()

        