In [None]:
import math
import gzip
import pickle
import argparse
import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data as data_utils
from torch.autograd import Variable
from torch.distributions import Normal

from pytorch_util import MLP
from reparameterize import Nreparameterize
from reparameterize import N0reparameterize
from reparameterize import SO3reparameterize
from pytorch_util import logsumexp

import matplotlib.pyplot as plt

from pytorch_util import load_mnist_data, n2p

from vae import VAE

In [None]:
def _build_parser():
    """Create the command-line parser holding the run's hyper-parameters."""
    cli = argparse.ArgumentParser()
    # (long flag, short flag, value type, default value)
    option_table = (
        ('--gpu', '-g', int, 0),
        ('--epochs', '-e', int, 10),
        ('--batch_dim', '-bd', int, 32),
        ('--learning_rate', '-lr', float, 1e-3),
    )
    for long_flag, short_flag, value_type, default_value in option_table:
        cli.add_argument(long_flag, short_flag, type=value_type,
                         default=default_value, help='')
    return cli


parser = _build_parser()

# parse_known_args tolerates arguments the parser does not declare and
# returns them separately in `unparsed`.
FLAGS, unparsed = parser.parse_known_args()

In [None]:
# Select the requested GPU only when CUDA is usable, matching the
# torch.cuda.is_available() guards used for the other device transfers in
# this notebook; on a CPU-only machine this cell is a no-op.
if torch.cuda.is_available():
    torch.cuda.set_device(FLAGS.gpu)

In [None]:
DIGITS = '3c1s10000r/'


def _load_rmnist(kind, split):
    """Load one rmnist array; `kind` is 'digits' or 'labels', `split` names the subset."""
    return np.load('../data/' + DIGITS + 'rmnist.%s.%s.npy' % (kind, split))


# Files are read in the same order as before: images first, then labels,
# for the train, validation and test subsets.
train_imgs, train_labels = _load_rmnist('digits', 'train'), _load_rmnist('labels', 'train')
val_imgs, val_labels = _load_rmnist('digits', 'validation'), _load_rmnist('labels', 'validation')
test_imgs, test_labels = _load_rmnist('digits', 'test'), _load_rmnist('labels', 'test')

In [None]:
class VAE2D(VAE):
    """VAE over flattened inputs with an MLP encoder and decoder.

    The latent code is made of several parts, one per entry of `latents`.
    Each part gets its own reparameterization module, registered as the
    submodule ``latent<i>`` and also collected in ``self.reparameterize``.

    Args:
        input_dim: size of the flattened input; also the decoder output size.
        z_dims: size each latent part contributes to the decoder input
            (an 'so3' part must be 9).
        encoder_mlp_h: layer sizes handed to the encoder `MLP` after
            `input_dim` has been prepended.
        latents: kind of each latent part, ``'gaussian'`` or ``'so3'``.
        decoder_mlp_h: layer sizes handed to the decoder `MLP` after
            ``sum(z_dims)`` has been prepended and `input_dim` appended.

    Raises:
        RuntimeError: if an entry of `latents` is not a supported kind.
    """

    def __init__(self,
                 input_dim=28**2,
                 z_dims=[10, 9],
                 encoder_mlp_h=[100],
                 latents=['gaussian', 'gaussian'],
                 decoder_mlp_h=[100]):
        super(VAE2D, self).__init__()

        # The list defaults are shared between calls, so every list argument
        # is copied before being stored or extended. list() also lets callers
        # pass tuples.
        self.encoder_mlp_h = [input_dim] + list(encoder_mlp_h)
        self.encoder = MLP(self.encoder_mlp_h)
        self.z_dims = list(z_dims)

        self.reparameterize = []
        self.r_callback = []
        for i, (latent_kind, z_dim) in enumerate(zip(latents, self.z_dims)):
            # no callback needed
            self.r_callback.append(lambda x: x)

            if latent_kind == 'gaussian':
                reparam = Nreparameterize(encoder_mlp_h[-1], z_dim)
            elif latent_kind == 'so3':
                # the 3x3 lie group element will be concatenated
                assert z_dim == 9
                # 3-dimensional N0reparameterize wrapped by SO3reparameterize;
                # k=10 is forwarded unchanged (see the reparameterize module).
                reparam_g = N0reparameterize(encoder_mlp_h[-1], 3)
                reparam = SO3reparameterize(reparam_g, k=10)
            else:
                # Carry the diagnostic in the exception itself rather than
                # printing it and raising an empty RuntimeError.
                raise RuntimeError(
                    "unknown latent type %r at position %d; "
                    "expected 'gaussian' or 'so3'" % (latent_kind, i))

            # add_module makes the parameters visible to .parameters()/.cuda();
            # the plain list keeps positional access to the same modules.
            self.add_module(('latent%d' % i), reparam)
            self.reparameterize.append(reparam)

        z_dim_out = sum(self.z_dims)
        self.decoder_mlp_h = [z_dim_out] + list(decoder_mlp_h) + [input_dim]
        self.decoder = MLP(self.decoder_mlp_h)

    def recon_loss(self, x_recon, x):
        """Binary cross-entropy between logits `x_recon` and targets `x`.

        Computes ``x_recon * (1 - x) + log(1 + exp(-x_recon))`` element-wise,
        using the max-shift trick so the exponentials cannot overflow.
        `x` is broadcast to the shape of `x_recon`.

        Returns:
            The loss summed over the last axis and then over the first axis.
            NOTE(review): presumably these are the feature axis and the
            sample axis of a (samples, batch, features) tensor — confirm
            against VAE.elbo.
        """
        x = x.expand_as(x_recon)
        max_val = (-x_recon).clamp(min=0)
        loss = x_recon - x_recon * x + max_val + ((-max_val).exp() + \
                                               (-x_recon - max_val).exp()).log()
        return loss.sum(-1).sum(0)

