In [None]:
import numpy as np 
import math

import torch 
from torch import nn
from torch.autograd import Variable 
from torch import Tensor as t
import torch.nn.functional as F
from torch.distributions import Normal
from torch.optim import Adam 

import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
%matplotlib inline 

#### Utils

In [None]:
def n2p(x, requires_grad = True):
    """Wrap array-like data in a PyTorch ``Variable``.

    Args:
        x: data accepted by the ``torch.Tensor`` constructor (e.g. a numpy array).
        requires_grad: whether the returned ``Variable`` should track gradients.

    Returns:
        ``Variable`` holding ``torch.Tensor(x)``.
    """
    as_tensor = t(x)
    return Variable(as_tensor, requires_grad=requires_grad)

def t2c(x):
    """Return ``x.cuda()``, i.e. whatever `x` hands back when asked to move to the GPU."""
    on_gpu = x.cuda()
    return on_gpu

# https://github.com/pytorch/pytorch/issues/2591
def logsumexp(inputs, dim=None, keepdim=False):
    """Numerically stable logsumexp.

    Args:
        inputs: A Variable with any shape.
        dim: An integer, or None to reduce over every element.
        keepdim: A boolean.

    Returns:
        Equivalent of log(sum(exp(inputs), dim=dim, keepdim=keepdim)).
        Slices whose maximum is -inf / +inf yield -inf / +inf rather than NaN.
    """
    # For a 1-D array x (any array along a single dimension),
    # log sum exp(x) = s + log sum exp(x - s)
    # with s = max(x) being a common choice.
    if dim is None:
        # reshape (unlike view) also flattens non-contiguous inputs
        inputs = inputs.reshape(-1)
        dim = 0
    s, _ = torch.max(inputs, dim=dim, keepdim=True)
    # An infinite max would make `inputs - s` evaluate inf - inf = nan;
    # shifting by 0 in that case lets exp/log propagate the infinity itself.
    s = torch.where(torch.isfinite(s), s, torch.zeros_like(s))
    outputs = s + (inputs - s).exp().sum(dim=dim, keepdim=True).log()
    if not keepdim:
        outputs = outputs.squeeze(dim)
    return outputs

def deg2coo(x):
    """Convert spherical angles into Cartesian coordinates on the unit sphere.

    Args:
        x: array whose transpose has the polar angles ``theta`` as its first
           entry and the azimuthal angles ``phi`` as its second (e.g. shape
           (N, 2)). Despite the function name the values are passed straight
           to ``np.sin``/``np.cos``, i.e. they are interpreted as radians.

    Returns:
        Array with one row
        ``(sin(phi) sin(theta), cos(phi) sin(theta), cos(theta))`` per input row.
    """
    theta, phi = x.T[0], x.T[1]
    sin_theta = np.sin(theta)

    coords = np.vstack((np.sin(phi) * sin_theta,
                        np.cos(phi) * sin_theta,
                        np.cos(theta)))
    return coords.T

def randomR():
    """Draw a random 3x3 rotation matrix (orthogonal, determinant +1).

    The matrix comes from the QR decomposition of a 3x3 standard-normal matrix
    drawn from numpy's global random state.
    """
    gaussian = np.random.normal(size=(3, 3))
    q, r = np.linalg.qr(gaussian)
    diag_r = np.diag(r)
    # Rescale the columns of Q by the signs of R's diagonal.
    ortho = q @ np.diag(diag_r / np.abs(diag_r))
    # det(ortho) is +-1; in odd dimension multiplying by it flips a reflection
    # into a proper rotation.
    return ortho * np.linalg.det(ortho)

