In [None]:
# Single seed value reused for every random number generator seeded in this cell.
seed = 123

import numpy as np
np.random.seed(seed)
from matplotlib import pyplot as plt
from mpl_toolkits.axes_grid1 import ImageGrid
import torch
# Seed the torch CPU and CUDA generators and disable cuDNN autotuning so that
# repeated runs are as reproducible as the backend allows.
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
import torch.nn as nn
# Turns on autograd anomaly detection globally. This is a debugging aid and adds
# overhead to any forward pass that records an autograd graph.
torch.autograd.set_detect_anomaly(True)
import torchvision
from torchvision import datasets, transforms

import hickle as hkl
import pickle as pkl
import pdb
import os

from sklearn.cluster import KMeans
from sklearn.mixture import GaussianMixture as GMM
import time

from torch.utils.data import DataLoader

from tqdm import tqdm

from copy import deepcopy
import io
import PIL.Image
import multiprocessing as mp
import tikzplotlib

# Move the working directory two levels up before importing the project packages.
# NOTE(review): presumably this lands on the repository root so that `src` is
# importable -- confirm. The path is relative, so re-running this cell in the same
# kernel moves up two more levels each time.
os.chdir("../..")

from src.utils.utils_model import to_var
from src.driver_sensor_model.models_cvae import VAE
# NOTE(review): later cells use `SequenceGenerator` and `unnormalize`, which are not
# defined in this notebook; presumably they come from these wildcard imports -- confirm.
from src.utils.data_generator import *
from src.utils.interaction_utils import *

In [None]:
# Load data.
# Constants shared with the later cells: sequence length passed to the generator,
# number of state channels, and the (rows, cols) size of one label grid.
nt = 10
num_states = 7
grid_shape = (20, 30)

# Use the first CUDA device when one is available, otherwise stay on the CPU.
use_cuda = torch.cuda.is_available()
device = torch.device("cuda:0") if use_cuda else torch.device("cpu")

# NOTE(review): `dir` shadows the Python builtin of the same name; it is kept here
# because it is a notebook-level name.
dir = '/data/INTERACTION-Dataset-DR-v1_1/Processed_data_new_goal/driver_sensor_dataset/'

# Test data.
data_file_states, data_file_grids, data_file_sources = (
    os.path.join(dir, fname)
    for fname in ('states_test.hkl', 'label_grids_test.hkl', 'sources_test.hkl')
)

data_test = SequenceGenerator(
    data_file_state=data_file_states,
    data_file_grid=data_file_grids,
    source_file=data_file_sources,
    nt=nt,
    batch_size=None,
    shuffle=False,
    sequence_start_mode='unique',
    norm=True,
)

# The batch size equals the dataset length, so the loader yields the whole test
# set as a single batch. One CPU core is left free for the main process.
test_loader = DataLoader(
    data_test,
    batch_size=len(data_test),
    shuffle=False,
    num_workers=mp.cpu_count() - 1,
    pin_memory=True,
)

In [None]:
# Pull the test batch out of the loader and derive per-timestep kinematic
# quantities from the un-normalized states. Every name assigned inside the loop is
# read by later cells, so after the loop they hold the values of the last batch
# (the loader is built with batch_size=len(data_test), i.e. a single batch).
for batch_x_test, batch_y_test, sources_test in test_loader:
    batch_x_test = batch_x_test.to(device)
    batch_y_test_orig = batch_y_test.to(device)
    # Reshape the labels to (batch, 1, rows, cols) using the shared `grid_shape`
    # constant instead of a second hard-coded copy of 20 x 30.
    batch_y_test_orig = batch_y_test_orig.view(batch_y_test_orig.shape[0], 1, *grid_shape)

    # Undo the dataset normalization; the last axis of `y_full` is indexed as
    # 0-1 position, 2 orientation, 3-4 velocity, 5-6 acceleration.
    y_full = unnormalize(batch_x_test.cpu().data.numpy(), nt)
    pos_x = y_full[:,:,0]
    pos_y = y_full[:,:,1]
    orientation = y_full[:,:,2]
    cos_theta = np.cos(orientation)
    sin_theta = np.sin(orientation)
    vel_x = y_full[:,:,3]
    vel_y = y_full[:,:,4]
    speed = np.sqrt(vel_x**2 + vel_y**2)
    acc_x = y_full[:,:,5]
    acc_y = y_full[:,:,6]

    # Project the acceleration on the orientation vector to get longitudinal acceleration.
    dot_prod = acc_x * cos_theta + acc_y * sin_theta
    sign = np.sign(dot_prod)
    acc_proj_x = dot_prod * cos_theta
    acc_proj_y = dot_prod * sin_theta

    # Signed magnitude of the projected vector. Mathematically this equals
    # `dot_prod`; the explicit form is kept so the floating-point values (and the
    # threshold-based sample selection in later cells) stay exactly as before.
    acc_proj_sign = sign * np.sqrt(acc_proj_x**2 + acc_proj_y**2)

    batch_size = batch_y_test_orig.shape[0]

