In [None]:
import os
import torch
from torch.autograd import Variable
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data as data_utils
import numpy as np
import matplotlib.pyplot as plt
import argparse
import math

from vae import VAE
from reparameterize import Nreparameterize, SO3reparameterize, N0reparameterize

from torch.utils.data.dataset import Dataset

# Query CUDA availability once; `CUDA` is the boolean flag used by later cells,
# `device` is the matching torch.device.
CUDA = torch.cuda.is_available()
device = torch.device('cuda:0') if CUDA else torch.device('cpu')

In [None]:
from glob import glob
from torch.utils.data import Dataset, DataLoader
import os.path
from PIL import Image
from lie_learn.groups.SO3 import change_coordinates as SO3_coordinates
from tensorboardX import SummaryWriter


import os,sys,inspect
# Make the parent directory importable so `utils` and `lie_tools` resolve.
# In a script/module `__file__` gives this file's location. A Jupyter kernel
# does not define `__file__`, and inspecting the current frame there yields the
# kernel's temporary cell file (e.g. /tmp/ipykernel_*/...), not the notebook's
# folder. Fall back to the working directory instead, which Jupyter sets to
# the notebook's directory by default.
try:
    currentdir = os.path.dirname(os.path.abspath(__file__))
except NameError:
    currentdir = os.getcwd()
parentdir = os.path.dirname(currentdir)
sys.path.insert(0, parentdir)

from utils import MLP, random_split
from lie_tools import group_matrix_to_eazyz, block_wigner_matrix_multiply

In [None]:
parser = argparse.ArgumentParser()

# `choices` rejects unsupported values up front; otherwise ConvVAE would be
# built without a decoder and fail much later with an obscure error.
parser.add_argument('--latent', '-z', type=str, default='so3', choices=['normal', 'so3'],
                    help='latent space: normal or so3')
parser.add_argument('--decoder', '-d', type=str, default='mlp', choices=['mlp', 'action'],
                    help='decoder: mlp or action')
parser.add_argument('--epochs', '-e', type=int, default=200,
                    help='number of training epochs')
parser.add_argument('--load', '-l', type=str, default='',
                    help='path of a saved model to resume from (empty: start from a fresh model)')
parser.add_argument('--save', '-s', type=str, default='',
                    help='path the model is saved to after every epoch (empty: do not save)')
parser.add_argument('--batch_dim', '-b', type=int, default=32,
                    help='mini-batch size')

# parse_known_args (not parse_args) so that extra arguments, such as the
# `-f <connection file>` a Jupyter kernel is launched with, are ignored
# instead of aborting the run.
FLAGS, unparsed = parser.parse_known_args()

In [None]:
class DeconvNet(nn.Sequential):
    """Decode a flat feature vector into a 3-channel 32x32 image.

    Despite the name, no transposed convolutions are used: a linear layer
    produces a 32x8x8 feature map (reshaped with ``View``), which is refined by
    3x3 convolutions and upsampled twice with nearest-neighbour interpolation
    (8x8 -> 16x16 -> 32x32). The last layer is a plain convolution, so the
    output is unbounded (no final nonlinearity).

    Args:
        in_dims: size of each input feature vector.
        hidden_dims: currently unused by this architecture; kept so existing
            callers (which pass it) keep working.
    """
    def __init__(self, in_dims, hidden_dims):
        # ``View`` is defined in a later cell; the name is only looked up when
        # an instance is constructed, so that cell must run before any
        # DeconvNet is built.
        super().__init__(
            nn.Linear(in_dims, 32 * 8 * 8),
            View(-1, 32, 8, 8),                             # -> [N, 32, 8, 8]
            nn.ReLU(),
            nn.Conv2d(32, 32, 3, padding=1),
            nn.ReLU(),
            nn.Conv2d(32, 32, 3, padding=1),
            nn.Upsample(scale_factor=2, mode='nearest'),    # -> 16x16
            nn.ReLU(),
            nn.Conv2d(32, 16, 3, padding=1),
            nn.Upsample(scale_factor=2, mode='nearest'),    # -> 32x32
            nn.ReLU(),
            nn.Conv2d(16, 16, 3, padding=1),
            nn.ReLU(),
            nn.Conv2d(16, 3, 3, padding=1)                  # -> [N, 3, 32, 32]
        )
        
