# VAE6
## Changes relative to VAE3
* Added robustness to blurriness
* Improved beta-VAE

In [None]:
import os
import sys

import numpy as np
import torch
import torch.nn as nn
import pytorch3d
from skimage.filters import gaussian

import trimesh

from tqdm import tqdm

from vae_3d import VAE3D
from gaussian_smoothing import GaussianSmoothing

# Check whether GPU is available.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)

# Seed every RNG this notebook draws from. Python's `random` is included
# because a later cell picks the mesh with `random.choice`; without seeding it,
# the chosen mesh (and everything computed from it) differs on every run.
import random

SEED = 1337
random.seed(SEED)
torch.manual_seed(SEED)
np.random.seed(SEED)

In [None]:
import difx
# Module-level differentiable voxelizer (image_size=32, sigma=4e-3).
# NOTE(review): this global is not referenced by the cells visible here —
# Mesh2SP constructs its own Voxelizer with the same settings. Keep the two
# configurations in sync or drop this one.
voxelize = difx.Voxelizer(image_size=32, sigma=4e-3)

In [None]:
from pytorch3d.datasets import ShapeNetCore

# ShapeNet synset IDs for the categories used in this project.
SYNSET_CHAIR = '03001627'
SYNSET_JAR = '03593526'

# Dataset locations. Each can be overridden with an environment variable of the
# same name, so the notebook runs outside the original /home/ubuntu layout;
# the defaults are the previous hard-coded paths.
SHAPENET_PATH = os.environ.get(
    'SHAPENET_PATH', '/home/ubuntu/voxel-autoencoder/shapenet/ShapeNetCore.v2')
R2N2_PATH = os.environ.get(
    'R2N2_PATH', '/home/ubuntu/voxel-autoencoder/shapenet/ShapeNetVox32')

# Load chair meshes only (with textures) from ShapeNetCore v2.
shapenet_dataset = ShapeNetCore(SHAPENET_PATH, synsets=[SYNSET_CHAIR], version=2, load_textures=True)

# Final bare expression: the notebook displays the number of meshes loaded.
len(shapenet_dataset)

In [None]:
class Vertex2Face(nn.Module):
    """Gather the corner coordinates of every face from batched vertices.

    Parameter-free module: it only indexes into the vertex tensor.
    """

    def __init__(self):
        super().__init__()

    def forward(self, vert, idx):
        """
        Look up the three vertex positions of each face.

        :param vert: real tensor [batch size, num vertices, 3]
        :param idx: int tensor [batch size, num faces, 3] of vertex indices
            local to each mesh in the batch
        :return: real tensor [batch size, num_faces, 3, 3]
        """
        # Shape contract: violations raise AssertionError.
        assert vert.ndimension() == 3
        assert idx.ndimension() == 3
        assert vert.shape[0] == idx.shape[0]
        assert vert.shape[2] == 3
        assert idx.shape[2] == 3

        batch_size, num_verts = vert.shape[0], vert.shape[1]

        # Shift each mesh's indices by the first row of that mesh in the
        # flattened (batch_size * num_verts, 3) vertex table, so one gather
        # serves the whole batch.
        row_offsets = torch.arange(batch_size, dtype=torch.int32, device=vert.device) * num_verts
        global_idx = idx + row_offsets.view(-1, 1, 1)
        vertex_table = vert.reshape((batch_size * num_verts, 3))

        # Advanced indexing needs int64 (long) or bool/byte index tensors.
        return vertex_table[global_idx.long()]