In [None]:
# VAE
# Build the model, restore the trained weights, run inference on the test batch
# and report accuracy / MSE of the most likely reconstruction.
folder_vae = '/models/cvae'
name_vae = 'lstm_1_Adam_z_100_lr_0.001_rand_123_norm_True_kl_start_0_finish_1.0_center_10000.0_mutual_info_const_alpha_1.5_epochs_30_batch_256'

vae = VAE(
    encoder_layer_sizes_p=[7, 5],
    n_lstms=1,
    latent_size=100,
    decoder_layer_sizes=[256, 600], # used to be 10
    dim=4
    )

# Put the model on the same device as the test batch. An unconditional `.cuda()`
# would fail on a machine without CUDA even though `device` falls back to the CPU.
vae = vae.to(device)

# NOTE(review): the epoch suffix is appended to `name_vae` without a separator;
# presumably this matches how the checkpoint file was written -- confirm.
save_filename = os.path.join(folder_vae, name_vae) + 'epoch_30_vae.pt'

with open(save_filename, 'rb') as f:
    # map_location lets a checkpoint saved from a GPU be restored on whatever
    # device is in use here.
    state_dict = torch.load(f, map_location=device)
vae.load_state_dict(state_dict)

vae.eval()

with torch.no_grad():
    recon_y_inf_most_likely, alpha_p, alpha_p_lin, full_c, z = vae.inference(n=1, c=batch_x_test, mode='most_likely')
    # NOTE(review): `recon_x_inf` is not read anywhere in the visible cells.
    recon_x_inf, _, _, _, _ = vae.inference(n=100, c=batch_x_test, mode='all')
print(recon_y_inf_most_likely.shape, batch_y_test_orig.shape)
print(torch.max(recon_y_inf_most_likely), torch.min(recon_y_inf_most_likely))

# Flatten predictions and labels to (N, rows, cols) using the `grid_shape` set in
# the data-loading cell (the duplicate re-definition here was dropped).
recon_y_inf_np = np.reshape(recon_y_inf_most_likely.cpu().data.numpy(), (-1, grid_shape[0], grid_shape[1]))
y_np = np.reshape(batch_y_test_orig.cpu().data.numpy(), (-1, grid_shape[0], grid_shape[1]))
print(recon_y_inf_np.shape, y_np.shape)

# Quantize the prediction to three levels before comparing with the labels:
# >= 0.6 -> 1.0, <= 0.4 -> 0.0, strictly in between -> 0.5.
recon_y_inf_np_pred = (recon_y_inf_np >= 0.6).astype(float)
recon_y_inf_np_pred[recon_y_inf_np <= 0.4] = 0.0
recon_y_inf_np_pred[np.logical_and(recon_y_inf_np < 0.6, recon_y_inf_np > 0.4)] = 0.5

# Accuracy is the fraction of cells whose quantized value equals the label; MSE is
# computed on the raw (un-quantized) prediction.
acc = np.mean(recon_y_inf_np_pred == y_np)
mse = np.mean((recon_y_inf_np - y_np)**2)
print('Acc: ', acc, 'MSE: ', mse)

