In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
from __future__ import division

import pickle
import os

import numpy as np

from sensei import utils
from sensei.user_models import CarUser
from sensei.envs import GuideEnv
from sensei import envs
from sensei import ase
from sensei import dynamics_models
from sensei import encoder_models

In [None]:
from matplotlib import pyplot as plt
import matplotlib as mpl
%matplotlib inline

In [None]:
import warnings
warnings.filterwarnings('ignore')

In [None]:
sess = utils.make_tf_session(gpu_mode=False)

In [None]:
data_dir = utils.car_data_dir
fig_dir = os.path.join(data_dir, 'figures')

Create the Car Racing environment

In [None]:
encoder_model = encoder_models.load_wm_pretrained_vae(sess)

In [None]:
dynamics_model = dynamics_models.load_wm_pretrained_rnn(encoder_model, sess)

In [None]:
env = envs.CarEnv(encoder_model, dynamics_model, delay=10)

create guide env

In [None]:
env.delay = 10

In [None]:
user_model = CarUser(env)

In [None]:
guide_env = GuideEnv(env, user_model)

In [None]:
naive_guide_model = utils.CarGuidePolicy('naive')

In [None]:
oracle_guide_model = utils.CarGuidePolicy('oracle')

In [None]:
iden_guide_policy = utils.CarGuidePolicy('iden')

Sanity-check the environments and agents

In [None]:
rollout = utils.run_ep(env.oracle_policy, env, render=True, max_ep_len=200)

In [None]:
frames = [x[-1]['img'] for x in rollout]
utils.play_nb_vid(frames)

In [None]:
env.close()

In [None]:
rollout = utils.run_ep(user_model, env, render=True, max_ep_len=200)

In [None]:
env.close()

In [None]:
rollout = utils.run_ep(iden_guide_policy, guide_env, render=True, max_ep_len=200)

In [None]:
guide_env.close()

In [None]:
rollout = utils.run_ep(naive_guide_model, guide_env, render=True, max_ep_len=200)

In [None]:
guide_env.close()

In [None]:
rollout = utils.run_ep(oracle_guide_model, guide_env, render=True, max_ep_len=200)

In [None]:
guide_env.close()

In [None]:
unifobses_path = os.path.join(data_dir, 'unif_obses.pkl')

In [None]:
# Load-or-build cache for `unif_obses`.
# An unconditional load fails on a fresh data directory (the pickle is only
# written further down), and an unconditional rebuild right after a load throws
# the loaded data away. Guarding on the file's existence makes this cell safe
# under Restart & Run All. Delete the pickle to force regeneration from the
# current `rollout`.
if os.path.exists(unifobses_path):
  # NOTE: pickle.load can execute arbitrary code; only load files written by
  # this notebook.
  with open(unifobses_path, 'rb') as f:
    unif_obses = pickle.load(f)
else:
  # Element 1 of every rollout step -- presumably what the guide emitted at
  # that step; TODO confirm against utils.run_ep.
  unif_obses = [x[1] for x in rollout]
  with open(unifobses_path, 'wb') as f:
    pickle.dump(unif_obses, f, pickle.HIGHEST_PROTOCOL)

In [None]:
unif_obs_idxes = list(range(len(unif_obses)))
class UnifGuidePolicy(object):
  """Guide policy that answers every query with a uniformly sampled stored observation.

  Reads the notebook-level `unif_obses` and `unif_obs_idxes` each time it is
  called.
  """

  def __init__(self):
    # Frame taken from the `info` dict of the most recent call; None until then.
    self.img = None

  def __call__(self, obs, info):
    """Store info['img'] and return a random entry of `unif_obses`.

    `obs` is not used.
    """
    self.img = info['img']
    sampled_idx = np.random.choice(unif_obs_idxes)
    return unif_obses[sampled_idx]

  def get_action_info(self):
    """Return the frame stored by the last call, under the key 'img'."""
    return dict(img=self.img)
  
unif_guide_policy = UnifGuidePolicy()

In [None]:
rollout = utils.run_ep(unif_guide_policy, guide_env, render=True, max_ep_len=200)

In [None]:
guide_env.close()

evaluate baselines

In [None]:
baseline_guides = {
  'oracle': oracle_guide_model,
  'unif': unif_guide_policy
}

In [None]:
baseline_guide_evals = ase.evaluate_baseline_guides(
  sess, 
  guide_env, 
  baseline_guides, 
  n_eval_rollouts=20
)

In [None]:
for k, v in baseline_guide_evals.items():
  print(k, {x: y for x, y in v['perf'].items() if not x.endswith('_t')})

