#### plan
1. function that takes in a list of trajs and returns the traj with the worst performance
2. select id from that traj
3. simulate that id (and starting point) for each model with rendering
4. display a stacked video of all models running that trajectory

In [1]:
%matplotlib inline

import h5py
from IPython.display import HTML
from matplotlib import animation
import matplotlib.pyplot as plt
import numpy as np
import os
from scipy import ndimage
import sys
import tensorflow as tf

import hgail.misc.utils

import hyperparams
import utils
import validate

This call to matplotlib.use() has no effect because the backend has already
been chosen; matplotlib.use() must be called *before* pylab, matplotlib.pyplot,
or matplotlib.backends is imported for the first time.



In [2]:
def score_trajs(trajs, key='rmse_pos', worst_fn=np.mean):
    scores = np.zeros(len(trajs))
    for i, traj in enumerate(trajs):
        scores[i] = worst_fn(traj[key])
    return scores

In [3]:
relevant_trajs, relevant_labels = utils.load_trajs_labels('../../data/experiments/hgail/imitate/validation/')

In [51]:
scores = score_trajs(relevant_trajs[0])
sorted_score_idxs = list(reversed(np.argsort(scores)))
worst_traj_idx = sorted_score_idxs[-2]
worst_score = scores[worst_traj_idx]
worst_traj = relevant_trajs[0][worst_traj_idx]
worst_egoid = worst_traj['egoid']

In [52]:
print(worst_traj_idx)
print(worst_score)

2036
0.855071528861


In [53]:
# simulate with rendering and return the imgs
def simulate(env, policy, max_steps, env_kwargs=dict(), render_kwargs=dict()):
    x = env.reset(**env_kwargs)
    policy.reset()
    imgs = []
    for step in range(max_steps):
        sys.stdout.write('\rstep: {} / {}'.format(step+1, max_steps))
        img = env.render(**render_kwargs)
        imgs.append(img)
        a, a_info = policy.get_action(x)
        nx, r, done, e_info = env.step(a)
        if done: break
        x = nx
    return imgs

In [54]:
basedir = '../../data/experiments/'
model_labels = [
    'gail',
    'infogail',
    'recurrent_gail',
    'hgail'
]
model_params_filepaths = [
    os.path.join(basedir, 'gail/imitate/log/itr_2000.npz'),
    os.path.join(basedir, 'infogail/imitate/log/itr_1000.npz'),
    os.path.join(basedir, 'gail-recurrent/imitate/log/itr_2000.npz'),
    os.path.join(basedir, 'hgail/imitate/log/itr_1000.npz')
]
model_args_filepaths = [
    os.path.join(basedir, 'gail/imitate/log/args.npz'),
    os.path.join(basedir, 'infogail/imitate/log/args.npz'),
    os.path.join(basedir, 'gail-recurrent/imitate/log/args.npz'),
    os.path.join(basedir, 'hgail/imitate/log/args.npz')
]
n_models = len(model_labels)

# load start timestep for egoid
args_filepath = os.path.join(os.path.join(basedir, 'gail'), 'imitate/log/args.npz')
args = hyperparams.load_args(args_filepath)
ngsim_filename = 'trajdata_i101_trajectories-0750am-0805am.txt'
_, starts = validate.load_egoids(ngsim_filename, args)
worst_start = starts[worst_egoid]

In [55]:
print(worst_start)
print(worst_egoid)

5860
1967


In [56]:
render_map = dict()
max_steps = 200
env_kwargs = dict(
    egoid=worst_egoid, 
    start=worst_start
)
render_kwargs = dict(
    camera_rotation=45.,
    canvas_height=300,
    canvas_width=600
)
for i in range(n_models):
    print('\nrunning: {}'.format(model_labels[i]))
    
    # create session
    tf.reset_default_graph()
    sess = tf.InteractiveSession()
    
    # load args and params
    args = hyperparams.load_args(model_args_filepaths[i])
    # temporary ugly
    if 'recurrent' in model_labels[i]:
        args.recurrent_hidden_dim = 64
    
    params = hgail.misc.utils.load_params(model_params_filepaths[i])
    
    # load env and params
    env, _, _ = utils.build_ngsim_env(args, alpha=0.)
    normalized_env = hgail.misc.utils.extract_normalizing_env(env)
    if normalized_env is not None:
        normalized_env._obs_mean = params['normalzing']['obs_mean']
        normalized_env._obs_var = params['normalzing']['obs_var']
        
    # load policy
    if 'hgail' in model_labels[i]:
        policy = utils.build_hierarchy(args, env)
    else:
        policy = utils.build_policy(args, env)
        
    # initialize variables
    sess.run(tf.global_variables_initializer())
        
    # load params
    if 'hgail' in model_labels[i]:
        for j, level in enumerate(policy):
            level.algo.policy.set_param_values(params[j]['policy'])
        policy = policy[0].algo.policy
    else:
        policy.set_param_values(params['policy'])
        
    # collect imgs
    imgs = simulate(
        env, 
        policy, 
        max_steps=max_steps, 
        env_kwargs=env_kwargs,
        render_kwargs=render_kwargs
    )
    render_map[model_labels[i]] = imgs


running: gail
step: 199 / 200
running: infogail
step: 199 / 200
running: recurrent_gail
step: 199 / 200
running: hgail
step: 199 / 200

In [57]:
# stack the images
gail_imgs = render_map['gail']
infogail_imgs = render_map['infogail']
recurrent_gail_imgs = render_map['recurrent_gail']
hgail_imgs = render_map['hgail']
imgs = [np.concatenate((a,b,c,d), 0) for (a,b,c,d) in zip(
    gail_imgs, 
    infogail_imgs,
    recurrent_gail_imgs,
    hgail_imgs
)]

In [58]:
fig, ax = plt.subplots(figsize=(16,16))
plt.title('gail, infogail, recurrent_gail, hgail')
img = plt.imshow(imgs[0])

def animate(i):
    img.set_data(imgs[i])
    return (img,)

anim = animation.FuncAnimation(
    fig, 
    animate, 
    frames=len(imgs), 
    interval=100, 
    blit=True
)

WriterClass = animation.writers['ffmpeg']
writer = WriterClass(fps=10, metadata=dict(artist='bww'), bitrate=1800)
anim.save('../../data/media/example_10.mp4', writer=writer)

HTML(anim.to_html5_video())