In [1]:
"""
We use the FFHQ dataset as training data in this experiment.
--------------------------------------------------------
1) Pre-train the facenet and the disentangle nets (part encoders).

2) Freeze the facenet, the disentangle nets (part encoders) and the part decoders, then train the blending decoders.

--------------------------------------------------------
Author: Peizhi Yan
Date Updated (1st): Sep.-17-2021
Date Updated (2nd): Oct.-24-2021
"""

'\nWe use FFHQ dataset as training data in this experiment\n--------------------------------------------------------\n1) Pre-train the facenet, disentangle nets (part encoders).\n\n2) Freeze facenet, disentangle nets (part encoders), part-decoders, and train blending decoders.\n\n--------------------------------------------------------\nAuthor: Peizhi Yan\nDate Updated (1st): Sep.-17-2021\nDate Updated (2nd): Oct.-24-2021\n'

In [2]:
import sys
import gc
import os
from tqdm import tqdm
import time

import numpy as np
import pickle
import scipy.io as sio
from scipy.io import loadmat
import matplotlib.pyplot as plt
import cv2
import torchvision.transforms as transforms
from PIL import Image


# Pytorch 1.9
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils
import torch.distributions

# Pytorch3d
from pytorch3d.structures import Meshes
from pytorch3d.renderer import (
    look_at_view_transform,
    FoVPerspectiveCameras,
    PointLights,
    RasterizationSettings,
    MeshRenderer,
    MeshRasterizer,
    SoftPhongShader,
    TexturesVertex,
    blending
)

import open3d as o3d

# facenet-pytorch 2.5.2
from facenet_pytorch import MTCNN, InceptionResnetV1

# face-alignment 1.3.4
import face_alignment

#######################################
## Setup PyTorch
# Pick the first GPU when one is visible, otherwise fall back to the CPU.
use_cuda = torch.cuda.is_available()
device = torch.device("cuda:0" if use_cuda else "cpu")
if not use_cuda:
    print('CUDA is NOT available. Use CPU instead.')
else:
    torch.cuda.set_device(device)
    print('CUDA is available. Device: ', torch.cuda.get_device_name(0))

c:\users\yanpe\appdata\local\programs\python\python36\lib\site-packages\numpy\.libs\libopenblas.TXA6YQSD3GCQQC22GEQ54J2UDCXDXHWN.gfortran-win_amd64.dll
c:\users\yanpe\appdata\local\programs\python\python36\lib\site-packages\numpy\.libs\libopenblas.WCDJNK7YVMPZQ2ME2ZZHJJRJ3JIKNDB7.gfortran-win_amd64.dll
  stacklevel=1)


CUDA is available. Device:  NVIDIA GeForce RTX 3070


In [3]:
img_path = '../datasets/FFHQ/images224x224/{}.png'
shape_path = '../datasets/FFHQ/raw_bfm_shape/{}.npy'
albedo_path = '../datasets/FFHQ/raw_bfm_color/{}.npy'

# Image ids = .png file names without the extension. Sorted because os.listdir
# returns entries in an arbitrary, filesystem-dependent order, and the training
# cells treat the FIRST 100 ids as the validation split -- that split must be
# reproducible across machines and runs.
img_dir = os.path.dirname(img_path) # same folder img_path points into
img_indices = sorted(
    fname[:-len('.png')]
    for fname in os.listdir(img_dir)
    if fname.endswith('.png')
)

print(len(img_indices))

10


In [5]:
"""
 1: face skin
 2: eye brows
 3: eyes
 4: nose
 5: upper lip
 6: lower lip
"""
label_map = dict(
    skin=1,
    eye_brow=2,
    eye=3,
    nose=4,
    u_lip=5,
    l_lip=6,
)

## Per-vertex face-parsing label of every BFM vertex (values as in label_map)
vert_labels = np.load('../BFM/bfm_vertex_labels.npy')
print(set(vert_labels))


## Load the BFM model (pickle is imported in the setup cell)
with open('../BFM/bfm09.pkl', 'rb') as f:
    bfm = pickle.load(f)
print('BFM model loaded\n')

## Triangle facets: shift BFM's 1-based 'tri' indices to 0-based
Faces = bfm['tri'] - 1 ## -1 is critical !!!

# Parsing label that selects each part's vertices; None keeps every vertex.
part_label = {
    'S_overall': None,
    'S_eyebrows': label_map['eye_brow'],
    'S_eyes': label_map['eye'],
    'S_llip': label_map['l_lip'],
    'S_nose': label_map['nose'],
    'S_ulip': label_map['u_lip'],
}

# find the vertices of each part (index arrays into the BFM vertex list)
part_vertices = {}
for part, label in part_label.items():
    part_vertices[part] = np.array(
        [v for v, lab in enumerate(vert_labels) if label is None or lab == label]
    )
    print(part, ' n_vert: ', len(part_vertices[part]))