def next_batch(batch_dim):
    """Generate a batch of randomly oriented point sets and rotated copies.

    A fixed 5-point pattern on the unit sphere (built with ``deg2coo``) is
    rotated by one random matrix per batch element to give the "original"
    sets; each of those is then rotated by a second random matrix.

    Args:
        batch_dim: number of examples to generate.

    Returns:
        Tuple ``(originalL, rotatedL, rotations)`` with shapes
        ``(batch_dim, 5, 3)``, ``(batch_dim, 5, 3)`` and ``(batch_dim, 3, 3)``,
        where ``rotatedL[i] == originalL[i] @ rotations[i]``.
    """
    step = np.pi / 12
    quarter_turn = np.pi / 2
    pattern_angles = np.array([[0, 0],
                               [step / 2, quarter_turn],
                               [step, quarter_turn],
                               [step / 2, 0],
                               [step, 0]])
    canonicalL = deg2coo(pattern_angles)

    # Draw order matters for reproducibility: all "original" orientations
    # first, then all target rotations.
    originals = [canonicalL @ randomR() for _ in range(batch_dim)]
    originalL = np.stack(originals)
    rotations = np.stack([randomR() for _ in range(batch_dim)])

    rotated = [points @ rot for points, rot in zip(originalL, rotations)]
    rotatedL = np.stack(rotated)

    return originalL, rotatedL, rotations

