# Visualizing Latent Dimensions

In [1]:
# %matplotlib widget
import numpy as np

# import matplotlib.pyplot as plt

from plotly.offline import init_notebook_mode
init_notebook_mode()
import plotly.graph_objs as go
from plotly.subplots import make_subplots

import numpy as np

import torch
from torch.utils.data import DataLoader
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

from mild_hri.vae import VAE
from mild_hri.utils import *
from mild_hri.dataloaders import *

import pbdlib_torch as pbd_torch
from pbdlib.functions import multi_variate_normal

def smoothing(indices):
	"""Remove isolated one-step blips from a sequence of state indices.

	An element that differs from both of its neighbours while those two
	neighbours agree (pattern ``a, b, a``) is replaced by the neighbour value.
	The scan runs left to right over a copy, and each comparison with the left
	neighbour sees the already-smoothed value, so ``a, b, a, b, a`` collapses to
	all ``a``. The first and last elements are never changed (they lack a
	neighbour on one side).

	Args:
		indices: 1-D sequence of state indices (list, NumPy array or torch tensor).

	Returns:
		A smoothed copy. Torch tensors and objects with ``.copy()`` (lists,
		NumPy arrays) keep their type; any other sequence is returned as a list.
		``indices`` itself is not modified.
	"""
	if hasattr(indices, "clone"):  # torch.Tensor
		smoothed = indices.clone()
	elif hasattr(indices, "copy"):  # list, numpy.ndarray
		smoothed = indices.copy()
	else:
		smoothed = list(indices)

	# Interior positions only: both i-1 and i+1 must exist.
	for i in range(1, len(smoothed) - 1):
		if smoothed[i] != smoothed[i-1] and smoothed[i] != smoothed[i+1] and smoothed[i+1] == smoothed[i-1]:
			smoothed[i] = smoothed[i+1]

	return smoothed

# Checkpoint to inspect: it bundles the training args, the VAE weights and one SSM per activity.
ckpt_path = '../logs/2023/nuisiv2_3joints_xvel/z5h6_old/trial3/models/180.pth'
# map_location keeps every tensor (VAE weights and the pickled SSMs) on `device`, so a
# GPU-saved checkpoint also opens on a CPU-only machine and the SSMs match the model's device.
ckpt = torch.load(ckpt_path, map_location=device)
args = ckpt['args']
ssm_list = ckpt['ssm']
# The SSMs model the joint latent [z_agent1, z_agent2], so one agent's latent size is half of nb_dim.
z_dim = ssm_list[0].nb_dim//2
if args.dataset == 'buetepage':
	dataset = buetepage.HHWindowDataset
elif args.dataset == 'nuisi':
	dataset = nuisi.HHWindowDataset
else:
	# Fail here with a clear message instead of a NameError on `dataset` below.
	raise ValueError(f"Unsupported dataset {args.dataset!r}; expected 'buetepage' or 'nuisi'")

print("Reading Data")
train_iterator = DataLoader(dataset('../'+args.src, train=True, window_length=args.window_size, downsample=args.downsample), batch_size=1, shuffle=True)
test_iterator = DataLoader(dataset('../'+args.src, train=False, window_length=args.window_size, downsample=args.downsample), batch_size=1, shuffle=False)

model = VAE(**(args.__dict__)).to(device)
model.load_state_dict(ckpt['model'])
# Inference only: switch off any training-time behaviour (e.g. dropout, if the VAE has any).
model.eval()


# Per-activity containers, filled in by the encoding loop further down.
z_encoded = []
z_segments = []
z_transitionstates = []
transitionstates_gmm = []

# One 3D subplot per activity.
fig = make_subplots(rows=1, cols=4,
					specs=[[{'is_3d': True} for _ in range(4)]],
					print_grid=False)

# For every state of every activity's SSM, draw the Gaussian over the first three
# latent dims of agent 1 (red, dims 0..2) and of agent 2 (blue, dims z_dim..z_dim+2).
# Each Gaussian is labelled at its mean with the 1-based state number.
for col, hmm in enumerate(ssm_list, start=1):
	for state in range(hmm.nb_states):
		for offset, color in ((0, 'red'), (z_dim, 'blue')):
			axes = slice(offset, offset + 3)
			pbd_torch.plot_gauss3d(fig, hmm.mu[state, axes].detach().cpu().numpy(),
						hmm.sigma[state, axes, axes].detach().cpu().numpy(),
						color=color, alpha=0.2, row=1, col=col)
			center = [hmm.mu[state, offset + k].detach().cpu().numpy() for k in range(3)]
			fig.add_trace(go.Scatter3d(
				x=center[0],
				y=center[1],
				z=center[2],
				mode="text",
				text=[str(state + 1)],
				textposition="bottom center"
			), row=1, col=col)