{1, 2, 3, 4, 5, 6}
BFM model loaded

S_overall  n_vert:  35709
S_eyebrows  n_vert:  444
S_eyes  n_vert:  586
S_llip  n_vert:  309
S_nose  n_vert:  1711
S_ulip  n_vert:  576


In [6]:
#########################
## Latent dimensions
# Size of each part's latent code (the overall shape gets a larger code).
latent_dims = dict(
    S_overall=30,
    S_eyebrows=10,
    S_eyes=10,
    S_llip=10,
    S_ulip=10,
    S_nose=10,
)


# Half-open [from, to) slice of every part inside the concatenated latent
# vector, laid out in the insertion order of latent_dims. Later code uses these
# pre-generated ranges to merge the part latents.
latent_coeff_from_to = {}
_from_ = 0
for key, dim in latent_dims.items():
    latent_coeff_from_to[key] = (_from_, _from_ + dim)
    _from_ += dim
print(latent_coeff_from_to)


{'S_overall': (0, 30), 'S_eyebrows': (30, 40), 'S_eyes': (40, 50), 'S_llip': (50, 60), 'S_ulip': (60, 70), 'S_nose': (70, 80)}


In [8]:
sys.path.append('../')
from models.Modules import Encoder, Decoder, VAE, VariationalDisentangleModule,\
OffsetRegressorA, OffsetRegressorB, OffsetRegressor

#################
## Part Decoders
# One decoder per shape part (keys of part_vertices), loaded from
# MODEL_PATH.format(<part key>) and then frozen.
MODEL_PATH = '../saved_models/part_decoders/{}'
part_decoders = {}
for key in part_vertices:
    part_decoders[key] = Decoder(latent_dim=latent_dims[key], n_vert=len(part_vertices[key])).to(device)
    # Load pre-trained parameters. map_location lets a checkpoint saved on a
    # GPU be loaded on a CPU-only machine (and vice versa).
    if os.path.exists(MODEL_PATH.format(key)):
        part_decoders[key].load_state_dict(torch.load(MODEL_PATH.format(key), map_location=device))
        print('Decoder {} loaded'.format(key))
    else:
        # The decoder is frozen below, so a random init could never be trained
        # away -- make the missing checkpoint visible instead of silent.
        print('WARNING: Decoder {} NOT loaded ({} not found); it keeps random weights'.format(
            key, MODEL_PATH.format(key)))
    # freeze the network parameters
    for param in part_decoders[key].parameters():
        param.requires_grad = False

print('done')

Decoder S_overall loaded
Decoder S_eyebrows loaded
Decoder S_eyes loaded
Decoder S_llip loaded
Decoder S_nose loaded
Decoder S_ulip loaded
done


In [9]:

## Load facenet (InceptionResnetV1 with VGGFace2 weights), frozen by default
facenet = InceptionResnetV1(pretrained='vggface2').eval().to(device)
for param in facenet.parameters():
    param.requires_grad = False

## load our retrained facenet parameters (these replace the VGGFace2 weights)
# map_location lets a GPU-saved checkpoint load on a CPU-only machine.
facenet.load_state_dict(torch.load('../saved_models/facenet', map_location=device))

print('model on CUDA: ', next(facenet.parameters()).is_cuda) # True


model on CUDA:  True


In [10]:

## One VariationalDisentangleModule (part encoder) per latent part, mapping the
## 512-d facenet embedding to that part's latent code.
disentangleNets = {}
for key in latent_dims:
    disentangleNets[key] = VariationalDisentangleModule(512, latent_dims[key]).to(device)
    print(key, 'model on CUDA: ', next(disentangleNets[key].parameters()).is_cuda) # True

print('-----------------')

## Load Disentangle Net params
# Best effort: a missing or incompatible checkpoint leaves the net at its
# random init (e.g. before pre-training). Only checkpoint-related errors are
# caught -- a bare `except:` would also hide KeyboardInterrupt and real bugs --
# and the reason is printed so a silent fallback cannot go unnoticed.
disentangle_nets_path = '../saved_models/disentangle_nets/{}'
for key in latent_dims:
    try:
        disentangleNets[key].load_state_dict(
            torch.load(disentangle_nets_path.format(key), map_location=device))
        print(key, ' loaded')
    except (OSError, RuntimeError, EOFError, pickle.UnpicklingError) as err:
        print(key, ' NOT loaded')
        print('    reason: ', err)
        
    

S_overall model on CUDA:  True
S_eyebrows model on CUDA:  True
S_eyes model on CUDA:  True
S_llip model on CUDA:  True
S_ulip model on CUDA:  True
S_nose model on CUDA:  True
-----------------
S_overall  loaded
S_eyebrows  loaded
S_eyes  loaded
S_llip  loaded
S_ulip  loaded
S_nose  loaded


