In [None]:
import numpy as np
import torch.nn as nn
from s2cnn.nn.soft.so3_conv import SO3Convolution
from s2cnn.nn.soft.s2_conv import S2Convolution
from s2cnn.nn.soft.so3_integrate import so3_integrate
from s2cnn.ops.so3_localft import near_identity_grid as so3_near_identity_grid
from s2cnn.ops.s2_localft import near_identity_grid as s2_near_identity_grid
import torch.nn.functional as F
import torch
import torch.utils.data as data_utils
import gzip, pickle
import numpy as np
from torch.autograd import Variable
from torch.distributions import Normal

import matplotlib.pyplot as plt
import math

In [None]:
 def n2p(x, requires_grad = True):
    """converts numpy tensor to pytorch variable"""
    return Variable(torch.Tensor(x), requires_grad)

# https://github.com/pytorch/pytorch/issues/2591
def logsumexp(inputs, dim=None, keepdim=False):
    """Numerically stable logsumexp.

    Args:
        inputs: A tensor (or Variable) with any shape.
        dim: An integer, or None to reduce over all elements.
        keepdim: A boolean.

    Returns:
        Equivalent of log(sum(exp(inputs), dim=dim, keepdim=keepdim)).
        Slices whose maximum is -inf (resp. +inf) yield -inf (resp. +inf)
        rather than NaN.
    """
    # For a 1-D array x (any array along a single dimension),
    # log sum exp(x) = s + log sum exp(x - s)
    # with s = max(x) being a common choice.
    if dim is None:
        # reshape (not view) so non-contiguous inputs, e.g. transposes, work.
        inputs = inputs.reshape(-1)
        dim = 0
    s, _ = torch.max(inputs, dim=dim, keepdim=True)
    # If a slice's max is infinite, `inputs - s` contains inf - inf = NaN.
    # Shifting such slices by 0 instead gives the correct +/-inf result.
    s = torch.where(torch.isfinite(s), s, torch.zeros_like(s))
    outputs = s + (inputs - s).exp().sum(dim=dim, keepdim=True).log()
    if not keepdim:
        outputs = outputs.squeeze(dim)
    return outputs

