## setup

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm

import sys
import os
sys.path.append(os.path.abspath(''))

import torch
# Use the first CUDA GPU when one is available; otherwise fall back to the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
if device.type == "cuda":
    # Make the chosen GPU the current CUDA device as well.
    torch.cuda.set_device(device)

print(f"device:{device}")

## set training parameters

In [None]:
# Number of viewpoints rendered per mesh; also the encoder's input channel count.
NUM_VIEWS = 15
# Dimensionality of the latent vector z.
Z_DIM = 100
# Learning rate handed to the Adam optimizer.
Lr = 0.0001
# Total number of training iterations (one randomly chosen mesh per iteration).
MAX_ITER = 100_000
# A comparison figure is written to disk every SAVE_ITER iterations.
SAVE_ITER = 1_000

## create renderer

In [None]:
from pytorch3d.utils import ico_sphere

# Util function for loading meshes
from pytorch3d.io import load_objs_as_meshes, save_obj, IO

# Data structures and functions for rendering
from pytorch3d.structures import Meshes
from pytorch3d.renderer import (
    look_at_view_transform,
    OpenGLPerspectiveCameras, 
    PointLights, 
    DirectionalLights, 
    Materials, 
    RasterizationSettings, 
    MeshRenderer, 
    MeshRasterizer,  
    SoftPhongShader,
    SoftSilhouetteShader,
    SoftPhongShader,
    TexturesVertex
)

from plot_image_grid import image_grid

In [None]:
class Render_silhouette():
    """Render soft silhouettes of a single mesh from a fixed set of viewpoints.

    The viewpoints are built once in the constructor: elevation is swept
    linearly over [0, 360] degrees and azimuth over [-180, 180] degrees, one
    (elevation, azimuth) pair per view, all at the same camera distance.

    Args:
        num_views: number of viewpoints to render.
        device: torch device on which lights and cameras are created.
        image_size: side length handed to ``RasterizationSettings``.
        dist: camera distance passed to ``look_at_view_transform``.
        sigma: scales the rasterizer blur radius,
            ``blur_radius = log(1 / 1e-4 - 1) * sigma``.
        faces_per_pixel: ``faces_per_pixel`` rasterization setting.
    """

    def __init__(self, num_views, device, image_size=64, dist=2.7,
                 sigma=1e-4, faces_per_pixel=50):
        self.num_views = num_views
        self.device = device

        # One elevation/azimuth pair (in degrees) per viewpoint.
        elev = torch.linspace(0, 360, self.num_views)
        azim = torch.linspace(-180, 180, self.num_views)

        # Point light located at (0, 0, -3); forwarded to the renderer on every call.
        self.lights = PointLights(device=device, location=[[0.0, 0.0, -3.0]])

        # A batch of perspective cameras, one per viewpoint. The camera helpers
        # broadcast the scalar distance against the per-view angle tensors.
        R, T = look_at_view_transform(dist=dist, elev=elev, azim=azim)
        self.cameras = OpenGLPerspectiveCameras(device=device, R=R, T=T)

        # Default camera stored in the rasterizer. render_silhouette() always
        # passes the full camera batch explicitly, so this one only matters if
        # the renderer is called directly. The index is clamped so that
        # num_views == 1 no longer raises an IndexError.
        default_idx = min(1, self.num_views - 1)
        camera = OpenGLPerspectiveCameras(device=device,
                                          R=R[None, default_idx, ...],
                                          T=T[None, default_idx, ...])

        # Rasterization settings for silhouette rendering
        raster_settings_silhouette = RasterizationSettings(
            image_size=image_size,
            blur_radius=np.log(1. / 1e-4 - 1.) * sigma,
            faces_per_pixel=faces_per_pixel,
            perspective_correct=False  ## avoid nan in backprop
        )

        # Silhouette renderer
        self.renderer_silhouette = MeshRenderer(
            rasterizer=MeshRasterizer(
                cameras=camera,
                raster_settings=raster_settings_silhouette
            ),
            shader=SoftSilhouetteShader()
        )

    def render_silhouette(self, mesh):
        """Normalize ``mesh`` and render its silhouette from every viewpoint.

        NOTE: ``mesh`` is modified in place. Its vertices are translated so
        their mean sits at the origin and then scaled so that the largest
        absolute coordinate becomes 1.

        Args:
            mesh: a single PyTorch3D mesh.

        Returns:
            The alpha channel of the rendering with a leading dimension of
            size 1 added (indexed as ``[0, view]`` by the callers in this
            notebook), clamped to [0.001, 0.999].
        """
        verts = mesh.verts_packed()
        center = verts.mean(0)
        # Largest absolute coordinate after centering, over all vertices and axes.
        scale = (verts - center).abs().max()
        mesh.offset_verts_(-center)
        mesh.scale_verts_(1.0 / float(scale))

        # Repeat the mesh once per viewpoint so it can be rendered as a batch.
        meshes = mesh.extend(self.num_views)

        # Channel index 3 of the rendering output is the alpha/silhouette channel.
        silhouette_images = self.renderer_silhouette(
            meshes, cameras=self.cameras, lights=self.lights
        )[..., 3]
        # Keep values strictly inside (0, 1) to avoid nan in backprop.
        return torch.clamp(silhouette_images.unsqueeze(dim=0), min=0.001, max=0.999)

