In [1]:
% matplotlib inline

import numpy as np
import pandas as pd
import json
import sys
import os
import matplotlib
#matplotlib.use('Agg') 
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
from matplotlib import cm
from matplotlib.ticker import LinearLocator, FormatStrFormatter
import seaborn as sns
import pdb
import ipdb
import h5py
import importlib
import pickle

# Add root folder to sys path
sys.path.append("../")

# %pylab inline
# inline doesn't give interactive plots
%matplotlib inline 
# %matplotlib notebook
plt.rcParams['figure.figsize'] = (12.0, 12.0) # set default size of plots
plt.rcParams['image.interpolation'] = 'nearest'
plt.rcParams['image.cmap'] = 'Blues'

sns.set()

# for auto-reloading external modules
# see http://stackoverflow.com/questions/1907993/autoreload-of-modules-in-ipython
%load_ext autoreload
%autoreload 2

# Import own modules
# Need to use importlib since module starts with a numerical value.

from load_expert_traj import Expert, ExpertHDF5
from grid_world import create_obstacles

plt.rcParams['figure.figsize'] = (6.0, 6.0) # set default size of plots

  from ._conv import register_converters as _register_converters


In [2]:
def plot_data_n_colors_policy(data, grid_size, policy,
                              bounds=(-100, -0.5, 0.5, 1000),
                              color_map=('black', 'grey', 'red'),
                              save_path=''):
    """Render a square grid image and overlay one marker per policy action.

    Args:
        data: 2-D array drawn with ``imshow``; its values are binned by
            ``bounds`` into the colours of ``color_map``.
        grid_size: side length of the square grid, used as the image extent
            ``[0, grid_size] x [0, grid_size]``.
        policy: 2-D array laid out like ``data``. Cells equal to 0, 1, 2, 3
            are drawn with the markers '>', 'v', '<', '^' respectively.
        bounds: bin edges passed to ``BoundaryNorm``.
        color_map: colours passed to ``ListedColormap``, one per bin.
        save_path: if non-empty, the figure is also written to this path.
    """
    cmap = matplotlib.colors.ListedColormap(list(color_map))
    norm = matplotlib.colors.BoundaryNorm(list(bounds), cmap.N)

    fig = plt.figure()
    ax = fig.add_subplot(1, 1, 1)
    ax.imshow(data, cmap=cmap, norm=norm, extent=[0, grid_size, 0, grid_size])

    ax.get_xaxis().set_ticklabels([])
    ax.get_yaxis().set_ticklabels([])
    ax.grid(True)
    ax.autoscale(enable=False)

    policy = np.asarray(policy)
    for action, marker in [(0, '>'), (1, 'v'), (2, '<'), (3, '^')]:
        # np.nonzero/np.where return one index array per dimension, so the
        # tuple itself is never empty; test the index arrays instead.
        rows, cols = np.nonzero(policy == action)
        if rows.size == 0:
            continue
        # imshow (default origin='upper') draws row 0 at the TOP of the
        # extent, while scatter uses data coordinates with y growing upward.
        # Row r therefore occupies y in [grid_size-1-r, grid_size-r]; +0.5
        # puts the marker at the cell centre (columns map directly to x).
        ax.scatter(x=cols + 0.5, y=grid_size - 1 - rows + 0.5, c='r',
                   s=120, marker=marker, edgecolors='none')

    if save_path:
        fig.savefig(save_path)

    # plt.show() renders the figure right away on the inline backend;
    # fig.show() only warns there and defers display to the end of the cell.
    plt.show()
    