class Mesh2SP(nn.Module):
    """Map a triangle mesh to semantic coordinates through a pretrained 3D VAE.

    Pipeline: vertices + face indices -> per-face triangles (Vertex2Face)
    -> 32^3 voxel grid (difx.Voxelizer) -> VAE encoder latent
    -> projection onto a loaded basis.

    The VAE weights are frozen. Gradients are only needed with respect to the
    input vertices, so no gradient is computed or accumulated for the network.
    """

    def __init__(self, pretrained_path, basis_path, device="cuda"):
        """
        :param pretrained_path: path to a VAE3D(32) state dict saved with torch.save
        :param basis_path: path to a tensor saved with torch.save; encoder
            latents are right-multiplied by it
        :param device: device holding the network and basis; defaults to
            "cuda", matching the previous hard-coded .cuda() behaviour
        """
        super(Mesh2SP, self).__init__()

        self.vert2face = Vertex2Face()

        self.voxelize = difx.Voxelizer(image_size=32, sigma=4e-3)

        # map_location lets checkpoints saved on a different device load here.
        self.net = VAE3D(32).to(device).double().eval()
        self.net.load_state_dict(torch.load(pretrained_path, map_location=device))
        # Freeze the encoder. Callers optimise the mesh vertices, so computing
        # and accumulating .grad for every network weight on each backward()
        # is wasted work.
        for param in self.net.parameters():
            param.requires_grad_(False)

        self.basis = torch.load(basis_path, map_location=device).to(device).double()

    def forward(self, vert, idx):
        """
        :param vert: vertex tensor [batch size, num vertices, 3]; gradients
            flow back to it
        :param idx: integer face-index tensor [batch size, num faces, 3]
        :return: (semantics, voxels), where semantics = latent @ basis
        """
        faces = self.vert2face(vert, idx)
        # NOTE(review): unsqueeze(0) prepends a dimension to the voxelizer
        # output. Presumably this is the encoder's leading batch/channel dim,
        # which is only correct for a batch of one. Confirm against
        # difx.Voxelizer's output shape.
        voxels = self.voxelize(faces).unsqueeze(0)
        # The encoder returns a pair; only the first element is used
        # (presumably the latent mean — verify against VAE3D).
        latent, _ = self.net.encoder(voxels)
        semantics = torch.matmul(latent, self.basis)
        return semantics, voxels

In [None]:
# Build the mesh -> semantics model from the VAE6 run's saved weights and basis.
# Paths are relative to the notebook's working directory.
mesh2sp = Mesh2SP('outputs_vae6/model_500_best.pth', 'outputs_vae6/model_basis.pth')

In [None]:
import random
# Pick a random chair mesh from the dataset.
mesh = random.choice(shapenet_dataset)
# Add a batch dimension of one and move to the GPU. Vertices are cast to
# float64 to match the double-precision network inside Mesh2SP.
vert, idx = mesh['verts'].unsqueeze(0).cuda().double().detach(), mesh['faces'].unsqueeze(0).cuda().detach()

# Encode the unmodified mesh to obtain its starting semantics.
semantics, voxels = mesh2sp(vert, idx)

In [None]:
# Visualise the mesh as loaded, before any optimisation.
mesh_tri = trimesh.Trimesh(mesh['verts'], mesh['faces'])
mesh_tri.show()

In [None]:
print(semantics.shape)
# Target: a detached copy of the current semantics with element [0, 0, 0]
# decreased by 1.
# NOTE(review): indexing [0, 0, 0] assumes semantics is 3-D — check the
# printed shape above.
target_semantics = semantics.clone().detach()
target_semantics[0,0,0] += -1
target_semantics

In [None]:
# Optimise the vertex positions directly so that the encoded semantics move
# toward `target_semantics`. Face connectivity (`idx`) stays fixed.
vert.requires_grad = True
optimizer = torch.optim.Adam([vert], lr=1e-4)

LOSS_TOLERANCE = 1e-2
MAX_STEPS = 10000  # stop even if the target turns out to be unreachable

loss = float("inf")  # np.infty was removed in NumPy 2.0
step = 0
while loss > LOSS_TOLERANCE and step < MAX_STEPS:
    # Must be *called*. Referencing optimizer.zero_grad without () is a no-op,
    # which let gradients accumulate across iterations.
    optimizer.zero_grad()
    semantics, voxels = mesh2sp(vert, idx)
    loss = torch.sum((semantics - target_semantics) ** 2)
    print(loss.item())

    loss.backward()
    optimizer.step()
    step += 1

if loss > LOSS_TOLERANCE:
    print(f"stopped after {MAX_STEPS} steps without reaching loss <= {LOSS_TOLERANCE} "
          f"(last loss {float(loss):.4g})")

In [None]:
# Visualise the mesh with the optimised vertex positions (first batch entry,
# detached and moved to CPU). Faces are the original, unchanged connectivity.
mesh_tri = trimesh.Trimesh(vert[0,:,:].detach().cpu(), mesh['faces'])
mesh_tri.show()