In [11]:

def next_batch(idx, batch_size, img_path, shape_path, albedo_path, img_indices):
    """
    Load one batch of images together with their BFM shapes and albedos.

    Args:
        idx: position in img_indices where the batch starts.
        batch_size: requested batch size; truncated at the end of img_indices.
        img_path, shape_path, albedo_path: format strings taking a file index.
        img_indices: list of file indices (image names without extension).

    return: images, shapes, albedos, new index
        images  -- [n, 3, 224, 224] float32, raw 0-255 RGB values
        shapes  -- [n, 35709, 3] float32
        albedos -- [n, 35709, 3] float32, divided by 255
        new index -- idx + n
        When no samples remain (idx >= len(img_indices)) this returns
        (None, None, None, idx) -- the value callers already test for --
        instead of failing on a negative tensor size.
    """
    n = min(batch_size, len(img_indices) - idx)
    if n <= 0:
        return None, None, None, idx

    batch_x = torch.zeros([n, 3, 224, 224], dtype=torch.float32).to(device)
    batch_s = torch.zeros([n, 35709, 3], dtype=torch.float32).to(device)
    batch_t = torch.zeros([n, 35709, 3], dtype=torch.float32).to(device)

    for counter, file_index in enumerate(img_indices[idx:idx + n]):
        # `with` closes the file handle right away; convert('RGB') drops an
        # alpha channel / expands grayscale so the tensor always has 3 channels.
        with Image.open(img_path.format(file_index)) as pil_img:
            rgb_img = pil_img.convert('RGB').resize((224, 224))
        img = torch.from_numpy(np.asarray(rgb_img, dtype=np.float32)).permute(2, 0, 1)

        batch_x[counter] = img
        batch_s[counter] = torch.from_numpy(np.reshape(np.load(shape_path.format(file_index)), [35709, 3]))
        batch_t[counter] = torch.from_numpy(np.reshape(np.load(albedo_path.format(file_index))/255., [35709, 3]))

    return batch_x, batch_s, batch_t, idx + n


def next_batch_cached(idx, batch_size, encoding_path, shape_path, albedo_path, img_indices):
    """
    Load one batch of cached facenet encodings with the matching BFM shapes
    and albedos, starting at position idx of img_indices.

    The batch is truncated to whatever remains of img_indices after idx.
        return: img encodings [n, 512], shapes [n, 35709, 3],
                albedos [n, 35709, 3] (divided by 255), new index (idx + n)
    """
    n = min(batch_size, len(img_indices) - idx)

    def _zeros(*dims):
        return torch.zeros(list(dims), dtype=torch.float32).to(device)

    batch_e = _zeros(n, 512)
    batch_s = _zeros(n, 35709, 3)
    batch_t = _zeros(n, 35709, 3)

    for row, file_index in enumerate(img_indices[idx:idx + n]):
        encoding = np.load(encoding_path.format(file_index))
        shape = np.load(shape_path.format(file_index))
        albedo = np.load(albedo_path.format(file_index)) / 255.
        batch_e[row] = torch.from_numpy(np.reshape(encoding, [512]))
        batch_s[row] = torch.from_numpy(np.reshape(shape, [35709, 3]))
        batch_t[row] = torch.from_numpy(np.reshape(albedo, [35709, 3]))

    return batch_e, batch_s, batch_t, idx + n




In [12]:
def compute_shape_loss(pred_shape, targ_shape):
    """
    Sum of squared errors, averaged over the batch.

    Args:
        pred_shape: predicted tensor whose first dimension is the batch,
            e.g. [n, n_vert, 3].
        targ_shape: target tensor of the same shape (or broadcastable to it).
    Returns:
        0-dim tensor: sum((pred_shape - targ_shape)^2) / n.
    """
    batch_size = pred_shape.shape[0]
    # Reduce over every dimension (rather than a hard-coded [0, 1, 2]) so the
    # loss works for inputs of any rank.
    return torch.square(pred_shape - targ_shape).sum() / batch_size

def standardize_part_shape(part_shapes, return_offsets=False):
    """
    Centre each part shape on the midpoint of its y and z bounding box.

    For every sample, (max + min) / 2 of the y and z coordinates is
    subtracted; x is left unchanged.

    Args:
        part_shapes: [n, n_vert, 3] tensor, n -- batch size.
        return_offsets: if True, also return the centres that were removed.
    Returns:
        A new [n, n_vert, 3] tensor. The input is not modified in place.
        With return_offsets=True: (shapes, y_center [n], z_center [n]).
    """
    y = part_shapes[..., 1]
    z = part_shapes[..., 2]
    y_center = (y.max(dim=1).values + y.min(dim=1).values) / 2 # [n]
    z_center = (z.max(dim=1).values + z.min(dim=1).values) / 2 # [n]

    # [n, 3] per-sample shift (0 for x), broadcast over all vertices
    shift = torch.stack([torch.zeros_like(y_center), y_center, z_center], dim=1)
    standardized = part_shapes - shift.unsqueeze(1)

    if return_offsets:
        return standardized, y_center, z_center
    return standardized


