In [1]:
import pickle
from utils import Trajectory
from typing import List
def extract_data_from_trajs(trajs: List[Trajectory]):
    """Flatten a sequence of trajectories into three parallel per-step lists.

    Args:
        trajs: trajectories, each exposing ``obs_arr`` and ``gt_arr`` sequences.

    Returns:
        ``(traj_ids, obs_flat, gt_flat)``. ``obs_flat`` / ``gt_flat`` are the
        concatenation of every trajectory's ``obs_arr`` / ``gt_arr``, and
        ``traj_ids`` holds, per element of ``obs_arr``, the position in
        ``trajs`` of the trajectory it came from. ``traj_ids`` is sized from
        ``obs_arr`` only, so it lines up with ``gt_flat`` only when each
        trajectory's ``gt_arr`` has the same length as its ``obs_arr``.
    """
    traj_ids = []
    obs_flat = []
    gt_flat = []

    for position, trajectory in enumerate(trajs):
        observations = trajectory.obs_arr
        obs_flat.extend(observations)
        gt_flat.extend(trajectory.gt_arr)
        # Tag every observation step with the index of its source trajectory.
        traj_ids.extend(position for _ in range(len(observations)))

    return traj_ids, obs_flat, gt_flat


In [2]:
def preprocess_obs(obs_arr, normalize=True, n_channels=4):
    """Keep the trailing axis-2 entries of each observation and move that axis first.

    Each observation is sliced to its last ``n_channels`` entries along axis 2,
    and that axis is then moved to position 0. For a 3-D ``(d0, d1, d2)``
    array the result has shape ``(n_channels, d0, d1)`` (fewer along axis 0 if
    ``d2 < n_channels``).

    Args:
        obs_arr: iterable of numpy arrays with at least 3 dimensions.
        normalize: if True, divide each result by 255.0 (producing floats).
        n_channels: number of trailing axis-2 entries to keep. The default of 4
            reproduces the previous hard-coded behaviour.

    Returns:
        List of processed arrays, in the same order as ``obs_arr``.

    Raises:
        ValueError: if ``n_channels`` < 1. ``obs[:, :, -0:]`` would silently keep
            every entry instead of none.
    """
    import numpy as np

    if n_channels < 1:
        raise ValueError(f"n_channels must be >= 1, got {n_channels}")

    mod_arr = []
    for obs in obs_arr:
        obs = obs[:, :, -n_channels:]
        # Axes 0,1 shift to 1,2, so axis 2 lands at position 0 (channel-last -> channel-first).
        obs = np.moveaxis(obs, [0, 1], [1, 2])
        if normalize:
            obs = obs / 255.0
        mod_arr.append(obs)

    return mod_arr

In [3]:
def load_data(trajectory_data, normalize=True):
    """Flatten trajectories, preprocess their observations and print summaries.

    Prints, for both the raw and the preprocessed observations: the count,
    the set of distinct array shapes, and the sorted unique values.

    Args:
        trajectory_data: trajectories passed to ``extract_data_from_trajs``.
        normalize: forwarded to ``preprocess_obs``.

    Returns:
        ``(traj_ids, obs_flat, gt_flat, proc_obs)``: the three flattened lists
        returned by ``extract_data_from_trajs``, plus the output of
        ``preprocess_obs`` on ``obs_flat``.
    """
    import numpy as np

    def _unique_values(arrays):
        # Sorted unique values across a list of arrays. Each array is reduced
        # first and the results are merged. This works even when the arrays have
        # different shapes: np.unique(list_of_arrays) must stack them into one
        # array, which fails for ragged input. It also avoids building one
        # dataset-sized array just to print a summary.
        if len(arrays) == 0:
            return np.array([])
        return np.unique(np.concatenate([np.unique(a) for a in arrays]))

    traj_ids, obs_flat, gt_flat = extract_data_from_trajs(trajectory_data)
    print("1. Loaded Raw Obs: ", len(obs_flat), {obs.shape for obs in obs_flat}, _unique_values(obs_flat))

    proc_obs = preprocess_obs(obs_flat, normalize=normalize)
    print("2. Preprocessed Obs: ", len(proc_obs), {obs.shape for obs in proc_obs}, _unique_values(proc_obs))

    return traj_ids, obs_flat, gt_flat, proc_obs