In [None]:
# Delay values to sweep: the integers 0 through 20 inclusive.
delays = list(range(21))
# NOTE(review): if this is a single cell, the bare expression below displays
# nothing, since only a cell's last expression is echoed -- confirm whether a
# cell boundary was intended here.
delays

# Guide policies evaluated at every delay, keyed by guide name.
guides = {
  'iden': iden_guide_policy,
  'naive': naive_guide_model
}

In [None]:
guide_evals_of_delay = {}

In [None]:
def plot_evals(delays, save_path=None):
  """Plot each performance metric against delay, one figure per metric.

  For every metric, draws one error-bar curve per entry of the notebook-level
  `guides` (values read from `guide_evals_of_delay`) and one dashed horizontal
  line per entry of `baseline_guide_evals`.

  Args:
    delays: delays to place on the x-axis; each must be a key of
      `guide_evals_of_delay`.
    save_path: optional directory. When given, every figure is also written
      there as 'carracing_<metric>.pdf'; the directory is created if it does
      not exist yet.
  """
  # savefig does not create missing directories, so make sure the target
  # exists up front instead of failing after the first figure is drawn.
  if save_path and not os.path.isdir(save_path):
    os.makedirs(save_path)

  for metric in ['return', 'user_belief_in_true_state', 'succ', 'crash']:
    plt.title('Car Racing')
    plt.xlabel(r'Delay $d_{\mathrm{max}}$')
    plt.ylabel(utils.label_of_perf_met[metric])

    # One curve per swept guide, with standard-error bars.
    for guide_name in guides:
      perfs = [guide_evals_of_delay[delay][guide_name]['perf'] for delay in delays]
      ys = [perf[metric] for perf in perfs]
      yerrs = [perf['%s_stderr' % metric] for perf in perfs]
      guide_label = utils.label_of_guide[guide_name]
      color = utils.color_of_guide[guide_name]
      if guide_name == 'naive':
        # In these figures the 'naive' guide is shown under the ASE label and
        # its own color instead of the defaults from utils.
        guide_label = 'ASE (Our Method)'
        color = 'orange'
      plt.errorbar(delays, ys, yerr=yerrs, label=guide_label, color=color, capsize=5)

    # Baselines were evaluated once, so each is a constant reference line.
    for guide_name, guide_eval in baseline_guide_evals.items():
      y = guide_eval['perf'][metric]
      guide_label = utils.label_of_guide[guide_name]
      color = utils.color_of_guide[guide_name]
      plt.axhline(y=y, color=color, label=guide_label, linestyle='--')

    if metric == 'return':
      # Clip the lower end of the axis at -100; the upper end stays automatic.
      plt.ylim([-100, None])

    plt.legend(loc='lower left', prop={'size': 10})
    if save_path is not None:
      plt.savefig(os.path.join(save_path, '%s_%s.pdf' % ('carracing', metric)), bbox_inches='tight')
    plt.show()

In [None]:
# Sweep the delay setting: for each value, rebuild the user model and guide
# environment around the shared `env`, evaluate every entry of `guides`, and
# redraw the plots using all delays evaluated so far.
# NOTE(review): this mutates `env.delay` and rebinds the notebook-level
# `user_model` / `guide_env`, so after the loop they reflect the last delay
# rather than the values set up earlier in the notebook.
for i, delay in enumerate(delays):
  env.delay = delay
  user_model = CarUser(env)
  guide_env = GuideEnv(env, user_model)
  guide_evals = ase.evaluate_baseline_guides(
    sess, 
    guide_env, 
    guides, 
    n_eval_rollouts=20
  )
  
  # save memory, discard rollouts: keep only the 'perf' entry of each guide's result
  guide_evals = {k: {'perf': v['perf']} for k, v in guide_evals.items()}
  
  guide_evals_of_delay[delay] = guide_evals
  # Incremental plot over the delays finished so far (no files are saved here).
  plot_evals(delays[:i+1])

In [None]:
baselines_eval_path = os.path.join(data_dir, 'baselines_eval.pkl')

In [None]:
with open(baselines_eval_path, 'wb') as f:
  pickle.dump((baseline_guide_evals, guide_evals_of_delay), f, pickle.HIGHEST_PROTOCOL)

In [None]:
with open(baselines_eval_path, 'rb') as f:
  baseline_guide_evals, guide_evals_of_delay = pickle.load(f)

In [None]:
mpl.rcParams.update({'font.size': 14})

In [None]:
plot_evals(sorted(guide_evals_of_delay.keys()), save_path=fig_dir)