In [None]:
from torch.utils.tensorboard import SummaryWriter

# TensorBoard logger shared by the training cells below; log_dir=None keeps
# SummaryWriter's default log directory.
tb_writer = SummaryWriter(log_dir=None)


## Train disentangle layers (part encoders)

In [None]:
#######################################
## Disentangle Layers Pre-Training
#######################################
# Trains the part encoders (disentangle nets) on cached facenet encodings while
# facenet and the part decoders stay frozen.

EPOCHS = 100
#BATCH_SIZE = 128
BATCH_SIZE = 8
N_VAL = 100 # the first N_VAL images are held out as validation data
KL_WEIGHT = 1e3 # weight of the KL term relative to the shape loss

caching_path = './cache/facenet_encoding/{}.npy'


####################
## Some switches
SW = {0: False, 1: True} # Switch
train_or_not = {} # indicate whether or not to train a part module
train_or_not['Facenet'] =      SW[0]
use_cached_facenet_encodings = SW[1]
train_or_not['disentangle']  = SW[1]
train_or_not['part_decoders'] = SW[0] # please do not train part-decoders!!
train_or_not['S_overall'] =     SW[1]
train_or_not['S_eyebrows'] = SW[1]
train_or_not['S_eyes'] =     SW[1]
train_or_not['S_llip'] =     SW[1]
train_or_not['S_ulip'] =     SW[1]
train_or_not['S_nose'] =     SW[1]

# Cached encodings bypass facenet, so it could never receive a gradient.
assert not (train_or_not['Facenet'] and use_cached_facenet_encodings), \
    'facenet cannot be trained from cached encodings'

parameters = [] ## the parameters to be optimized

################################
## Disentangle Layer parameters
if train_or_not['disentangle']:
    for key in latent_dims:
        if train_or_not[key]:
            parameters += list(disentangleNets[key].parameters())

################################
## Unfreeze FaceNet parameters
if train_or_not['Facenet']:
    facenet.train()
    for param in facenet.parameters():
        param.requires_grad = True
    parameters += list(facenet.parameters())
else:
    facenet.eval()
    for param in facenet.parameters():
        param.requires_grad = False

################################
## Part Decoders
if train_or_not['part_decoders']:
    for key in latent_dims:
        if not train_or_not[key]:
            continue
        for param in part_decoders[key].parameters():
            param.requires_grad = True
        parameters += list(part_decoders[key].parameters())

###############
## Optimizer
# One optimizer over everything selected above. The parts own disjoint
# parameters, so Adam on the summed part losses gives each disentangle net the
# same update a per-part optimizer would. It also means facenet / part-decoder
# parameters are actually updated when their switches are on.
lr = 1e-4 # initial learning rate
#lr = 1e-5
optimizer = torch.optim.Adam(parameters, lr=lr)


#############
## Training
iter_ = 0
for epoch in range(EPOCHS):

    idx = N_VAL # first N_VAL images as validation data
    with tqdm(total=len(img_indices) - N_VAL) as pbar:
        # idx indexes the FULL img_indices list, so the bound is its full
        # length; len(img_indices[N_VAL:]) would drop the last N_VAL images.
        while idx < len(img_indices):
            prev_idx = idx

            ######################
            # Prepare batch data
            if use_cached_facenet_encodings:
                img_encodings, batch_s, batch_t, idx = next_batch_cached(idx, BATCH_SIZE, caching_path,
                                                                         shape_path, albedo_path, img_indices)
            else:
                batch_x, batch_s, batch_t, idx = next_batch(idx, BATCH_SIZE, img_path, shape_path, albedo_path, img_indices)
                img_encodings = facenet(batch_x)

            ##################
            # Clean gradient
            optimizer.zero_grad()

            ##################
            # Loss of each trained part
            part_losses = {}
            for key in latent_dims:
                if not train_or_not[key]:
                    continue

                part_mu, part_sigma = disentangleNets[key](img_encodings)
                # NOTE(review): the KL term below treats part_sigma as a
                # log-variance, while this sample uses it as a standard
                # deviation; one of the two looks inconsistent -- confirm
                # against VariationalDisentangleModule.
                part_latents = part_mu + part_sigma * torch.randn_like(part_mu)
                kl_loss = torch.mean(torch.sum(-0.5 * (1 + part_sigma - part_mu**2 - torch.exp(part_sigma)), dim=1))
                preds = part_decoders[key](part_latents)

                batch_shape_part = batch_s[:,part_vertices[key],:]
                if key != 'S_overall':
                    ## Standardize the part shapes
                    batch_shape_part = standardize_part_shape(batch_shape_part)
                part_shape_loss = compute_shape_loss(preds, batch_shape_part)

                part_losses[key] = (part_shape_loss, kl_loss, part_shape_loss + KL_WEIGHT*kl_loss)

            if part_losses:
                ######
                # BP: a single backward pass over the summed part losses
                total_loss = sum(loss for _, _, loss in part_losses.values())
                total_loss.backward()

                ############
                # Optimize
                optimizer.step()

            ################
            ## TensorBoard
            if iter_ % 100 == 0:
                for key, (part_shape_loss, kl_loss, loss) in part_losses.items():
                    tb_writer.add_scalar('{}/res_loss'.format(key), part_shape_loss.item(), iter_)
                    tb_writer.add_scalar('{}/kl_loss'.format(key), kl_loss.item(), iter_)
                    tb_writer.add_scalar('{}/loss'.format(key), loss.item(), iter_)

            pbar.update(idx - prev_idx) # the last batch may be smaller than BATCH_SIZE
            iter_ += 1

    print('Epoch: ', epoch)

    
    

