# Testing Human VAE

In [None]:
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torch.nn.utils.rnn import *
import numpy as np
import os

import networks
from config import *

from phd_utils.dataloaders import *

# Make autograd report the forward op responsible for NaN/Inf gradients
# (a debugging aid; it adds overhead wherever an autograd graph is built).
torch.autograd.set_detect_anomaly(True)
# Run on the GPU when one is visible to torch, otherwise on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [None]:
# Load the two TDMs from a user-supplied checkpoint path. Then load the VAE whose
# checkpoint path is stored in the TDM run's saved args. Finally build the test
# dataset and the column-index sets that slice each agent's features.
# NOTE(review): np.load(allow_pickle=True) and torch.load both unpickle data;
# only point this at checkpoints you trust.
tdm_ckpt = input()

tdm_hyperparams = np.load(os.path.join(os.path.dirname(tdm_ckpt), 'tdm_hyperparams.npz'), allow_pickle=True)
tdm_args = tdm_hyperparams['args'].item()
tdm_model_kwargs = tdm_hyperparams['tdm_config'].item().__dict__

# The checkpoint holds both TDMs, so read it once. map_location lets a checkpoint
# saved on the GPU be restored on a CPU-only machine (and vice versa).
tdm_state = torch.load(tdm_ckpt, map_location=device)
tdm_1 = networks.TDM(**tdm_model_kwargs).to(device)
tdm_1.load_state_dict(tdm_state['model_1'])
tdm_2 = networks.TDM(**tdm_model_kwargs).to(device)
tdm_2.load_state_dict(tdm_state['model_2'])
tdm_1.eval()
tdm_2.eval()
del tdm_state  # weights are now copied into the modules

vae_hyperparams = np.load(os.path.join(os.path.dirname(tdm_args.vae_ckpt), 'hyperparams.npz'), allow_pickle=True)
vae_args = vae_hyperparams['args'].item()
vae = networks.VAE(**(vae_hyperparams['config'].item().__dict__)).to(device)
vae.load_state_dict(torch.load(tdm_args.vae_ckpt, map_location=device)['model'])
vae.eval()

# Dataset class for the model the VAE was trained on. Fail loudly here instead of
# hitting an undefined `dataset` further down.
if vae_args.model == 'BP_HH':
    dataset = buetepage.HHWindowDataset
elif vae_args.model == 'NUISI_HH':
    dataset = nuisi.HHWindowDataset
elif vae_args.model == 'ALAP':
    dataset = alap.HHWindowDataset
else:
    raise ValueError(f'Unsupported vae_args.model: {vae_args.model!r}')

# *_vae_idx select each agent's block of columns (agent 1 first, agent 2 after it).
# *_tdm_idx select the first 18 (36 for ALAP) columns of the agent's block plus
# the last 4 (2 for ALAP) columns of the sample.
if vae_args.model in ('BP_HH', 'NUISI_HH'):
    tdm_config = human_tdm_config()
    p1_tdm_idx = np.concatenate([np.arange(18), np.arange(-4, 0)])
    p2_tdm_idx = np.concatenate([90 + np.arange(18), np.arange(-4, 0)])
    p1_vae_idx = np.arange(90)
    p2_vae_idx = np.arange(90) + 90
else:  # 'ALAP' (every other value was rejected above)
    tdm_config = handover_tdm_config()
    p1_tdm_idx = np.concatenate([np.arange(36), np.arange(-2, 0)])
    p2_tdm_idx = np.concatenate([180 + np.arange(36), np.arange(-2, 0)])
    p1_vae_idx = np.arange(180)
    p2_vae_idx = np.arange(180) + 180

test_dataset = dataset(train=False, window_length=vae_hyperparams['global_config'].item().WINDOW_LEN, downsample=0.2)
test_data = [torch.Tensor(t) for t in test_dataset.traj_data]
actidx = test_dataset.actidx

In [None]:
# For every test trajectory:
#   1. encode both agents with the VAE;
#   2. run agent 1's TDM features through tdm_1;
#   3. decode tdm_1's d1 mean with tdm_2 into a latent for agent 2;
#   4. decode that latent with the VAE.
# Collect per-timestep squared errors against agent 2's ground truth, grouped by
# the action ranges in `actidx`.
x2_tdm_out = []   # VAE decoding of the TDM-predicted agent-2 latent
x1_in = []        # agent-1 VAE inputs
x2_gt = []        # agent-2 ground-truth VAE inputs
z1_x1 = []        # VAE posterior means for agent 1
z2_x2 = []        # VAE posterior means for agent 2
z1_d1 = []        # means of the first distribution returned by tdm_1
z2_d1 = []        # tdm_2 output means decoded from tdm_1's d1 mean
mse_actions = []  # one list of per-timestep errors per entry of actidx

# Unflattened layout of a VAE vector. This is the same layout the MSE reshape
# uses, rather than a hard-coded (5, 3, 6) that only fits the 90-dim models.
window_shape = (-1, vae.window_size, vae.num_joints, vae.joint_dims)