class SO3reparameterize(nn.Module):
    """Reparameterizable distribution on SO(3) built from a Gaussian on so(3).

    Linear heads map features ``x`` to a mean ``mu`` (an axis-angle vector in
    R^3) and to a covariance ``L diag(D) L^T`` with ``L`` unit lower
    triangular and ``D > 0``. A sample ``v ~ N(0, L diag(D) L^T)`` is mapped
    onto SO(3) with the Rodrigues exponential map and composed with
    ``exp(mu)``: ``z = exp(mu) @ exp(v)``.

    Args:
        input_dim: size of the last dimension of the features given to ``forward``.
        k: number of 2*pi windings considered on each side in ``log_posterior``
           (the density sums over 2k + 1 preimages of the exponential map).
    """

    def __init__(self,input_dim, k=10):
        super(SO3reparameterize, self).__init__()

        self.input_dim = input_dim
        self.z_dim = 3
        self.k = k

        # Mean of the Gaussian in the Lie algebra (axis-angle vector).
        self.mu_linear = nn.Linear(input_dim, 3)
        # Diagonal D of the L D L^T covariance factorisation (softplus-ed in forward).
        self.Ldiag_linear = nn.Linear(input_dim, 3)
        # The three strictly-lower-triangular entries of the unit lower-triangular L.
        self.Lnondiag_linear = nn.Linear(input_dim, 3)

    @staticmethod
    def _lieAlgebra(v):
        """Map vectors in R^3 to skew-symmetric matrices of so(3) (the "hat" map).

        Args:
            v: tensor of shape (..., 3).
        Returns:
            Tensor of shape (..., 3, 3) equal to
            v[..., 0] * R_x + v[..., 1] * R_y + v[..., 2] * R_z.
        """
        # so(3) generators, created with v's dtype/device so double or GPU
        # inputs do not hit a dtype/device mismatch.
        generators = v.new_tensor([
            [[0., 0., 0.], [0., 0., -1.], [0., 1., 0.]],    # R_x
            [[0., 0., 1.], [0., 0., 0.], [-1., 0., 0.]],    # R_y
            [[0., -1., 0.], [1., 0., 0.], [0., 0., 0.]],    # R_z
        ])
        # (..., 3, 1, 1) * (3, 3, 3) -> (..., 3, 3, 3); sum over the generator axis.
        return (v[..., None, None] * generators).sum(-3)

    @staticmethod
    def _expmap_rodrigues(v, eps=1e-12):
        """Exponential map so(3) -> SO(3) via Rodrigues' formula.

        R = I + sin(t) K + (1 - cos(t)) K @ K, with t = |v| and K = hat(v / t).

        Args:
            v: axis-angle vectors of shape (..., 3).
            eps: lower bound on |v| used only when normalising the axis, so
                 that v = 0 maps to the identity instead of NaN.
        Returns:
            Rotation matrices of shape (..., 3, 3).
        """
        theta = v.norm(p=2, dim=-1, keepdim=True)
        # At theta == 0 the axis is undefined; clamping keeps K finite, and the
        # sin / (1 - cos) coefficients below are then exactly 0, giving R = I.
        K = SO3reparameterize._lieAlgebra(v / theta.clamp(min=eps))
        I = torch.eye(3, dtype=v.dtype, device=v.device)
        sin_t = torch.sin(theta)[..., None]
        one_minus_cos_t = (1. - torch.cos(theta))[..., None]
        return I + sin_t * K + one_minus_cos_t * (K @ K)

    def forward(self, x, n=1):
        """Draw ``n`` rotation samples per input row.

        Stores ``mu``, ``D``, ``L``, ``v`` and ``z`` on the module (read later
        by ``log_posterior``).

        Args:
            x: features whose last dimension is ``input_dim``.
            n: number of samples per input row.
        Returns:
            z: rotation matrices of shape (n, *x.shape[:-1], 3, 3).
        """
        self.mu = self.mu_linear(x)
        self.D = F.softplus(self.Ldiag_linear(x))
        L = self.Lnondiag_linear(x)

        batch = self.D.size()[:-1]
        one = self.D.new_ones(torch.Size((*batch, 1)))
        zero = self.D.new_zeros(torch.Size((*batch, 1)))
        # Row-major entries of the unit lower-triangular matrix
        # [[1, 0, 0], [L0, 1, 0], [L1, L2, 1]].
        self.L = torch.cat((one, zero, zero,
                            L[..., 0:1], one, zero,
                            L[..., 1:], one), -1).view(torch.Size((*batch, 3, 3)))

        self.v, self.z = self.nsample(self.mu, self.L, self.D, n = n)

        return self.z

    def kl(self):
        """Placeholder KL term: always returns 0 (not implemented)."""
        kl = 0
        return kl

    def log_posterior(self):
        """Log-density of the last sample ``z`` on SO(3); call after ``forward``.

        The density is taken w.r.t. the measure (2 - 2 cos|v|) / |v|^2 dv on
        the radius-pi ball. With theta = |v|, u = v / theta and the 2k + 1
        preimages x_j = u (theta + 2 pi j), j = -k..k:

            log p = logsumexp_j [ log N(x_j; 0, L D L^T)
                                  + 2 log|theta_j| - log(2 - 2 cos theta_j) ]

        Returns:
            Tensor of shape ``v.shape[:-1]``.
        """
        theta = self.v.norm(p=2, dim=-1, keepdim=True)               # (..., 1)
        u = self.v / theta                                            # (..., 3)
        angles = torch.arange(-self.k, self.k + 1, dtype=theta.dtype,
                              device=theta.device) * 2 * math.pi
        theta_hat = theta + angles                                    # (..., 2k+1)
        x = u[..., None] * theta_hat[..., None, :]                    # (..., 3, 2k+1)

        I = torch.eye(3, dtype=self.L.dtype, device=self.L.device)
        L_hat = self.L - I
        # L_hat is strictly lower triangular, so L_hat^3 = 0 and this
        # truncated Neumann series is the exact inverse of L = I + L_hat.
        L_inv = I - L_hat + L_hat @ L_hat
        D_inv = 1. / self.D
        A = L_inv @ x                                                 # (..., 3, 2k+1)

        # Sum the Mahalanobis term over the 3 coordinates only; the Jacobian
        # terms are per preimage and must not be broadcast over coordinates.
        mahalanobis = (A * D_inv[..., None] * A).sum(-2)              # (..., 2k+1)
        log_jacobian = 2 * torch.log(theta_hat.abs()) - \
            torch.log(2 - 2 * torch.cos(theta_hat))
        p = torch.logsumexp(-0.5 * mahalanobis + log_jacobian, dim=-1)
        # Gaussian normaliser: det(L D L^T) = prod(D) since L has unit diagonal.
        p = p - 0.5 * (torch.log(self.D.prod(-1)) +
                       self.v.size(-1) * math.log(2. * math.pi))

        return p

    def log_prior(self):
        """Log-density of the uniform distribution on SO(3).

        Uses the same measure as ``log_posterior``, under which SO(3) has total
        volume 8 pi^2, so the log-density is -log(8 pi^2).
        """
        return -math.log(8 * math.pi ** 2)

    @staticmethod
    def nsample(mu, L, D, n=1):
        """Reparameterised sampling.

        Args:
            mu: axis-angle means, shape (..., 3).
            L: unit lower-triangular factors, shape (..., 3, 3).
            D: positive diagonal variances, shape (..., 3).
            n: number of samples.
        Returns:
            v: Lie-algebra samples L (sqrt(D) * eps), shape (n, ..., 3).
            z: exp(mu) @ exp(v), shape (n, ..., 3, 3).
        """
        # Noise has no parameters, so gradients flow through D and L only.
        eps = Normal(torch.zeros_like(mu), torch.ones_like(mu)).sample(torch.Size((n,)))
        v = (L @ (D.pow(0.5) * eps)[..., None]).squeeze(-1)
        mu_lie = SO3reparameterize._expmap_rodrigues(mu)
        v_lie = SO3reparameterize._expmap_rodrigues(v)
        return v, mu_lie @ v_lie
    