def plot_trajectory(traj_data, grid_size,
                    pred_traj_data=None,
                    obstacles=None,
                    rooms=None,
                    bounds=(-100, -.5, .5, 100),
                    color_map=('black', 'grey', 'red'),
                    save_path='', figsize=(6, 6)):
    """Draw the grid with obstacles and overlay a true (and predicted) path.

    Args:
        traj_data: array-like of shape (T, >=2). Column 0 is the x (column)
            cell index and column 1 the y (row) cell index of each step. The
            first step is drawn as '*', the rest as '.'.
        grid_size: (rows, cols); the background image is ``np.ones(grid_size)``.
        pred_traj_data: optional array-like laid out like ``traj_data``. Its
            first T steps are drawn in green, raised a quarter cell so they
            do not hide the red ground-truth points.
        obstacles: optional iterable of (x, y) cells drawn with value 0.0.
        rooms: currently unused.
        bounds: bin edges for ``BoundaryNorm``.
        color_map: colours for ``ListedColormap``.
        save_path: if non-empty, the finished figure is written here.
        figsize: matplotlib figure size in inches.
    """
    cmap = matplotlib.colors.ListedColormap(list(color_map))
    norm = matplotlib.colors.BoundaryNorm(list(bounds), cmap.N)

    fig = plt.figure(figsize=figsize)
    ax = fig.add_subplot(1, 1, 1)

    # Free cells are 1.0, obstacle cells 0.0. Obstacles are (x, y) while the
    # image is indexed [row, col] = [y, x].
    img_arr = np.ones(grid_size)
    if obstacles is not None:
        for o in obstacles:
            img_arr[o[1], o[0]] = 0.0
    ax.imshow(img_arr,
              cmap=cmap,
              norm=norm,
              extent=[0, grid_size[1], 0, grid_size[0]],
              interpolation='none', aspect='equal', origin='lower')

    ax.set_xticks(range(grid_size[1]))
    ax.set_yticks(range(grid_size[0]))
    ax.grid(True)

    traj = np.asarray(traj_data)
    if len(traj) > 0:
        # One scatter per marker style instead of one per point.
        ax.scatter(x=traj[:1, 0] + 0.5, y=traj[:1, 1] + 0.5, c='r',
                   s=120, marker='*', edgecolors='none')
        ax.scatter(x=traj[1:, 0] + 0.5, y=traj[1:, 1] + 0.5, c='r',
                   s=120, marker='.', edgecolors='none')

        if pred_traj_data is not None:
            # Only the steps that line up with the true trajectory are drawn.
            pred = np.asarray(pred_traj_data)[:len(traj)]
            if len(pred) > 0:
                ax.scatter(x=pred[:1, 0] + 0.5, y=pred[:1, 1] + 0.75, c='g',
                           s=120, marker='*', edgecolors='none')
                ax.scatter(x=pred[1:, 0] + 0.5, y=pred[1:, 1] + 0.75, c='g',
                           s=120, marker='.', edgecolors='none')

    # Lay out before saving so the written file matches what is displayed.
    fig.tight_layout()
    if save_path:
        fig.savefig(save_path)

    # Render now on the inline backend so figures interleave with any text
    # the caller prints; fig.show() cannot display there.
    plt.show()
    
def softmax(x):
    """Numerically stable softmax taken along axis 1.

    The per-slice maximum is subtracted before exponentiating so large
    inputs do not overflow. Accepts anything ``np.max`` accepts with at
    least two dimensions and returns an ndarray of the same shape.
    """
    shifted = x - np.max(x, axis=1, keepdims=True)
    exps = np.exp(shifted)
    return exps / np.sum(exps, axis=1, keepdims=True)

In [3]:
def get_all_trajectory_from_sequence(states):
    """Split one episode into consecutive runs that share the same goal.

    Columns ``2:`` of every row are compared after ``astype(int)``
    truncation; a new run begins at each row whose truncated columns differ
    from the row before it.

    Args:
        states: 2-D array, one row per time step. Callers slice ``[:, :2]``
            as the (x, y) position; columns ``2:`` are presumably the goal
            encoding (TODO confirm).

    Returns:
        List of row-slices of ``states`` in their original dtype, in order.
        An input with zero rows yields an empty list.
    """
    if states.shape[0] == 0:
        return []
    goal_cols = states.astype(int)[:, 2:]
    # changed[k] is True when row k+1 carries a different goal than row k.
    changed = np.any(goal_cols[1:] != goal_cols[:-1], axis=1)
    run_starts = np.flatnonzero(changed) + 1
    return np.split(states, run_starts)

