In [9]:
import numpy as np 
import math

import torch 
from torch import nn
from torch.autograd import Variable 
from torch import Tensor as t
import torch.nn.functional as F
from torch.distributions import Normal
from torch.optim import Adam 

from torch.utils import data
from torchvision import datasets, transforms

import matplotlib.pyplot as plt

In [118]:
def n2p(x, requires_grad = True):
    """converts numpy tensor to pytorch variable"""
    return Variable(t(x), requires_grad)

def t2c(x):
    return x.cuda()

# https://github.com/pytorch/pytorch/issues/2591
def logsumexp(inputs, dim=None, keepdim=False):
    """Numerically stable logsumexp.

    Args:
        inputs: A Variable with any shape.
        dim: An integer.
        keepdim: A boolean.

    Returns:
        Equivalent of log(sum(exp(inputs), dim=dim, keepdim=keepdim)).
    """
    # For a 1-D array x (any array along a single dimension),
    # log sum exp(x) = s + log sum exp(x - s)
    # with s = max(x) being a common choice.
    if dim is None:
        inputs = inputs.view(-1)
        dim = 0
    s, _ = torch.max(inputs, dim=dim, keepdim=True)
    outputs = s + (inputs - s).exp().sum(dim=dim, keepdim=True).log()
    if not keepdim:
        outputs = outputs.squeeze(dim)
    return outputs

def map2LieAlgebra(v):
    """Map a point in R^N to the tangent space at the identity, i.e. 
    to the Lie Algebra
    Arg:
        v = vector in R^N, (..., 3) in our case
    Return:
        R = v converted to Lie Algebra element, (3,3) in our case"""
    
    # make sure this is a sample from R^3
    assert v.size()[-1] == 3
    
    R_x = n2p(np.array([[ 0., 0., 0.],
                        [ 0., 0.,-1.],
                        [ 0., 1., 0.]]))
    
    R_y = n2p(np.array([[ 0., 0., 1.],
                        [ 0., 0., 0.],
                        [-1., 0., 0.]]))
    
    R_z = n2p(np.array([[ 0.,-1., 0.],
                        [ 1., 0., 0.],
                        [ 0., 0., 0.]]))
    
    R = R_x * v[..., 0, None, None] + \
        R_y * v[..., 1, None, None] + \
        R_z * v[..., 2, None, None]
    return R

def rodrigues(v):
    theta = v.norm(p=2,dim=-1, keepdim=True)
    # normalize K
    K = map2LieAlgebra(v/theta)
    
    I = Variable(torch.eye(3))
    R = I + torch.sin(theta)[...,None]*K + (1. - torch.cos(theta))[...,None]*(K@K)
    a = torch.sin(theta)[...,None]
    return R

def log_density(v, L, D, k = 10):
    theta = v.norm(p=2,dim=-1, keepdim=True)
    u = v / theta
    angles = Variable(torch.arange(-k, k+1) * 2 * math.pi)
    theta_hat = theta[...,None] + angles
    
#     print(theta_hat.min())
    
    x = u[...,None] * theta_hat
    
    L_hat = L - Variable(torch.eye(3))
    L_inv = Variable(torch.eye(3)) - L_hat + L_hat @ L_hat
    D_inv = 1. / D
    A = L_inv @ x
    
#     print(A * D_inv[...,None] * A + 2 )
#     print(theta)
    
    p = -0.5*(A * D_inv[...,None] * A + 2 * torch.log(theta_hat.abs()) - torch.log(2 - 2 * torch.cos(theta_hat)) ).sum(-2) 
    p = logsumexp(p, -1)
    p += -0.5*(torch.log(D.prod(-1)) + v.size()[-1]*math.log(2.*math.pi))*(2*k + 1)
    return p

def randomR():
    q, r = np.linalg.qr(np.random.normal(size=(3, 3)))
    r = np.diag(r)
    ret = q @ np.diag(r / np.abs(r))
    return ret * np.linalg.det(ret)

