# Visualization

This notebook demonstrates how to generate samples with the pretrained MUGL model and plot them.

In [35]:
import numpy as np
import ipyvolume as ipv
import h5py
import os

import torch 
import torch.nn as nn

from model.model import *
from rotation.rotation import rot6d_to_rotmat, batch_rigid_transform
from torch.autograd import Variable

In [36]:
# Kinematic tree of the 24-joint skeleton: parent_array[j] is the parent of joint j (-1 = root).
# Originally obtained as SMPL['kintree_table'][0][:24] from ./files/SMPLX_NEUTRAL.npz.
parent_array = np.array([-1, 0, 0, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 9, 9, 12, 13, 14, 16, 17, 18, 19, 15, 15])
# T-pose joint positions, passed to fkt() as `mean_pose`.
skeleton = np.load('./files/skeleton.npy')

latent_dim = 352
# Fall back to CPU when no GPU is available so the notebook still runs end to end.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
num_class = 120

In [37]:
def plot(skeleton_motion1, skeleton_motion2, single=False, save_gif=False, save_name = "example"):
    '''
        Animate generated skeleton sequences with ipyvolume and optionally save them as a GIF.
        Input:
            skeleton_motion1: (T, 24, 3) joint positions of person 1
            skeleton_motion2: (T, 24, 3) joint positions of person 2; not used
                (and may be None) when single is True
            single: True if it's a single-person class, else False
            save_gif: if True, also write the animation to `save_name + '.gif'`
            save_name: name of the GIF file (without extension)
        The inputs are copied before the Y and Z axes are negated for display, so
        the caller's arrays are never modified and repeated calls on the same data
        render identically.
    '''
    ipv.figure(width=600, height=600)

    # Copy before flipping: callers usually pass slices (views) of the generated
    # batch, and negating in place would silently flip the caller's data as well.
    motion1 = np.array(skeleton_motion1, copy=True)
    motion1[:, :, 1:3] *= -1
    motion2 = None
    if not single:
        motion2 = np.array(skeleton_motion2, copy=True)
        motion2[:, :, 1:3] *= -1

    def _bone(motion, joint, parent, color):
        # One 2-point segment per frame: each coordinate array has shape (T, 2).
        xs, ys, zs = (np.stack([motion[:, joint, k], motion[:, parent, k]], axis=1) for k in range(3))
        return ipv.plot(xs, ys, zs, size=10, color=color)

    anim_list = [ipv.scatter(motion1[:, :, 0], motion1[:, :, 1], motion1[:, :, 2],
                             size=2.5, color='indigo', marker='sphere')]
    if not single:
        anim_list.append(ipv.scatter(motion2[:, :, 0], motion2[:, :, 1], motion2[:, :, 2],
                                     size=2.5, color='red', marker='sphere'))

    for joint, parent in enumerate(parent_array):  # one line object per bone
        if parent == -1:  # the root joint has no parent bone
            continue
        anim_list.append(_bone(motion1, joint, parent, 'darkviolet'))
        if not single:
            anim_list.append(_bone(motion2, joint, parent, 'orange'))

    ipv.animation_control(anim_list, interval=0.01)
    ipv.style.background_color('white')
    ipv.style.box_off()
    ipv.style.axes_off()
    ipv.show()

    if save_gif:
        def slide(figure, framenr, fraction):
            # ipv.movie calls this once per frame with framenr = 0 .. frames-1;
            # show exactly that frame on every animated object (never out of range).
            for a in anim_list:
                a.sequence_index = framenr
        ipv.movie(save_name + '.gif', slide, fps=5, frames=motion1.shape[0])

In [38]:
def fkt(x, mean_pose):
    '''
        Forward kinematics: pose the rest skeleton `mean_pose` with per-joint rotations.
        Input:
            x: tensor of 6D joint rotations whose first two dims are (batch_size, T);
               `infer` passes (batch_size, T, 24*6). The layout of the remaining dims
               is whatever `rot6d_to_rotmat` accepts (TODO confirm).
            mean_pose: rest-pose joint positions (numpy array or tensor), read as
               rows of 3 coordinates, e.g. (24, 3) for the 24-joint skeleton.
        Returns:
            (batch_size, T, K) tensor: the output of `batch_rigid_transform` flattened
            per frame (callers reshape it to (batch_size, T, 24, -1)).
    '''
    batch_size, seq_len = x.shape[0], x.shape[1]
    # Rotation matrices, one 3x3 per joint, for every (sample, frame) pair.
    rotmat = rot6d_to_rotmat(x).reshape((batch_size * seq_len, -1, 3, 3))
    # Same rest pose for every frame. It is built directly on x's device so this
    # function does not depend on a global device, and its width is derived from
    # mean_pose instead of a hard-coded 72.
    rest = torch.as_tensor(mean_pose, dtype=torch.float32, device=x.device).reshape((1, -1))
    rest = rest.expand((batch_size, seq_len, rest.shape[1])).reshape((batch_size * seq_len, -1, 3))
    pred = batch_rigid_transform(rotmat.float(), rest, parent_array)
    return pred.reshape((batch_size, seq_len, -1))

