# Visualize Mesh

This notebook demonstrates how to generate SMPL meshes from the 3D skeletons of MUGL.

In [2]:
import numpy as np
import ipyvolume as ipv
import h5py
import os

from smplx import SMPL
import pickle

import torch 
import torch.nn as nn

from model import *
from rotation.rotation import rot6d_to_rotmat, batch_rigid_transform
from torch.autograd import Variable

In [None]:
skeleton = np.load('./files/skeleton.npy')
# Conditioning viewpoints; infer() selects one with viewpoint[v:v+1].
viewpoint = np.load('./files/viewpoint.npy')
# Fixed 10-dim SMPL shape vector applied to every generated body (see infer()).
beta = torch.tensor([-0.1474,  0.0632,  0.7616,  2.9261,  0.3609,  0.2267, -0.3828,  0.3000,
         0.5667,  0.0230])

latent_dim = 352
# Fall back to CPU so the notebook also runs on machines without a GPU.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
num_class = 120
main_path = "/ssd_scratch/cvit/debtanu.gupta/"  # NOTE(review): not used by the visible cells

# NOTE(review): plot() feeds one sample's frames as the SMPL batch; presumably the
# number of generated frames must match batch_size=32 — confirm.
smpl = SMPL('./files', batch_size=32)

# Trusted local model file; pickle must never be used on untrusted data.
with open("./files/SMPL_NEUTRAL.pkl", 'rb') as f:
    smpl_data = pickle.load(f, encoding='latin1')

# Parent joint index for each of the 24 SMPL joints; the root gets -1 (no parent).
parent_array = list(smpl_data['kintree_table'][0][:24])
parent_array[0] = -1

In [4]:
def infer(model, label, v=0, num_samples=10):
    '''
        Sample `num_samples` motions of a single class from the model's mixture prior.

        model: trained model object; must provide gaussian_parameters, z_pre, pi,
               sample_gaussian, latent2hidden, decoder_net and root_traj
        label: class index in [0, num_class)
        v: index into the module-level `viewpoint` array used as conditioning
        num_samples: number of samples N to draw (default 10)

        Returns (alpha, betas, root_pred):
            alpha: (N, T, 2, 24, 3, 3) rotation matrices, two sets per frame
            betas: (N, T, 2, 10) copies of the module-level shape vector `beta`
            root_pred: model.root_traj output with an extra axis inserted at dim 2
        All three are moved to the device of `beta`.
    '''
    model.eval()

    # Every sample gets the same class label.
    y = np.repeat(label, num_samples)

    # Viewpoint conditioning: take entry [:, 0] of the chosen viewpoint, view it as
    # 48 6-D rotations and keep only the first one (presumably the global orientation — confirm).
    rot = np.repeat(viewpoint[v:v + 1], num_samples, axis=0)
    rot = rot[:, 0, :].reshape((rot.shape[0], 48, 6))[:, 0, :]
    rot = torch.tensor(rot).to(device).float()

    one_hot = np.zeros((y.shape[0], num_class))
    one_hot[np.arange(y.shape[0]), y] = 1
    one_hot = torch.tensor(one_hot).to(device).float()

    with torch.no_grad():
        # Pick one mixture component per sample according to model.pi, then sample z from it.
        # `var` is deliberately not named `v` so it does not shadow the viewpoint argument.
        mean, var = model.gaussian_parameters(model.z_pre.squeeze(0), dim=0)
        comp = torch.distributions.categorical.Categorical(model.pi).sample((one_hot.shape[0],))
        z = model.sample_gaussian(mean[comp], var[comp])

        z = torch.cat((z, one_hot, rot), dim=1)
        z = model.latent2hidden(z)
        z = z.reshape((z.shape[0], 4, -1))
        pred = model.decoder_net(z)
        root_pred = model.root_traj(z).unsqueeze(2)

        N, T, _ = pred.shape
        # pred is already a tensor: reshape it directly (contiguous for .view inside the
        # converter) instead of re-wrapping it with torch.tensor(), which warns and copies.
        rot6d = pred.reshape((N, T, 2, 144)).contiguous()
        alpha = rot6d_to_rotmat(rot6d).view((N, T, 2, 24, 3, 3)).float()
        betas = beta.view((1, 1, 1, 10)).repeat((N, T, 2, 1))
        alpha = alpha.to(betas.device)
        root_pred = root_pred.to(betas.device)

    return alpha, betas, root_pred

In [5]:
def _smpl_forward(alpha, betas, ind, person):
    '''
        Run the module-level SMPL layer for one person of sample `ind`.
        Returns (vertices, joints) as numpy arrays whose first axis is the frame.
    '''
    pose = alpha[ind, :, person]
    output = smpl(betas=betas[ind, :, person, :],
                  body_pose=pose[:, 1:],
                  global_orient=pose[:, 0].unsqueeze(1),
                  pose2rot=False)
    return output.vertices.data.cpu().numpy(), output.joints.data.cpu().numpy()


