## setup

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm

import time
import sys
import os
sys.path.append(os.path.abspath(''))

import torch
# Use the first CUDA device when one is visible, otherwise fall back to the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
if device.type == "cuda":
    # Make the selected GPU the default for subsequent CUDA allocations.
    torch.cuda.set_device(device)

print(f"device:{device}")

## set hyperparameters

In [None]:
# --- paths ---
# Directory globbed for `*/models/model_normalized.obj` training meshes.
DATASET_ROOT = 'shapenet/shapenetcore_v2/02773838'
# Directory that receives the saved figures (renders, samples, loss curve).
RESULT_DIR = 'result/shapenet2'

# --- model / renderer sizes ---
# Number of camera viewpoints rendered per mesh.
NUM_VIEWS = 16
# Dimensionality of the latent vector z.
Z_DIM = 100

# --- optimisation schedule ---
# Learning rate handed to Adam.
Lr = 0.0001
# Total number of training iterations (one mesh per iteration).
MAX_ITER = 100_000
# A comparison figure is written every SAVE_ITER iterations.
SAVE_ITER = 1_000

## create renderer

In [None]:
from pytorch3d.utils import ico_sphere

# Util function for loading meshes
from pytorch3d.io import load_objs_as_meshes, save_obj, IO

# Data structures and functions for rendering
from pytorch3d.structures import Meshes
from pytorch3d.renderer import (
    look_at_view_transform,
    OpenGLPerspectiveCameras, 
    PointLights,
    AmbientLights,
    DirectionalLights, 
    Materials, 
    RasterizationSettings, 
    MeshRenderer, 
    MeshRasterizer,  
    SoftPhongShader,
    SoftSilhouetteShader,
    SoftPhongShader,
    TexturesVertex
)

from plot_image_grid import image_grid

In [None]:
class Render3d():
    """Differentiable multi-view renderer built on PyTorch3D.

    A fixed batch of `num_views` cameras is created once; `render` then draws
    a single mesh from every camera and returns the RGBA images.

    Args:
        num_views: number of viewpoints rendered per mesh.
        device: torch device on which cameras and lights are created.
        image_size: side length (pixels) of the square rendered images.
    """

    # The preview helpers lay the views out on a fixed 4x4 grid.
    _GRID_ROWS = 4
    _GRID_COLS = 4

    def __init__(self, num_views, device, image_size=64):
        # the number of different viewpoints from which we want to render the mesh.
        self.num_views = num_views
        self.device = device
        self.image_size = image_size

        # Get a batch of viewing angles.
        elev = torch.linspace(0, 360, self.num_views)
        azim = torch.linspace(-180, 180, self.num_views)

        # Place an ambient light.
        self.lights = AmbientLights(device=self.device)

        # Initialize an OpenGL perspective camera that represents a batch of different
        # viewing angles. All the camera helper methods support mixed type inputs and
        # broadcasting, so we can view the object from a distance of dist=2.7 and
        # then specify elevation and azimuth angles for each viewpoint as tensors.
        R, T = look_at_view_transform(dist=2.7, elev=elev, azim=azim)
        self.cameras = OpenGLPerspectiveCameras(device=device, R=R, T=T)
        # Single camera (the view at index 1) used by the silhouette renderer.
        self.silhouette_cameras = OpenGLPerspectiveCameras(device=device, R=R[None, 1, ...],
                                          T=T[None, 1, ...])

        # Rasterization settings for silhouette rendering
        sigma = 1e-4
        raster_settings_silhouette = RasterizationSettings(
            image_size=self.image_size,
            blur_radius=np.log(1. / 1e-4 - 1.)*sigma,
            faces_per_pixel=50,
            perspective_correct=False  ## avoid nan in backprop
        )

        # Silhouette renderer
        self.renderer_silhouette = MeshRenderer(
            rasterizer=MeshRasterizer(
                cameras=self.silhouette_cameras,
                raster_settings=raster_settings_silhouette
            ),
            shader=SoftSilhouetteShader()
        )

        # Rasterization settings for normal rendering
        raster_settings = RasterizationSettings(
            image_size=self.image_size,
            blur_radius=np.log(1. / 1e-4 - 1.)*sigma,
            faces_per_pixel=50,
            perspective_correct=False  ## avoid nan in backprop
        )

        # normal renderer
        self.renderer = MeshRenderer(
            rasterizer=MeshRasterizer(
                cameras=self.cameras,
                raster_settings=raster_settings
            ),
            shader=SoftPhongShader(device=self.device)
        )

    def render(self, mesh):
        """Render `mesh` from every camera of this renderer.

        The mesh is normalised IN PLACE before rendering: its vertices are
        centred on their mean and scaled so the largest absolute coordinate
        becomes 1. The scale factor is converted to a Python float, so no
        gradient flows through it.

        Args:
            mesh: a single PyTorch3D mesh (it is modified by this call).

        Returns:
            Tensor of shape (1, num_views, H, W, C) clamped to [0.001, 0.999];
            callers in this notebook read channels [:3] as RGB and channel 3
            as the silhouette.
        """
        verts = mesh.verts_packed()
        center = verts.mean(0)
        # Largest absolute coordinate after centring.
        scale = float((verts - center).abs().max())
        # A degenerate mesh (all vertices coincide) has scale 0; avoid dividing by zero.
        scale = max(scale, 1e-12)
        mesh.offset_verts_(-center)
        mesh.scale_verts_(1.0 / scale)

        meshes = mesh.extend(self.num_views)

        rendered_images = self.renderer(meshes, cameras=self.cameras, lights=self.lights)
        return torch.clamp(rendered_images.unsqueeze(dim=0), min=0.001, max=0.999) # avoid nan in backprop

    def _plot_view_grid(self, images, isSave, f_name):
        """Draw up to 16 views of `images[0]` on a 4x4 grid, then save or show it."""
        num_shown = min(self._GRID_ROWS * self._GRID_COLS, images.shape[1])
        plt.figure(figsize=(6,6))
        for view_idx in range(num_shown):
            plt.subplot(self._GRID_ROWS, self._GRID_COLS, view_idx+1)
            plt.imshow(images[0, view_idx].cpu().detach().numpy())
            plt.axis('off')

        if isSave:
            out_path = f'{f_name}.jpg'
            out_dir = os.path.dirname(out_path)
            if out_dir:
                # savefig does not create missing directories.
                os.makedirs(out_dir, exist_ok=True)
            plt.savefig(out_path)
        else:
            plt.show()

    def show_rendered_imgs(self, renderd_imgs, isSave=True, f_name='rgb'):
        """Plot the RGB views in `renderd_imgs`, indexed as [0, view].

        Args:
            renderd_imgs: images to plot; the first 16 views (or fewer) are drawn.
            isSave: when True write `<f_name>.jpg`, otherwise call `plt.show()`.
            f_name: output path without the `.jpg` extension.
        """
        self._plot_view_grid(renderd_imgs, isSave, f_name)

    def show_rendered_silhouette(self, renderd_silhouette, isSave=True, f_name='silhouette'):
        """Plot the silhouette views in `renderd_silhouette`, indexed as [0, view].

        Args:
            renderd_silhouette: silhouettes to plot; the first 16 views (or fewer) are drawn.
            isSave: when True write `<f_name>.jpg`, otherwise call `plt.show()`.
            f_name: output path without the `.jpg` extension.
        """
        self._plot_view_grid(renderd_silhouette, isSave, f_name)

