In [1]:
import numpy as np
import ipyvolume as ipv
import h5py
import os

import torch 
import torch.nn as nn

from model import *
from rotation import rot6d_to_rotmat, rot6d_to_rotmat, batch_rigid_transform
from torch.autograd import Variable
import time

In [2]:
parent_array = np.array([-1, 0, 0, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 9, 9, 12, 13, 14, 16, 17, 18, 19, 20, 20, 23, 20,
                   25, 26, 20, 28, 29, 20, 31, 32, 20, 34, 35, 21, 21, 38, 21, 40, 41, 21, 43, 44, 21, 46, 
                   47, 21, 49, 50])
print(len(parent_array))
skeleton = np.load('./files/skeleton.npy')

latent_dim = 768
device = torch.device('cuda:0')
num_class = 120
main_path = "/ssd_scratch/cvit/debtanu.gupta/"

52


In [3]:
def plot(skeleton_motion1, skeleton_motion2, single=False, save_gif=False, save_name = "example"):
    fig = ipv.figure(width=600,height=600)
#     skeleton_motion1[:,:,1] *= -1
#     skeleton_motion1[:,:,2] *= -1
    
#     skeleton_motion2[:,:,1] *= -1
#     skeleton_motion2[:,:,2] *= -1
    
    anim_list = []
    for i,p in enumerate(parent_array): # Run loop for each bone
        if p == -1:
            continue
        b = ipv.plot(np.array([skeleton_motion1[:,i,0],skeleton_motion1[:,p,0]]).T,np.array([skeleton_motion1[:,i,1],skeleton_motion1[:,p,1]]).T,np.array([skeleton_motion1[:,i,2],skeleton_motion1[:,p,2]]).T ,size=10, color='darkviolet')
        anim_list.append(b)
        if not single:
            b1 = ipv.plot(np.array([skeleton_motion2[:,i,0],skeleton_motion2[:,p,0]]).T,np.array([skeleton_motion2[:,i,1],skeleton_motion2[:,p,1]]).T,np.array([skeleton_motion2[:,i,2],skeleton_motion2[:,p,2]]).T ,size=10, color='orange')
            anim_list.append(b1)
    
    
#     ipv.plot_surface(x,y,z, color='red')
#     ipv.plot_wireframe(x,y,z,color='tan')
    ipv.animation_control(anim_list, interval=0.01)
    ipv.style.background_color('white')
    ipv.style.box_off()
    ipv.style.axes_off()
    ipv.show()
    
    if save_gif:
        def slide(figure, framenr, fraction):
            for a in anim_list:
                if a.sequence_index == skeleton_motion1.shape[0]:
                    a.sequence_index = 0
                a.sequence_index += 1        
        ipv.movie(save_name + '.gif', slide, fps=5, frames=skeleton_motion1.shape[0])

In [4]:
def fkt(x, mean_pose):
    # forward kinematics
    rotmat = rot6d_to_rotmat(x)
    # same mean pose across timesteps
    mean_pose = torch.tensor(mean_pose.reshape((1, -1)))
    mean_pose = mean_pose.expand((x.shape[0], x.shape[1], 156))
    mean_pose = mean_pose[:,:,:].reshape((x.shape[0]*x.shape[1],-1,3))
    rotmat = rotmat.reshape((x.shape[0]*x.shape[1],-1, 3, 3))
    pred = batch_rigid_transform(rotmat.float(),mean_pose.to(device).float(),parent_array)
    x = pred.reshape((x.shape[0], x.shape[1], -1))
    return x

In [5]:
def infer(model, epoch, components=0, num_samples=10):
    model.eval()
    rot = np.load('./files/camera.npy')
    y = np.repeat(epoch,num_samples)
    rot = np.repeat(rot[0:1], num_samples, axis=0)
    rot = torch.tensor(rot).to(device).float()

    label = np.zeros((y.shape[0], num_class))
    label[np.arange(y.shape[0]), y] = 1
    label = torch.tensor(label).to(device).float()
