# Visualizing Latent Dimensions

In [None]:
# %matplotlib widget
import numpy as np

# import matplotlib.pyplot as plt

from plotly.offline import init_notebook_mode
init_notebook_mode()
import plotly.graph_objs as go
from plotly.subplots import make_subplots

import numpy as np
import torch

import pbdlib_torch as pbd_torch

import torch
from torch import nn
import torch.nn.functional as F
from torch.utils.data import DataLoader

# Run on the GPU when one is visible to torch, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

from utils import *
from dataset import HumanHandoverDataset
import networks

# Figure grid: one row per training trial (checkpoint), one column per test sample.
N_TRIALS = 4
N_SAMPLES = 4
# Number of latent components / segment labels; matches len(colors) below.
N_COMPONENTS = 3
# Draw a Gaussian only every STRIDE timesteps to keep the 3D plots readable.
STRIDE = 20
# Optional extra figures (previously commented out). Off by default.
PLOT_ALPHA = False  # component weights vs. ground-truth one-hot labels over normalized time
PLOT_RECON = False  # input/output trajectories vs. r_out_h, split by (predicted) component


def _subplot_grid(is_3d):
    """Return an N_TRIALS x N_SAMPLES subplot figure with 3D scenes (is_3d) or 2D axes."""
    # A fresh spec dict per cell, rather than one dict object shared across cells.
    specs = [[{'is_3d': is_3d} for _ in range(N_SAMPLES)] for _ in range(N_TRIALS)]
    return make_subplots(rows=N_TRIALS, cols=N_SAMPLES, specs=specs, print_grid=False)


def _to_numpy(t):
    """Detach a tensor from the autograd graph and copy it to a host NumPy array."""
    return t.detach().cpu().numpy()


def _add_line3d(fig, points, mask, color, dash, row, col):
    """Add the rows of `points` selected by boolean `mask` as a 3D polyline (cols 0-2 = x, y, z)."""
    sel = points[mask]
    fig.add_trace(go.Scatter3d(x=sel[:, 0], y=sel[:, 1], z=sel[:, 2], mode='lines',
                               line=dict(color=color, width=10, dash=dash)),
                  row=row, col=col)


fig_latent = _subplot_grid(True)
fig_alpha = _subplot_grid(False)
fig_recon = _subplot_grid(True)

colors = ['red', 'green', 'blue']
object_color = ['#000000', '#555555', '#aaaaaa']
for trial in range(N_TRIALS):
    ckpt_path = f'logs/moveint/trial{trial}/models/399.pth'
    print(ckpt_path)
    # map_location lets a checkpoint saved on a GPU be loaded when `device` is the CPU.
    ckpt = torch.load(ckpt_path, map_location=device)
    training_args = ckpt['args']

    test_iterator = DataLoader(HumanHandoverDataset(training_args, train=False), batch_size=1, shuffle=True)

    model = networks.MoVEInt(test_iterator.dataset.input_dims, test_iterator.dataset.output_dims, training_args).to(device)
    model.load_state_dict(ckpt['model'])
    model.eval()

    # Diagonal covariance of each human latent component (first 3 dims). `human_std` holds
    # standard deviations, so square them -- same as the robot covariance below (r_std**2).
    human_var = _to_numpy(model.human_std[0, :3] ** 2)

    for n, (x_in, x_out, label) in enumerate(test_iterator):
        row, col = trial + 1, n + 1
        x_in = torch.Tensor(x_in[0]).to(device)
        x_out = torch.Tensor(x_out[0]).to(device)
        label = label[0].cpu().numpy()

        # Inference only: no autograd graph is needed.
        with torch.no_grad():
            h_mean, h_alpha, r_mean, r_std, _, r_out_h = model(x_in, x_out)
        h_mean = _to_numpy(h_mean)
        h_alpha = _to_numpy(torch.softmax(h_alpha, -1))
        r_mean = _to_numpy(r_mean)
        r_var = _to_numpy(r_std ** 2)

        # Human latent: one Gaussian per component l every STRIDE steps.
        for l in range(N_COMPONENTS):
            sigma_l = np.diag(human_var[:, l])  # does not depend on the time step
            for i in range(0, h_mean.shape[0], STRIDE):
                # Opacity tracks component l's own weight at step i, clamped to [0.1, 0.5].
                pbd_torch.plot_gauss3d(fig_latent, h_mean[i, :3, l], sigma_l,
                                       color=colors[l],
                                       alpha=float(np.clip(h_alpha[i, l], 0.1, 0.5)),
                                       row=row, col=col)

        # Robot latent: one Gaussian every STRIDE steps.
        for i in range(0, r_mean.shape[0], STRIDE):
            pbd_torch.plot_gauss3d(fig_latent, r_mean[i, :3], np.diag(r_var[i, :3]),
                                   color='magenta', alpha=0.3, row=row, col=col)

        if PLOT_ALPHA:
            t = np.linspace(0, 1, h_alpha.shape[0])
            label_onehot = np.eye(N_COMPONENTS)[label]
            for l in range(N_COMPONENTS):
                fig_alpha.add_trace(go.Scatter(x=t, y=h_alpha[:, l], line=dict(color=colors[l], dash='solid')), row=row, col=col)
                fig_alpha.add_trace(go.Scatter(x=t, y=label_onehot[:, l], line=dict(color=colors[l], dash='longdash')), row=row, col=col)

        if PLOT_RECON:
            x_in_np, x_out_np, r_out_h_np = _to_numpy(x_in), _to_numpy(x_out), _to_numpy(r_out_h)
            h_label = h_alpha.argmax(1)
            for l in range(N_COMPONENTS):
                gt, pred = label == l, h_label == l
                # Ground truth (solid): columns 0-2 and 3-6 of inputs and outputs, split by label.
                for traj in (x_in_np, x_out_np):
                    _add_line3d(fig_recon, traj[:, 0:3], gt, colors[l], 'solid', row, col)
                    _add_line3d(fig_recon, traj[:, 3:6], gt, colors[l], 'solid', row, col)
                # Model output (dashed), split by the most likely component.
                _add_line3d(fig_recon, r_out_h_np[:, 0:3], pred, colors[l], 'longdash', row, col)
                _add_line3d(fig_recon, r_out_h_np[:, 3:6], pred, colors[l], 'longdash', row, col)
                # Object trajectory (object_color) is not plotted: this loader yields no object_traj.

        if n + 1 >= N_SAMPLES:
            break

fig_latent.update_layout(height=2000, width=2000)
fig_latent.show()
if PLOT_ALPHA:
    fig_alpha.update_layout(height=2000, width=2000, showlegend=False)
    fig_alpha.show()
if PLOT_RECON:
    fig_recon.update_layout(height=2000, width=2000, showlegend=False)
    fig_recon.show()
# Optional static export:
# import plotly.io as pio
# pio.write_image(fig_latent, 'latents.pdf',scale=6, width=1080, height=1080)
# pio.write_image(fig_alpha, 'alphas.pdf',scale=6, width=1080, height=1080)
# pio.write_image(fig_recon, 'recons.pdf',scale=6, width=1080, height=1080)