## Train facenet + disentangle layers (part encoders)

In [None]:
#######################################
## Fine-tuning
#######################################
# Jointly fine-tunes facenet and the part encoders on raw images; the part
# decoders stay frozen.

EPOCHS = 100
#BATCH_SIZE = 128
BATCH_SIZE = 8
N_VAL = 100 # the first N_VAL images are held out as validation data
KL_WEIGHT = 1e3 # weight of the KL term relative to the shape loss

caching_path = './cache/facenet_encoding/{}.npy'


####################
## Some switches
SW = {0: False, 1: True} # Switch
train_or_not = {} # indicate whether or not to train a part module
train_or_not['Facenet'] =      SW[1]
use_cached_facenet_encodings = SW[0]
train_or_not['disentangle']  = SW[1]
train_or_not['part_decoders'] = SW[0] # please do not train part-decoders!!
train_or_not['S_overall'] =     SW[1]
train_or_not['S_eyebrows'] = SW[1]
train_or_not['S_eyes'] =     SW[1]
train_or_not['S_llip'] =     SW[1]
train_or_not['S_ulip'] =     SW[1]
train_or_not['S_nose'] =     SW[1]

# Cached encodings bypass facenet, so it could never receive a gradient.
assert not (train_or_not['Facenet'] and use_cached_facenet_encodings), \
    'facenet cannot be trained from cached encodings'

parameters = [] ## the parameters to be optimized

################################
## Disentangle Layer parameters
if train_or_not['disentangle']:
    for key in latent_dims:
        if train_or_not[key]:
            parameters += list(disentangleNets[key].parameters())

################################
## Unfreeze FaceNet parameters
if train_or_not['Facenet']:
    facenet.train()
    for param in facenet.parameters():
        param.requires_grad = True
    parameters += list(facenet.parameters())
else:
    facenet.eval()
    for param in facenet.parameters():
        param.requires_grad = False

################################
## Part Decoders
if train_or_not['part_decoders']:
    for key in latent_dims:
        if not train_or_not[key]:
            continue
        for param in part_decoders[key].parameters():
            param.requires_grad = True
        parameters += list(part_decoders[key].parameters())

###############
## Optimizer
# One optimizer over everything selected above -- including facenet, which a
# disentangle-net-only optimizer never updated. Summing the part losses and
# calling backward once is also required here: all parts share the facenet
# graph, and a per-part backward frees that graph after the first part.
#lr = 1e-4 # initial learning rate
lr = 1e-5
optimizer = torch.optim.Adam(parameters, lr=lr)