In [None]:
## TEST rendering
# Smoke test: load one mesh, render its silhouettes and save them as a 3x5 grid.
io = IO()
mesh = io.load_mesh('modelnet/off/chair_0001.off', device=device, load_textures=True)

render_s = Render_silhouette(15, device, image_size=64)
res = render_s.render_silhouette(mesh)

print(mesh.textures)

# plt.savefig does not create missing directories, so make sure the target exists.
os.makedirs('result/modelnet', exist_ok=True)

fig, axes = plt.subplots(3, 5, figsize=(10,10))
for view_idx, ax in enumerate(axes.flat):
    ax.imshow(res[0, view_idx].cpu().detach().numpy())
    ax.axis('off')
fig.savefig('result/modelnet/sample.jpg')

## create trainable autoencoder network

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class VarAutoEncoder(nn.Module):
    """Variational autoencoder over multi-view silhouette images.

    The encoder consumes a stack of ``num_views`` silhouettes (one channel per
    view) and predicts the mean and variance of a ``z_dim``-dimensional
    Gaussian. The decoder maps a latent sample to 162 * 3 values in (0, 1),
    which are used as the vertex positions of a level-2 ico-sphere; that mesh
    is rendered back into silhouettes by ``render_s``.

    Args:
        z_dim: size of the latent vector.
        num_views: number of silhouette views, i.e. encoder input channels.
        render_s: object exposing ``render_silhouette(mesh)``.
    """

    def __init__(self, z_dim, num_views, render_s):
        super(VarAutoEncoder, self).__init__()

        self.num_views = num_views
        self.render_s = render_s

        # Four stride-2, unpadded 3x3 convolutions reduce a 64x64 input to 3x3,
        # which is what the 64*3*3 input size of the linear heads below expects.
        self.enc = nn.Sequential(
            nn.Conv2d(num_views, 64, 3, 1, padding='same'),
            nn.LeakyReLU(0.3),
            nn.Dropout2d(0.25),

            nn.Conv2d(64, 128, 3, 2),
            nn.LeakyReLU(0.3),
            nn.Dropout2d(0.25),

            nn.Conv2d(128, 128, 3, 2),
            nn.LeakyReLU(0.3),
            nn.Dropout2d(0.25),

            nn.Conv2d(128, 128, 3, 2),
            nn.LeakyReLU(0.3),
            nn.Dropout2d(0.25),

            nn.Conv2d(128, 64, 3, 2),
            nn.LeakyReLU(0.3),
            nn.Dropout2d(0.25),

            nn.Flatten()
        )

        self.encmean = nn.Linear(64*3*3, z_dim)
        self.encvar = nn.Linear(64*3*3, z_dim)

        self.dec = nn.Sequential(
            nn.Linear(z_dim, 64*3),
            nn.LeakyReLU(0.3),
            nn.Dropout(0.25),

            nn.Linear(64*3, 64*3),
            nn.LeakyReLU(0.3),
            nn.Dropout(0.25),

            nn.Linear(64*3, 64*3),
            nn.LeakyReLU(0.3),
            nn.Dropout(0.25),

            nn.Linear(64*3, 64*3),
            nn.LeakyReLU(0.3),
            nn.Dropout(0.25),

            # 162 vertices (level-2 ico-sphere) times 3 coordinates each.
            nn.Linear(64*3, 162*3),
            nn.Sigmoid()
        )

    def _encoder(self, x):
        """Return ``(mean, var)`` of the latent Gaussian for input ``x``.

        ``var`` goes through softplus so it is non-negative.
        """
        x = self.enc(x)
        mean = self.encmean(x)
        var = F.softplus(self.encvar(x))
        return mean, var

    def _sample_z(self, mean, var):
        """Reparameterized sample ``mean + sqrt(var) * eps`` with ``eps ~ N(0, I)``.

        The noise is created with ``randn_like`` so it follows the device and
        dtype of ``mean`` instead of depending on a notebook-global ``device``.
        """
        epsilon = torch.randn_like(mean)
        return mean + torch.sqrt(var) * epsilon

    def _decoder(self, z):
        """Decode a single latent vector ``z`` of shape (1, z_dim) into silhouettes.

        The decoder output is reshaped to (1, V, 3) and installed as the vertex
        positions of a fresh level-2 ico-sphere via the public
        ``Meshes.update_padded`` API, rather than by overwriting the private
        ``_verts_packed`` attribute. Only a batch of one is supported.
        """
        new_verts = self.dec(z)
        new_verts = new_verts.reshape(1, -1, 3)

        # Template provides the face connectivity; its vertices are replaced.
        template = ico_sphere(2, new_verts.device)
        mesh = template.update_padded(new_verts)
        return self.render_s.render_silhouette(mesh)

    def forward(self, x):
        """Encode ``x``, sample a latent vector and return the re-rendered silhouettes."""
        mean, var = self._encoder(x)
        z = self._sample_z(mean, var)
        x = self._decoder(z)
        return x  # return x, z

    def loss(self, x):
        """Negative evidence lower bound for ``x``: KL term minus reconstruction term.

        The reconstruction term is the Bernoulli log-likelihood of ``x`` under
        the decoded silhouettes. Both terms are summed over all elements; the
        previous ``torch.mean`` around those scalar sums was a no-op and has
        been dropped.
        """
        mean, var = self._encoder(x)
        KL = -0.5 * torch.sum(1 + torch.log(var) - mean**2 - var)
        z = self._sample_z(mean, var)
        y = self._decoder(z)

        reconstruction = torch.sum(x * torch.log(y) + (1 - x) * torch.log(1 - y))
        # ELBO = reconstruction - KL; the loss to minimize is its negation.
        return KL - reconstruction

