In [None]:
import numpy as np 
import torch 
from torch import nn
from torch.autograd import Variable 
import torch.nn.functional as F
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D

In [None]:
class Net(nn.Module):
    """Two-hidden-layer tanh MLP with three 3-dimensional output heads.

    For every input row the network returns:

    * ``mu``  -- an unconstrained 3-vector,
    * ``L``   -- a 3x3 unit lower-triangular matrix (ones on the diagonal,
      zeros above it, three learned entries below it),
    * ``D``   -- a strictly positive 3-vector (softplus output).

    NOTE(review): the attribute names suggest ``L`` and ``D`` are the factors
    of an ``L diag(D) L^T`` parameterisation of a covariance/precision matrix;
    the loss is not visible here, so confirm against the training code.

    Args:
        n_hidden: width of both hidden layers.
        n_input: number of input features. Defaults to ``2 * 5 * 3`` (30),
            the row width produced by ``next_batch`` (two flattened sets of
            five 3-D points).
    """

    def __init__(self, n_hidden, n_input=2 * 5 * 3):
        super().__init__()
        # Layer creation order is kept as-is so seeded initialisation is unchanged.
        self.hidden_1 = nn.Linear(n_input, n_hidden)
        self.hidden_2 = nn.Linear(n_hidden, n_hidden)
        self.hidden_mu = nn.Linear(n_hidden, 3)
        self.hidden_Ldiag = nn.Linear(n_hidden, 3)
        self.hidden_Lnondiag = nn.Linear(n_hidden, 3)

    def forward(self, x):
        """Run the network on ``x`` of shape ``(..., n_input)``.

        Returns:
            Tuple ``(mu, L, D)`` with shapes ``(..., 3)``, ``(..., 3, 3)`` and
            ``(..., 3)``; the leading dimensions are those of ``x``.
        """
        # torch.tanh replaces the deprecated F.tanh alias.
        h0 = torch.tanh(self.hidden_1(x))
        h1 = torch.tanh(self.hidden_2(h0))

        mu = self.hidden_mu(h1)
        D = F.softplus(self.hidden_Ldiag(h1))  # softplus keeps D > 0
        offdiag = self.hidden_Lnondiag(h1)

        # Build the constant entries with new_ones/new_zeros so they inherit
        # the device and dtype of the activations; bare torch.ones/zeros live
        # on the CPU with the default dtype and break torch.cat on GPU.
        batch_shape = offdiag.shape[:-1]
        one = offdiag.new_ones((*batch_shape, 1))
        zero = offdiag.new_zeros((*batch_shape, 1))

        # Row-major layout of
        #   [[1,          0,          0],
        #    [offdiag_0,  1,          0],
        #    [offdiag_1,  offdiag_2,  1]]
        L = torch.cat((one, zero, zero,
                       offdiag[..., 0:1], one, zero,
                       offdiag[..., 1:2], offdiag[..., 2:3], one), dim=-1)
        L = L.view(*batch_shape, 3, 3)

        return mu, L, D

In [None]:
# Instantiate the model with 32 units in each hidden layer.
net = Net(n_hidden=32)

In [None]:
# Random smoke-test mini-batch: 10 rows of 2 * 5 * 3 = 30 uniform [0, 1) features.
# torch.as_tensor replaces the deprecated Variable wrapper and the legacy
# FloatTensor constructor; dtype=float32 keeps the same tensor type.
x_mb = torch.as_tensor(np.random.uniform(size=(10, 2 * 5 * 3)), dtype=torch.float32)

In [None]:
# Smoke test: run a forward pass on the random mini-batch; the cell displays
# the (mu, L, D) tuple returned by Net.forward.
net(x_mb)

In [None]:
def deg2coo(x):
    """Map spherical angles to Cartesian points on the unit sphere.

    Despite the name, angles are passed straight to ``np.sin``/``np.cos`` and
    are therefore interpreted as radians.

    Args:
        x: array whose transpose has the polar angle (measured from the +z
            axis) in row 0 and the azimuth in row 1, e.g. shape ``(N, 2)``.

    Returns:
        For ``(N, 2)`` input, an ``(N, 3)`` array of ``(x, y, z)`` rows where
        ``x = sin(phi) sin(theta)``, ``y = cos(phi) sin(theta)`` and
        ``z = cos(theta)``; azimuth 0 therefore points along +y.
    """
    polar, azimuth = x.T[0], x.T[1]
    sin_polar = np.sin(polar)

    components = (
        np.sin(azimuth) * sin_polar,
        np.cos(azimuth) * sin_polar,
        np.cos(polar),
    )
    return np.transpose(np.vstack(components))