In [None]:
# Single SO(3) latent (flattened to 9 numbers) with a two-layer encoder spec.
model_config = dict(
    input_dim=28 ** 2,
    z_dims=[9],
    encoder_mlp_h=[256, 128],
    latents=['so3'],
    decoder_mlp_h=[128],
)
vae = VAE2D(**model_config)

param_count = sum(p.numel() for p in vae.parameters())
print("#params", param_count)

# Move the model to the selected GPU when one is available.
if torch.cuda.is_available():
    vae.cuda(FLAGS.gpu)

optimizer = torch.optim.Adam(vae.parameters(), lr=FLAGS.learning_rate)

In [None]:
# Number of mini-batches that make up one epoch, and the total number of
# optimisation steps of the whole run (used for the KL warm-up schedule).
epoch_size = train_labels.shape[0] // FLAGS.batch_dim
max_it = int(FLAGS.epochs * epoch_size)

# The epoch count follows --epochs, so the "epoch: k/N" progress line and the
# loop agree.
for epoch in range(FLAGS.epochs):
    print('\nepoch: %d/%d' % (epoch + 1, FLAGS.epochs))
    ave_epoch_loss = []
    # One epoch is epoch_size steps, which is also what the progress line shows.
    for i in range(epoch_size):
        optimizer.zero_grad()

        # Sample a mini-batch with replacement.
        idx = np.random.choice(train_labels.shape[0], FLAGS.batch_dim, replace=True)
        x_mb = train_imgs[idx]
        # Stochastic binarisation: a pixel is on when its value exceeds a
        # uniform draw (assumes intensities in [0, 1] — TODO confirm).
        x_mb = x_mb > np.random.uniform(size=x_mb.shape)
        y_mb = train_labels[idx]

        images = n2p(x_mb)
        labels = n2p(y_mb)

        if torch.cuda.is_available():
            images = images.cuda(FLAGS.gpu)
            labels = labels.cuda(FLAGS.gpu)

        rec, kl = vae.elbo(images, n=1)

        # Linear KL warm-up from 0 to 1 over the whole run. The previous
        # np.minimum(0., ...) was always 0, so the KL term never contributed.
        global_it = epoch * epoch_size + i
        kl_w = min(1., global_it / max_it)
        elbo = rec.mean() + kl_w * kl.mean()
        elbo.backward()

        optimizer.step()

        rec_value = rec.cpu().data.numpy().mean()
        kl_value = kl.cpu().data.numpy().mean()
        # Track the quantity that is actually optimised, so the
        # "ave elbo" summary below reports what its label says.
        ave_epoch_loss.append(rec_value + kl_w * kl_value)

        print('\r it: %d/%d \t rec: %4.3f \t kl: %4.3f' %
              (i, epoch_size, rec_value, kl_value), end='')

    print('\r ave elbo: %4.3f\n' % (np.asarray(ave_epoch_loss).mean()), end='')

In [None]:
# Pick a random image from the last training batch. The bound is the actual
# batch size, not a fixed 16, so the index is valid for any --batch_dim.
b = images.size(0)
idx = np.random.randint(0, b)

# Input image.
plt.imshow(images[idx, :].cpu().data.numpy().reshape(28, 28), cmap="gray")
plt.title("input")
plt.show()