#############
## Training
iter_ = 0
for epoch in range(EPOCHS):

    idx = N_VAL # first N_VAL images as validation data
    with tqdm(total=len(img_indices) - N_VAL) as pbar:
        # idx indexes the FULL img_indices list, so the bound is its full
        # length; len(img_indices[N_VAL:]) would drop the last N_VAL images.
        while idx < len(img_indices):
            prev_idx = idx

            ######################
            # Prepare batch data
            if use_cached_facenet_encodings:
                img_encodings, batch_s, batch_t, idx = next_batch_cached(idx, BATCH_SIZE, caching_path,
                                                                         shape_path, albedo_path, img_indices)
            else:
                batch_x, batch_s, batch_t, idx = next_batch(idx, BATCH_SIZE, img_path, shape_path, albedo_path, img_indices)
                # NOTE(review): next_batch returns raw 0-255 RGB with no
                # normalization; confirm this matches the input the retrained
                # facenet / cached encodings were produced with.
                img_encodings = facenet(batch_x)

            ##################
            # Clean gradient
            optimizer.zero_grad()

            ##################
            # Loss of each trained part
            part_losses = {}
            for key in latent_dims:
                if not train_or_not[key]:
                    continue

                part_mu, part_sigma = disentangleNets[key](img_encodings)
                # NOTE(review): the KL term below treats part_sigma as a
                # log-variance, while this sample uses it as a standard
                # deviation; one of the two looks inconsistent -- confirm
                # against VariationalDisentangleModule.
                part_latents = part_mu + part_sigma * torch.randn_like(part_mu)
                kl_loss = torch.mean(torch.sum(-0.5 * (1 + part_sigma - part_mu**2 - torch.exp(part_sigma)), dim=1))
                preds = part_decoders[key](part_latents)

                batch_shape_part = batch_s[:,part_vertices[key],:]
                if key != 'S_overall':
                    ## Standardize the part shapes
                    batch_shape_part = standardize_part_shape(batch_shape_part)
                part_shape_loss = compute_shape_loss(preds, batch_shape_part)

                part_losses[key] = (part_shape_loss, kl_loss, part_shape_loss + KL_WEIGHT*kl_loss)

            if part_losses:
                ######
                # BP: a single backward pass over the summed part losses
                total_loss = sum(loss for _, _, loss in part_losses.values())
                total_loss.backward()

                ############
                # Optimize
                optimizer.step()

            ################
            ## TensorBoard
            if iter_ % 100 == 0:
                for key, (part_shape_loss, kl_loss, loss) in part_losses.items():
                    tb_writer.add_scalar('{}/res_loss'.format(key), part_shape_loss.item(), iter_)
                    tb_writer.add_scalar('{}/kl_loss'.format(key), kl_loss.item(), iter_)
                    tb_writer.add_scalar('{}/loss'.format(key), loss.item(), iter_)

            pbar.update(idx - prev_idx) # the last batch may be smaller than BATCH_SIZE
            iter_ += 1

    print('Epoch: ', epoch)

    
    

In [31]:
## Save all the networks
# Set SAVE = True to overwrite the checkpoints that the loading cells read.

SAVE = False

if SAVE:
    model_save_path = '../saved_models/disentangle_nets/{}'
    # torch.save does not create missing parent directories
    os.makedirs(os.path.dirname(model_save_path), exist_ok=True)
    for key in latent_dims:
        torch.save(disentangleNets[key].state_dict(), model_save_path.format(key))
        print(key, ' saved')

    # facenet is a single network: write its checkpoint once, not once per part.
    model_save_path = '../saved_models/facenet'
    os.makedirs(os.path.dirname(model_save_path), exist_ok=True)
    torch.save(facenet.state_dict(), model_save_path)
    print('facenet saved')



    


## Train offset regressor

In [None]:
##############################
## Offset Prediction Module
##############################
# Row of each part in the regressor's offset output; presumably the output is
# [n, 5, 2] with (y, z) per part -- TODO confirm against OffsetRegressor.
offsetKeys = {key: row for row, key in enumerate(
    ['S_eyebrows', 'S_eyes', 'S_llip', 'S_ulip', 'S_nose'])}

offsetRegressorA = OffsetRegressorA().to(device)

# One OffsetRegressorB per (part, axis), keyed '<part>-y' / '<part>-z'.
offserRegressorsB = {
    key + axis: OffsetRegressorB()
    for key in offsetKeys
    for axis in ('-y', '-z')
}

offsetRegressor = OffsetRegressor(offsetRegressorA, offserRegressorsB).to(device)

## Load the pre-trained decoder weights
#offsetRegressor.load_state_dict(torch.load('../saved_models/offset_regressor'))


In [None]:
parameters = [] ## the parameters to be optimized


################################
## Offset Decoders
for param in offsetRegressor.parameters():
    param.requires_grad = True

# NOTE(review): only the S_eyes y/z branches are optimized; switch to the
# commented line to train the whole regressor.
#parameters += list(offsetRegressor.parameters())
parameters += list(offsetRegressor.S_eyes_y.parameters())
parameters += list(offsetRegressor.S_eyes_z.parameters())


###############
## Optimizer
optimizer = torch.optim.Adam(parameters, lr=1e-4) # initial learning rate
#optimizer = torch.optim.Adam(parameters, lr=1e-5) #

N_VAL = 100 # the first N_VAL images are held out as validation data