In [None]:
# Visualize all latent classes.
# Inference only: run without autograd bookkeeping, like the evaluation cell does.
with torch.no_grad():
    recon_y_inf_all, _, _, _, _ = vae.inference(n=1, c=batch_x_test, mode='all')
recon_y_inf_all_np = recon_y_inf_all.cpu().data.numpy()

# 10 x 10 panel of images; `zip` below stops at whichever of the 100 axes or the
# first dimension of the reconstructions runs out first.
fig = plt.figure(figsize=(10, 10))
grid = ImageGrid(fig, 111,
                 nrows_ncols=(10, 10),
                 axes_pad=0.1,
                 )

for ax, im in zip(grid, recon_y_inf_all_np[:,0]):
    ax.matshow(im, cmap='gray_r', vmin=0, vmax=1)
    # Hide the ticks. Only the tick list is passed, because the meaning of a
    # second positional argument has changed between matplotlib versions.
    ax.set_xticks([])
    ax.set_yticks([])
plt.savefig('all_latent_classes_vae.png')
plt.show()

In [None]:
# Visualize the two most likely modes of the VAE next to the ground truth: slowing down.
# Select the 17th (index 16) test sequence that has at least one timestep whose
# signed projected acceleration is below -1.5. This raises IndexError if fewer
# than 17 sequences qualify.
sample = np.where(np.sum(acc_proj_sign < -1.5, axis=-1) > 0)[0][16]
# `alpha_p` values of this sample in ascending order, so [-1] is the largest and
# [-2] the second largest. Computed once and reused for the print and file names.
probs_sorted = torch.sort(alpha_p[sample])[0]
print('sample: ', sample)
print('acceleration: ', acc_proj_sign[sample])
print('speed: ', speed[sample])
print('probabilitiies: ', probs_sorted)

# Condition on this single sequence (batch dimension of 1) and decode the modes
# selected by k=1 and k=2. Inference only, so no autograd graph is recorded.
cond_sample = torch.unsqueeze(batch_x_test[sample], dim=0)
with torch.no_grad():
    recon_y_inf_1, _, _, _, _ = vae.inference(n=1, c=cond_sample, mode='multimodal', k=1)
    recon_y_inf_2, _, _, _, _ = vae.inference(n=1, c=cond_sample, mode='multimodal', k=2)
recon_y_inf_np_1 = np.reshape(recon_y_inf_1.cpu().data.numpy(), (-1, grid_shape[0], grid_shape[1]))
recon_y_inf_np_2 = np.reshape(recon_y_inf_2.cpu().data.numpy(), (-1, grid_shape[0], grid_shape[1]))

# Trajectory of the selected sequence, saved as PNG and as TikZ.
plt.gca().set_aspect('equal', adjustable='box') # 'datalim'
plt.scatter(pos_x[sample], pos_y[sample])
plt.savefig(str(sample) + '_traj_dec.png')
tikzplotlib.save(str(sample) + '_traj_dec.tex')

# `plt.matshow` opens its own new figure, so no explicit `plt.figure()` is needed
# (one would only leave an empty extra figure behind).
# NOTE(review): `pad_inches` is documented to apply only together with
# bbox_inches='tight'; as written it presumably has no effect -- confirm intent.
plt.matshow(recon_y_inf_np_1[0], cmap='gray_r', vmin=0, vmax=1)
plt.xticks([])
plt.yticks([])
plt.savefig('models/' + str(sample) + '_mode_1_' + str(probs_sorted[-1].cpu().data.numpy()) + '_dec.png', pad_inches=0.0)
plt.show()

plt.matshow(recon_y_inf_np_2[0], cmap='gray_r', vmin=0, vmax=1)
plt.xticks([])
plt.yticks([])
plt.savefig('models/' + str(sample) + '_mode_2_' + str(probs_sorted[-2].cpu().data.numpy()) + '_dec.png', pad_inches=0.0)
plt.show()