def randomR():
    """Draw a random 3x3 rotation matrix (orthogonal, determinant +1).

    A Gaussian matrix is QR-factorised and the columns of ``Q`` are rescaled
    by the signs of ``diag(R)``, which removes the sign ambiguity of the
    factorisation (the usual recipe for sampling orthogonal matrices
    uniformly). That result can still have determinant -1, i.e. be a
    reflection, so one column is negated in that case to guarantee a proper
    rotation.

    Returns:
        ``(3, 3)`` float array ``R`` with ``R @ R.T == I`` and ``det(R) == +1``
        up to floating-point error.
    """
    q, r = np.linalg.qr(np.random.normal(size=(3, 3)))
    diag = np.diag(r)
    # Column-wise scaling; equivalent to q @ np.diag(sign(diag)).
    rot = q * (diag / np.abs(diag))

    # Negating a single column flips the determinant's sign while keeping
    # the matrix orthogonal, turning a reflection into a rotation.
    if np.linalg.det(rot) < 0:
        rot[:, 0] = -rot[:, 0]
    return rot

def next_batch(batch_dim):
    """Generate a batch of randomly oriented "L" shapes and their rotated copies.

    A canonical five-point shape is built on the unit sphere: the pole plus
    two points on each of two arms (azimuth ``pi/2`` and azimuth ``0``) at
    polar angles ``pi/24`` and ``pi/12``. Each sample applies one random
    matrix from ``randomR`` to the canonical shape and then a second one to
    that result.

    Args:
        batch_dim: number of samples to generate.

    Returns:
        Tuple ``(x_mb, rotations)``:
        ``x_mb`` has shape ``(batch_dim, 30)`` -- the 15 flattened coordinates
        of the first shape followed by the 15 of its transformed copy;
        ``rotations`` has shape ``(batch_dim, 3, 3)`` and holds the second
        matrix of each sample (points are row vectors, so
        ``rotated = original @ rotation``).
    """
    arm = np.pi / 12
    template_angles = np.array([
        [0, 0],
        [arm / 2, np.pi / 2],
        [arm, np.pi / 2],
        [arm / 2, 0],
        [arm, 0],
    ])
    canonicalL = deg2coo(template_angles)

    # All initial orientations are drawn before any target rotation, which
    # keeps the order of random draws fixed.
    oriented = []
    for _ in range(batch_dim):
        oriented.append(canonicalL @ randomR())
    originalL = np.stack(oriented)

    targets = []
    for _ in range(batch_dim):
        targets.append(randomR())
    rotations = np.stack(targets)

    rotatedL = np.stack([shape @ rot for shape, rot in zip(originalL, rotations)])

    flat_original = originalL.reshape(-1, 5 * 3)
    flat_rotated = rotatedL.reshape(-1, 5 * 3)
    x_mb = np.concatenate((flat_original, flat_rotated), axis=1)

    return x_mb, rotations

In [None]:
# Draw a batch of two samples: x_mb holds the flattened point sets and y_mb the
# matching 3x3 matrices returned by next_batch.
# Note: this rebinds x_mb, which an earlier cell bound to a torch tensor, to a
# NumPy array.
x_mb, y_mb = next_batch(2)

In [None]:
# x_mb.shape, y_mb.shape

In [None]:
# Visualise the un-rotated shapes of the current batch on the unit sphere.
fig = plt.figure(figsize=(8, 8))
ax = fig.add_subplot(111, projection='3d')

# The first 15 columns of each row are the original shape (5 points x 3 coords).
xs, ys, zs = x_mb[...,:15].reshape(-1, 3).T
ax.scatter(xs, ys, zs)

# Parametric unit sphere, drawn as a translucent backdrop.
u = np.linspace(0, 2 * np.pi, 100)
v = np.linspace(0, np.pi, 100)

x = 1 * np.outer(np.cos(u), np.sin(v))
y = 1 * np.outer(np.sin(u), np.sin(v))
z = 1 * np.outer(np.ones(np.size(u)), np.cos(v))

ax.plot_surface(x, y, z,  rstride=4, cstride=4, color='#DAE8FC', linewidth=0, alpha=0.3)

# Equal scaling on all three axes so the sphere is not distorted.
# ax.set_aspect('equal') raises NotImplementedError on 3D axes in some
# matplotlib releases, and was previously called before any data existed.
# Fixed symmetric limits plus a cubic box (where available) are used instead.
ax.set_xlim(-1, 1)
ax.set_ylim(-1, 1)
ax.set_zlim(-1, 1)
if hasattr(ax, 'set_box_aspect'):
    ax.set_box_aspect((1, 1, 1))

ax.grid(False)
plt.axis('off')
plt.show()