#############
## Training
# EPOCHS, BATCH_SIZE and caching_path come from the training cells above.
loss = None
for epoch in range(EPOCHS):

    idx = N_VAL # first N_VAL images as validation data
    with tqdm(total=len(img_indices) - N_VAL) as pbar:
        # idx indexes the FULL img_indices list, so compare with its full length.
        while idx < len(img_indices):
            prev_idx = idx
            ##################
            # Clean gradient
            optimizer.zero_grad()

            ######################
            # Prepare batch data (the albedos are not needed here)
            img_encodings, batch_s, _, idx = next_batch_cached(idx, BATCH_SIZE, caching_path,
                                                               shape_path, albedo_path, img_indices)
            n = img_encodings.shape[0] # the last batch may be smaller than BATCH_SIZE

            ###########
            # Predict
            pred_offsets = offsetRegressor(img_encodings)

            ################
            # Targets: for each part, the (y, z) bounding-box centres that
            # standardize_part_shape removes when centring the part; one row
            # per part in offsetKeys order.
            batch_offsets = torch.zeros([n, len(offsetKeys), 2], dtype=torch.float32).to(device)
            for key, row in offsetKeys.items():
                batch_shape_part = batch_s[:, part_vertices[key], :]
                for col, axis in enumerate((1, 2)): # y, z
                    coords = batch_shape_part[..., axis]
                    batch_offsets[:, row, col] = (coords.max(dim=1).values + coords.min(dim=1).values) / 2

            ################
            # Compute loss
            offset_loss = torch.sum(torch.square(pred_offsets - batch_offsets)) / n
            loss = offset_loss

            ######
            # BP
            loss.backward()

            ############
            # Optimize
            optimizer.step()

            pbar.update(idx - prev_idx)

    print('Epoch: ', epoch)
    if loss is not None:
        print(' -- Overall loss: ' , loss.item())





In [None]:
## Save offset regressor network
# Set SAVE = True to write the checkpoint.

SAVE = False

if SAVE:
    model_save_path = '../saved_models/offset_regressor'
    # torch.save does not create missing parent directories
    os.makedirs(os.path.dirname(model_save_path), exist_ok=True)
    torch.save(offsetRegressor.state_dict(), model_save_path)
    print('offset regressor saved')
            
    


## Check results

In [21]:
Faces = bfm['tri'] - 1 ## -1 is critical !!! (BFM 'tri' is 1-based)

def render(V, T, Faces, width=512, height=512):
    """
    Render a triangle mesh off-screen with Open3D.

    Args:
        V: vertex positions, [n_vert, 3].
        T: per-vertex RGB colors, [n_vert, 3] floats (Open3D expects [0, 1]),
           or None to keep Open3D's default color.
        Faces: triangle vertex indices, [n_tri, 3] (0-based).
        width, height: size of the hidden render window in pixels.
    Returns:
        (o3d_mesh, image): the TriangleMesh and the captured float image.
    """
    ###############################
    ## Build the mesh (Open3D wants float64 vertices/colors, int32 triangles)
    o3d_mesh = o3d.geometry.TriangleMesh()
    o3d_mesh.vertices = o3d.utility.Vector3dVector(np.ascontiguousarray(V, dtype=np.float64))
    o3d_mesh.triangles = o3d.utility.Vector3iVector(np.ascontiguousarray(Faces, dtype=np.int32))
    if T is not None:
        o3d_mesh.vertex_colors = o3d.utility.Vector3dVector(np.ascontiguousarray(T, dtype=np.float64))
    o3d_mesh.compute_vertex_normals() # normals give shading (specular) when rendering

    ###############################
    ## Render in an invisible window and always release it -- otherwise every
    ## call (e.g. per image in a test loop) leaks a window / GL context.
    vis = o3d.visualization.Visualizer()
    vis.create_window(width=width, height=height, visible = False)
    try:
        vis.add_geometry(o3d_mesh)
        #depth = vis.capture_depth_float_buffer(True)
        image = vis.capture_screen_float_buffer(True)
    finally:
        vis.destroy_window()

    return o3d_mesh, image


In [None]:
############################
## Reconstruction Testing
############################
# Reconstruct the first N_TEST images (the ones the training cells hold out
# for validation). Shows the target mesh, the predicted mesh and the input image.

facenet.eval()
for param in facenet.parameters():
    param.requires_grad = False

N_TEST = min(100, len(img_indices)) # never index past the end of the dataset

