In [None]:
import os, sys, inspect

# Put the directory above this notebook/script at the front of `sys.path`,
# presumably so the local `lie_vae` package imported below resolves without
# being installed.
try:
    # Script / module execution: anchor on the file's own location.
    currentdir = os.path.dirname(os.path.abspath(__file__))
except NameError:
    # Notebook kernel: `__file__` is not defined, and the frame filename that
    # `inspect.getfile(inspect.currentframe())` would give can be a temporary
    # path (recent ipykernel versions compile cells under a temp directory),
    # which would put the wrong directory on the path. Under a default Jupyter
    # setup the working directory is the notebook's own folder.
    currentdir = os.getcwd()
parentdir = os.path.dirname(currentdir)
sys.path.insert(0, parentdir)

In [None]:
import os
import torch
from torch.autograd import Variable
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data as data_utils
from torch.utils.data.dataset import Dataset
import numpy as np
import matplotlib.pyplot as plt
import argparse
import math

from glob import glob
from torch.utils.data import Dataset, DataLoader
import os.path
from PIL import Image
from lie_learn.groups.SO3 import change_coordinates as SO3_coordinates
# from tensorboardX import SummaryWriter

from lie_vae.vae import CubeVAE
from lie_vae.reparameterize import Nreparameterize, SO3reparameterize, N0reparameterize
from lie_vae.utils import MLP, random_split, View, Flatten
from lie_vae.lie_tools import group_matrix_to_eazyz, block_wigner_matrix_multiply
from lie_vae.datasets import CubeDataset
from lie_vae.nets import CubesDeconvNet, CubesConvNet
from lie_vae.decoders import ActionNet, MLPNet


# Query GPU availability once: `CUDA` is the boolean flag later cells branch on,
# and `device` is where the model (and its inputs) are placed.
CUDA = torch.cuda.is_available()
device = torch.device('cuda:0') if CUDA else torch.device('cpu')

In [None]:
# Command-line / notebook options. `parse_known_args` (not `parse_args`) collects
# any arguments this parser does not define in `unparsed` instead of raising,
# so the cell also works when the kernel process was launched with extra flags.
parser = argparse.ArgumentParser()

parser.add_argument('--latent', '-z', type=str, default='so3', help='normal or so3')
parser.add_argument('--decoder', '-d', type=str, default='mlp', help='mlp or action')
parser.add_argument('--epochs', '-e', type=int, default=1, help='number of training epochs')
parser.add_argument('--load', '-l', type=str, default='',
                    help='path of a saved model to load (empty: build a fresh model)')
parser.add_argument('--save', '-s', type=str, default='',
                    help='path to save the model to after every epoch (empty: do not save)')
parser.add_argument('--batch_dim', '-b', type=int, default=32, help='mini-batch size')
# Fixed: the original line had a stray extra comma (`'-bn', , action=...`),
# which is a SyntaxError and stopped the whole cell from running.
parser.add_argument('--batch_norm', '-bn', action='store_true',
                    help='enable batch normalization (passed to CubeVAE)')

FLAGS, unparsed = parser.parse_known_args()

In [None]:
# Train / dev / test splits of the cubes dataset.
train_dataset = CubeDataset('train')
dev_dataset = CubeDataset('dev')
test_dataset = CubeDataset('test')

# Only the training loader shuffles. The dev/test loaders are only used to
# average metrics over the whole split, so shuffling gains nothing there and
# makes the batch composition (and hence batch-averaged numbers) vary per run.
train_loader = data_utils.DataLoader(train_dataset, batch_size=FLAGS.batch_dim, shuffle=True)
dev_loader = data_utils.DataLoader(dev_dataset, batch_size=FLAGS.batch_dim, shuffle=False)
test_loader = data_utils.DataLoader(test_dataset, batch_size=FLAGS.batch_dim, shuffle=False)

In [None]:
# Build a fresh model on `device`, or restore a previously saved one.
model = CubeVAE(FLAGS.decoder, FLAGS.latent, FLAGS.batch_norm).to(device)

if FLAGS.load != '':
    # The file is expected to hold a whole pickled module (as written by
    # `torch.save(model, ...)`), so loading it unpickles arbitrary objects:
    # only load trusted files.
    # `map_location` lets a checkpoint saved on a GPU be restored on a CPU-only
    # machine; `.to(device)` keeps the model where the inputs will be sent.
    # NOTE(review): PyTorch >= 2.6 defaults `torch.load(weights_only=True)`,
    # which rejects whole pickled modules; pass `weights_only=False` there.
    model = torch.load(FLAGS.load, map_location=device).to(device)

optimizer = torch.optim.Adam(model.parameters())

In [None]:
print("training {} with {}".format(FLAGS.latent, FLAGS.decoder))
for j in range(FLAGS.epochs):
    # ---- one optimisation pass over the training set ----
    model.train()
    for i, (images, ) in enumerate(train_loader):
        images = images.to(device)

        optimizer.zero_grad()

        # `elbo` returns a reconstruction term and a KL term; their sum is the
        # quantity minimised (presumably the negative ELBO).
        loss_recon, loss_kl = model.elbo(images, n=1)

        loss_recon = loss_recon.mean()
        loss_kl = loss_kl.mean()
        loss = loss_recon + loss_kl

        loss.backward()
        optimizer.step()

        print('\r epoch: {:4}/{:4}, it: {:4}/{:4}: ELBO: {:.4f}, recon: {:.4f}, KL: {:.4f}'.format(
                j, FLAGS.epochs, i, len(train_loader),
                loss.item(), loss_recon.item(), loss_kl.item()),
            end='')

    if FLAGS.save != '':
        torch.save(model, FLAGS.save)

    # ---- evaluation on every dev batch ----
    # An explicit loop is used on purpose: in a nested comprehension the first
    # `for` clause is the outer one, so `model.elbo(images)` would run once on
    # the last *training* batch instead of on each dev batch.
    # eval() + no_grad(): batch-norm layers (if enabled) use running statistics,
    # and no autograd graph is built for pure evaluation.
    model.eval()
    dev_recon, dev_kl, dev_count = 0.0, 0.0, 0
    with torch.no_grad():
        for (images, ) in dev_loader:
            images = images.to(device)
            batch_size = images.size(0)
            batch_recon, batch_kl = model.elbo(images)
            # Weight by batch size so a smaller final batch does not skew the mean.
            dev_recon += batch_recon.mean().item() * batch_size
            dev_kl += batch_kl.mean().item() * batch_size
            dev_count += batch_size
    dev_count = max(dev_count, 1)  # avoid dividing by zero on an empty split
    dev_recon /= dev_count
    dev_kl /= dev_count

    # Leading newline terminates the '\r' progress line printed above.
    print('\n epoch: {:4}/{:4}, dev ELBO: {:.4f}, recon: {:.4f}, KL: {:.4f}'.format(
        j, FLAGS.epochs, dev_recon + dev_kl, dev_recon, dev_kl))

In [None]:
# Evaluate the trained model on every test batch: average reconstruction and KL
# terms, plus `model.log_likelihood(..., n=500)` (n presumably = number of
# samples used by the estimator; confirm in lie_vae.vae).
# An explicit loop replaces the former nested comprehension, whose outer clause
# evaluated `model.elbo` once on a leftover training batch instead of on the
# test batches.
model.eval()  # batch-norm layers (if enabled) use running statistics
test_recon, test_kl, test_ll, test_count = 0.0, 0.0, 0.0, 0
with torch.no_grad():  # evaluation only: no autograd graph for the n=500 estimate
    for (images, ) in test_loader:
        images = images.to(device)
        batch_size = images.size(0)
        batch_recon, batch_kl = model.elbo(images)
        # Weight by batch size so a smaller final batch does not skew the mean.
        test_recon += batch_recon.mean().item() * batch_size
        test_kl += batch_kl.mean().item() * batch_size
        test_ll += model.log_likelihood(images, n=500).mean().item() * batch_size
        test_count += batch_size
test_count = max(test_count, 1)  # avoid dividing by zero on an empty split
loss_recon = test_recon / test_count
loss_kl = test_kl / test_count
ll = test_ll / test_count

print("TEST VAE CUBES results for {} with {}  ---  {} epochs ".format(FLAGS.latent, FLAGS.decoder, FLAGS.epochs))
print('ELBO: {:.4f}, recon: {:.4f}, KL: {:.4f}, LL: {:.4f}'.format(
    loss_recon + loss_kl, loss_recon, loss_kl, ll))

In [None]:
# Show one training image next to the model's reconstruction of it.
# eval(): with batch norm enabled, a train-mode forward pass on a single-image
# batch would normalise with that one image's statistics (or fail outright).
model.eval()
img = next(iter(train_loader))[0][0:1].to(device)  # first image of a batch, kept 4-D
with torch.no_grad():
    rec_img = model(img)

fig, (ax_in, ax_rec) = plt.subplots(1, 2)
ax_in.imshow(img.cpu().numpy()[0].transpose(1, 2, 0))  # CHW -> HWC for imshow
ax_in.set_title('input')
# The output is passed through a sigmoid before display (presumably the decoder
# returns logits); [0, 0] selects the first item and, presumably, the first
# latent sample -- confirm the output layout in CubeVAE.forward.
ax_rec.imshow(torch.sigmoid(rec_img).cpu().numpy()[0, 0].transpose(1, 2, 0))
ax_rec.set_title('reconstruction')
plt.show()

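## Continuity checks

(Currently disabled.) These cells build a smooth random path (a double cumulative sum of small Gaussian steps), map it to SO(3) with the Rodrigues exponential map, and plot the resulting Euler angles (`group_matrix_to_eazyz`) and Wigner-matrix products (`block_wigner_matrix_multiply`). The plots show whether these quantities vary continuously along the path.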

In [None]:
# degrees = 3
# matrix_dims = (degrees + 1) ** 2
# in_dims = 2
# item_rep = nn.Parameter(torch.randn((matrix_dims, in_dims)))

In [None]:
# N = 1000
# x0 = np.random.normal(0,0.001,(N+1,3))
# print(x0.shape)
# x0 = np.cumsum(np.cumsum(x0,-2), -2)
# print(x0.shape)
# x0 = torch.tensor(x0, dtype = torch.float32)
# group_el = SO3reparameterize._expmap_rodrigues(x0)
# angles = group_matrix_to_eazyz(group_el)


# plt.plot(angles.detach().numpy()[:,0], "o", color = "b")
# plt.show()
# plt.plot(angles.detach().numpy()[:,1], "o", color = "b")
# plt.show()
# plt.plot(angles.detach().numpy()[:,2], "o", color = "b")
# plt.show()

In [None]:

# n = angles.size(0)
# item_expanded = item_rep.expand(n, -1, -1)
# item = block_wigner_matrix_multiply(angles, item_expanded, degrees) \
#             .view(-1, matrix_dims * in_dims)

# i = np.random.randint(item.shape[1])
   
# plt.plot(item.detach().numpy()[:,i], "o", color = "b")
# plt.show()


In [None]:
# torch.arange(-3, 4) * 2 * math.pi