#     print(rot.shape, label.shape)
    with torch.no_grad():
        m, v = model.gaussian_parameters(model.z_pre.squeeze(0), dim=0)
        idx = torch.distributions.categorical.Categorical(model.pi).sample((label.shape[0],))
        m, v = m[idx], v[idx]
        z = model.sample_gaussian(m, v)
        N = z.shape[0]
        z = torch.cat((z, rot,label), dim=1)
        z = model.latent2hidden(z)
        seq_pred = model.seq_decoder(z).cpu().data.numpy()
        z_body = z[:,:320] # for decoding body joints
        z_hand = z[:,320:] # for decoding hand joints
        root_pred = model.root_traj(z_body).unsqueeze(2).cpu().data.numpy()
        z_body = z_body.reshape((N,4,-1))
        z_hand = z_hand.reshape((N,4,-1))
        
        x = model.decoder_net(z_body)
        hand_x = model.decoder_net_hand(z_hand)
        
        print(root_pred.shape)
        x = x.reshape(N, 64, 2, 22, -1)
        hand_x = hand_x.reshape((N, 64, 2, 30, -1))
        
        pred = torch.cat((x, hand_x), dim=3)
        pred_3d1 = fkt(pred[:,:,0,:,:].contiguous(), skeleton)
        pred_3d2 = fkt(pred[:,:,1,:,:].contiguous(), skeleton)
        pred_3d1 = pred_3d1.reshape((pred_3d1.shape[0], 64, 52, -1)).cpu().data.numpy()
        pred_3d2 = pred_3d2.reshape((pred_3d2.shape[0], 64, 52, -1)).cpu().data.numpy()
        pred = rot6d_to_rotmat(pred)
        pred = pred.reshape((pred_3d1.shape[0], 64, 2, 52, 3,3)).cpu().data.numpy()
        print(root_pred.shape)
        pred_3d1 = pred_3d1 + root_pred[:,:,:,:3]
        pred_3d2 = pred_3d2 + root_pred[:,:,:,3:]

        return pred_3d1, pred_3d2, pred, seq_pred, root_pred

In [14]:
model = Model(num_class, latent_dim).to(device)
model.load_state_dict(torch.load('./checkpoints/' + 'model_121.pt', map_location=torch.device('cpu')))
print('model loaded..')

model loaded..


In [21]:
p1, p2, rot6d, seq_pred, root_pred = infer(model,16, components=27, num_samples=50)
print(p1.shape, p2.shape, rot6d.shape, seq_pred.shape)

(50, 64, 1, 6)
(50, 64, 1, 6)
(50, 64, 52, 3) (50, 64, 52, 3) (50, 64, 2, 52, 3, 3) (50, 64, 1)


In [22]:
pred_len = seq_pred[:,:,0]
pred_len = pred_len <= 0.975
pred_len = np.sum(pred_len, axis=1)
print(pred_len)

[18 18 18 18 18 18 18 18 18 18 18 18 18 18 18 18 18 18 18 18 18 18 18 17
 18 18 18 18 18 18 18 18 18 18 18 18 18 18 18 18 17 18 18 18 18 18 18 18
 18 18]


In [24]:
idx = 10
plot(p1[idx,:,:,:], p2[idx,:,:,:], single=False)

VBox(children=(Figure(animation=0.01, camera=PerspectiveCamera(fov=46.0, position=(0.0, 0.0, 2.0), quaternion=…

In [27]:
p = p1[idx]

In [28]:
print(p[:].shape)

(64, 52, 3)


In [30]:
np.save("take_off_a_shoe_without_ST_x.npy", p)

In [31]:
np.load("take_off_a_shoe_without_ST_x.npy").shape

(64, 52, 3)

In [15]:
print(root_pred[idx].shape, rot6d[idx,:,0].shape)

(64, 1, 6) (64, 52, 3, 3)


In [23]:
f = h5py.File("Pushing.h5", "w")
f.create_dataset("rotation", data=rot6d[idx,1:])
f.create_dataset("root", data=root_pred[idx,1:])
f.close()

In [16]:
np.save("Monolythic_drink.npy", rot6d[idx,:,0])