with torch.no_grad():  # pure inference: don't build autograd graphs
    for a in actidx:
        action_errors = []
        mse_actions.append(action_errors)
        for i in range(a[0], a[1]):
            x = test_data[i]
            seq_len = x.shape[0]
            x1_tdm = x[:, p1_tdm_idx].to(device)
            x1_vae = x[:, p1_vae_idx].to(device)
            x2_vae = x[:, p2_vae_idx].to(device)

            # Only the posterior distributions (third output) are used here.
            _, _, z1post_dist = vae(x1_vae)
            _, _, z2post_dist = vae(x2_vae)
            z1_x1.append(z1post_dist.mean.detach().cpu().numpy())
            z2_x2.append(z2post_dist.mean.detach().cpu().numpy())

            x1_in.append(x1_vae.detach().cpu().numpy().reshape(window_shape))
            x2_gt.append(x2_vae.detach().cpu().numpy().reshape(window_shape))

            z1_d1_dist, _, d1_dist, _ = tdm_1(x1_tdm, None)
            z1_d1.append(z1_d1_dist.mean.detach().cpu().numpy())
            z2_d1_mean = tdm_2.output_mean(tdm_2._decoder(d1_dist.mean))
            z2_d1.append(z2_d1_mean.detach().cpu().numpy())
            x2_pred = vae._output(vae._decoder(z2_d1_mean))
            x2_tdm_out.append(x2_pred.detach().cpu().numpy().reshape(window_shape))

            # Squared error summed over the last axis (joint_dims), then averaged
            # over joints and window steps: one value per timestep.
            sq_err = ((x2_pred - x2_vae) ** 2).reshape((seq_len, vae.window_size, vae.num_joints, vae.joint_dims))
            action_errors.extend(sq_err.sum(-1).mean(-1).mean(-1).detach().cpu().numpy().tolist())

# "mean$\pm$std" per action, tab-separated (LaTeX-ready). The backslash is
# escaped explicitly; a bare "\p" is an invalid escape sequence.
print(''.join(f'{np.mean(mse):.4e}$\\pm${np.std(mse):.4e}\t' for mse in mse_actions))


In [None]:
%matplotlib widget
import matplotlib.pyplot as plt
from phd_utils.visualization import *
import asyncio

def _stack(chunks):
    """Concatenate per-trajectory arrays along axis 0.

    Arrays that are already stacked (from a previous run of this cell) are returned
    unchanged. Stacking them again would silently fold the window axis into axis 0.
    """
    return chunks if isinstance(chunks, np.ndarray) else np.vstack(chunks)


x1_in = _stack(x1_in)
x2_gt = _stack(x2_gt)
x2_tdm_out = _stack(x2_tdm_out)

z1_x1 = _stack(z1_x1)
z2_x2 = _stack(z2_x2)
z1_d1 = _stack(z1_d1)
z2_d1 = _stack(z2_d1)

# Left panel: skeletons. Right panel: first three latent dimensions.
fig = plt.figure(figsize=(14, 7))
ax = fig.add_subplot(1, 2, 1, projection='3d')
ax.view_init(25, -155)
ax.set_xlim3d([-0.05, 0.75])
ax.set_ylim3d([-0.3, 0.5])
ax.set_zlim3d([-0.8, 0.2])
ax2 = fig.add_subplot(1, 2, 2, projection='3d')
N = x1_in.shape[0]


def _mirror(pose):
    """Return a copy of `pose` with x -> 0.7 - x and y -> 0.2 - y along the last axis.

    This is used to draw agent 2 opposite agent 1. A copy is made so the source
    arrays stay intact (mirroring in place would also flip them back on a rerun).
    """
    mirrored = pose.copy()
    mirrored[..., 0] = 0.7 - mirrored[..., 0]
    mirrored[..., 1] = 0.2 - mirrored[..., 1]
    return mirrored


async def update():
    """Animate every 5th frame.

    On the left axis, redraw agent 1's input (red), agent 2's ground truth (blue)
    and agent 2's TDM prediction (cyan, dashed). On the right axis, accumulate
    latent points.
    """
    global ax
    latent_series = (
        (z1_x1, 'r', 'o'),
        (z2_x2, 'b', 'o'),
        (z1_d1, 'm', 'o'),
        (z2_d1, 'c', '^'),
    )
    for frame_idx in range(0, N, 5):
        ax = reset_axis(ax)
        ax = visualize_skeleton(ax, x1_in[frame_idx], markerfacecolor='r', linestyle='-')
        ax = visualize_skeleton(ax, _mirror(x2_gt[frame_idx]), markerfacecolor='b', linestyle='-')
        ax = visualize_skeleton(ax, _mirror(x2_tdm_out[frame_idx]), markerfacecolor='c', linestyle='--')

        for latents, color, marker in latent_series:
            ax2.scatter(latents[frame_idx, 0], latents[frame_idx, 1], latents[frame_idx, 2], color=color, marker=marker)

        fig.canvas.draw_idle()
        fig.canvas.flush_events()
        await asyncio.sleep(0.001)  # yield so the widget backend can repaint


loop = asyncio.get_event_loop()
# Hold a strong reference: the event loop only keeps weak references to tasks,
# so an unreferenced task can be garbage-collected mid-animation.
update_task = loop.create_task(update())