def get_sample(N, x_mb):
    """Sample one rotation from the model's output distribution and apply it.

    Args:
        N: callable model returning ``(mu, L, D)`` for a batch of inputs.
        x_mb: numpy array of shape (batch, 2 * P * 3) whose rows are
            ``[original points | rotated points]`` flattened, as assembled
            with ``np.hstack`` by the callers in this notebook.

    Returns:
        numpy array of shape (batch, P, 3): the original points of ``x_mb``
        multiplied by the sampled rotation ``rodrigues(mu) @ rodrigues(v)``.
    """
    mu, L, D = N(n2p(x_mb))
    # A single noise draw, broadcast over the batch.
    noise = Variable(Normal(t(np.zeros(3)), t(np.ones(3))).sample_n(1))
    # squeeze(-1) drops only the trailing matrix-column axis; a bare squeeze()
    # would also drop the batch axis when the batch holds a single example.
    v = (L @ (D.pow(0.5)*noise)[..., None]).squeeze(-1)
    mu_lie = rodrigues(mu)
    v_lie = rodrigues(v)
    g_lie = mu_lie @ v_lie

    # Recover the original point set from the first half of each input row
    # instead of reading a notebook-level global.
    x_arr = np.asarray(x_mb)
    half = x_arr.shape[-1] // 2
    xo_mb = x_arr[..., :half].reshape(-1, half // 3, 3)

    xrot_recon = (n2p(xo_mb) @ g_lie).data.numpy()
    return xrot_recon

#### MLP outputs $\mu$, $L$ and $D$ (the noise $\epsilon$ is sampled separately)

In [None]:
class Net(nn.Module):
    """Two-hidden-layer tanh MLP producing the parameters ``(mu, L, D)``.

    The input is a vector of size ``2 * 5 * 3`` (two flattened sets of five
    3-D points). For every input the network returns

    * ``mu``: unconstrained 3-vector,
    * ``L``:  3x3 unit lower-triangular matrix (ones on the diagonal, zeros
      above it, three free entries below it),
    * ``D``:  3 strictly positive values (softplus output).
    """

    def __init__(self, n_hidden):
        """Create the layers; ``n_hidden`` is the width of both hidden layers."""
        super(Net, self).__init__()
        self.hidden_1 = nn.Linear(2 * 5 * 3, n_hidden)
        self.hidden_2 = nn.Linear(n_hidden, n_hidden)
        self.hidden_mu = nn.Linear(n_hidden, 3)
        self.hidden_Ldiag = nn.Linear(n_hidden, 3)
        self.hidden_Lnondiag = nn.Linear(n_hidden, 3)

    def forward(self, x):
        """Map ``x`` of shape (..., 30) to ``mu`` (..., 3), ``L`` (..., 3, 3), ``D`` (..., 3)."""
        # torch.tanh replaces the deprecated F.tanh
        h0 = torch.tanh(self.hidden_1(x))
        h1 = torch.tanh(self.hidden_2(h0))

        mu = self.hidden_mu(h1)
        D = F.softplus(self.hidden_Ldiag(h1))
        below_diag = self.hidden_Lnondiag(h1)

        # Constant entries are created from D so they share its dtype and
        # device (works unchanged for double precision or a GPU model).
        batch_shape = D.size()[:-1]
        one = D.new_ones((*batch_shape, 1))
        zero = D.new_zeros((*batch_shape, 1))

        # Row-major layout of the unit lower-triangular matrix:
        #   [ 1    0    0 ]
        #   [ l0   1    0 ]
        #   [ l1   l2   1 ]
        L = torch.cat((one, zero, zero,
                       below_diag[..., 0:1], one, zero,
                       below_diag[..., 1:], one), -1).view((*batch_shape, 3, 3))

        return mu, L, D

#### Convert outputs in $\mathbb{R}^3$ to Lie algebra

In [None]:
def map2LieAlgebra(v):
    """Map a point in R^3 to the tangent space at the identity, i.e.
    to the Lie algebra, as a skew-symmetric matrix.
    Arg:
        v = vector(s) in R^3, shape (..., 3)
    Return:
        R = v converted to Lie algebra element(s), shape (..., 3, 3), with the
            same dtype and device as v; R @ w equals the cross product v x w"""

    # make sure this is a sample from R^3
    assert v.size()[-1] == 3

    # Basis generators, created as plain constants matching v's dtype/device.
    # They are fixed, so they must not be gradient-tracking leaves.
    R_x = v.new_tensor([[ 0., 0., 0.],
                        [ 0., 0.,-1.],
                        [ 0., 1., 0.]])

    R_y = v.new_tensor([[ 0., 0., 1.],
                        [ 0., 0., 0.],
                        [-1., 0., 0.]])

    R_z = v.new_tensor([[ 0.,-1., 0.],
                        [ 1., 0., 0.],
                        [ 0., 0., 0.]])

    # v[..., i, None, None] has shape (..., 1, 1) and broadcasts over a generator.
    R = R_x * v[..., 0, None, None] + \
        R_y * v[..., 1, None, None] + \
        R_z * v[..., 2, None, None]
    return R

# x = Normal(t([0., 0., 0.]), t([1., 1., 1.])).sample_n(3)
# v = Variable(x)
# l = map2LieAlgebra(v)

#### Use an exponential map, $\exp(\cdot)$, to convert elements of the Lie algebra to the Lie group

In [None]:
def rodrigues(v):
    """Exponential map from rotation vectors to rotation matrices.

    Applies Rodrigues' formula ``R = I + sin(theta) K + (1 - cos(theta)) K^2``
    with ``theta = ||v||`` and ``K = map2LieAlgebra(v / theta)``.

    Args:
        v: tensor of shape (..., 3); direction is the rotation axis, norm is
           the rotation angle.

    Returns:
        Tensor of shape (..., 3, 3). A zero vector maps to the identity.
    """
    theta = v.norm(p=2, dim=-1, keepdim=True)
    # normalize K; the clamp avoids 0/0 = NaN for a zero rotation vector
    # (K is then 0 and R reduces to the identity) and is a no-op otherwise.
    K = map2LieAlgebra(v / theta.clamp(min=1e-12))

    I = torch.eye(3, dtype=v.dtype, device=v.device)
    # theta is (..., 1); the extra axis makes the coefficients (..., 1, 1)
    R = I + torch.sin(theta)[..., None]*K + (1. - torch.cos(theta))[..., None]*(K@K)
    return R

# x = Normal(t([0., 0., 0.]), t([1., 1., 1.])).sample_n(2)
# v = Variable(x)
# R = rodrigues(v)

#### Log Density, not using the Jacobian change of coordinates
$$\sum_{k=-\infty}^{\infty} \mathcal{N}(u (2\pi k + \theta) \;|\;0,1)$$

In [None]:
def log_density(v, L, D, k = 10):
    """Log of the wrapped Gaussian density, without the Jacobian term.

    Computes, for every vector in the batch,

        log sum_{j=-k}^{k} N(u * (theta + 2*pi*j) | 0, L diag(D) L^T)

    where ``theta = ||v||`` and ``u = v / theta``.

    Args:
        v: tensor of shape (..., 3); must be non-zero (the direction ``u`` is
           undefined otherwise and the result is NaN).
        L: tensor of shape (..., 3, 3). The inverse is taken as
           ``I - N + N^2`` with ``N = L - I``, which is exact when L is unit
           lower-triangular (as produced by ``Net``).
        D: tensor of shape (..., 3) with the positive diagonal variances.
        k: the infinite sum is truncated to the 2k+1 terms j = -k..k.

    Returns:
        Tensor of shape (...) with the log density of each vector.
    """
    theta = v.norm(p=2, dim=-1, keepdim=True)
    u = v / theta
    angles = torch.arange(-k, k + 1, dtype=v.dtype, device=v.device) * (2 * math.pi)
    theta_hat = theta[..., None] + angles          # (..., 1, 2k+1)
    x = u[..., None] * theta_hat                   # (..., 3, 2k+1)

    eye = torch.eye(3, dtype=L.dtype, device=L.device)
    L_hat = L - eye
    L_inv = eye - L_hat + L_hat @ L_hat
    A = L_inv @ x

    # Quadratic form x^T (L D L^T)^{-1} x for every wrapped copy: (..., 2k+1)
    quad = -0.5 * (A * A / D[..., None]).sum(-2)

    # The Gaussian normaliser is the same factor in every term of the sum, so
    # it leaves the log-sum exactly once (not once per term). log|D| is taken
    # as a sum of logs so that a product of small variances cannot underflow.
    log_norm = -0.5 * (torch.log(D).sum(-1) + v.size()[-1] * math.log(2. * math.pi))

    # torch's built-in logsumexp keeps the reduction numerically stable.
    return torch.logsumexp(quad, -1) + log_norm
    
# x = Normal(t([0., 0., 0.]), t([1., 1., 1.])).sample_n(2)
# v = Variable(x)
# L = Variable((torch.rand(3,3).tril(-1) + torch.eye(3)).repeat(2, 1, 1))
# D = Variable(torch.rand(2, 3))
# log_density(v, L, D)

In [None]:
# Model with 128 units in each hidden layer, optimised by Adam with its
# default hyper-parameters.
N = Net(n_hidden=128)
optimizer = Adam(params=N.parameters())

In [None]:
batch_size = 100
num_steps = 5000
# weighting the entropy term to not be too strong (0 disables it; try 1e-2)
kl_w = 0.
loss_plot = []  # reconstruction loss of every step, as Python floats
for i in range(num_steps):
    optimizer.zero_grad()

    # Fresh synthetic batch; the network sees [original | rotated] flattened.
    xo_mb, xrot_mb, y_mb = next_batch(batch_size)
    x_mb = np.hstack((xo_mb.reshape(-1, 5 * 3), xrot_mb.reshape(-1, 5 * 3)))

    # Reparameterised sample v = L sqrt(D) eps, eps ~ N(0, I).
    mu, L, D = N(n2p(x_mb))
    noise = Variable(Normal(t(np.zeros(3)), t(np.ones(3))).sample_n(batch_size))
    # squeeze(-1) removes only the matrix-column axis, so the batch axis
    # survives even for batch_size == 1.
    v = (L @ (D.pow(0.5)*noise)[..., None]).squeeze(-1)

    # Negative log density of the sample (entropy term of the objective).
    H = -log_density(v, L, D, k = 10)

    # Group element: mean rotation composed with the sampled perturbation.
    mu_lie = rodrigues(mu)
    v_lie = rodrigues(v)
    g_lie = mu_lie @ v_lie
    z_rot = g_lie

    xrot_recon_mb = n2p(xo_mb) @ z_rot
    # Spherical Loss: angle between reconstructed and target points; the
    # 0.999 factor keeps acos away from +-1 where its gradient is unbounded.
    L_rec = ((xrot_recon_mb * n2p(xrot_mb)).sum(-1) * 0.999).acos().sum(-1).mean()
    # L2 loss
    # L_rec = (( xrot_recon_mb- n2p(xrot_mb))**2).sum(-1).sum(-1).mean()

    # Named `loss` so the Cholesky-like factor `L` above is not overwritten.
    loss = L_rec - torch.mean(H)*kl_w
    loss.backward()
    optimizer.step()

    L_rec_value = L_rec.item()
    print ('\r (%d/%d) L: %.3f \t D: %.3f ' %
           (i, num_steps, L_rec_value, np.mean(D.detach().numpy())), end='')
    loss_plot.append(L_rec_value)
    

In [None]:
# Training curve: reconstruction loss recorded at every optimisation step.
plt.plot(loss_plot, color='y')
plt.xlabel('training step')
plt.ylabel('reconstruction loss')
plt.title('Training curve')
plt.show()

#### Sample Plots

In [None]:
def plot_hammer(data):
    """Scatter several point sets on a Hammer projection, one colour per set.

    Args:
        data: array whose first axis indexes the point sets; ``data[i]`` is
            reshaped to (-1, 3) and read as Cartesian (x, y, z) coordinates
            of points on the unit sphere.
    """
    plt.figure(figsize=(12,12))
    plt.subplot(111, projection='hammer')

    cols = ('blue', 'red', 'green', 'cyan', 'magenta', 'pink', 'black', 'purple', 'brown', 'orange')
    for i in range(data.shape[0]):
        xs_, ys_, zs_ = data[i,:].reshape(-1, 3).T
        # Clip before arccos: rounding can push |z| marginally above 1, which
        # would otherwise produce NaN.
        # NOTE(review): arccos(z) - pi/2 is the negated latitude, so the plot
        # is flipped north/south — confirm this orientation is intended.
        ys = np.arccos(np.clip(zs_, -1., 1.)) - np.pi/2
        xs = np.arctan2(ys_, xs_)

        # Cycle through the palette so more than len(cols) sets can be drawn.
        plt.scatter(xs, ys, color=cols[i % len(cols)], label=str(i))

    plt.legend(loc='center left', bbox_to_anchor=(1, 0.5))
    plt.show()
    
def plot_sphere(ax=None, radius=1.0):
    """Draw a translucent sphere centred at the origin on a 3-D axes.

    Args:
        ax: 3-D matplotlib axes to draw on. When omitted, the current axes
            (``plt.gca()``) is used, so the function no longer depends on a
            notebook-level variable named ``ax``.
        radius: sphere radius (default 1, the unit sphere).
    """
    if ax is None:
        ax = plt.gca()

    # Parameter grid: u is the azimuth, v the polar angle.
    u = np.linspace(0, 2 * np.pi, 100)
    v = np.linspace(0, np.pi, 100)

    x = radius * np.outer(np.cos(u), np.sin(v))
    y = radius * np.outer(np.sin(u), np.sin(v))
    z = radius * np.outer(np.ones(np.size(u)), np.cos(v))

    ax.plot_surface(x, y, z,  rstride=4, cstride=4, color='#DAE8FC', linewidth=0, alpha=0.3)

In [None]:
# 3-D canvas for one example: original points, target rotation, model samples.
fig = plt.figure(figsize=(12, 12))
ax = fig.add_subplot(111, projection='3d')
ax.set_aspect('equal')

# get an input and rotation
xo, xrot, y = next_batch(1)
x_mb = np.hstack((xo.reshape(-1, 5 * 3), xrot.reshape(-1, 5 * 3)))

# plot the original (blue) and the rotated target (red)
for points, colour in ((xo, 'b'), (xrot, 'r')):
    xs, ys, zs = points.reshape(-1, 3).T
    ax.scatter(xs, ys, zs, c=colour)

# get samples from neural net; `samples` grows along axis 1, starting with
# the target point set
n_samples = 1
samples = xrot
for _ in range(n_samples):
    xrot_recon = get_sample(N, x_mb)
    xs, ys, zs = xrot_recon.reshape(-1, 3).T
    ax.scatter(xs, ys, zs)

    samples = np.concatenate((samples, xrot_recon), axis=1)

plot_sphere()

ax.grid(False)
plt.axis('off')
plt.show()

# One (5, 3) point set per row: original, target, then each sample.
data = np.concatenate((xo, samples), axis=1).reshape(2 + n_samples, 5, 3)
plot_hammer(data)