# Testing Human VAE

In [1]:
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
import numpy as np
import os
os.environ['CUDA_VISIBLE_DEVICES'] = '3'
import pbdlib as pbd

import networks
import config
from utils import *
import dataloaders

torch.autograd.set_detect_anomaly(True)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


In [2]:
model_types = "ablation_vae"
z_dims = [10]
z_trajs = [[], [], [], []]

for i in range(len(z_dims)):
	z_dim = z_dims[i]
	ckpt = f'logs/2023/{model_types}/z{z_dim}/trial0/models/final.pth' # input()
	dirname = os.path.dirname(ckpt)
	hyperparams = np.load(os.path.join(dirname,'hyperparams.npz'), allow_pickle=True)
	args = hyperparams['args'].item()
	ckpt = torch.load(ckpt)

	model = getattr(networks, args.model)(**(hyperparams['ae_config'].item().__dict__)).to(device)
	model.load_state_dict(ckpt['model'])
	model.eval()
	z_dim = model.latent_dim
	if model.window_size == 1:
		nb_dim = 4*model.latent_dim
	else:
		nb_dim = 2*model.latent_dim
	# dataset = getattr(dataloaders, args.dataset)
	if model.window_size ==1:
		train_iterator = DataLoader(dataloaders.buetepage.SequenceDataset(args.src, train=True), batch_size=1, shuffle=True)
		test_iterator = DataLoader(dataloaders.buetepage.SequenceDataset(args.src, train=False), batch_size=1, shuffle=True)
	else:
		train_iterator = DataLoader(dataloaders.buetepage.SequenceWindowDataset(args.src, train=True, window_length=model.window_size), batch_size=1, shuffle=True)
		test_iterator = DataLoader(dataloaders.buetepage.SequenceWindowDataset(args.src, train=False, window_length=model.window_size), batch_size=1, shuffle=True)
	for a in range(len(train_iterator.dataset.actidx)):
		s = train_iterator.dataset.actidx[a]
		z_encoded = []
		for j in range(s[0], s[1]):
		# for j in np.random.randint(s[0], s[1], 12):
			x, label = train_iterator.dataset[j]
			assert np.all(label == a)
			x = torch.Tensor(x).to(device)
			seq_len, dims = x.shape
			x = torch.concat([x[None, :, :dims//2], x[None, :, dims//2:]]) # x[0] = Agent 1, x[1] = Agent 2
			
			zpost_samples = model(x, encode_only=True)
			if model.window_size == 1:
				z1_vel = torch.diff(zpost_samples[0], prepend=zpost_samples[0][0:1], dim=0)
				z2_vel = torch.diff(zpost_samples[1], prepend=zpost_samples[1][0:1], dim=0)
				z_encoded.append(torch.concat([zpost_samples[0], z1_vel, zpost_samples[1], z2_vel], dim=-1).detach().cpu().numpy()) # (num_trajs, seq_len, 2*z_dim)
			else:
				z_encoded.append(torch.concat([zpost_samples[0], zpost_samples[1]], dim=-1).detach().cpu().numpy()) # (num_trajs, seq_len, 2*z_dim)
		z_encoded = np.array(z_encoded)
		z_trajs[i].append(z_encoded)
	for nb_states in [4,5,8,10]:
		hsmm = [pbd.HSMM(nb_dim=nb_dim, nb_states=nb_states) for a in range(len(train_iterator.dataset.actidx))]
		
		mse_error = []
		for a in range(len(train_iterator.dataset.actidx)):
			train_idx = np.random.choice(np.arange(len(z_trajs[i][a])),10,False).astype(int)
			hsmm[a].init_hmm_kbins(z_trajs[i][a][train_idx])
			hsmm[a].em(z_trajs[i][a][train_idx], nb_max_steps=60)

		# predictions = []
		# for a in range(len(train_iterator.dataset.actidx)):
			for z_traj in z_trajs[i][a]:
				z_traj = np.array(z_traj)
				if model.window_size >1:
					z2_pred, sigma2 = hsmm[a].condition(z_traj[:, :z_dim], None, dim_in=slice(0, z_dim), dim_out=slice(z_dim, 2*z_dim))
					# z2 = hsmm[label].condition(zpost_dist.mean[0].detach().cpu().numpy(), dim_in=slice(0, z_dim), dim_out=slice(z_dim, 2*z_dim))
				else:
					z1_vel = np.diff(z_traj, prepend=z_traj[0:1, :z_dim], axis=0)
					z2_pred, sigma2 = hsmm[a].condition(torch.concat([z_traj[:, :z_dim], z1_vel], dim=-1), None, dim_in=slice(0, 2*z_dim), dim_out=slice(2*z_dim, 3*z_dim))
				
				mse_i = ((z_traj[:, z_dim:] - z2_pred)**2).sum(-1)
				mse_error += mse_i.tolist()
				# recon_error += recon_i.tolist()
		print(f"| {nb_states} | {z_dim} | {np.mean(mse_error):.4e} ± {np.std(mse_error):.4e}")
				# print(f'MSE: {np.sum(mse_error)}')
	print('')


  z_encoded = np.array(z_encoded)


| 4 | 10 | 5.0636e-01 ± 9.5254e-01
| 5 | 10 | 5.0876e-01 ± 9.4302e-01


KeyboardInterrupt: 