def _encode_window(ds, idx):
	"""Encode window ``idx`` of dataset ``ds`` for both agents with the VAE encoder.

	The window's feature axis is split in half: the first half is agent 1 and
	the second half is agent 2.

	Returns:
		(zh, zt): the agent-1 latents and the joint latents
		``torch.concat([zh, zr], dim=-1)``.
	"""
	x, _ = ds[idx]
	x = torch.Tensor(x).to(device)
	_, dims = x.shape
	x = torch.concat([x[None, :, :dims//2], x[None, :, dims//2:]]) # x[0] = Agent 1, x[1] = Agent 2
	zh = model(x[0], encode_only=True)
	zr = model(x[1], encode_only=True)
	# NOTE(review): indexing below treats zt as 2-D, presumably (seq_len, 2*z_dim) -- confirm.
	return zh, torch.concat([zh, zr], dim=-1)


# Pure inference: no autograd graph is needed for any encoding or forward variable below.
with torch.no_grad():
	for a, s in enumerate(train_iterator.dataset.actidx):
		ssm = ssm_list[a]
		col = a + 1  # one subplot column per activity
		z_encoded.append([])
		z_segments.append([])
		z_transitionstates.append([])

		# Pass 1: encode every window of this activity once. Keep the joint latents
		# at steps where the joint model's most likely state is not state 0 while the
		# agent-1-only model (marginal over the first z_dim dims) still picks state 0.
		window_cache = []  # (zh, alpha_t) per window, reused in pass 2
		for j in range(s[0], s[1]):
			zh, zt = _encode_window(train_iterator.dataset, j)
			alpha_t = ssm.forward_variable(zt).argmax(0)
			window_cache.append((zh, alpha_t))
			alpha_th = ssm.forward_variable(zh, marginal=slice(0, z_dim)).argmax(0)
			transition_idx = torch.where((alpha_t != alpha_th) & (alpha_th == 0))
			if len(transition_idx[0]) == 0:
				continue
			# Store into this activity's sub-lists (created above), not the outer lists.
			z_segments[-1].append(alpha_t)
			z_encoded[-1].append(zt)
			transition_z = zt[transition_idx].detach().cpu().numpy()
			z_transitionstates[-1].append(transition_z)
			for offset, color in ((0, 'magenta'), (z_dim, 'cyan')):
				fig.add_trace(go.Scatter3d(
								x=transition_z[:, offset+0],
								y=transition_z[:, offset+1],
								z=transition_z[:, offset+2],
								mode='markers',
								marker=dict(
									color=color,
									opacity=0.2
								)
							), row=1, col=col)

		# np.vstack raises on an empty list, so skip activities without any transition.
		if not z_transitionstates[-1]:
			print(f"Activity {a}: no transition states found; skipping GMM fit")
			continue

		# Fit a single Gaussian to all transition latents of this activity and draw it.
		z_transitionstates[-1] = np.vstack(z_transitionstates[-1])
		ssm_np = pbd.GMM(nb_dim=ssm.nb_dim, nb_states=1)
		ssm_np.em(z_transitionstates[-1])
		print(ssm_np.mu.tolist())
		print(ssm_np.sigma.tolist())
		for offset, color in ((0, 'magenta'), (z_dim, 'cyan')):
			dims3 = slice(offset, offset + 3)
			pbd_torch.plot_gauss3d(fig, ssm_np.mu[0, dims3], ssm_np.sigma[0, dims3, dims3],
						color=color, alpha=0.4, row=1, col=col)

		# Pass 2: per window, print the joint and agent-1-only state sequences (1-based),
		# then whether rows 0 and -1 of the agent-1 forward variable (axis 0 is the
		# state axis, given argmax(0) above) reach the transition Gaussian's density.
		mu_h = ssm_np.mu[0, :z_dim]
		sigma_h = ssm_np.sigma[0, :z_dim, :z_dim]
		for zh, alpha_t in window_cache:
			B1 = multi_variate_normal(zh[:, :z_dim].detach().cpu().numpy(), mu_h, sigma_h, log=False)
			alpha_th = ssm.forward_variable(zh, marginal=slice(0, z_dim))
			print((alpha_t+1).tolist())
			print((alpha_th.argmax(0)+1).tolist())
			t1 = alpha_th[0].detach().cpu().numpy()>=B1
			t2 = alpha_th[-1].detach().cpu().numpy()>=B1
			print(t1.astype(int).tolist())
			print(t2.astype(int).tolist())
			print('\n')
		print('\n')

# Render the four activity subplots side by side in one wide figure.
figure_size = dict(height=750, width=3000)
fig.update_layout(**figure_size)
fig.show()

Reading Data
[[-1.2382518768310549, -1.4311123847961427, -0.08004970401525498, -0.430169415473938, 0.6422056555747987, -0.7152567923069001, -0.832540822029114, 0.22498532012104988, -0.4915773153305054, 0.6321725964546203]]
[[[0.41748771890732495, -0.01614697677472037, 0.06724723462632383, -0.09783424389470718, -0.12257544169805998, 0.608921967625613, 0.019568680480818872, 0.05818807148044699, -0.2444983036386644, -0.1492431850067915], [-0.01614697677472037, 0.004257474357987221, -0.00019754079431272983, 0.006080039657556142, 0.004711294805471766, -0.02141880527861247, 0.002340120854217105, -0.0020422651163973355, 0.010516179852255619, 0.00611780897838912], [0.06724723462632383, -0.00019754079431272937, 0.012482194998355371, -0.01418141618745298, -0.01978192899791349, 0.09929525999635189, 0.004771313883337811, 0.009316678172781784, -0.038782718158978886, -0.023908475452198683], [-0.09783424389470717, 0.006080039657556142, -0.01418141618745298, 0.02516579568738524, 0.02934107035776102, -