# Ground-truth label grid of the same sample.
plt.matshow(batch_y_test_orig[sample,0].cpu().data.numpy(), cmap='gray_r', vmin=0, vmax=1)
plt.xticks([])
plt.yticks([])
plt.savefig('models/' + str(sample) + '_gt_dec.png', pad_inches=0.0)
plt.show()

In [None]:
# Visualize the two most likely modes of the VAE next to the ground truth: constant speed.
# Diagnostic count of sequences with more than 9 timesteps at speed > 5.0.
# NOTE(review): this print uses |acc| < 10 while the selection below uses
# |acc| < 0.25, so the printed count is not the size of the candidate set.
print(np.sum(np.sum(np.logical_and(np.abs(acc_proj_sign) < 10, speed > 5.0), axis=-1) > 9), speed.shape)
# Select the 8th (index 7) sequence in which more than 9 timesteps have
# |acc| < 0.25 and speed > 5.0, and whose `alpha_p` entry 27 exceeds 0.3.
# This raises IndexError if fewer than 8 sequences qualify.
sample = np.where(np.logical_and(np.sum(np.logical_and(np.abs(acc_proj_sign) < 0.25, speed > 5.0), axis=-1) > 9, alpha_p.cpu().detach().numpy()[:,27] > 0.3))[0][7]
# `alpha_p` values of this sample in ascending order, so [-1] is the largest and
# [-2] the second largest. Computed once and reused for the print and file names.
probs_sorted = torch.sort(alpha_p[sample])[0]
print('sample: ', sample)
print('acceleration: ', acc_proj_sign[sample])
print('speed: ', speed[sample])
print('probabilitiies: ', probs_sorted)

# Condition on this single sequence (batch dimension of 1) and decode the modes
# selected by k=1 and k=2. Inference only, so no autograd graph is recorded.
cond_sample = torch.unsqueeze(batch_x_test[sample], dim=0)
with torch.no_grad():
    recon_y_inf_1, _, _, _, _ = vae.inference(n=1, c=cond_sample, mode='multimodal', k=1)
    recon_y_inf_2, _, _, _, _ = vae.inference(n=1, c=cond_sample, mode='multimodal', k=2)
recon_y_inf_np_1 = np.reshape(recon_y_inf_1.cpu().data.numpy(), (-1, grid_shape[0], grid_shape[1]))
recon_y_inf_np_2 = np.reshape(recon_y_inf_2.cpu().data.numpy(), (-1, grid_shape[0], grid_shape[1]))

# Trajectory of the selected sequence, saved as PNG and as TikZ.
plt.gca().set_aspect('equal', adjustable='box') # 'datalim'
plt.scatter(pos_x[sample], pos_y[sample])
plt.savefig(str(sample) + '_traj_const.png')
tikzplotlib.save(str(sample) + '_traj_const.tex')

# `plt.matshow` opens its own new figure, so no explicit `plt.figure()` is needed
# (one would only leave an empty extra figure behind).
# NOTE(review): `pad_inches` is documented to apply only together with
# bbox_inches='tight'; as written it presumably has no effect -- confirm intent.
plt.matshow(recon_y_inf_np_1[0], cmap='gray_r', vmin=0, vmax=1)
plt.xticks([])
plt.yticks([])
plt.savefig('models/' + str(sample) + '_mode_1_' + str(probs_sorted[-1].cpu().data.numpy()) + '_const.png', pad_inches=0.0)
plt.show()

plt.matshow(recon_y_inf_np_2[0], cmap='gray_r', vmin=0, vmax=1)
plt.xticks([])
plt.yticks([])
plt.savefig('models/' + str(sample) + '_mode_2_' + str(probs_sorted[-2].cpu().data.numpy()) + '_const.png', pad_inches=0.0)
plt.show()

# Ground-truth label grid of the same sample.
plt.matshow(batch_y_test_orig[sample,0].cpu().data.numpy(), cmap='gray_r', vmin=0, vmax=1)
plt.xticks([])
plt.yticks([])
plt.savefig('models/' + str(sample) + '_gt_const.png', pad_inches=0.0)
plt.show()