# Visualizing Latent dimesnions

In [4]:
%matplotlib widget
import numpy as np
import scipy
import matplotlib.pyplot as plt
import os
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"

from vae import *
import dataloaders

import pbdlib as pbd
import pbdlib_torch as pbd_torch

import torch
from torch.utils.data import DataLoader
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

ckpt_path = 'logs/2023aug/bp_pepper_downsampled_robotfuture/z03h05/diagvaehmmgmm/models/init_ckpt.pth'
ckpt = torch.load(ckpt_path)
hyperparams = np.load(os.path.join(os.path.dirname(ckpt_path),'hyperparams.npz'), allow_pickle=True)
args_ckpt = hyperparams['args'].item()
ae_config = hyperparams['ae_config'].item()
if 'robot_vae_config' in hyperparams.keys():
	robot_vae_config = hyperparams['robot_vae_config'].item()
else:
	robot_vae_config = hyperparams['ae_config'].item()

model_h = VAE(**(ae_config.__dict__)).to(device)
model_h.load_state_dict(ckpt['model_h'])
model_h.eval()
model_r = VAE(**(robot_vae_config.__dict__)).to(device)
model_r.load_state_dict(ckpt['model_r'])
model_r.eval()
z_dim = args_ckpt.latent_dim

if args_ckpt.dataset == 'buetepage_pepper':
	dataset = dataloaders.buetepage.PepperWindowDataset
elif args_ckpt.dataset == 'buetepage':
	dataset = dataloaders.buetepage.HHWindowDataset
# TODO: Nuitrack

train_iterator = DataLoader(dataset(args_ckpt.src, train=True, window_length=args_ckpt.window_size, downsample=args_ckpt.downsample), batch_size=1, shuffle=True)
hsmm=[]
with torch.no_grad():
	for a in range(len(train_iterator.dataset.actidx)):
		hsmm.append(pbd_torch.HMM(nb_dim=args_ckpt.latent_dim*2, nb_states=args_ckpt.hsmm_components))
		s = train_iterator.dataset.actidx[a]
		z_encoded = []
		for j in range(s[0], s[1]):
			# for j in np.random.randint(s[0], s[1], 12):
			x, label = train_iterator.dataset[j]
			x = torch.Tensor(x).to(device)
			x_h = x[:, :model_h.input_dim]
			x_r = x[:, model_h.input_dim:]
			
			z_h = model_h(x_h, encode_only=True)
			z_r = model_r(x_r, encode_only=True)
			z_encoded.append(torch.concat([z_h, z_r], dim=-1).cpu().numpy()) # (num_trajs, seq_len, 2*z_dim)
		hsmm_np = pbd.HMM(nb_dim=args_ckpt.latent_dim*2, nb_states=args_ckpt.hsmm_components)
		hsmm_np.init_params_scikit(np.concatenate(z_encoded))
		hsmm_np.em(z_encoded, reg=args_ckpt.cov_reg, reg_finish=args_ckpt.cov_reg)
		hsmm[a].reg = torch.Tensor(hsmm_np.reg).to(device).requires_grad_(False)
		hsmm[a].mu = torch.Tensor(hsmm_np.mu).to(device).requires_grad_(False)
		hsmm[a].sigma = torch.Tensor(hsmm_np.sigma).to(device).requires_grad_(False)
		hsmm[a].priors = torch.Tensor(hsmm_np.priors).to(device).requires_grad_(False)
		hsmm[a].trans = torch.Tensor(hsmm_np.trans).to(device).requires_grad_(False)
		hsmm[a].Trans = torch.Tensor(hsmm_np.Trans).to(device).requires_grad_(False)
		hsmm[a].init_priors = torch.Tensor(hsmm_np.init_priors).to(device).requires_grad_(False)

# hsmm = ckpt['hsmm']

test_dataset = dataset(args_ckpt.src, train=False, window_length=model_h.window_size, downsample=args_ckpt.downsample)
actions = ['Hand Wave', 'Hand Shake', 'Rocket Fistbump', 'Parachute Fistbump']

fig = plt.figure()
ax_dists = []
ax_alpha= [] 
ax_trans = []
for i in range(4):
	ax_dists.append(fig.add_subplot(3, 4, i+1, projection='3d'))
	ax_alpha.append(fig.add_subplot(3, 4, 4+i+1))
	ax_trans.append(fig.add_subplot(3, 4, 8+i+1))
actidx = np.hstack(test_dataset.actidx - np.array([0,1]))
for a in actidx:
	x, label = test_dataset[a]
	seq_len = x.shape[0]
	dims_h = model_h.input_dim
	x = torch.Tensor(x).to(device)
	x_h = x[:, :dims_h]
	x_r = x[:, dims_h:]
	
	zh_post = model_h(x_h, dist_only=True)
	zr_post = model_r(x_r, dist_only=True)
	ax_dists[label].scatter3D(zh_post.mean[::5, 0].detach().cpu().numpy(), zh_post.mean[::5, 1].detach().cpu().numpy(), zh_post.mean[::5, 2].detach().cpu().numpy(), 'r.')
	ax_dists[label].scatter3D(zr_post.mean[::5, 0].detach().cpu().numpy(), zr_post.mean[::5, 1].detach().cpu().numpy(), zr_post.mean[::5, 2].detach().cpu().numpy(), 'b.')

	alpha = hsmm[label].forward_variable(marginal=[], sample_size=100).detach().cpu().numpy()
	alpha_h = hsmm[label].forward_variable(demo=zh_post.mean, marginal=slice(0, z_dim)).detach().cpu().numpy()
	ax_alpha[label].plot(np.linspace(0, 1, 100), alpha.T, linestyle='-')
	ax_alpha[label].plot(np.linspace(0, 1, seq_len), alpha_h.T, linestyle='--')

	ax_trans[label].imshow(hsmm[label].Trans.detach().cpu().numpy())
	
	for i in range(hsmm[label].nb_states):
		pbd.plot_gauss3d(ax_dists[label], hsmm[label].mu[i, :3].detach().cpu().numpy(), hsmm[label].sigma[i, :3, :3].detach().cpu().numpy(),
					color='red', alpha=0.1)
		pbd.plot_gauss3d(ax_dists[label], hsmm[label].mu[i, z_dim:z_dim+3].detach().cpu().numpy(), hsmm[label].sigma[i, z_dim:z_dim+3, z_dim:z_dim+3].detach().cpu().numpy(),
					color='blue', alpha=0.1)
		
	# break


RuntimeError: mat1 and mat2 shapes cannot be multiplied (161x35 and 20x40)