with torch.no_grad(): # inference only; no autograd graph needed
    for idx in range(N_TEST):

        # Prepare batch data
        batch_x, batch_s, batch_t, _ = next_batch(idx, 1, img_path, shape_path, albedo_path, img_indices)
        if batch_x is None:
            continue

        # Predict every part from its latent mean (no sampling at test time)
        preds = {}
        img_encodings = facenet(batch_x)
        for key in latent_dims:
            part_mu, part_sigma = disentangleNets[key](img_encodings)
            preds[key] = part_decoders[key](part_mu).cpu().numpy()

        # Parts were trained against centred targets (standardize_part_shape);
        # add the predicted (y, z) offsets to move each part back into place.
        pred_offsets = offsetRegressor(img_encodings).cpu().numpy()
        for key in offsetKeys:
            preds[key][0,:,1:] += pred_offsets[0, offsetKeys[key], :]

        V_targ = batch_s.cpu().numpy()[0]

        # Assemble the full mesh; parts later in latent_dims overwrite the
        # S_overall prediction on their own vertices.
        V_pred = np.zeros([35709, 3])
        for key in latent_dims:
            V_pred[part_vertices[key]] = preds[key][0]

        plt.figure(figsize=(10,5))
        o3d_mesh_targ, image_targ = render(V_targ, None, Faces)
        plt.subplot(1,3,1)
        plt.title('Target Mesh')
        plt.imshow(np.asarray(image_targ))

        o3d_mesh_pred, image_pred = render(V_pred, None, Faces)
        plt.subplot(1,3,2)
        plt.title('Predicted Mesh')
        plt.imshow(np.asarray(image_pred))

        plt.subplot(1,3,3)
        plt.title('{}'.format(idx))
        plt.imshow(batch_x[0].permute(1,2,0).cpu().numpy()/255)
        plt.show()



In [None]:
############################
## Latent Swapping Demo
############################


def demo(idx_A, idx_B, result_combination):
    """
    Predict the shapes of faces A and B, build a face C that takes every part
    from A or B as chosen by result_combination, and render all three.

    Args:
        idx_A, idx_B: positions of the two source images in img_indices.
        result_combination: dict mapping every key of latent_dims to 'A' or 'B'.
    Returns:
        (mesh_A, mesh_B, mesh_C): the Open3D meshes that were rendered.
    Raises:
        KeyError: if result_combination does not cover every key of latent_dims.
        IndexError: if idx_A or idx_B is outside img_indices.
    """
    missing = [key for key in latent_dims if key not in result_combination]
    if missing:
        raise KeyError('result_combination is missing parts: {}'.format(missing))

    facenet.eval() # inference mode

    ## Per-face predictions: decoded parts (from the latent mean) and the
    ## per-part (y, z) offsets from the offset regressor
    preds = {'A': {}, 'B': {}, 'C': {}}
    with torch.no_grad():
        for pnt, idx in (('A', idx_A), ('B', idx_B)):
            batch_x, _, _, _ = next_batch(idx, 1, img_path, shape_path, albedo_path, img_indices)
            if batch_x is None or batch_x.shape[0] == 0:
                raise IndexError('image index {} is outside img_indices'.format(idx))
            img_encodings = facenet(batch_x) # computed once, shared by all parts
            for key in latent_dims:
                part_mu, _ = disentangleNets[key](img_encodings)
                preds[pnt][key] = part_decoders[key](part_mu).cpu().numpy()
            preds[pnt]['offsets'] = offsetRegressor(img_encodings).cpu().numpy()

    ## Face C: each part -- and that part's offset -- comes from its chosen source
    preds['C']['offsets'] = np.copy(preds['A']['offsets'])
    for key in latent_dims:
        source = preds[result_combination[key]]
        preds['C'][key] = source[key]
        if key in offsetKeys:
            preds['C']['offsets'][0, offsetKeys[key], :] = source['offsets'][0, offsetKeys[key], :]

    ## Visualize results
    plt.figure(figsize=(15,5))
    meshes = []
    for i, pnt in enumerate(('A', 'B', 'C'), start=1):
        V_pred = np.zeros([35709, 3], dtype=np.float32)
        for key in latent_dims:
            pred_part = np.copy(preds[pnt][key][0])
            if key in offsetKeys:
                # move the centred part back into place
                pred_part[:, 1:] += preds[pnt]['offsets'][0, offsetKeys[key], :]
            V_pred[part_vertices[key]] = pred_part
        # latent_dims holds only shape parts, so there are no predicted colors;
        # pass None instead of an all-zero (black) color array.
        pred_mesh, rendered_img = render(V_pred, None, Faces)
        meshes.append(pred_mesh)
        plt.subplot(1,3,i)
        plt.title('{}'.format(pnt))
        plt.imshow(np.asarray(rendered_img))

    return meshes[0], meshes[1], meshes[2]


idx_A = 20
idx_B = 29
# Source face ('A' or 'B') of each part of the mixed face C; must cover every
# key of latent_dims.
result_combination = {}
result_combination['S_overall'] = 'A'
result_combination['S_eyebrows'] = 'A'
result_combination['S_eyes'] = 'B'
result_combination['S_llip'] = 'B'
result_combination['S_ulip'] = 'B'
result_combination['S_nose'] = 'B'

mesh_A, mesh_B, mesh_C = demo(idx_A, idx_B, result_combination)

In [51]:
# Open an interactive Open3D viewer on the mixed mesh C returned by the swapping demo.
o3d.visualization.draw_geometries([mesh_C]) 