# Reconstruction: the model output is passed through a sigmoid to turn it
# into pixel intensities. torch.sigmoid replaces the deprecated F.sigmoid.
plt.imshow(torch.sigmoid(vae(images)[0, idx, :]).cpu().data.numpy().reshape(28, 28), cmap="gray")
plt.title("reconstruction")
plt.show()

In [None]:
# Random 3-vector used as the base point, repeated n times along axis 1 so
# that each copy can later be shifted independently.
c = np.random.normal(loc=0, scale=1, size=(1, 1, 3))
n = 20
cs = np.repeat(c, n, axis=1)

In [None]:
# Shift the second coordinate of each copy by an increasing offset in [0, 2].
# The sweep is built on a copy so that re-running this cell does not keep
# adding the offsets to `cs`.
cs_sweep = cs.copy()
cs_sweep[0, :, 1] = cs_sweep[0, :, 1] + np.linspace(0, 2, n)

comp = 'so3'
for i in range(cs_sweep.shape[1]):
    # Loop-local names are used so the notebook-level `c`, `c1` and `c2`
    # are not overwritten by this cell.
    tangent = cs_sweep[0, i, :][np.newaxis, np.newaxis, :]
    if comp == 'so3':
        # Map the 3-vector through the Rodrigues exponential map and flatten
        # the last two axes of the result into one.
        z = SO3reparameterize._expmap_rodrigues(n2p(tangent))
        z = z.view(*z.size()[:-2], -1)
    elif comp == 'mix':
        # Random gaussian part concatenated with a random rotation part
        # (this branch ignores the sweep point).
        z_gauss = n2p(np.random.normal(0, 1, (1, 1, vae.z_dims[0])))
        z_rot = SO3reparameterize._expmap_rodrigues(n2p(np.random.normal(0, 1, (1, 1, 3))))
        z_rot = z_rot.view(*z_rot.size()[:-2], -1)
        z = torch.cat([z_gauss, z_rot], -1)
    else:
        # Plain gaussian latent sampled at random (ignores the sweep point).
        z = n2p(np.random.normal(0, 1, (1, 1, vae.z_dims[0])))

    # Move the code to the GPU only when one is available, mirroring how the
    # model itself was placed.
    if torch.cuda.is_available():
        z = z.cuda(FLAGS.gpu)
    xz = vae.decode(z)

    # torch.sigmoid replaces the deprecated F.sigmoid.
    plt.imshow(torch.sigmoid(xz[0, 0, :]).cpu().data.numpy().reshape(28, 28), cmap="gray")
    plt.show()

In [None]:
# Evaluate the log-likelihood of the last training batch with n=2.
# The batch is moved to the GPU only when CUDA is available, consistent with
# the guarded transfers used during training.
eval_images = images.cuda(FLAGS.gpu) if torch.cuda.is_available() else images
vae.log_likelihood(eval_images, n=2)

In [None]:
# NOTE(review): a 10-d gaussian part plus an SO(3) part matches a model built
# with z_dims=[10, 9]; the `vae` constructed earlier in this notebook uses
# z_dims=[9] with a single 'so3' latent — confirm which model this cell targets.
c0 = np.random.normal(0, 1, (1, 1, 10))
z1 = n2p(c0)

# Starting 3-vector for the rotation part; the next cell perturbs it.
c1 = np.random.normal(0, 1, (1, 1, 3))

# Shown as the cell output (a bare expression is only displayed when it is
# the last statement of the cell).
z1.size()

In [None]:
# NOTE(review): c1 is replaced by a perturbed copy, so every re-run of this
# cell moves it further from its starting value — presumably an intentional
# random walk; confirm.
c1 = c1 + np.random.normal(0, 0.3, (1, 1, 3))
z2 = SO3reparameterize._expmap_rodrigues(n2p(c1))

# Flatten the last two axes of z2 and append it to the gaussian part.
# NOTE(review): the resulting width must equal sum(vae.z_dims) for the
# decoder built in VAE2D — verify against the model in use.
z = torch.cat([z1, z2.view(*z2.size()[:-2], -1)], -1)

# Move the code to the GPU only when one is available.
s = vae.decode(z.cuda(FLAGS.gpu) if torch.cuda.is_available() else z)
print(s.size())

# Render like the other decode cells: sigmoid over the full output vector,
# reshaped to 28x28 (indexing s[0, 0, 0] selects a single element, not an image).
plt.imshow(torch.sigmoid(s[0, 0, :]).cpu().data.numpy().reshape(28, 28), cmap="gray")
plt.show()

In [None]:
# Display the list of per-latent reparameterization modules held by the model.
vae.reparameterize