def plot_expert_trajectories_from_h5(expert_h5_path,
                                     num_traj_to_plot=10,
                                     obstacles=None,
                                     rooms=None,
                                     grid_size=(15, 11)):
    """Sample expert episodes from an HDF5 dump and plot their trajectories.

    Each sampled episode is split into constant-goal runs with
    ``get_all_trajectory_from_sequence``, and every run is drawn as its own
    figure with ``plot_trajectory``. Drawing stops after ``num_traj_to_plot``
    figures, or earlier if the sampled episodes run out of runs.

    Args:
        expert_h5_path: path handed to ``ExpertHDF5``.
        num_traj_to_plot: number of episodes sampled, and the maximum number
            of trajectory figures drawn.
        obstacles: forwarded to ``plot_trajectory``.
        rooms: forwarded to ``plot_trajectory``.
        grid_size: (rows, cols) of the grid, forwarded to ``plot_trajectory``.
    """
    # NOTE(review): the meaning of ExpertHDF5's second argument (2) is not
    # visible here — confirm against load_expert_traj.
    expert = ExpertHDF5(expert_h5_path, 2)
    # Push data into memory
    expert.push()

    expert_data = expert.sample(num_traj_to_plot)

    # Episodes and plotted figures are counted separately: one episode may
    # hold several goal runs, and an empty episode holds none (the old shared
    # counter skipped episodes in the first case and never advanced in the
    # second).
    num_plotted = 0
    for ep_idx in range(num_traj_to_plot):
        ep_states = expert_data.state[ep_idx]
        for traj in get_all_trajectory_from_sequence(ep_states):
            # (x, y) columns, reversed in time. NOTE(review): the reason for
            # the reversal is not visible here — confirm the storage order.
            ep_positions = traj[:, :2][::-1]
            plot_trajectory(ep_positions,
                            grid_size,
                            color_map=sns.color_palette("Blues_r"),
                            figsize=(6, 6),
                            obstacles=obstacles,
                            rooms=rooms)
            num_plotted += 1
            if num_plotted >= num_traj_to_plot:
                return

# Room-style grid world. NOTE(review): the plotting helpers here use
# grid_size=(15, 11) as (rows, cols); presumably create_obstacles takes
# (width, height) = (11, 15) — confirm against grid_world.
obstacles, rooms = create_obstacles(11, 15, env_name='room',
                                    room_size=3)

# Set to True to also plot sampled expert trajectories from the HDF5 dump.
PLOT_EXPERT_TRAJS = False
if PLOT_EXPERT_TRAJS:
    plot_expert_trajectories_from_h5('../h5_trajs/room_trajs/',
                                     num_traj_to_plot=10,
                                     obstacles=obstacles,
                                     rooms=rooms)


In [4]:
# Pickled evaluation results read by plot_pickle_results (checkpoint 20).
results_file = ('../results/vae/room_traj/goal_policy_only/'
                'discrete_action/lr_001/results_cp_20.pkl')