## train

In [None]:
from torch import optim

render_s = Render_silhouette(NUM_VIEWS, device, image_size=64)
model = VarAutoEncoder(Z_DIM, NUM_VIEWS, render_s).to(device)

optimizer = optim.Adam(model.parameters(), lr=Lr)
model.train()
losses = []

import random
import glob
file_pathes = glob.glob('modelnet/off/*.off')
if not file_pathes:
    # Fail early with a clear message instead of random.choice's IndexError.
    raise FileNotFoundError("no .off meshes found matching 'modelnet/off/*.off'")

from torch.autograd import detect_anomaly
io = IO()

# plt.savefig does not create missing directories, so make sure the target exists.
os.makedirs('result/modelnet', exist_ok=True)

# Layout of the comparison figure: inputs on top, reconstructions below.
grid_cols = 5

for i in range(MAX_ITER):
    # To hunt down NaNs, wrap the body of this loop in `with detect_anomaly():`.
    obj_filename = random.choice(file_pathes)
    print(obj_filename, end=', ')
    mesh = io.load_mesh(obj_filename, device=device, load_textures=True)
    x = render_s.render_silhouette(mesh)

    model.zero_grad()
    loss = model.loss(x)
    loss.backward()
    optimizer.step()
    losses.append(loss.cpu().detach().numpy())
    print(f'iter:{(i+1)}/{MAX_ITER}, loss:{losses[-1]}')

    if (i+1) % SAVE_ITER == 0:
        # The reconstruction is only needed for this figure, so it is computed
        # here (without autograd) rather than as an extra forward pass and
        # render on every iteration.
        with torch.no_grad():
            y = model(x)

        n_views = x.shape[1]
        grid_rows = -(-n_views // grid_cols)  # ceiling division

        fig = plt.figure(figsize=(10,10))
        plt.subplots_adjust(wspace=0, hspace=0)
        for idx in range(n_views):
            plt.subplot(2 * grid_rows, grid_cols, idx+1)
            plt.imshow(x[0, idx].cpu().detach().numpy())
            plt.axis('off')
        for idx in range(n_views):
            plt.subplot(2 * grid_rows, grid_cols, idx+1+(grid_rows*grid_cols))
            plt.imshow(y[0, idx].cpu().detach().numpy())
            plt.axis('off')
        plt.savefig('result/modelnet/result_{:07d}.jpg'.format(i+1))
        # The figure is on disk; close it so open figures do not pile up
        # over the course of training.
        plt.close(fig)

In [None]:
# Draw the recorded per-iteration training loss and write the curve to disk.
fig, ax = plt.subplots(figsize=(10,10))
ax.plot(losses)
fig.savefig('result/modelnet/train_loss.png')

## sampling and evaluation

In [None]:
# Sample from the prior with dropout disabled; no gradients are needed here.
model.eval()
with torch.no_grad():
    for i in range(50):
        z = torch.randn(1, Z_DIM).to(device)
        y = model._decoder(z)

        fig = plt.figure(figsize=(10,5))
        # `idx` (not `x`) so the training input `x` is not overwritten.
        for idx in range(3*5):
            plt.subplot(3, 5, idx+1)
            plt.imshow(y[0, idx].cpu().detach().numpy())
            plt.axis('off')
        plt.savefig('result/modelnet/sample_{:07d}.jpg'.format(i+1))
        # The figure is on disk; close it so 50 figures do not stay open.
        plt.close(fig)