In [None]:
import os
import torch
from torch.autograd import Variable
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data as data_utils
import numpy as np
import matplotlib.pyplot as plt
import argparse

from vae import VAE
from reparameterize import Nreparameterize, SO3reparameterize, N0reparameterize

from torch.utils.data.dataset import Dataset

# True when a CUDA device is usable; later cells check this flag before
# moving the model and the image batches to the GPU.
CUDA = torch.cuda.is_available()

In [None]:
from glob import glob
from torch.utils.data import Dataset, DataLoader
import os.path
from PIL import Image
from lie_learn.groups.SO3 import change_coordinates as SO3_coordinates
from tensorboardX import SummaryWriter


import os,sys,inspect
# Resolve the directory that contains the currently executing code and put
# its parent at the front of sys.path (presumably so that the project
# modules imported right after this can be found -- confirm the layout).
currentdir = os.path.split(
    os.path.abspath(inspect.getfile(inspect.currentframe())))[0]
parentdir = os.path.split(currentdir)[0]
sys.path.insert(0, parentdir)

from utils import MLP, random_split
from lie_tools import group_matrix_to_eazyz, block_wigner_matrix_multiply

In [None]:
# Command-line style options. Every option has a default, and unknown
# arguments are tolerated (parse_known_args), so the cell also runs when the
# process was started with arguments this parser does not declare.
parser = argparse.ArgumentParser()

# One row per option: (long flag, short flag, value type, default value).
for _long, _short, _kind, _default in (
        ('--latent', '-z', str, 'so3'),
        ('--epochs', '-e', int, 10),
        ('--load', '-l', str, ''),
        ('--save', '-s', str, ''),
        ('--batch_dim', '-b', int, 64),
):
    parser.add_argument(_long, _short, type=_kind, default=_default, help='')

# FLAGS holds the recognised options; `unparsed` collects everything else.
FLAGS, unparsed = parser.parse_known_args()

In [None]:
class DeconvNet(nn.Sequential):
    """Transposed-convolution stack that upsamples a 1x1 map to 32x32.

    The first layer (kernel 4, stride 1, no padding) turns a 1x1 input into
    4x4. Each of the three following layers (kernel 4, stride 2, padding 1)
    doubles the spatial size, ending at 32x32. The last layer produces
    3 channels and is not followed by an activation.

    Args:
        in_dims: number of input channels.
        hidden_dims: number of channels of every intermediate feature map.
    """
    def __init__(self, in_dims, hidden_dims):
        # Settings shared by the three layers that double the resolution.
        doubling = dict(kernel_size=4, stride=2, padding=1)
        layers = [
            nn.ConvTranspose2d(in_dims, hidden_dims,
                               kernel_size=4, stride=1, padding=0),
            nn.ReLU(),
            nn.ConvTranspose2d(hidden_dims, hidden_dims, **doubling),
            nn.ReLU(),
            nn.ConvTranspose2d(hidden_dims, hidden_dims, **doubling),
            nn.ReLU(),
            nn.ConvTranspose2d(hidden_dims, 3, **doubling),
        ]
        super().__init__(*layers)
        
class ActionNet(nn.Module):
    """Uses proper group action.

    Holds a learned ``item_rep`` parameter of shape ``(matrix_dims, in_dims)``
    with ``matrix_dims = (degrees + 1) ** 2``. In ``forward`` it is passed,
    together with angles derived from the input matrices, to
    ``block_wigner_matrix_multiply``; the flattened result is optionally run
    through an MLP and then decoded by ``DeconvNet``.

    Args:
        degrees: forwarded to ``block_wigner_matrix_multiply``; also fixes
            ``matrix_dims``.
        in_dims: number of columns of the learned item representation.
        with_mlp: when true, an ``MLP`` is applied before the deconvolution.
    """
    def __init__(self, degrees, in_dims=10, with_mlp=False):
        super().__init__()
        self.degrees = degrees
        self.in_dims = in_dims
        self.matrix_dims = (degrees + 1) ** 2
        # Length of the flattened representation fed to the MLP / deconv.
        flat_dims = self.matrix_dims * in_dims
        self.item_rep = nn.Parameter(torch.randn((self.matrix_dims, in_dims)))
        self.deconv = DeconvNet(flat_dims, 50)
        self.mlp = MLP(flat_dims, flat_dims, 50, 3) if with_mlp else None

    def forward(self, matrices):
        """Input dim is [batch, 3, 3].

        Returns the output of the deconvolution stack for each input matrix.
        """
        batch = matrices.size(0)
        # Same learned item for every sample in the batch (expand, no copy).
        item_batch = self.item_rep.expand(batch, -1, -1)

        euler = group_matrix_to_eazyz(matrices)
        transformed = block_wigner_matrix_multiply(
            euler, item_batch, self.degrees)
        flat = transformed.view(-1, self.matrix_dims * self.in_dims)

        if self.mlp:
            flat = self.mlp(flat)
        # Add two singleton spatial axes so the vector becomes a 1x1 map.
        decoded = self.deconv(flat[:, :, None, None])
        return decoded[:, :, :, :]

In [None]:
class View(nn.Module):
    """Module wrapper around ``Tensor.view`` with a fixed target shape.

    The positional arguments given to the constructor are stored and later
    unpacked into ``x.view(...)``.
    """

    def __init__(self, *v):
        super().__init__()
        # Target shape, kept as the tuple of constructor arguments.
        self.v = v

    def forward(self, x):
        target_shape = self.v
        return x.view(*target_shape)
    
class Flatten(nn.Module):
    """Collapse every dimension after the first into a single one."""

    def forward(self, x):
        # Keep the leading (batch) dimension and let view() infer the rest.
        leading = x.shape[0]
        return x.view(leading, -1)