class ActionNet(nn.Module):
    """Decoder that applies an SO(3) action to a learned item representation.

    A single learned matrix ``item_rep`` of shape
    [(degrees + 1) ** 2, in_dims] is shared across the batch, transformed per
    sample by ``block_wigner_matrix_multiply`` with the given angles,
    flattened, optionally passed through an MLP, and decoded to an image by
    ``DeconvNet``.

    Args:
        degrees: maximum degree passed to ``block_wigner_matrix_multiply``;
            the representation has (degrees + 1) ** 2 rows.
        in_dims: number of columns of the item representation.
        with_mlp: if True, apply ``MLP(d, d, 50, 3)`` with
            d = (degrees + 1) ** 2 * in_dims to the transformed representation
            before decoding.
    """
    def __init__(self, degrees, in_dims=10, with_mlp=False):
        super().__init__()
        self.degrees = degrees
        self.in_dims = in_dims
        self.matrix_dims = (degrees + 1) ** 2
        self.item_rep = nn.Parameter(torch.randn((self.matrix_dims, in_dims)))
        self.deconv = DeconvNet(self.matrix_dims * self.in_dims, 50)
        if with_mlp:
            self.mlp = MLP(self.matrix_dims * in_dims,
                           self.matrix_dims * in_dims, 50, 3)
        else:
            self.mlp = None

    def forward(self, angles):
        """Decode a batch of rotations into images.

        Args:
            angles: batch of Euler angles, presumably [batch, 3] as produced by
                ``group_matrix_to_eazyz`` or by the tanh mapping in
                ``ConvVAE.forward`` -- TODO confirm the convention expected by
                ``block_wigner_matrix_multiply``.

        Returns:
            The ``DeconvNet`` output, [batch, 3, 32, 32].
        """
        n = angles.size(0)
        # Broadcast the single learned representation over the batch
        # (``expand`` creates a view, no copy).
        item_expanded = self.item_rep.expand(n, -1, -1)

        item = block_wigner_matrix_multiply(angles, item_expanded, self.degrees)
        item = item.view(-1, self.matrix_dims * self.in_dims)

        if self.mlp is not None:
            item = self.mlp(item)
        return self.deconv(item)
    
class MLPNet(nn.Module):
    """Decoder: an MLP on the flattened latent sample followed by DeconvNet.

    Args:
        input_dim: number of features per latent sample; ``forward`` reshapes
            its input to [-1, input_dim] (ConvVAE passes 9 for a 3x3 group
            matrix and 3 for a 3-vector).
        degrees: together with ``in_dims`` sets the MLP output width,
            (degrees + 1) ** 2 * in_dims.
        in_dims: multiplier for the MLP output width.
    """
    def __init__(self, input_dim, degrees, in_dims=10):
        super().__init__()
        self.input_dim = input_dim
        feature_width = in_dims * (degrees + 1) ** 2
        self.mlp = MLP(self.input_dim, feature_width, 50, 3)
        self.deconv = DeconvNet(feature_width, 50)

    def forward(self, x):
        """Map latent samples ([batch, 3, 3] or [batch, 3]) to images."""
        flat = x.view(-1, self.input_dim)
        features = self.mlp(flat)
        return self.deconv(features)


In [None]:
class View(nn.Module):
    """Reshape layer: applies ``Tensor.view`` with a fixed target shape.

    Lets a reshape be used as a stage inside ``nn.Sequential``.
    """
    def __init__(self, *v):
        super().__init__()
        # Target shape, forwarded verbatim to Tensor.view (may contain -1).
        self.v = v

    def forward(self, x):
        target_shape = self.v
        return x.view(*target_shape)
    
class Flatten(nn.Module):
    """Collapse every dimension after the first (batch) dimension."""
    def __init__(self):
        super().__init__()

    def forward(self, x):
        batch_size = x.size(0)
        return x.view(batch_size, -1)