def _bone_line(motion, child, parent):
    '''
        Draw the animated bone between joints `child` and `parent`.
        motion: joint positions indexed as motion[frame, joint, xyz]
    '''
    pair = [child, parent]
    return ipv.plot(motion[:, pair, 0], motion[:, pair, 1], motion[:, pair, 2],
                    size=10, color='darkviolet')


def plot(alpha, betas, root, ind, single=False, save_gif=False, save_name='example'):
    '''
        Render sample `ind` as animated SMPL meshes plus skeletons with ipyvolume.

        alpha: rotation matrices from infer(), indexed as alpha[sample, frame, person, joint]
        betas: SMPL shape parameters from infer(), indexed as betas[sample, frame, person]
        root: root trajectory from infer(); root[ind] is added to person 2's vertices and joints
        ind: index of the sample to render
        single: True if single-person class (only person 1 is drawn), else False
        save_gif: also save the animation as `<save_name>.gif` if True
        save_name: output file name without the .gif extension
    '''
    v1, skeleton_motion1 = _smpl_forward(alpha, betas, ind, 0)
    v2, skeleton_motion2 = _smpl_forward(alpha, betas, ind, 1)

    root_offset = root[ind].data.cpu().numpy()
    v2 = v2 + root_offset
    skeleton_motion2 = skeleton_motion2 + root_offset

    # Negate the y and z axes of everything for display.
    for arr in (v1, skeleton_motion1, v2, skeleton_motion2):
        arr[:, :, 1:] *= -1

    # Floor extent and axis limits come only from the person(s) actually drawn,
    # so a hidden second person never affects the view.
    if single:
        drawn_joints, drawn_verts = skeleton_motion1, v1
    else:
        drawn_joints = np.concatenate((skeleton_motion1, skeleton_motion2), axis=1)
        drawn_verts = np.concatenate((v1, v2), axis=1)

    ipv.figure(height=800, width=800)

    plot_mesh1 = ipv.plot_trisurf(x=v1[:, :, 0], y=v1[:, :, 1], z=v1[:, :, 2],
                                  triangles=smpl_data['f'], color=[1, 1, .5, 1])
    plot_mesh1.material.transparent = True
    plot_mesh1.material.side = "FrontSide"
    s1 = ipv.scatter(skeleton_motion1[:, :, 0], skeleton_motion1[:, :, 1], skeleton_motion1[:, :, 2],
                     size=1, color='indigo', marker='sphere')
    anim_list = [plot_mesh1, s1]

    if not single:
        plot_mesh2 = ipv.plot_trisurf(x=v2[:, :, 0], y=v2[:, :, 1], z=v2[:, :, 2],
                                      triangles=smpl_data['f'], color=[1, 0, 1, 1])
        plot_mesh2.material.transparent = True
        plot_mesh2.material.side = "FrontSide"
        s2 = ipv.scatter(skeleton_motion2[:, :, 0], skeleton_motion2[:, :, 1], skeleton_motion2[:, :, 2],
                         size=1, color='indigo', marker='sphere')
        anim_list += [plot_mesh2, s2]

    for i, p in enumerate(parent_array):  # one line per bone; the root has no parent
        if p == -1:
            continue
        anim_list.append(_bone_line(skeleton_motion1, i, p))
        if not single:
            anim_list.append(_bone_line(skeleton_motion2, i, p))

    # Square floor grid placed at the lowest drawn joint height.
    grid = np.arange(np.min(drawn_joints) - 1, np.max(drawn_joints) + 1)
    x, z = np.meshgrid(grid, grid)
    y = np.ones_like(x) * np.min(drawn_joints[:, :, 1])
    ipv.plot_surface(x, y, z, color='lightgray')
    ipv.plot_wireframe(x, y, z, color='black')

    ipv.animation_control(anim_list)
    ipv.style.background_color('dark')
    ipv.style.box_off()
    ipv.style.axes_off()
    ipv.xyzlim(np.min(drawn_verts), np.max(drawn_verts))
    ipv.show()

    if save_gif:
        def slide(figure, framenr, fraction):
            # ipv.movie calls this with framenr = 0 .. frames-1; show that frame on every object.
            for obj in anim_list:
                obj.sequence_index = framenr
        ipv.movie(save_name + '.gif', slide, fps=5, frames=skeleton_motion1.shape[0])

In [None]:
# Build the network on the target device, then restore weights from
# ./checkpoints/model_199.pt (mapped to CPU first; load_state_dict copies them in).
# NOTE(review): torch.load unpickles the file — only load trusted checkpoints.
model = Model(latent_dim).to(device)
state_dict = torch.load('./checkpoints/model_199.pt', map_location=torch.device('cpu'))
model.load_state_dict(state_dict)
print('model loaded..')

In [None]:
# Class to sample: a 0-based index into the num_class one-hot label used by infer().
cls_lbl = 58
alpha, betas, root_pred = infer(model, label=cls_lbl, v=0)

In [None]:
# Which of the samples returned by infer() to render (must be below its sample count).
idx = 6
plot(alpha, betas, root_pred, ind=idx, single=False)