# Testing Human VAE

In [None]:
import torch
import numpy as np
import os

import networks

# Turn on autograd anomaly detection for everything that runs after this cell.
torch.autograd.set_detect_anomaly(True)

# Pick the compute device: CUDA when torch reports it as available, CPU otherwise.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

In [None]:
# Ask for the checkpoint path. Surrounding whitespace (e.g. from a pasted path)
# is stripped so the file lookups below do not fail on it.
ckpt_path = input().strip()

# hyperparams.npz is read from the same directory as the checkpoint file.
dirname = os.path.dirname(ckpt_path)
# NOTE(security): np.load(allow_pickle=True) and torch.load both unpickle
# arbitrary Python objects -- only point this notebook at files you trust.
hyperparams = np.load(os.path.join(dirname, 'hyperparams.npz'), allow_pickle=True)
training_args = hyperparams['args'].item()
# map_location remaps the stored tensors onto the device selected earlier, so a
# checkpoint written on a GPU machine can still be loaded on a CPU-only one.
ckpt = torch.load(ckpt_path, map_location=device)

# Rebuild the network class named in the training args, passing the attributes
# of the stored config object as keyword arguments, then restore its weights.
vae_config = hyperparams['vae_config'].item()
model_cls = getattr(networks, training_args.model)
model = model_cls(**vae_config.__dict__).to(device)
model.load_state_dict(ckpt['model'])
model.eval()

# Only the array extraction needs the archive open; everything after it works
# on the in-memory tensor.
with np.load(training_args.src, allow_pickle=True) as data:
    test_data = torch.from_numpy(
        np.array(data['test_data']).astype(np.float32)
    ).to(device)

# Number of samples before the split below doubles the leading axis.
N = test_data.shape[0]
# Split the last axis at model.input_dim and stack the two parts along axis 0:
# rows [0, N) hold the first part, rows [N, 2N) the second.
# NOTE(review): this requires the last axis to be exactly 2 * model.input_dim
# (torch.concat needs matching trailing sizes); presumably the two parts are
# the two people of an interaction -- confirm against the dataset.
test_data = torch.concat([test_data[..., :model.input_dim], test_data[..., model.input_dim:]])
data_shape = (N*2, model.window_size, model.num_joints, model.joint_dims)

In [None]:

# Inference only: no_grad keeps autograd from recording a graph for the
# forward pass, which would otherwise hold on to memory that is never used.
# The model call returns three values; only the first one is used here.
with torch.no_grad():
    x_gen, _, _ = model(test_data)

# Bring both tensors to (2N, window_size, num_joints, joint_dims) so they can
# be compared element-wise and indexed per sample later on.
x_gen = x_gen.reshape(data_shape)
test_data = test_data.reshape(data_shape)

# Squared error, summed over the last (joint_dims) axis and averaged over all
# remaining axes (samples, window steps, joints).
error = (test_data - x_gen)**2
print("Prediction MSE {:.4e}".format(error.flatten(0, 2).sum(-1).mean().item()))

# Convert to NumPy arrays on the host for the plotting cell.
x_gen = x_gen.detach().cpu().numpy()
test_data = test_data.detach().cpu().numpy()

In [None]:
%matplotlib widget
import matplotlib.pyplot as plt
from utils import *
import asyncio

fig, ax = prepare_axis()


def _mirror_for_display(skeleton):
    """Return a copy of *skeleton* with x -> 0.7 - x and y -> 0.2 - y applied.

    Working on a copy leaves the global arrays untouched, so re-running this
    cell always draws the same picture instead of flipping the data back.
    """
    mirrored = skeleton.copy()
    mirrored[..., 0] = 0.7 - mirrored[..., 0]
    mirrored[..., 1] = 0.2 - mirrored[..., 1]
    return mirrored


async def update():
    """Animate ground truth against model output, one sample per frame.

    For each frame index ``i`` four skeletons are drawn on the shared axis:
    ``test_data[i]`` (red, solid), ``x_gen[i]`` (magenta, dashed), and the
    mirrored ``test_data[N + i]`` (blue, solid) and ``x_gen[N + i]``
    (green, dashed). Control is yielded to the event loop after every frame
    so the notebook stays responsive.
    """
    global ax
    # Rows N + i must exist, so never iterate past N samples even though the
    # animation is capped at 1000 frames.
    num_frames = min(1000, N)
    for frame_idx in range(num_frames):
        ax = reset_axis(ax)
        ax = visualize_skeleton(ax, test_data[frame_idx], markerfacecolor='r', linestyle='-', alpha=0.5)
        ax = visualize_skeleton(ax, x_gen[frame_idx], markerfacecolor='m', linestyle='--', alpha=0.2)

        # Second half of the stacked data, reflected/shifted for display only.
        other_truth = _mirror_for_display(test_data[N + frame_idx])
        other_gen = _mirror_for_display(x_gen[N + frame_idx])

        ax = visualize_skeleton(ax, other_truth, markerfacecolor='b', linestyle='-', alpha=0.5)
        ax = visualize_skeleton(ax, other_gen, markerfacecolor='g', linestyle='--', alpha=0.2)

        fig.canvas.draw_idle()
        fig.canvas.flush_events()
        await asyncio.sleep(0.001)


loop = asyncio.get_event_loop()
# The event loop only keeps weak references to tasks; holding the task in a
# variable prevents it from being garbage-collected before it finishes.
animation_task = loop.create_task(update())