class ConvAE(nn.Module):
    """Convolutional auto-encoder with a 3-dimensional rotation bottleneck.

    The encoder maps an image to a 3-vector. ``forward`` converts that vector
    with ``SO3reparameterize._expmap_rodrigues`` (named as the Rodrigues
    exponential map from the Lie algebra to the group) and feeds the result
    to an ``ActionNet`` decoder, whose output is treated as logits by
    ``recon_loss``.
    """
    def __init__(self):
        super().__init__()

        # Linear(32 * 8 * 8, ...) fixes the input size: after two 2x2
        # max-pools the feature map must be 8x8, i.e. inputs are 3x32x32.
        self.encoder = nn.Sequential(
            nn.Conv2d(3, 16, 3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Conv2d(16, 32, 3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),
            Flatten(),
            nn.Linear(32 * 8 * 8, 256),
            nn.ReLU(),
            # Deliberately no activation after this layer. A trailing ReLU
            # would confine the Lie-algebra vector to the non-negative
            # octant, so rotation axes with mixed-sign components could never
            # be produced, and the output could collapse to the zero vector
            # with zero gradient. Parameter names are unaffected.
            nn.Linear(256, 3),
        )

        self.decoder = ActionNet(6, 10)

    def forward(self, x):
        """Encode ``x``, map the code to a group element, and decode it.

        Returns the decoder output (logits; apply a sigmoid to view it as an
        image).
        """
        lie_algebra_el = self.encoder(x)
        group_el = SO3reparameterize._expmap_rodrigues(lie_algebra_el)

        x_recon = self.decoder(group_el)
        return x_recon

    def recon_loss(self, x_recon, x):
        """Binary cross-entropy with logits, summed over the last three dims.

        Args:
            x_recon: reconstruction logits.
            x: targets, expanded to the shape of ``x_recon`` (assumed to lie
                in [0, 1] -- TODO confirm against the data pipeline).

        Returns:
            A tensor with the last three dimensions of ``x_recon`` summed out
            (one value per sample for 4-D input).
        """
        target = x.expand_as(x_recon)
        # Numerically stable form of  -[t*log(s(z)) + (1-t)*log(1-s(z))]:
        # factoring out max(-z, 0) keeps both exponentials bounded by 1.
        shift = (-x_recon).clamp(min=0)
        log_term = ((-shift).exp() + (-x_recon - shift).exp()).log()
        loss = x_recon - x_recon * target + shift + log_term

        return loss.sum(-1).sum(-1).sum(-1)

In [None]:
# Build a fresh model, or restore a previously saved one when --load is set.
if FLAGS.load != '':
    # The checkpoint is a whole pickled model object (see torch.save(model, ...)
    # in the training cell). Unpickling executes code from the file, so only
    # load checkpoints from a trusted source.
    # map_location keeps every storage on the CPU while loading, so a
    # checkpoint written on a GPU machine can also be opened on a CPU-only one.
    model = torch.load(FLAGS.load, map_location=lambda storage, loc: storage)
else:
    model = ConvAE()

# Move to the GPU in both cases (a loaded model is not placed automatically).
if CUDA:
    model = model.cuda()

# Create the optimizer only after the model is on its final device so that it
# references the parameters that are actually trained.
optimizer = torch.optim.Adam(model.parameters())

In [None]:
# NOTE(review): `train_loader` and `dev_data` are not defined in the cells
# visible here; they must already exist in the kernel for this cell to run.
for j in range(FLAGS.epochs):
    # One optimisation pass over the training batches (labels are unused).
    for i, (images, labels) in enumerate(train_loader):
        images = Variable(images).cuda() if CUDA else Variable(images)

        optimizer.zero_grad()

        images_recon = model(images)

        # Alternative objective (summed squared error):
        # loss = (images_recon - images).pow(2).sum(-1).sum(-1).sum(-1).mean()
        loss = model.recon_loss(images_recon, images).mean()
        loss.backward()
        optimizer.step()
        # '\r' together with end='' rewrites the same console line each step.
        print('\r epoch: {:4}/{:4}, it: {:4}/{:4}: loss: {:.4f},'.format(
                j, FLAGS.epochs, i, len(train_loader),
                float(loss.data.cpu().numpy())),
            end='')

    # Checkpoint the whole model object after every epoch when requested.
    if FLAGS.save != '':
        torch.save(model, FLAGS.save)

    # Evaluate on `dev_data` in a single forward pass. Nothing is
    # back-propagated here, so autograd is switched off to avoid building and
    # holding a graph over the entire dev set.
    with torch.no_grad():
        images = Variable(dev_data).cuda() if CUDA else Variable(dev_data)
        images_recon = model(images)
        loss = model.recon_loss(images_recon, images).mean()

    # Final line for the epoch: overwrites the progress line with the dev loss.
    print('\r epoch: {:4}/{:4}, it: {:4}/{:4}: loss: {:.4f},'.format(
                j, FLAGS.epochs, i, len(train_loader),
                float(loss.data.cpu().numpy())))

In [None]:
# Take the first image of the first training batch, keeping the batch axis.
img = Variable(next(iter(train_loader))[0][0:1])
# Only move the input to the GPU when one is available, like the other cells;
# an unconditional .cuda() fails on CPU-only machines.
rec_img = model(img.cuda() if CUDA else img)

# Show the input image: channels-first -> channels-last for matplotlib.
plt.imshow(img.data.cpu().numpy()[0].transpose(1, 2, 0))
plt.show()
# Show the reconstruction: the model emits logits, so squash them to [0, 1]
# first (torch.sigmoid replaces the deprecated F.sigmoid).
plt.imshow(torch.sigmoid(rec_img).data.cpu().numpy()[0].transpose(1, 2, 0))
plt.show()