In [1]:
import sys
import torch
import numpy as np
import math

import matplotlib.pyplot as plt
import seaborn as sns


folder_path = 'models/'
if folder_path not in sys.path:
    sys.path.append(folder_path)

from cnep import CNEP

device = torch.device("cpu")

In [2]:
root_path = '/home/yigit/projects/cnep/outputs/'
test_type = 'ablation/sines_4/'
test_path = '2_4_8/1718015585/'
run_path = root_path + test_type + test_path

data_path = run_path + 'y.pt'

Y = torch.load(data_path, map_location='cpu').to('cpu')
x = torch.linspace(0, 1, Y.shape[1]).unsqueeze(-1)
print(f'x: {x.shape}, Y: {Y.shape}')

num_demos = Y.shape[0]
t_steps = Y.shape[1]

criterion = torch.nn.MSELoss()

x: torch.Size([200, 1]), Y: torch.Size([128, 200, 1])


In [3]:
model_folder = 'saved_models/'
models_path = f'{run_path}{model_folder}'

m2_path, m4_path, m8_path = f'{models_path}cnep2.pt', f'{models_path}cnep4.pt', f'{models_path}cnep8.pt'

batch_size = 1
dx, dy = 1, 1
n_max, m_max = 10, 10
t_steps = Y.shape[1]
device = 'cpu'

cnep2 = CNEP(dx, dy, n_max, m_max, [64,64], num_decoders=2, decoder_hidden_dims=[130, 130], batch_size=batch_size, scale_coefs=True, device=device)
cnep4 = CNEP(dx, dy, n_max, m_max, [64,64], num_decoders=4, decoder_hidden_dims=[64, 64], batch_size=batch_size, scale_coefs=True, device=device)
cnep8 = CNEP(dx, dy, n_max, m_max, [64,64], num_decoders=8, decoder_hidden_dims=[32, 32], batch_size=batch_size, scale_coefs=True, device=device)

cnep2.load_state_dict(torch.load(m2_path, map_location='cpu'))
cnep4.load_state_dict(torch.load(m4_path, map_location='cpu'))
cnep8.load_state_dict(torch.load(m8_path, map_location='cpu'))


# data for testing cnxp
val_obs = torch.zeros((batch_size, n_max, dx+dy), dtype=torch.float32, device=device)
val_tar_x = torch.zeros((batch_size, t_steps, dx), dtype=torch.float32, device=device)
# val_tar_y = torch.zeros((batch_size, t_steps, dy), dtype=torch.float32, device=device)
val_obs_mask = torch.zeros((batch_size, n_max), dtype=torch.bool, device=device)

In [4]:
def find_closest_traj_ind(traj):
    # find the closest trajectory in the dataset
    min_dist = np.inf
    min_ind = -1
    for i, y in enumerate(Y):
        dist = criterion(y, traj[0]).item()
        if dist < min_dist:
            min_dist = dist
            min_ind = i
    return min_ind, min_dist

In [5]:
from tqdm import tqdm
num_tests = 1000

euc_errors = np.zeros((num_tests, 3))

for test_id in tqdm(range(num_tests)):
    val_obs.fill_(0)
    val_tar_x.fill_(0)
    val_obs_mask.fill_(False)
    ###############
    m = torch.randint(1, m_max, (1,)).item()
    val_obs_mask[0, :m] = True
    val_tar_x = torch.linspace(0, 1, t_steps).unsqueeze(0).unsqueeze(-1)
    ###############
    m_ids = torch.randperm(t_steps)[:m]
    for i in range(m):
        y_min, y_max = torch.min(Y[:, m_ids[i], 0]), torch.max(Y[:, m_ids[i], 0])
        val_obs[0, i, 0] = torch.rand(1)*(y_max-y_min) + y_min + 0.1*torch.randn(1)


    with torch.no_grad():
        pred_cnep, gate = cnep2.val(val_obs, val_tar_x, val_obs_mask)
        dec_id = torch.argmax(gate.squeeze(1), dim=-1)
        traj2 = pred_cnep[dec_id, 0, :, :dy]

        pred_cnep, gate = cnep4.val(val_obs, val_tar_x, val_obs_mask)
        dec_id = torch.argmax(gate.squeeze(1), dim=-1)
        traj4 = pred_cnep[dec_id, 0, :, :dy]

        pred_cnep, gate = cnep8.val(val_obs, val_tar_x, val_obs_mask)
        dec_id = torch.argmax(gate.squeeze(1), dim=-1)
        traj8 = pred_cnep[dec_id, 0, :, :dy]

        closest_traj_ind, closest_dist = find_closest_traj_ind(traj2)
        euc_errors[test_id, 0] = closest_dist / t_steps
        closest_traj_ind, closest_dist = find_closest_traj_ind(traj4)
        euc_errors[test_id, 1] = closest_dist / t_steps
        closest_traj_ind, closest_dist = find_closest_traj_ind(traj8)
        euc_errors[test_id, 2] = closest_dist / t_steps

100%|██████████| 1000/1000 [00:12<00:00, 78.90it/s]


In [6]:
print(euc_errors.mean(axis=0))

[1.38781485e-03 1.46971401e-04 9.76262219e-05]