def plot_pickle_results(results_pkl_path, obstacles, rooms,
                        num_traj_to_plot=10, grid_size=(15, 11),
                        seed=None, show_context=False):
    """Plot randomly chosen true vs. predicted trajectories from a results pickle.

    The pickle holds a dict whose lists are indexed per trajectory. The keys
    read here are 'true_traj_state', 'pred_traj_state' and 'true_goal', plus
    optionally 'pred_goal' and 'pred_context'. Trajectory states are indexed
    as ``[step, 0, :2]``, i.e. shape (T, 1, 2).

    Args:
        results_pkl_path: path to the pickle; must exist.
        obstacles: forwarded to ``plot_trajectory``.
        rooms: forwarded to ``plot_trajectory``.
        num_traj_to_plot: number of trajectories drawn. They are sampled with
            replacement, so repeats are possible.
        grid_size: (rows, cols) forwarded to ``plot_trajectory``.
        seed: optional int for reproducible sampling; None uses the global
            ``np.random`` state.
        show_context: also print the softmax of 'pred_context' when present.

    Raises:
        AssertionError: if the pickle file does not exist.
        ValueError: if trajectories are requested but the pickle holds none.
    """
    assert os.path.exists(results_pkl_path), \
        'results pickle does not exist {}'.format(results_pkl_path)
    # SECURITY: pickle.load can execute arbitrary code — only load result
    # files produced by your own runs. The file is closed before plotting.
    with open(results_pkl_path, 'rb') as results_f:
        results_dict = pickle.load(results_f)

    true_states = results_dict['true_traj_state']
    pred_states = results_dict['pred_traj_state']
    true_goals = results_dict['true_goal']
    pred_goals = results_dict.get('pred_goal')
    has_pred_goal = pred_goals is not None and len(pred_goals) > 0

    total_traj = len(true_states)
    if num_traj_to_plot > 0 and total_traj == 0:
        raise ValueError(
            'no trajectories in results pickle {}'.format(results_pkl_path))

    rng = np.random if seed is None else np.random.RandomState(seed)
    for _ in range(num_traj_to_plot):
        traj_idx = rng.randint(total_traj)

        # NOTE(review): np.argmax flattens its input. The captured debugger
        # session shows a 'true_goal' entry of shape (1, 50, 2), for which the
        # flat index is not a goal id — confirm the stored goal layout.
        if has_pred_goal:
            print("True goal: {}, Pred goal: {}".format(
                np.argmax(true_goals[traj_idx]),
                np.argmax(pred_goals[traj_idx])
            ))
            print("Pred goal softmax: {}".format(np.array_str(
                pred_goals[traj_idx],
                precision=2,
                suppress_small=True)))
        else:
            print("True goal: {}".format(np.argmax(true_goals[traj_idx])))

        if show_context and results_dict.get('pred_context') is not None:
            print("Pred latent variable:\n{}".format(
                np.array_str(
                    softmax(results_dict['pred_context'][traj_idx].squeeze()),
                    precision=2, suppress_small=True)))

        # Drop the singleton middle axis: (T, 1, D) -> (T, D). The prediction
        # is cut to the true trajectory's length.
        true_traj = np.asarray(true_states[traj_idx])[:, 0, :]
        traj_len = true_traj.shape[0]
        pred_traj = np.asarray(pred_states[traj_idx])[:traj_len, 0, :]

        plot_trajectory(true_traj,
                        grid_size,
                        pred_traj_data=pred_traj,
                        color_map=sns.color_palette("Blues_r"),
                        figsize=(6, 6),
                        obstacles=obstacles,
                        rooms=rooms)

# Draw 10 randomly sampled true/predicted trajectories from the results file.
plot_pickle_results(
    results_pkl_path=results_file,
    obstacles=obstacles,
    rooms=rooms,
    num_traj_to_plot=10,
)

> <ipython-input-4-edac861cf768>(9)plot_pickle_results()
-> total_traj = len(results_dict['true_traj_state'])
(Pdb) results_dict.keys()
dict_keys(['true_traj_action', 'pred_goal', 'true_goal', 'pred_context', 'pred_traj_action', 'true_traj_state', 'pred_traj_goal', 'pred_traj_state'])
(Pdb) type(results_dict['true_goal'])
<class 'list'>
(Pdb) len(results_dict['true_goal'])
300
(Pdb) len(results_dict['true_traj_state'])
300
(Pdb) results_dict['true_goal'][0].shape
(1, 50, 2)
(Pdb) results_dict['true_traj_state'].shape
*** AttributeError: 'list' object has no attribute 'shape'
(Pdb) results_dict['true_traj_state'][0].shape
(50, 1, 2)
(Pdb) q


BdbQuit: 