In [119]:
class Encoder(nn.Module):
    def __init__(self, n_hidden):
        super(Encoder, self).__init__()
        self.hidden_1 = nn.Linear(784, n_hidden)
        self.hidden_2 = nn.Linear(n_hidden, n_hidden)
        self.hidden_mu = nn.Linear(n_hidden, 3)
        self.hidden_Ldiag = nn.Linear(n_hidden, 3)
        self.hidden_Lnondiag = nn.Linear(n_hidden, 3)

    def forward(self, x):
        h0 = F.tanh(self.hidden_1(x))
        h1 = F.tanh(self.hidden_2(h0))
        
        mu = self.hidden_mu(h1)
        D = F.softplus(self.hidden_Ldiag(h1))
        L = self.hidden_Lnondiag(h1)

        L = torch.cat((Variable(torch.ones(torch.Size((*D.size()[:-1], 1)))),
                Variable(torch.zeros(torch.Size((*D.size()[:-1], 2)))),
                L[...,0].unsqueeze(-1),
                Variable(torch.ones(torch.Size((*D.size()[:-1], 1)))),
                Variable(torch.zeros(torch.Size((*D.size()[:-1], 1)))),
                L[...,1:],
                Variable(torch.ones(torch.Size((*D.size()[:-1], 1))))), -1).view(
            torch.Size((*D.size()[:-1], 3, 3)))

        return mu, L, D

class Decoder(nn.Module):
    def __init__(self, n_hidden):
        super(Decoder, self).__init__()
        self.hidden_1 = nn.Linear(9, n_hidden)
        self.hidden_2 = nn.Linear(n_hidden, 784)
        
    def forward(self, x):
        h0 = F.tanh(self.hidden_1(x))
        h1 = F.sigmoid(self.hidden_2(h0))
        
        return h1

In [134]:
enc = Encoder(128)
dec = Decoder(128)

In [135]:
optimizer = Adam(list(enc.parameters()) + list(dec.parameters()))

In [136]:
cuda = False
batch_size = 32
kwargs = {'num_workers': 1, 'pin_memory': True} if cuda else {}
train_loader = torch.utils.data.DataLoader(
    datasets.MNIST('./data', train=True, download=True,
                   transform=  transforms.ToTensor()),
    batch_size=batch_size, shuffle=True, **kwargs)

In [137]:
# list(train_loader)[0]

In [143]:
for idx, (x, t_) in enumerate(train_loader):
    optimizer.zero_grad()
    
    x = Variable(x).view(-1, 784)
    mu, L, D = enc(x)
    
    noise = Variable(Normal(t(np.zeros(3)), t(np.ones(3))).sample_n(batch_size))    
    v = (L @ (D.pow(0.5)*noise)[..., None]).squeeze()
    
    H = -log_density(v, L, D, k = 10)
    
    mu_lie = rodrigues(mu)
    v_lie = rodrigues(v)
    g_lie = mu_lie @ v_lie
    z_rot = g_lie
    x_ = dec(z_rot.view(-1, 9))
    
    loss_bce = nn.BCELoss(size_average=False)(x_, x) / batch_size
    loss_H = H.mean()
    loss = loss_bce - 0 * loss_H
    
    loss.backward()
    optimizer.step()
    
    print('\r', loss_bce.data.numpy()[0], loss_H.data.numpy()[0], end='')
    
    z_ = np.array([randomR() for _ in range(32)])
    _plot_digits(8, 4, dec(n2p(z_).view(-1, 9)).data.numpy())
    _plot_digits(8, 4, x_.data.numpy())
    print(H)
    lol()

 154.48254 nan4.220587

RuntimeError: Assertion `x >= 0. && x <= 1.' failed. input value should be between 0~1, but got -nan at /opt/conda/conda-bld/pytorch-cpu_1515613813020/work/torch/lib/THNN/generic/BCECriterion.c:34

In [139]:
def _plot_digits(w, h, x):
    h_, w_ = (28, 28)

    plt.figure(figsize=(w, h))

    image = np.zeros((h_ * h, w_ * w))
    for i in range(h):
        for j in range(w):
            image[h_ * i : h_ * (i + 1), w_ * j : w_ * (j + 1)] = x[i * w + j].reshape(28,28)

    plt.imshow(image, interpolation='none', cmap='gray')
    plt.axis('off')

    plt.show()

(32, 3, 3)