# SO(3) reparameterisation head over 10-dimensional input features.
s3 = SO3reparameterize(input_dim=10)

In [None]:
def uniform_over_ball(n, r=math.pi, d=3):
    """Draw ``n`` points uniformly from the d-dimensional ball of radius ``r``.

    Directions come from normalised standard Gaussians, and radii from
    ``r * U**(1/d)`` with U ~ Uniform(0, 1); see
    https://www.sciencedirect.com/science/article/pii/S0047259X10001211

    Returns:
        (points, radii): arrays of shape (n, d) and (n, 1).
    """
    directions = np.random.normal(0, 1, (n, d))
    directions /= np.linalg.norm(directions, ord=2, axis=-1, keepdims=True)
    radii = np.random.uniform(0, 1, (n, 1)) ** (1 / d)
    return directions * radii * r, radii * r

In [None]:
def nprodrigues(v):
    """Rodrigues exponential map applied to numpy axis-angle vectors.

    Converts ``v`` (shape (..., 3)) with ``n2p``, maps it through
    ``SO3reparameterize._expmap_rodrigues`` and returns the resulting
    rotation matrices (shape (..., 3, 3)) as a numpy array.
    """
    rotations = SO3reparameterize._expmap_rodrigues(n2p(v))
    return rotations.data.numpy()


def E(samples, f, w = np.array([1])):
    """Weighted Monte Carlo average of ``f`` over ``samples``.

    Computes ``mean_i(f(samples)[i] * w[i])`` along the leading (sample) axis.
    The weights are NOT renormalised; divide by ``w.mean()`` for a
    self-normalised estimate.

    Args:
        samples: array whose leading axis indexes samples.
        f: function mapping ``samples`` to an array with the same leading axis.
        w: per-sample weights of shape (N,), or shape (1,) to broadcast.
    Returns:
        Array of shape ``f(samples).shape[1:]``.
    """
    values = f(samples)
    weights = np.asarray(w)
    # One trailing singleton axis per extra dimension of f's output, so the
    # weights broadcast along the sample axis for any output rank (the old
    # w[:, None, None] only worked for 3-D outputs).
    weights = weights.reshape(weights.shape[:1] + (1,) * (np.ndim(values) - 1))
    return (values * weights).mean(0)

def volume(s_norm):
    """Volume factor 2 (1 - cos s) / s^2 at rotation angle(s) ``s_norm``, squeezed.

    Evaluated as (sin(s/2) / (s/2))^2 = sinc(s / (2 pi))^2 with numpy's
    normalised sinc. At s = 0 this returns the limit value 1, whereas the
    direct formula gives 0/0 = NaN; it also avoids the cancellation in
    1 - cos s for small angles.
    """
    return np.squeeze(np.sinc(np.asarray(s_norm) / (2 * np.pi)) ** 2)


def rot2norm(R):
    """Rotation angle in [0, pi] of each rotation matrix in ``R`` (shape (..., 3, 3)).

    Uses trace(R) = 1 + 2 cos(angle). The cosine is clipped to [-1, 1] so that
    round-off just outside that range does not give NaN.
    """
    cos_angle = np.clip(0.5 * np.trace(R, axis1=-2, axis2=-1) - 1 / 2, -1, 1)
    return np.abs(np.arccos(cos_angle))

def rot2vol(R):
    """Volume factor ``volume`` evaluated at the rotation angle of each matrix in ``R``."""
    angles = rot2norm(R)
    return volume(angles)
    

In [None]:
# Seed numpy's global RNG at its first use so every Monte Carlo estimate in
# this notebook is reproducible under Restart & Run All.
np.random.seed(0)
s, s_norm = uniform_over_ball(100)

In [None]:
# Radii come back as (n, 1); flatten to (n,).
s_norm = s_norm.squeeze()

In [None]:
np.shape(s_norm)

In [None]:
# Random 3x3 test matrix; the leading singleton axis lets it broadcast
# against stacks of (N, 3, 3) rotation samples.
mat = np.random.uniform(low=0, high=10, size=(1, 3, 3))

In [None]:
# Draw 100000 axis-angle vectors uniformly from the radius-pi ball, map them to
# rotation matrices, and compute the volume-factor weight of each sample.
s, s_norm = uniform_over_ball(100000)
s_norm = s_norm.squeeze()
g = nprodrigues(s)
vol = volume(s_norm)

In [None]:
def f(x):
    """Compute mat^T @ x @ mat for every sample in x."""
    return mat.transpose(0, 2, 1) @ (x @ mat)


E(g, f, vol)

In [None]:
# Left-multiply every sample by a single random rotation mu_lie.
mu = uniform_over_ball(1)[0]
mu_lie = nprodrigues(mu)
g_rot = mu_lie @ g
# NOTE(review): vol_rot is not used by the next cell, which reuses `vol`.
vol_rot = rot2vol(g_rot)

In [None]:
# Same weights as the unrotated estimate, applied to the left-translated samples.
E(samples=g_rot, f=f, w=vol)