In [4]:
def plot_raw(data, free_lim = False, color_map=None, s = 5, alpha = 1):
    """Scatter-plot a collection of 2-D points in a new figure.

    Args:
        data: sequence of ``(x, y)`` pairs; it is unpacked with ``zip(*data)``.
        free_lim: if True, give both axes the same range. The range spans the
            overall min and max of ``data`` plus a 5% margin on each side.
        color_map: forwarded to ``plt.scatter`` as ``c``.
        s: marker size, forwarded to ``plt.scatter``.
        alpha: marker transparency, forwarded to ``plt.scatter``.
    """
    # numpy was previously used here without being imported (NameError when free_lim=True).
    import numpy as np
    from matplotlib import pyplot as plt

    plt.figure()

    if free_lim:
        lo = np.min(data)
        hi = np.max(data)
        # Pad outward by 5% of each bound's magnitude. Plain `* 1.05` moves a
        # positive minimum or a negative maximum *inward*, clipping the extreme
        # points.
        mini = lo - 0.05 * abs(lo)
        maxi = hi + 0.05 * abs(hi)
        plt.ylim(mini, maxi)
        plt.xlim(mini, maxi)

    plt.scatter(*zip(*data), c=color_map, s=s, alpha=alpha)

In [5]:
def neighborhood_comparison(K, scaled_states, embedding, leaf_size = 40):
    """Count how many of each point's K nearest neighbours an embedding preserves.

    Separate KD-trees are built over ``scaled_states`` and ``embedding``, and
    the K+1 nearest neighbours of every point are queried in each. For point
    ``i``, the match count is the number of neighbours shared by the two
    lists, after removing ``i`` itself and keeping the first K of each.

    Args:
        K: number of neighbours to compare per point.
        scaled_states: points in the original space; row ``i`` is point ``i``.
        embedding: the same points in the embedded space, row-aligned with
            ``scaled_states``.
        leaf_size: forwarded to ``sklearn.neighbors.KDTree``.

    Returns:
        ``(mean_matches, matches, intersects, ind_st_arr, ind_embd_arr)``:
        - ``mean_matches``: mean of ``matches``.
        - ``matches``: per-point shared-neighbour counts, each in ``[0, K]``.
        - ``intersects``: per-point ``np.intersect1d`` of the raw K+1 index
          rows, which normally include ``i`` itself.
        - ``ind_st_arr``, ``ind_embd_arr``: the raw neighbour index arrays
          returned by each tree's query.
    """
    from sklearn.neighbors import KDTree
    import numpy as np

    scld_st_kdtree = KDTree(scaled_states, leaf_size=leaf_size)
    embed_kdtree = KDTree(embedding, leaf_size=leaf_size)

    # Query K+1 neighbours because a point is normally returned as its own nearest neighbour.
    _, ind_st_arr = scld_st_kdtree.query(scaled_states, k=K+1)
    _, ind_embd_arr = embed_kdtree.query(embedding, k=K+1)

    intersects = []
    matches = []
    for ind, (st_row, embd_row) in enumerate(zip(ind_st_arr, ind_embd_arr)):
        intersects.append(np.intersect1d(st_row, embd_row))
        # Remove the query point explicitly instead of assuming it appears in
        # both rows. With duplicate points, tie-breaking can return other copies
        # in place of the point itself, and "intersection size - 1" then
        # undercounts.
        st_nbrs = st_row[st_row != ind][:K]
        embd_nbrs = embd_row[embd_row != ind][:K]
        matches.append(np.intersect1d(st_nbrs, embd_nbrs).size)

    matches = np.array(matches)

    return np.mean(matches), matches, intersects, ind_st_arr, ind_embd_arr