In [39]:
def infer(model, class_label, num_sample=10, viewpoint_idx=0):
    '''
        Generate samples of one action class with a pretrained model.
        Input:
            model: loaded Model instance (not a checkpoint path); it is put in eval mode
            class_label: class index in [0, num_class)
            num_sample: number of samples to be generated
            viewpoint_idx: entry of ./files/viewpoint.npy used as the viewpoint
                condition (default 0, the entry used previously; negative indices work)
        Return:
            pred_3d1: numpy array (num_sample, T, 24, -1), 3D skeleton of the first
                person (last dim presumably x, y, z)
            pred_3d2: numpy array (num_sample, T, 24, -1), 3D skeleton of the second person
            seq_pred: numpy output of model.seq_decoder, used downstream to derive
                each sample's sequence length
    '''
    model.eval()

    # Viewpoint condition: take the first 6 values of frame 0 of the chosen
    # viewpoint entry (the file is reshaped as 48 rotations x 6 values per frame;
    # presumably the root rotation of person 1 — TODO confirm).
    rot = np.load('./files/viewpoint.npy')
    rot = np.repeat(rot[viewpoint_idx][np.newaxis], num_sample, axis=0)
    rot = rot[:, 0, :].reshape((rot.shape[0], 48, 6))[:, 0, :]
    rot = torch.tensor(rot).to(device).float()

    # One-hot class labels, one row per sample.
    y = np.repeat(class_label, num_sample)
    label = np.zeros((y.shape[0], num_class))
    label[np.arange(y.shape[0]), y] = 1
    label = torch.tensor(label).to(device).float()

    with torch.no_grad():
        # Draw latents: pick one prior component per sample (Categorical over
        # model.pi), then sample a Gaussian with that component's parameters.
        m, v = model.gaussian_parameters(model.z_pre.squeeze(0), dim=0)
        idx = torch.distributions.categorical.Categorical(model.pi).sample((label.shape[0],))
        m, v = m[idx], v[idx]
        z = model.sample_gaussian(m, v)

        # Condition on class and viewpoint, then decode.
        z = model.latent2hidden(torch.cat((z, label, rot), dim=1))
        seq_pred = model.seq_decoder(z).detach().cpu().numpy()
        z = z.reshape((z.shape[0], 4, -1))
        pred = model.decoder_net(z)
        root_pred = model.root_traj(z).unsqueeze(2).detach().cpu().numpy()

        # pred packs both people: the first 144 channels (24 joints x 6) for
        # person 1, the remaining channels for person 2.
        pred_3d1 = fkt(pred[:, :, :144].contiguous(), skeleton)
        pred_3d2 = fkt(pred[:, :, 144:].contiguous(), skeleton)
        pred_3d1 = pred_3d1.reshape((pred_3d1.shape[0], pred_3d1.shape[1], 24, -1)).detach().cpu().numpy()
        pred_3d2 = pred_3d2.reshape((pred_3d2.shape[0], pred_3d2.shape[1], 24, -1)).detach().cpu().numpy()

    # Add the root trajectory: first 3 channels for person 1, the rest for person 2.
    pred_3d1 = pred_3d1 + root_pred[:, :, :, :3]
    pred_3d2 = pred_3d2 + root_pred[:, :, :, 3:]
    return pred_3d1, pred_3d2, seq_pred

In [40]:
# Load the pretrained model; the class count comes from num_class in the config cell.
model = Model(num_class, latent_dim).to(device)
# torch.load unpickles the checkpoint: only load files from a trusted source.
state_dict = torch.load('./checkpoints/' + 'model_199.pt', map_location=torch.device('cpu'))
model.load_state_dict(state_dict)
print('model loaded..')

model loaded..


In [41]:
# class label : class index can be found in utils/class_index.txt
class_label = 58
# Sampling is stochastic (prior component choice + Gaussian draw inside infer);
# seed torch's RNGs so re-running the notebook regenerates the same motions.
SEED = 0
torch.manual_seed(SEED)
p1, p2, seq = infer(model, class_label)

In [42]:
# Each sample's length = number of time steps whose seq_decoder output
# (channel 0) is at or below LENGTH_THRESHOLD.
LENGTH_THRESHOLD = 0.975
pred_len = np.sum(seq[:, :, 0] <= LENGTH_THRESHOLD, axis=1)
# Keep at least one frame so a sample with every step above the threshold still plots.
pred_len = np.maximum(pred_len, 1)

In [43]:
idx = 1
# The slices are views into p1/p2; pass copies so plotting can never alter the
# generated data and this cell renders the same result when re-run.
plot(p1[idx, :pred_len[idx]].copy(), p2[idx, :pred_len[idx]].copy())

VBox(children=(Figure(animation=0.01, camera=PerspectiveCamera(fov=46.0, position=(0.0, 0.0, 2.0), quaternion=…