#TODO better reparametrization action normal
class ConvVAE(VAE):
    """Convolutional VAE over 3x32x32 images with a 3-dimensional pose latent.

    The encoder maps an image to an ``ndf * 4`` (= 64) feature vector. The
    latent type and decoder are selected by the global ``FLAGS``:

    * ``FLAGS.latent``: ``"so3"`` (SO3reparameterize wrapping
      N0reparameterize with z_dim=3) or ``"normal"`` (Nreparameterize, 3 dims).
    * ``FLAGS.decoder``: ``"mlp"`` (MLPNet on the raw latent sample) or
      ``"action"`` (ActionNet on Euler angles derived from the sample).

    Raises:
        ValueError: if ``FLAGS.latent`` or ``FLAGS.decoder`` is unsupported.
    """
    def __init__(self):
        super(ConvVAE, self).__init__()
        ndf = 16
        self.encoder = nn.Sequential(
            # input is (nc) x 32 x 32
            nn.Conv2d(3, ndf, 3, 1, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. (ndf) x 32 x 32
            nn.Conv2d(ndf, ndf * 2, 4, 2, 1, bias=False),
            # nn.BatchNorm2d(ndf * 2),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. (ndf*2) x 16 x 16
            nn.Conv2d(ndf * 2, ndf * 4, 4, 2, 1, bias=False),
            # nn.BatchNorm2d(ndf * 4),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. (ndf*4) x 8 x 8
            nn.Conv2d(ndf * 4, ndf * 8, 4, 2, 1, bias=False),
            # nn.BatchNorm2d(ndf * 8),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. (ndf*8) x 4 x 4
            nn.Conv2d(ndf * 8, ndf * 8, 4, 1, 0, bias=False),
            Flatten(),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(ndf * 8, ndf * 4),
            nn.LeakyReLU(0.2, inplace=True)
        )

        # Latent reparameterization. `mlp_input_dim` is the flattened size of
        # one latent sample as seen by MLPNet: a 3x3 group matrix for so3,
        # a 3-vector for normal.
        if FLAGS.latent == "so3":
            self.rep0 = SO3reparameterize(N0reparameterize(ndf * 4, z_dim=3), k=10)
            mlp_input_dim = 9
        elif FLAGS.latent == "normal":
            self.rep0 = Nreparameterize(ndf * 4, 3)
            mlp_input_dim = 3
        else:
            raise ValueError("unsupported latent {!r}; expected 'so3' or 'normal'".format(FLAGS.latent))
        self.reparameterize = [self.rep0]
        self.r_callback = [torch.nn.Sequential()]

        # Decoder. (Previously the action/so3 branch also built an MLPNet that
        # was immediately overwritten; only the decoder actually used is built.)
        if FLAGS.decoder == "mlp":
            self.decoder = MLPNet(input_dim=mlp_input_dim, degrees=6, in_dims=100)
        elif FLAGS.decoder == "action":
            self.decoder = ActionNet(6, 10, with_mlp=True)
        else:
            raise ValueError("unsupported decoder {!r}; expected 'mlp' or 'action'".format(FLAGS.decoder))

    def forward(self, x, n=1):
        """Reconstruct ``x`` through the pose latent.

        Args:
            x: image batch (3x32x32 images, per the encoder).
            n: not used by this method.

        Returns:
            Decoder output (used as logits by ``recon_loss``) reshaped to
            ``z_pose.shape[:2] + (3, 32, 32)``.
        """
        z_list = self.encode(x)
        z_pose = z_list[0]
        # Merge the two leading dims of the latent sample so the decoder sees
        # one flat batch; they are restored on the output below.
        z_pose_ = z_pose.view(-1, *z_pose.shape[2:])

        if FLAGS.decoder == "action":
            if FLAGS.latent == "so3":
                angles = group_matrix_to_eazyz(z_pose_)
            elif FLAGS.latent == "normal":
                # tanh squashes to (-1, 1); scale/offset then map the three
                # components to (-pi, pi) x (0, pi) x (-pi, pi). Constants
                # follow the latent's device and dtype.
                scale = z_pose_.new_tensor([[math.pi, math.pi / 2, math.pi]])
                offset = z_pose_.new_tensor([[0, math.pi / 2, 0]])
                angles = torch.tanh(z_pose_) * scale + offset
            else:
                raise ValueError("unsupported latent {!r}; expected 'so3' or 'normal'".format(FLAGS.latent))
            x_recon = self.decoder(angles)
        elif FLAGS.decoder == "mlp":
            x_recon = self.decode(z_pose_)
        else:
            raise ValueError("unsupported decoder {!r}; expected 'mlp' or 'action'".format(FLAGS.decoder))

        return x_recon.view(*z_pose.shape[:2], 3, 32, 32)

    def recon_loss(self, x_recon, x):
        """Bernoulli negative log-likelihood of ``x`` under logits ``x_recon``.

        Elementwise binary cross-entropy with logits, written in the
        numerically stable form
        ``x_recon - x_recon * x + m + log(exp(-m) + exp(-x_recon - m))`` with
        ``m = max(-x_recon, 0)``, then summed over the last three dimensions
        (channel, height, width). ``x`` is expanded to ``x_recon``'s shape so
        one target can be compared with several reconstructions.
        """
        x = x.expand_as(x_recon)
        max_val = (-x_recon).clamp(min=0)
        loss = x_recon - x_recon * x + max_val + ((-max_val).exp() + (-x_recon - max_val).exp()).log()

        return loss.sum(-1).sum(-1).sum(-1)

In [None]:
def _tensor_loader(path, shuffle):
    """Load a ``.npy`` array and wrap it as (tensor, TensorDataset, DataLoader).

    Each dataset item is a 1-tuple ``(images,)``. Batch size comes from
    ``FLAGS.batch_dim``.
    """
    tensor = torch.from_numpy(np.load(path))
    dataset = data_utils.TensorDataset(tensor)
    loader = data_utils.DataLoader(dataset, batch_size=FLAGS.batch_dim, shuffle=shuffle)
    return tensor, dataset, loader


# Only the training split is shuffled. Dev/test are iterated in a fixed order
# so the per-batch averaging in the evaluation cells does not depend on a
# random permutation. (The *_labels.npy files are not used.)
train_data, train_dataset, train_loader = _tensor_loader('train_data.npy', shuffle=True)
dev_data, dev_dataset, dev_loader = _tensor_loader('dev_data.npy', shuffle=False)
test_data, test_dataset, test_loader = _tensor_loader('test_data.npy', shuffle=False)

In [None]:
# Either build a fresh model or resume from a whole pickled model
# (as written by `torch.save(model, FLAGS.save)` in the training loop).
if FLAGS.load != '':
    # NOTE: torch.load unpickles the file and can execute arbitrary code;
    # only load checkpoints you trust. map_location places the restored
    # parameters on this machine's device, so a model saved on a GPU can be
    # resumed on a CPU-only host (and vice versa).
    model = torch.load(FLAGS.load, map_location=device)
else:
    model = ConvVAE().to(device)

optimizer = torch.optim.Adam(model.parameters())

In [None]:
print("training {} with {}".format(FLAGS.latent, FLAGS.decoder))
for j in range(FLAGS.epochs):
    for i, (images, ) in enumerate(train_loader):
        images = images.to(device)

        optimizer.zero_grad()

        # model.elbo returns (reconstruction term, KL term); each is averaged
        # before they are summed into the training loss.
        loss_recon, loss_kl = model.elbo(images)
        loss_recon = loss_recon.mean()
        loss_kl = loss_kl.mean()
        loss = loss_recon + loss_kl

        loss.backward()
        optimizer.step()

        # '\r' with end='' keeps the per-iteration progress on one line.
        print('\r epoch: {:4}/{:4}, it: {:4}/{:4}: ELBO: {:.4f}, recon: {:.4f}, KL: {:.4f}'.format(
                j, FLAGS.epochs, i, len(train_loader),
                loss.item(),
                loss_recon.item(),
                loss_kl.item()),
            end='')

    if FLAGS.save != '':
        torch.save(model, FLAGS.save)

    # Dev-set evaluation. Under no_grad no autograd graph is built; otherwise
    # every dev batch's graph would stay alive in `losses` until the epoch ends.
    with torch.no_grad():
        losses = [model.elbo(images.to(device)) for (images, ) in dev_loader]
        loss_recon, loss_kl = np.array([[batch_recon.mean().item(), batch_kl.mean().item()]
                                        for batch_recon, batch_kl in losses]).mean(0)

    # `i` still holds the last training-iteration index of this epoch.
    print('\r epoch: {:4}/{:4}, it: {:4}/{:4}: ELBO: {:.4f}, recon: {:.4f}, KL: {:.4f}'.format(
        j, FLAGS.epochs, i, len(train_loader), loss_recon + loss_kl, loss_recon, loss_kl))

In [None]:
# Test-set evaluation. Inference only: no_grad avoids keeping an autograd
# graph for every batch (especially costly for log_likelihood with n=500).
with torch.no_grad():
    losses = [model.elbo(images.to(device)) for (images, ) in test_loader]

    # Flatten each batch's result before averaging. This gives the same value
    # as before when log_likelihood returns one value per batch, and also
    # works if it returns one value per example (where a smaller final batch
    # would make np.array(...) ragged).
    ll = np.concatenate([np.ravel(model.log_likelihood(images.to(device), n=500).cpu().numpy())
                         for (images, ) in test_loader]).mean()

loss_recon, loss_kl = np.array([[batch_recon.mean().item(), batch_kl.mean().item()]
                                for batch_recon, batch_kl in losses]).mean(0)
print("TEST VAE CUBES results for {} with {}  ---  {} epochs ".format(FLAGS.latent, FLAGS.decoder, FLAGS.epochs))
print('ELBO: {:.4f}, recon: {:.4f}, KL: {:.4f}, LL: {:.4f}'.format(
    loss_recon + loss_kl, loss_recon, loss_kl, ll))

In [None]:
# Visual check: one training image next to its reconstruction.
img = next(iter(train_loader))[0][0:1]
with torch.no_grad():
    # Move to the model's device (works on CPU-only machines too).
    rec_img = model(img.to(device))

plt.imshow(img.cpu().numpy()[0].transpose(1, 2, 0))
plt.title('input')
plt.show()
# rec_img has two leading (sample) dims before (3, 32, 32); the decoder
# output is treated as logits, so sigmoid maps it to [0, 1] intensities.
plt.imshow(torch.sigmoid(rec_img).cpu().numpy()[0, 0].transpose(1, 2, 0))
plt.title('reconstruction')
plt.show()

## Continuity checks

In [None]:
# degrees = 3
# matrix_dims = (degrees + 1) ** 2
# in_dims = 2
# item_rep = nn.Parameter(torch.randn((matrix_dims, in_dims)))

In [None]:
# N = 1000
# x0 = np.random.normal(0,0.001,(N+1,3))
# print(x0.shape)
# x0 = np.cumsum(np.cumsum(x0,-2), -2)
# print(x0.shape)
# x0 = torch.tensor(x0, dtype = torch.float32)
# group_el = SO3reparameterize._expmap_rodrigues(x0)
# angles = group_matrix_to_eazyz(group_el)


# plt.plot(angles.detach().numpy()[:,0], "o", color = "b")
# plt.show()
# plt.plot(angles.detach().numpy()[:,1], "o", color = "b")
# plt.show()
# plt.plot(angles.detach().numpy()[:,2], "o", color = "b")
# plt.show()

In [None]:

# n = angles.size(0)
# item_expanded = item_rep.expand(n, -1, -1)
# item = block_wigner_matrix_multiply(angles, item_expanded, degrees) \
#             .view(-1, matrix_dims * in_dims)

# i = np.random.randint(item.shape[1])
   
# plt.plot(item.detach().numpy()[:,i], "o", color = "b")
# plt.show()