In [None]:
# Render one sample mesh from the dataset and save preview grids of its
# silhouette and RGB views.
render_s = Render3d(NUM_VIEWS, device, image_size=64)

io = IO()
sample_obj = f'{DATASET_ROOT}/1b84dededd445058e44a5473032f38f/models/model_normalized.obj'
try:
    mesh = io.load_mesh(sample_obj, device=device, load_textures=True, create_texture_atlas=True)
    print(mesh.verts_packed().shape)
    print(mesh.faces_packed().shape)
    print(mesh.textures.atlas_padded().shape)

    # Render once and slice the result: every render() call re-normalises the
    # mesh in place, so repeating it only costs time.
    rendered = render_s.render(mesh)
    print(rendered.shape)

    # Channel 3 holds the silhouette.
    res = rendered[..., 3]
    print('silhouette:', res.shape)
    render_s.show_rendered_silhouette(res, f_name=f'{RESULT_DIR}/sample_silhouette')

    # Channels 0-2 hold the RGB image.
    res = rendered[..., :3]
    print('normal_rgb', res.shape)
    render_s.show_rendered_imgs(res, f_name=f'{RESULT_DIR}/sample_render')

except IndexError:
    print("skip sample rendering because this 3d model is broken.")

## create trainable autoencoder network

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class VarAutoEncoder(nn.Module):
    """Variational autoencoder from multi-view renders to a textured ico-sphere.

    The encoder consumes the silhouette and RGB channels of a rendered view
    stack and predicts the mean and variance of a Gaussian latent. The decoder
    maps a latent sample to vertex positions and per-vertex colours of a
    level-2 ico-sphere and renders that mesh with `render_s`.

    Inputs `x` are expected to be shaped (1, num_views, H, W, 4) with the
    silhouette in channel 3; the batch size of 1 is hard-coded in the reshapes.
    The linear layers after the encoders expect 2*64*3*3 features, which is
    what the convolution stacks produce for 64x64 images.

    Args:
        z_dim: size of the latent vector.
        num_views: number of views in every input stack.
        render_s: renderer object exposing `render(mesh)`.
    """

    def __init__(self, z_dim, num_views, render_s):
        super(VarAutoEncoder, self).__init__()

        self.num_views = num_views
        self.render_s = render_s

        # Sub-modules are created in the same order and with the same layer
        # layout as before, so state_dict keys and initialisation are unchanged.
        self.enc_silhouette = self._make_encoder(num_views)
        self.enc_rgb = self._make_encoder(num_views*3)

        self.encmean = nn.Linear(2*64*3*3, z_dim)
        self.encvar = nn.Linear(2*64*3*3, z_dim)

        self.dec_verts = self._make_decoder(z_dim)
        self.dec_textures = self._make_decoder(z_dim)

    @staticmethod
    def _make_encoder(in_channels):
        """Build the conv stack: one 'same' conv, four stride-2 convs, flatten."""
        layers = [
            nn.Conv2d(in_channels, 64, 3, 1, padding='same'),
            nn.LeakyReLU(0.3),
            nn.Dropout2d(0.25),
        ]
        for c_in, c_out in ((64, 128), (128, 128), (128, 128), (128, 64)):
            layers += [
                nn.Conv2d(c_in, c_out, 3, 2),
                nn.LeakyReLU(0.3),
                nn.Dropout2d(0.25),
            ]
        layers.append(nn.Flatten())
        return nn.Sequential(*layers)

    @staticmethod
    def _make_decoder(z_dim):
        """Build the MLP mapping z to 162*3 sigmoid outputs (one triple per vertex)."""
        hidden = 64*3
        layers = [nn.Linear(z_dim, hidden), nn.LeakyReLU(0.3), nn.Dropout(0.25)]
        for _ in range(3):
            layers += [nn.Linear(hidden, hidden), nn.LeakyReLU(0.3), nn.Dropout(0.25)]
        layers += [nn.Linear(hidden, 162*3), nn.Sigmoid()]
        return nn.Sequential(*layers)

    def _split_views(self, x):
        """Split a render stack into encoder inputs.

        Returns the silhouette channel (views act as conv channels) and the RGB
        channels rearranged to (1, num_views*3, H, W).
        """
        x_silhouette = x[..., 3]
        x_rgb = torch.reshape(
            torch.permute(x[..., :3], (0, 1, 4, 2, 3)),
            (1, self.num_views*3, x.shape[2], x.shape[3]),
        )
        return x_silhouette, x_rgb

    def _encoder(self, x_silhouette, x_rgb):
        """Return (mean, var) of the latent Gaussian; var is made positive by softplus."""
        x_silhouette = self.enc_silhouette(x_silhouette)
        x_rgb = self.enc_rgb(x_rgb)
        x = torch.cat((x_silhouette, x_rgb), 1)
        mean = self.encmean(x)
        var = F.softplus(self.encvar(x))
        return mean, var

    def _sample_z(self, mean, var):
        """Reparameterised sample: mean + sqrt(var) * eps with eps ~ N(0, I)."""
        # Draw the noise on the same device/dtype as `mean` rather than relying
        # on a notebook-level `device` variable.
        epsilon = torch.randn_like(mean)
        return mean + torch.sqrt(var) * epsilon

    def _decoder(self, z):
        """Decode a (1, z_dim) latent into a rendered view stack."""
        new_verts = self.dec_verts(z)
        new_verts = torch.reshape(new_verts, (new_verts.shape[1]//3, 3))
        new_textures = self.dec_textures(z)
        new_textures = torch.reshape(new_textures, (1, new_textures.shape[1]//3, 3))

        # Only the face connectivity of the level-2 ico-sphere is reused; the
        # vertex positions come from the network. Building the mesh through the
        # public constructor avoids overwriting Meshes' private packed cache.
        template = ico_sphere(2, z.device)
        mesh = Meshes(
            verts=[new_verts],
            faces=[template.faces_packed()],
            textures=TexturesVertex(verts_features=new_textures),
        )
        return self.render_s.render(mesh)

    def forward(self, x):
        """Encode `x`, sample a latent and return the rendered reconstruction."""
        x_silhouette, x_rgb = self._split_views(x)
        mean, var = self._encoder(x_silhouette, x_rgb)
        z = self._sample_z(mean, var)
        y = self._decoder(z)
        return y

    def loss(self, x):
        """Negative evidence lower bound for one render stack `x`.

        The bound is the sum of the negative KL divergence to N(0, I) and the
        Bernoulli log-likelihoods of the silhouette and RGB channels. Both `x`
        and the reconstruction are expected to lie strictly inside (0, 1) so
        the logarithms stay finite.
        """
        x_silhouette, x_rgb = self._split_views(x)
        mean, var = self._encoder(x_silhouette, x_rgb)

        ## KL_loss
        KL = -0.5 * torch.mean(torch.sum(1 + torch.log(var) - mean**2 - var))

        z = self._sample_z(mean, var)
        y = self._decoder(z)

        ## reconstruction silhouette & rgb_texture loss
        reconstruction_silhouette = torch.mean(torch.sum(x[..., 3] * torch.log(y[..., 3]) + (1 - x[..., 3]) * torch.log(1 - y[..., 3])))
        reconstruction_rgb = torch.mean(torch.sum(x[..., :3] * torch.log(y[..., :3]) + (1 - x[..., :3]) * torch.log(1 - y[..., :3])))

        lower_bound = [-KL, reconstruction_silhouette, reconstruction_rgb]
        return -sum(lower_bound)

## train

In [None]:
from torch import optim

render_s = Render3d(NUM_VIEWS, device, image_size=64)
model = VarAutoEncoder(Z_DIM, NUM_VIEWS, render_s).to(device)

optimizer = optim.Adam(model.parameters(), lr=Lr)
model.train()
losses = []

import random
import glob
file_pathes = glob.glob(f'{DATASET_ROOT}/*/models/model_normalized.obj')
if not file_pathes:
    # Fail early with a clear message instead of an IndexError from random.choice.
    raise FileNotFoundError(f'no model_normalized.obj found under {DATASET_ROOT}')

# savefig does not create missing directories.
os.makedirs(RESULT_DIR, exist_ok=True)

from torch.autograd import detect_anomaly

# Up to 16 views are drawn per panel; the figure is an 8x8 grid holding four
# 16-cell panels (input silhouette, input RGB, output silhouette, output RGB).
VIEWS_PER_PANEL = 8*2

for i in range(MAX_ITER):
#     with detect_anomaly():
    # Keep drawing random models until one loads; broken files raise IndexError.
    mesh = None
    while mesh is None:
        obj_filename = random.choice(file_pathes)
        print(obj_filename, end=', ')
        try:
            mesh = io.load_mesh(obj_filename, device=device, load_textures=True, create_texture_atlas=True)
        except IndexError:
            print("skip train process because this 3d model seems be broken.")

    # One optimisation step on the rendered views of this mesh.
    x = render_s.render(mesh)
    model.zero_grad()
    loss = model.loss(x)
    loss.backward()
    optimizer.step()
    losses.append(loss.cpu().detach().numpy())
    print(f'iter:{(i+1)}/{MAX_ITER}, loss:{losses[-1]}')

    if (i+1) % SAVE_ITER == 0:
        # The reconstruction is only needed for this figure, so compute it here
        # (without autograd) rather than on every iteration.
        with torch.no_grad():
            y = model(x)

        num_shown = min(VIEWS_PER_PANEL, x.shape[1])
        panels = (
            (x, 3),               # input silhouette
            (x, slice(None, 3)),  # input RGB
            (y, 3),               # reconstructed silhouette
            (y, slice(None, 3)),  # reconstructed RGB
        )

        fig = plt.figure(figsize=(12,12))
        plt.subplots_adjust(wspace=0, hspace=0)
        for panel_idx, (images, channels) in enumerate(panels):
            for idx in range(num_shown):
                plt.subplot(8, 8, idx+1+VIEWS_PER_PANEL*panel_idx)
                plt.imshow(images[0, idx, ..., channels].cpu().detach().numpy())
                plt.axis('off')
        plt.savefig('{}/result_{:07d}.jpg'.format(RESULT_DIR, i+1))
        # Release the figure; otherwise one stays open per save for the whole run.
        plt.close(fig)

In [None]:
# Plot the per-iteration training loss and write it to RESULT_DIR.
loss_fig, loss_ax = plt.subplots(figsize=(10, 10))
loss_ax.plot(losses)
loss_fig.savefig(f'{RESULT_DIR}/train_loss.png')

## sampling and evaluation

In [None]:
# Decode latents drawn from the N(0, I) prior and save the rendered views.
# Sampling runs in eval mode (dropout disabled) and without autograd; the
# model's previous train/eval mode is restored afterwards.
was_training = model.training
model.eval()
with torch.no_grad():
    for i in range(2):
        z = torch.randn(1, Z_DIM).to(device)
        y = model._decoder(z)

        plt.figure(figsize=(12, 6))
        plt.subplots_adjust(wspace=0, hspace=0)
        # Top two rows: silhouettes (channel 3).
        for idx in range(8*2):
            plt.subplot(4, 8, idx+1)
            plt.imshow(y[0, idx, ..., 3].cpu().detach().numpy())
            plt.axis('off')
        # Bottom two rows: RGB (channels 0-2).
        for idx in range(8*2):
            plt.subplot(4, 8, idx+1+(8*2))
            plt.imshow(y[0, idx, ..., :3].cpu().detach().numpy())
            plt.axis('off')
        plt.savefig('{}/sample_{:07d}.jpg'.format(RESULT_DIR, i+1))
model.train(was_training)