In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
from __future__ import division
from copy import deepcopy
import pickle
import random
import uuid
import os

import tensorflow as tf
import numpy as np

from rqst.traj_opt import GDTrajOptimizer, StochTrajOptimizer
from rqst.reward_opt import InteractiveRewardOptimizer
from rqst.dynamics_models import MDNRNNDynamicsModel, AbsorptionModel
from rqst.dynamics_models import load_wm_pretrained_rnn
from rqst.encoder_models import VAEModel, load_wm_pretrained_vae
from rqst.reward_models import RewardModel
from rqst import reward_models
from rqst import utils
from rqst import envs

In [None]:
from matplotlib import pyplot as plt
import matplotlib.animation
import matplotlib as mpl

%matplotlib inline

In [None]:
import warnings
warnings.filterwarnings('ignore')

In [None]:
sess = utils.make_tf_session(gpu_mode=True)

In [None]:
def plot_traj(traj, *args, **kwargs):
  """Plot a single trajectory.

  Wraps `traj` in a one-element list and forwards it, together with any extra
  positional and keyword arguments, to `utils.plot_trajs`. Returns whatever
  that call returns.
  """
  return utils.plot_trajs([traj], *args, **kwargs)

In [None]:
env = envs.make_carracing_env(sess, load_reward=True)
random_policy = utils.make_random_policy(env)

In [None]:
#trans_env = envs.make_carracing_trans_env(sess, load_reward=True)
trans_env = None

Set up the environment and run a sanity check.

In [None]:
env_rollout = utils.run_ep(env.expert_policy, env, max_ep_len=100, render=True)

In [None]:
env.close()

In [None]:
# The cell that defines `trans_env` currently sets it to None (the
# make_carracing_trans_env call is commented out), so only roll out when a
# transfer environment actually exists. `trans_rollout` stays None otherwise.
trans_rollout = None
if trans_env is not None:
  trans_rollout = utils.run_ep(env.expert_policy, trans_env, max_ep_len=100, render=True)

In [None]:
# `trans_env` is None when the transfer environment is disabled; None has no
# close(), so skip the call in that case.
if trans_env is not None:
  trans_env.close()

Collect demonstrations and augment them with random rollouts.

In [None]:
n_slow_rollouts = 10

In [None]:
def slow_policy(obs):
  """Expert policy with component 1 of its action capped at 0.005.

  Queries `env.expert_policy` for `obs`, clamps element 1 of the returned
  action from above (presumably the gas/throttle channel -- confirm against
  the environment's action layout), and returns that same action object.
  The action is modified in place.
  """
  action = env.expert_policy(obs)
  throttle = action[1]
  action[1] = throttle if throttle < 0.005 else 0.005
  return action

In [None]:
raw_slow_rollouts = [utils.run_ep(slow_policy, env, render=True) for _ in range(n_slow_rollouts)]

In [None]:
with open(os.path.join(utils.carracing_data_dir, 'raw_slow_rollouts.pkl'), 'wb') as f:
  pickle.dump(raw_slow_rollouts, f, pickle.HIGHEST_PROTOCOL)

In [None]:
with open(os.path.join(utils.carracing_data_dir, 'raw_slow_rollouts.pkl'), 'rb') as f:
  raw_slow_rollouts = pickle.load(f)

In [None]:
n_expert_rollouts = 10

In [None]:
raw_expert_rollouts = [utils.run_ep(env.expert_policy, env, render=True) for _ in range(n_expert_rollouts)]

In [None]:
with open(os.path.join(utils.carracing_data_dir, 'raw_expert_rollouts.pkl'), 'wb') as f:
  pickle.dump(raw_expert_rollouts, f, pickle.HIGHEST_PROTOCOL)

In [None]:
with open(os.path.join(utils.carracing_data_dir, 'raw_expert_rollouts.pkl'), 'rb') as f:
  raw_expert_rollouts = pickle.load(f)

In [None]:
n_subopt_rollouts = 10

In [None]:
raw_subopt_rollouts = [utils.run_ep(env.subopt_policy, env, render=True) for _ in range(n_subopt_rollouts)]

In [None]:
with open(os.path.join(utils.carracing_data_dir, 'raw_subopt_rollouts.pkl'), 'wb') as f:
  pickle.dump(raw_subopt_rollouts, f, pickle.HIGHEST_PROTOCOL)

In [None]:
with open(os.path.join(utils.carracing_data_dir, 'raw_subopt_rollouts.pkl'), 'rb') as f:
  raw_subopt_rollouts = pickle.load(f)

In [None]:
#raw_demo_rollouts = raw_subopt_rollouts + raw_expert_rollouts
#raw_demo_rollouts = raw_subopt_rollouts
raw_demo_rollouts = raw_expert_rollouts

In [None]:
# NOTE(review): the "demo" metrics are computed from raw_slow_rollouts, while
# raw_demo_rollouts is bound to raw_expert_rollouts in the cell just above.
# Confirm that the slow rollouts are the intended demonstration baseline.
demo_perf = utils.compute_perf_metrics(raw_slow_rollouts, env)

In [None]:
demo_perf

In [None]:
rand_perf = utils.compute_perf_metrics(raw_rand_rollouts, env)

In [None]:
rand_perf

In [None]:
n_aug_rollouts = 10

In [None]:
raw_rand_rollouts = [utils.run_ep(random_policy, env, render=True) for _ in range(n_aug_rollouts)]

In [None]:
with open(os.path.join(utils.carracing_data_dir, 'raw_rand_rollouts.pkl'), 'wb') as f:
  pickle.dump(raw_rand_rollouts, f, pickle.HIGHEST_PROTOCOL)

In [None]:
with open(os.path.join(utils.carracing_data_dir, 'raw_rand_rollouts.pkl'), 'rb') as f:
  raw_rand_rollouts = pickle.load(f)

In [None]:
raw_aug_rollouts = raw_expert_rollouts + raw_subopt_rollouts + raw_rand_rollouts

In [None]:
raw_aug_obses = np.array([x[0] for rollout in raw_aug_rollouts for x in rollout])
raw_aug_obs_data = utils.split_rollouts({'obses': raw_aug_obses})
raw_aug_obses.shape

In [None]:
encoder = VAEModel(
    sess,
    env,
    kl_tolerance=0.5,
    #scope=str(uuid.uuid4()),
    scope_file=os.path.join(utils.carracing_data_dir, 'enc_scope.pkl'),
    tf_file=os.path.join(utils.carracing_data_dir, 'enc.tf')
    )

In [None]:
encoder.train(
    raw_aug_obs_data,
    iterations=100000,
    ftol=1e-6,
    learning_rate=1e-3,
    batch_size=32,
    val_update_freq=10,
    verbose=True
    )

In [None]:
encoder.save()

In [None]:
encoder.load()

In [None]:
encoder = load_wm_pretrained_vae(sess, env)

In [None]:
obs = raw_demo_rollouts[-2][50][0]
plt.imshow(obs)
plt.show()

In [None]:
latent = encoder.encode_frame(obs)
latent

In [None]:
latent *= 0

In [None]:
recon = encoder.decode_latent(latent)
plt.imshow(recon)
plt.show()

In [None]:
raw_aug_traj_data = utils.split_rollouts(utils.vectorize_rollouts(
  raw_aug_rollouts, env.max_ep_len, preserve_trajs=True))
raw_aug_traj_data['obses'].shape

In [None]:
abs_model = None

In [None]:
abs_model = AbsorptionModel(
    sess,
    env,
    n_layers=1,
    layer_size=256,
    #scope=str(uuid.uuid4()),
    tf_file=os.path.join(utils.carracing_data_dir, 'abs.tf'),
    scope_file=os.path.join(utils.carracing_data_dir, 'abs_scope.pkl'),
    )

In [None]:
dynamics_model = MDNRNNDynamicsModel(
    encoder,
    sess,
    env,
    scope=str(uuid.uuid4()),
    tf_file=os.path.join(utils.carracing_data_dir, 'dyn.tf'),
    scope_file=os.path.join(utils.carracing_data_dir, 'dyn_scope.pkl'),
    abs_model=abs_model
    )

In [None]:
dynamics_model.train(
    raw_aug_traj_data,
    iterations=200,
    learning_rate=1e-3,
    ftol=1e-6,
    batch_size=32,
    val_update_freq=10,
    verbose=True
    )

In [None]:
dynamics_model.save()

In [None]:
dynamics_model.load()

In [None]:
dynamics_model = load_wm_pretrained_rnn(encoder, sess, env)

In [None]:
rnn_enc_demo_rollouts = utils.rollouts_of_traj_data(utils.rnn_encode_rollouts(
  raw_demo_rollouts, env, encoder, dynamics_model))

In [None]:
with open(os.path.join(utils.carracing_data_dir, 'rnn_enc_demo_rollouts.pkl'), 'wb') as f:
  pickle.dump(rnn_enc_demo_rollouts, f, pickle.HIGHEST_PROTOCOL)

In [None]:
rnn_enc_aug_rollouts = utils.rollouts_of_traj_data(utils.rnn_encode_rollouts(
  raw_aug_rollouts, env, encoder, dynamics_model))

In [None]:
with open(os.path.join(utils.carracing_data_dir, 'rnn_enc_aug_rollouts.pkl'), 'wb') as f:
  pickle.dump(rnn_enc_aug_rollouts, f, pickle.HIGHEST_PROTOCOL)

In [None]:
with open(os.path.join(utils.carracing_data_dir, 'rnn_enc_demo_rollouts.pkl'), 'rb') as f:
  demo_rollouts = pickle.load(f)

with open(os.path.join(utils.carracing_data_dir, 'rnn_enc_aug_rollouts.pkl'), 'rb') as f:
  aug_rollouts = pickle.load(f)

In [None]:
env.default_init_obs = demo_rollouts[-2][50][0]

In [None]:
plt.imshow(encoder.decode_latent(env.default_init_obs[:env.n_z_dim]))
plt.show()

In [None]:
env.default_init_obses = [demo_rollouts[-2][50][0]]
env.default_init_obses += random.sample([x[0] for r in demo_rollouts for x in r], 3)

In [None]:
with open(os.path.join(utils.carracing_data_dir, 'default_init_obses.pkl'), 'wb') as f:
  pickle.dump(env.default_init_obses, f, pickle.HIGHEST_PROTOCOL)

In [None]:
with open(os.path.join(utils.carracing_data_dir, 'default_init_obses.pkl'), 'rb') as f:
  env.default_init_obses = pickle.load(f)

In [None]:
for obs in env.default_init_obses:
  plt.imshow(encoder.decode_latent(obs[:env.n_z_dim]))
  plt.show()

In [None]:
# `trans_env` is None when the transfer environment is disabled (see the cell
# that defines it); there is then no transfer rollout to encode and no object
# on which to set default_init_obs.
if trans_env is not None:
  trans_rollouts = utils.rollouts_of_traj_data(utils.rnn_encode_rollouts(
    [trans_rollout], env, encoder, dynamics_model))
  trans_env.default_init_obs = trans_rollouts[0][20][0]

In [None]:
# Only meaningful when a transfer environment exists; `trans_env` is None
# when it is disabled.
if trans_env is not None:
  plt.imshow(encoder.decode_latent(trans_env.default_init_obs[:env.n_z_dim]))
  plt.show()

In [None]:
demo_data = utils.split_rollouts(utils.vectorize_rollouts(demo_rollouts, env.max_ep_len))
aug_data = utils.split_rollouts(utils.vectorize_rollouts(aug_rollouts, env.max_ep_len))
demo_data['obses'].shape, aug_data['obses'].shape

In [None]:
demo_data_for_reward_model = demo_data
demo_rollouts_for_reward_model = demo_rollouts

In [None]:
sketch_data_for_reward_model = aug_data
sketch_rollouts_for_reward_model = aug_rollouts

In [None]:
sketch_data_for_reward_model = demo_data
sketch_rollouts_for_reward_model = demo_rollouts

In [None]:
pref_data_for_reward_model = None
pref_logs_for_reward_model = None

In [None]:
sketch_rollouts_for_reward_model = [[x for x in rollout if x[2] != env.rew_classes[0]] for rollout in sketch_rollouts_for_reward_model]

In [None]:
sketch_data_for_reward_model = utils.split_rollouts(utils.vectorize_rollouts(
  sketch_rollouts_for_reward_model, env.max_ep_len))

In [None]:
autolabels = reward_models.autolabel_prefs(
    aug_rollouts,
    env,
    segment_len=env.max_ep_len+1
    )

In [None]:
pref_logs_for_reward_model = autolabels
pref_data_for_reward_model = utils.split_prefs(autolabels)

In [None]:
reward_init_kwargs = {
    'n_rew_nets_in_ensemble': 4,
    'n_layers': 1,
    'layer_size': 256,
    'scope': str(uuid.uuid4()),
    'scope_file': os.path.join(utils.carracing_data_dir, 'true_rew_scope.pkl'),
    'tf_file': os.path.join(utils.carracing_data_dir, 'true_rew.tf'),
    'rew_func_input': "s'",
    'use_discrete_rewards': True
    }

reward_train_kwargs = {
    'demo_coeff': 1.,
    'sketch_coeff': 1.,
    'iterations': 5000,
    'ftol': 1e-4,
    'batch_size': 32,
    'learning_rate': 1e-3,
    'val_update_freq': 100,
    'verbose': True
    }

In [None]:
data = envs.make_carracing_rew(
    sess, 
    env, 
    sketch_data=sketch_data_for_reward_model,
    reward_init_kwargs=reward_init_kwargs,
    reward_train_kwargs=reward_train_kwargs
    )

In [None]:
env.__dict__.update(data)

In [None]:
# Copy the entries of `data` onto the transfer environment as well, unless the
# transfer environment is disabled (`trans_env` is None).
if trans_env is not None:
  trans_env.__dict__.update(data)

In [None]:
reward_model = RewardModel(
    sess,
    env,
    n_rew_nets_in_ensemble=4,
    n_layers=1,
    layer_size=256,
    scope=str(uuid.uuid4()),
    scope_file=os.path.join(utils.carracing_data_dir, 'rew_scope.pkl'),
    tf_file=os.path.join(utils.carracing_data_dir, 'rew.tf'),
    rew_func_input="s'",
    use_discrete_rewards=True
    )

In [None]:
reward_model.train(
    demo_data=demo_data_for_reward_model,
    sketch_data=sketch_data_for_reward_model,
    pref_data=pref_data_for_reward_model,
    demo_coeff=1.,
    sketch_coeff=1.,
    iterations=1500,
    ftol=1e-4,
    batch_size=32,
    learning_rate=1e-3,
    val_update_freq=100,
    verbose=True
    )

In [None]:
reward_model.save()

In [None]:
reward_model.load()

In [None]:
reward_model.init_tf_vars()

In [None]:
reward_model.sketch_data = sketch_data_for_reward_model

In [None]:
reward_model.viz_learned_rew()

In [None]:
def plot_rew_mod_preds(traj, act_seq):
  """Plot the reward model's per-class outputs and uncertainty along a trajectory.

  Each transition is formed from consecutive rows of `traj` (row t, action t,
  row t+1). One curve is drawn per column of the model's raw output, labelled
  with the matching entry of `env.rew_classes`, plus one curve labelled 'unc'
  for the model's uncertainty. The plotted class values are the exponential of
  `utils.normalize_logits` applied to the raw outputs (presumably class
  probabilities -- confirm against that helper).
  """
  obses, next_obses = traj[:-1, :], traj[1:, :]
  logits = reward_model.compute_raw_of_transes(obses, act_seq, next_obses)
  class_probs = np.exp(utils.normalize_logits(logits))
  uncertainties = reward_model.compute_uncertainty_of_transes(obses, act_seq, next_obses)

  n_classes = logits.shape[1]
  for class_idx in range(n_classes):
    plt.plot(class_probs[:, class_idx], label=str(env.rew_classes[class_idx]))
  plt.plot(uncertainties, label='unc')
  plt.legend(loc='best')
  plt.show()

In [None]:
def integrate_accels(accels, min_vel=0, max_vel=1):
  """Accumulate accelerations into a clamped velocity profile.

  Starting from a velocity of 0, each acceleration is added to the running
  velocity, which is then clamped to [min_vel, max_vel]. The clamped value
  carries over to the next step.

  Args:
    accels: iterable of per-step accelerations.
    min_vel: lower bound applied after every step.
    max_vel: upper bound applied after every step.

  Returns:
    A list with one clamped velocity per input acceleration.
  """
  speed = 0
  profile = []
  for delta in accels:
    # Lower bound first, then upper bound, exactly one clamp per step.
    speed = min(max_vel, max(min_vel, speed + delta))
    profile.append(speed)
  return profile
  
def plot_actions(act_seq):
  """Plot the acceleration, integrated velocity and steering of an action sequence.

  `act_seq` is indexed as a 2-D array with one action per row. The 'accel'
  curve is column 1 minus column 2, 'vel' is that difference passed through
  `integrate_accels`, and 'steer' is column 0.
  """
  net_accel = act_seq[:, 1] - act_seq[:, 2]
  curves = (
      (net_accel, 'accel'),
      (integrate_accels(net_accel), 'vel'),
      (act_seq[:, 0], 'steer'),
      )
  for series, name in curves:
    plt.plot(series, label=name)
  plt.legend(loc='best')
  plt.show()

In [None]:
rollout_idx = 0
raw_rollout = raw_demo_rollouts[rollout_idx]
rollout = demo_rollouts[rollout_idx]
raw_traj = utils.traj_of_rollout(raw_rollout)
traj = utils.traj_of_rollout(rollout)
act_seq = utils.act_seq_of_rollout(rollout)

In [None]:
plot_rew_mod_preds(traj, act_seq)

In [None]:
plot_actions(act_seq)

In [None]:
plot_traj(traj, env, encoder)
plot_traj(raw_traj, env)

In [None]:
traj_len = 50

In [None]:
traj_optimizer = GDTrajOptimizer(
    sess,
    env,
    reward_model,
    dynamics_model,         
    traj_len=traj_len,
    n_trajs=1,
    prior_coeff=0,
    diversity_coeff=0.,
    #query_loss_opt='rew_uncertainty',
    #query_loss_opt='max_rew',
    query_loss_opt='min_rew',
    #query_loss_opt='max_nov',
    opt_init_obs=False,
    join_trajs_at_init_state=False,
    shoot_steps=None,
    learning_rate=1e-2,
    query_type='sketch',
    using_mixact=False
    )

In [None]:
rollout_idx = -2
t = 75
assert t+traj_len-1 <= len(demo_rollouts[rollout_idx])
rollout = demo_rollouts[rollout_idx][t:t+traj_len-1]
init_traj = utils.traj_of_rollout(rollout)
init_act_seq = utils.act_seq_of_rollout(rollout)

In [None]:
init_traj, init_act_seq = utils.rollout_in_dream(
    env.expert_policy,
    env, 
    dynamics_model, 
    init_obs=env.default_init_obs, 
    max_ep_len=(traj_len-1)
    )

In [None]:
utils.plot_trajs([init_traj], env, encoder)

In [None]:
plot_actions(init_act_seq)

In [None]:
plot_rew_mod_preds(init_traj, init_act_seq)

In [None]:
init_act_seq = np.array([[0.2, 0.2, 0] for _ in range(traj_len - 1)])

In [None]:
init_traj = None
init_act_seq = None

In [None]:
init_obs = env.default_init_obses[0] if init_traj is None else init_traj[0, :]
data = traj_optimizer.run(
    init_obs=init_obs,
    #init_obs=None,
    init_traj=init_traj,
    init_act_seq=init_act_seq,
    iterations=5000,
    ftol=1e-4,
    verbose=True,
    warm_start=False,
    init_with_lbfgs=False
    )
trajs = data['traj']
act_seqs = data['act_seq']

In [None]:
# DEBUG
feed_dict = {traj_optimizer.init_obs_ph: init_obs}
trajs, act_seqs = traj_optimizer.sess.run([traj_optimizer.trajs, traj_optimizer.act_seqs], feed_dict=feed_dict)

In [None]:
utils.plot_trajs(trajs, env, encoder)

In [None]:
plot_actions(act_seqs[0])

In [None]:
plot_rew_mod_preds(trajs[0], act_seqs[0])

In [None]:
offpol_eval_rollouts = aug_rollouts

In [None]:
rew_eval = reward_models.evaluate_reward_model(
    sess,
    env,
    trans_env,
    reward_model, 
    dynamics_model, 
    n_eval_rollouts=5,
    offpol_eval_rollouts=offpol_eval_rollouts,
    imitation_kwargs={'plan_horizon': 20, 'n_blind_steps': 1}
    )

In [None]:
rew_eval['perf']

In [None]:
imi_trajs = [utils.traj_of_rollout(rollout) for rollout in rew_eval['rollouts']]

In [None]:
utils.plot_trajs(imi_trajs, env, encoder)

In [None]:
imi_act_seqs = [utils.act_seq_of_rollout(rollout) for rollout in rew_eval['rollouts']]

In [None]:
for act_seq in imi_act_seqs:
  plot_actions(act_seq)

In [None]:
for traj, act_seq in zip(imi_trajs, imi_act_seqs):
  plot_rew_mod_preds(traj, act_seq)

In [None]:
with open(os.path.join(utils.carracing_data_dir, 'opt_rew_eval.pkl'), 'wb') as f:
  pickle.dump(rew_eval, f, pickle.HIGHEST_PROTOCOL)

In [None]:
with open(os.path.join(utils.carracing_data_dir, 'opt_rew_eval.pkl'), 'rb') as f:
  opt_rew_eval = pickle.load(f)

In [None]:
traj_optimizer = StochTrajOptimizer(
    sess,
    env,
    reward_model,
    dynamics_model,
    traj_len=1000,#env.max_ep_len+1,
    rollout_len=1000,#env.max_ep_len,
    #query_loss_opt='rew_uncertainty'
    query_loss_opt='max_rew',
    #query_loss_opt='min_rew',
    #query_loss_opt='max_nov',
    #query_loss_opt='unif',
    use_rand_policy=False,
    imitation_kwargs={'plan_horizon': 20, 'n_blind_steps': 1},
    query_type='sketch',
    guided_search=False
    )

In [None]:
data = traj_optimizer.run(
    n_trajs=1,
    n_samples=1,
    init_obs=None,
    #init_obs=env.default_init_obs,
    verbose=True
    )
trajs = data['traj']
act_seqs = data['act_seq']

In [None]:
utils.plot_trajs(trajs, env, encoder)

In [None]:
plot_actions(act_seqs[0])

In [None]:
plot_rew_mod_preds(trajs[0], act_seqs[0])

In [None]:
offpol_eval_rollouts = aug_rollouts

In [None]:
demo_rollouts_for_reward_model = None
pref_logs_for_reward_model = None

In [None]:
sketch_rollouts_for_reward_model = demo_rollouts[:1]

In [None]:
reward_model = RewardModel(
    sess,
    env,
    n_rew_nets_in_ensemble=4,
    n_layers=1,
    layer_size=256,
    scope=str(uuid.uuid4()),
    scope_file=os.path.join(utils.carracing_data_dir, 'rew_scope.pkl'),
    tf_file=os.path.join(utils.carracing_data_dir, 'rew.tf'),
    rew_func_input="s'",
    use_discrete_rewards=True
    )

In [None]:
dynamics_model = load_wm_pretrained_rnn(encoder, sess, env)

In [None]:
rew_optimizer = InteractiveRewardOptimizer(
    sess,
    env, 
    trans_env,
    reward_model, 
    dynamics_model
    )

In [None]:
reward_train_kwargs = {
    'demo_coeff': 1.,
    'sketch_coeff': 1.,
    'iterations': 1500,
    'ftol': 1e-4,
    'batch_size': 32,
    'learning_rate': 1e-3,
    'val_update_freq': 100,
    'verbose': False
    }

dynamics_train_kwargs = {
    'iterations': 1,
    'batch_size': 512,
    'learning_rate': 1e-3,
    'ftol': 1e-4,
    'val_update_freq': 100,
    'verbose': False
    }

imitation_kwargs = {
    'plan_horizon': 20,
    'n_blind_steps': 1
    }

eval_kwargs = {
    'n_eval_rollouts': 1,
    'offpol_eval_rollouts': offpol_eval_rollouts
    }

In [None]:
# Constructor settings shared by every gradient-descent trajectory optimizer.
# 'query_loss_opt' is overridden per acquisition function below.
gd_traj_opt_init_kwargs = {
    'traj_len': 50,
    'n_trajs': 1,
    'prior_coeff': 0.,
    'diversity_coeff': 0.,
    'query_loss_opt': 'rew_uncertainty',
    'opt_init_obs': False,
    'learning_rate': 1e-2,
    'join_trajs_at_init_state': False,
    'shoot_steps': None,
    'using_mixact': True
    }

# Settings passed to each optimizer's run call.
gd_traj_opt_run_kwargs = {
    'init_obs': env.default_init_obses,
    'iterations': 5000,
    'ftol': 1e-4,
    'verbose': False
    }

query_loss_opts = ['rew_uncertainty', 'max_nov', 'max_rew', 'min_rew']

# From here on both names hold lists: one init-kwargs dict and one run-kwargs
# dict per acquisition function, in the order of query_loss_opts.
traj_opt_init_kwargs = []
for query_loss_opt in query_loss_opts:
  kwargs = deepcopy(gd_traj_opt_init_kwargs)
  kwargs['query_loss_opt'] = query_loss_opt
  traj_opt_init_kwargs.append(kwargs)
gd_traj_opt_init_kwargs = traj_opt_init_kwargs
# Give every optimizer its own run-kwargs dict. `[d] * n` would store n
# references to one dict, so a key changed through one entry would change all.
# The copies are shallow: the init observations themselves are still shared.
gd_traj_opt_run_kwargs = [dict(gd_traj_opt_run_kwargs) for _ in gd_traj_opt_init_kwargs]

In [None]:
stoch_traj_opt_init_kwargs = {        
    'traj_len': env.max_ep_len+1,
    'rollout_len': env.max_ep_len,
    'query_loss_opt': 'unif',
    'use_rand_policy': False,
    'imitation_kwargs': imitation_kwargs
    }

rand_stoch_traj_opt_init_kwargs = deepcopy(stoch_traj_opt_init_kwargs)
rand_stoch_traj_opt_init_kwargs['use_rand_policy'] = True

stoch_traj_opt_run_kwargs = {
    'n_samples': 1,
    'n_trajs': 1,
    'init_obs': None,
    'verbose': False
    }

In [None]:
rew_opt_kwargs = {
    'demo_rollouts': demo_rollouts_for_reward_model,
    'sketch_rollouts': sketch_rollouts_for_reward_model,
    'pref_logs': pref_logs_for_reward_model,
    'rollouts_for_dyn': [],#aug_rollouts,
    'reward_train_kwargs': reward_train_kwargs,
    'dynamics_train_kwargs': dynamics_train_kwargs,
    'imitation_kwargs': imitation_kwargs,
    'eval_kwargs': eval_kwargs,
    'init_train_dyn': False,
    'init_train_rew': True,
    'n_imitation_rollouts_per_dyn_update': 1,
    'n_queries': 10000,
    'reward_update_freq': 1,
    'reward_eval_freq': 5,
    'dyn_update_freq': None,
    'verbose': False,
    'warm_start_rew': False,
    'query_type': 'sketch'
    }

In [None]:
rew_perf_evals, query_data = rew_optimizer.run(
    traj_opt_cls=StochTrajOptimizer,
    traj_opt_run_kwargs=stoch_traj_opt_run_kwargs,
    traj_opt_init_kwargs=stoch_traj_opt_init_kwargs,
    **rew_opt_kwargs
    )

In [None]:
rew_perf_evals = rew_optimizer.rew_perf_evals
query_data = rew_optimizer.query_data

In [None]:
plt.plot(rew_perf_evals['n_queries'], rew_perf_evals['corr'])
plt.show()

In [None]:
def make_eval_func(conf_key):
  """Build a function that runs one GD-based reward-optimization experiment.

  Args:
    conf_key: which setting the returned function varies.
      'prior_coeff': every optimizer's 'prior_coeff' is set to the value.
      'query_loss_opt': the optimizer whose 'query_loss_opt' equals the value
        is left out (ablation of one acquisition function).

  Returns:
    eval_func(conf_val), which returns the result of `rew_optimizer.run` for
    that setting.

  Raises:
    ValueError: if `conf_key` is not one of the supported keys. Without this
      check an unknown key only failed later, inside eval_func, with an
      UnboundLocalError.
  """
  if conf_key not in ('prior_coeff', 'query_loss_opt'):
    raise ValueError('unsupported conf_key: %r' % (conf_key,))

  def eval_func(conf_val):
    if conf_key == 'prior_coeff':
      # Work on a copy so the shared defaults are left untouched.
      traj_opt_init_kwargs = deepcopy(gd_traj_opt_init_kwargs)
      for init_kwargs in traj_opt_init_kwargs:
        init_kwargs[conf_key] = conf_val
    else:
      traj_opt_init_kwargs = [
        init_kwargs for init_kwargs in gd_traj_opt_init_kwargs
        if init_kwargs['query_loss_opt'] != conf_val]
    # Pass exactly one run-kwargs entry per optimizer. The ablation branch
    # drops an optimizer, and the run-kwargs entries all carry the same
    # settings, so trimming the tail loses nothing.
    traj_opt_run_kwargs = gd_traj_opt_run_kwargs[:len(traj_opt_init_kwargs)]
    return rew_optimizer.run(
      traj_opt_cls=GDTrajOptimizer,
      traj_opt_run_kwargs=traj_opt_run_kwargs,
      traj_opt_init_kwargs=traj_opt_init_kwargs,
      **rew_opt_kwargs
      )
  return eval_func

eval_prior_coeff = make_eval_func('prior_coeff')
eval_query_loss = make_eval_func('query_loss_opt')

In [None]:
n_trials = 3

In [None]:
prior_coeffs = [np.inf, 0., 1., 1e-1, 1e-2, 10.]

In [None]:
prior_coeff_evals = []
for prior_coeff in prior_coeffs:
  prior_coeff_evals.append([])
  for i in range(n_trials):
    print('%f %d' % (prior_coeff, i))
    prior_coeff_evals[-1].append(eval_prior_coeff(prior_coeff))

In [None]:
prior_coeff_eval_data = {
    'prior_coeffs': prior_coeffs,
    'prior_coeff_evals': prior_coeff_evals
    }

In [None]:
with open(os.path.join(utils.carracing_data_dir, 'prior_coeff_eval_data.pkl'), 'wb') as f:
  pickle.dump(prior_coeff_eval_data, f, pickle.HIGHEST_PROTOCOL)

In [None]:
with open(os.path.join(utils.carracing_data_dir, 'prior_coeff_eval_data.pkl'), 'rb') as f:
  prior_coeff_eval_data = pickle.load(f)

In [None]:
globals().update(prior_coeff_eval_data)

In [None]:
query_loss_opts = ['rew_uncertainty', 'max_nov', 'max_rew', 'min_rew']

In [None]:
query_loss_evals = []
for query_loss_opt in query_loss_opts:
  query_loss_evals.append([])
  for i in range(n_trials):
    print('%s %d' % (query_loss_opt, i))
    query_loss_evals[-1].append(eval_query_loss(query_loss_opt))

In [None]:
query_loss_eval_data = {
    'query_loss_opts': query_loss_opts,
    'query_loss_evals': query_loss_evals
    }

In [None]:
with open(os.path.join(utils.carracing_data_dir, 'query_loss_eval_data.pkl'), 'wb') as f:
  pickle.dump(query_loss_eval_data, f, pickle.HIGHEST_PROTOCOL)

In [None]:
with open(os.path.join(utils.carracing_data_dir, 'query_loss_eval_data.pkl'), 'rb') as f:
  query_loss_eval_data = pickle.load(f)

In [None]:
globals().update(query_loss_eval_data)

In [None]:
def compute_stoch_eval():
  """Run one reward-optimization experiment with StochTrajOptimizer queries.

  Uses `stoch_traj_opt_init_kwargs` / `stoch_traj_opt_run_kwargs` together
  with the shared `rew_opt_kwargs`, and returns the result of
  `rew_optimizer.run`.
  """
  return rew_optimizer.run(
    traj_opt_cls=StochTrajOptimizer,
    traj_opt_run_kwargs=stoch_traj_opt_run_kwargs,
    traj_opt_init_kwargs=stoch_traj_opt_init_kwargs,
    **rew_opt_kwargs
    )

In [None]:
stoch_evals = [compute_stoch_eval() for _ in range(n_trials)]

In [None]:
with open(os.path.join(utils.carracing_data_dir, 'stoch_evals.pkl'), 'wb') as f:
  pickle.dump(stoch_evals, f, pickle.HIGHEST_PROTOCOL)

In [None]:
with open(os.path.join(utils.carracing_data_dir, 'stoch_evals.pkl'), 'rb') as f:
  stoch_evals = pickle.load(f)

In [None]:
stoch_perf_eval = list(zip(*stoch_evals))[0]

In [None]:
def compute_rand_stoch_eval():
  """Run one reward-optimization experiment with the random-policy StochTrajOptimizer.

  Same as the plain stochastic run except that the optimizer is built from
  `rand_stoch_traj_opt_init_kwargs` (the variant with 'use_rand_policy' set to
  True). Returns the result of `rew_optimizer.run`.
  """
  return rew_optimizer.run(
    traj_opt_cls=StochTrajOptimizer,
    traj_opt_run_kwargs=stoch_traj_opt_run_kwargs,
    traj_opt_init_kwargs=rand_stoch_traj_opt_init_kwargs,
    **rew_opt_kwargs
    )

In [None]:
rand_stoch_evals = [compute_rand_stoch_eval() for _ in range(n_trials)]

In [None]:
with open(os.path.join(utils.carracing_data_dir, 'rand_stoch_evals.pkl'), 'wb') as f:
  pickle.dump(rand_stoch_evals, f, pickle.HIGHEST_PROTOCOL)

In [None]:
with open(os.path.join(utils.carracing_data_dir, 'rand_stoch_evals.pkl'), 'rb') as f:
  rand_stoch_evals = pickle.load(f)

In [None]:
cloud_dir = os.path.join(utils.carracing_data_dir, 'cloud')

In [None]:
# Rebuild the result tables from the per-run pickles in cloud_dir. Each pickle
# holds a (conf, result) pair: conf[0] identifies the experiment (a prior
# coefficient, an acquisition-function name, or 'stoch' / 'rand_stoch' /
# 'unif') and conf[1] is the trial index. Slots with no file stay None.
prior_coeff_evals = [[None for _ in range(n_trials)] for _ in prior_coeffs]
query_loss_evals = [[None for _ in range(n_trials)] for _ in query_loss_opts]
stoch_evals = [None for _ in range(n_trials)]
rand_stoch_evals = [None for _ in range(n_trials)]
unif_evals = [None for _ in range(n_trials)]
# Sorted so that the outcome does not depend on directory listing order if two
# files map to the same slot.
for fname in sorted(os.listdir(cloud_dir)):
  if not fname.endswith('.pkl'):
    continue
  # NOTE: pickle.load can execute arbitrary code; only load trusted files.
  with open(os.path.join(cloud_dir, fname), 'rb') as f:
    conf, result = pickle.load(f)

  if conf[0] in prior_coeffs:
    trial_idx = conf[1]
    if conf[0] == 0:
      # For coefficient 0 the first n_trials indices are ignored and the
      # following ones are shifted down to fill slots 0..n_trials-1.
      if trial_idx < n_trials:
        continue
      trial_idx -= n_trials
    prior_coeff_evals[prior_coeffs.index(conf[0])][trial_idx] = result
  elif conf[0] in query_loss_opts:
    query_loss_evals[query_loss_opts.index(conf[0])][conf[1]] = result
  elif conf[0] == 'stoch':
    stoch_evals[conf[1]] = result
  elif conf[0] == 'rand_stoch':
    rand_stoch_evals[conf[1]] = result
  elif conf[0] == 'unif':
    unif_evals[conf[1]] = result
  else:
    raise ValueError('unrecognized run configuration %r in %s' % (conf, fname))

In [None]:
prior_coeff_evals = [[x for x in y if x is not None] for y in prior_coeff_evals]
query_loss_evals = [[x for x in y if x is not None] for y in query_loss_evals]

In [None]:
# Sort the coefficients in ascending order and reorder their results to match.
# Unpacking the zip rebinds both names to tuples, whatever they were before.
prior_coeffs, prior_coeff_evals = list(zip(*sorted(list(zip(prior_coeffs, prior_coeff_evals)), key=lambda x: x[0])))

In [None]:
stoch_perf_eval = list(zip(*stoch_evals))[0]
rand_stoch_perf_eval = list(zip(*rand_stoch_evals))[0]
unif_perf_eval = list(zip(*unif_evals))[0]

In [None]:
print('\n'.join(list(zip(*prior_coeff_evals[0]))[0][0].keys()))

In [None]:
label_of_key = {
    'rew': 'Reward',
    'succ': 'Success Rate',
    'crash': 'Crash Rate',
    'rolloutlen': 'Trajectory Length',
    'ens_unc': 'Ensemble Uncertainty',
    'xent': 'Cross-Entropy',
    'ent': 'Entropy',
    'acc': 'Classification Accuracy',
    'tpr': 'True Positive Rate',
    'tnr': 'True Negative Rate',
    'fpr': 'False Positive Rate',
    'fnr': 'False Negative Rate',
    'n_queries': 'Number of Queries'
    }

In [None]:
plt.rcParams.update({'font.size': 14})

In [None]:
x_key = 'n_queries'
best_prior_coeff_idx = -1

In [None]:
def plot_baseline_comp(y_key, fig_num):
  """Plot metric `y_key` against `x_key` for ReQueST and the baselines, and save it.

  Draws the offline reward model as a horizontal line, the two stochastic
  baselines (`rand_stoch_perf_eval`, `stoch_perf_eval`) and the runs stored at
  `best_prior_coeff_idx` in `prior_coeff_evals`. The figure is written to
  figures/carracing-<fig_num>.pdf under `utils.carracing_data_dir` and shown.

  Args:
    y_key: metric to plot; also used to look up the axis label.
    fig_num: integer used in the output file name.
  """
  smooth_win = 5

  plt.title('Car Racing')

  plt.xlabel(label_of_key.get(x_key, x_key))
  plt.ylabel(label_of_key.get(y_key, y_key))

  plt.axhline(
      y=opt_rew_eval['perf'][y_key],
      linestyle=':',
      label='Offline Reward Model',
      color='green',
      linewidth=3
      )

  utils.plot_perf_evals(
      rand_stoch_perf_eval,
      x_key,
      y_key,
      label='Random Trajectories (Baseline)',
      smooth_win=smooth_win,
      color='teal'
      )

  utils.plot_perf_evals(
      stoch_perf_eval,
      x_key,
      y_key,
      label='Reward-Maximizing Trajectories (Baseline)',
      smooth_win=smooth_win,
      color='gray'
      )

  # Each trial result is a sequence whose first element is plotted here.
  evals = prior_coeff_evals[best_prior_coeff_idx]
  perf_evals = list(zip(*evals))[0]
  utils.plot_perf_evals(
      perf_evals,
      x_key,
      y_key,
      label='ReQueST (Ours)',
      smooth_win=smooth_win,
      color='orange'
      )

  plt.legend(loc='lower center', bbox_to_anchor=(0.5, -0.7), framealpha=0.)

  # savefig does not create missing directories, so make sure the target exists.
  fig_dir = os.path.join(utils.carracing_data_dir, 'figures')
  os.makedirs(fig_dir, exist_ok=True)
  plt.savefig(
      os.path.join(fig_dir, 'carracing-%d.pdf' % fig_num),
      dpi=500,
      bbox_inches='tight',
      transparent=True
      )
  plt.show()

In [None]:
plot_baseline_comp('fpr', 1)

In [None]:
plot_baseline_comp('tnr', 2)

In [None]:
plot_baseline_comp('fnr', 3)

In [None]:
y_key = 'rew'

In [None]:
smooth_win = 5

plt.title('Car Racing')

plt.xlabel(label_of_key.get(x_key, x_key))
plt.ylabel(label_of_key.get(y_key, y_key))

plt.axhline(
    y=opt_rew_eval['perf'][y_key], 
    linestyle=':', 
    label='Offline Reward Model', 
    color='green',
    linewidth=3
    )

plt.axhline(
    y=rand_perf[y_key], 
    linestyle=':', 
    label='Random Policy (Baseline)', 
    color='gray',
    linewidth=3
    )

utils.plot_perf_evals(
    rand_stoch_perf_eval, 
    x_key, 
    y_key, 
    label='Random Trajectories (Baseline)', 
    smooth_win=smooth_win, 
    color='teal'
    )

utils.plot_perf_evals(
    stoch_perf_eval, 
    x_key, 
    y_key, 
    label='Reward-Maximizing Trajectories (Baseline)', 
    smooth_win=smooth_win, 
    color='gray'
    )

if y_key in demo_perf:
  plt.axhline(
      y=np.mean(demo_perf[y_key]), 
      linestyle='--', 
      label='Demonstrations (Baseline)', 
      color='gray',
      linewidth=3
      )
  
prior_coeff_idx = best_prior_coeff_idx
prior_coeff = prior_coeffs[prior_coeff_idx]
evals = prior_coeff_evals[prior_coeff_idx]
perf_evals = list(zip(*evals))[0]
utils.plot_perf_evals(
    perf_evals, 
    x_key, 
    y_key, 
    label='ReQueST (Ours)', 
    smooth_win=smooth_win, 
    color='orange'
    )

plt.legend(loc='lower center', bbox_to_anchor=(0.5, -0.9), framealpha=0., ncol=1)

plt.savefig(
    os.path.join(utils.carracing_data_dir, 'figures', 'carracing-4.pdf'), 
    dpi=500, 
    bbox_inches='tight',
    transparent=True
    )
plt.show()

In [None]:
smooth_win = 5

y_key = 'rew'

plt.title('Car Racing')

plt.xlabel(label_of_key.get(x_key, x_key))
plt.ylabel(label_of_key.get(y_key, y_key))

plt.axhline(
    y=opt_rew_eval['perf'][y_key], 
    linestyle=':', 
    label='Offline Reward Model', 
    color='green',
    linewidth=3
    )

plt.axhline(
    y=rand_perf[y_key], 
    linestyle=':', 
    label='Random Policy (Baseline)', 
    color='gray',
    linewidth=3
    )

if y_key in demo_perf:
  plt.axhline(
      y=np.mean(demo_perf[y_key]), 
      linestyle='--', 
      label='Demonstrations (Baseline)', 
      color='gray',
      linewidth=3
      )

prior_coeff_idx = best_prior_coeff_idx
prior_coeff = prior_coeffs[prior_coeff_idx]
evals = prior_coeff_evals[prior_coeff_idx]
perf_evals = list(zip(*evals))[0]
utils.plot_perf_evals(
    perf_evals, 
    x_key, 
    y_key, 
    label='ReQueST (Ours)', 
    smooth_win=smooth_win, 
    color='orange'
    )

plt.legend(loc='lower center', bbox_to_anchor=(0.5, -0.7), framealpha=0., ncol=1)

plt.savefig(
    os.path.join(utils.carracing_data_dir, 'figures', 'carracing-5.pdf'), 
    dpi=500, 
    bbox_inches='tight',
    transparent=True
    )

plt.show()

In [None]:
# Learning curves for every prior coefficient, one colour per coefficient.
smooth_win = 5

y_key = 'rew'

plt.title('Car Racing')

plt.xlabel(label_of_key.get(x_key, x_key))
plt.ylabel(label_of_key.get(y_key, y_key))

plt.axhline(
    y=opt_rew_eval['perf'][y_key],
    linestyle=':',
    label='Offline Reward Model',
    color='green',
    linewidth=3
    )

plt.axhline(
    y=rand_perf[y_key],
    linestyle=':',
    label='Random Policy (Baseline)',
    color='gray',
    linewidth=3
    )

colors = [
    'blue',
    'lightblue',
    'gray',
    'yellow',
    'orange',
    'red'
    ]
assert len(colors) == len(prior_coeffs)

for prior_coeff, evals, color in zip(prior_coeffs, prior_coeff_evals, colors):
  perf_evals = list(zip(*evals))[0]
  # Raw strings keep the LaTeX backslashes literal. '\l' and '\i' are not
  # valid escape sequences and trigger a warning in ordinary string literals.
  label = r'$\lambda = $ %0.2f' % prior_coeff
  # An infinite coefficient formats as 'inf'; show it as the infinity symbol.
  label = label.replace('inf', r'$\infty$')
  utils.plot_perf_evals(
      perf_evals,
      x_key,
      y_key,
      label=label,
      smooth_win=smooth_win,
      color=color
      )

plt.legend(loc='lower center', bbox_to_anchor=(0.5, -0.7), framealpha=0., ncol=2)

plt.savefig(
    os.path.join(utils.carracing_data_dir, 'figures', 'carracing-6.pdf'),
    dpi=500,
    bbox_inches='tight',
    transparent=True
    )
plt.show()

In [None]:
def compute_best_perf(evals, y_key):
  """Return the best column mean of metric `y_key` and its standard error.

  The first element of every entry in `evals` is taken as that trial's
  performance record. The records are combined by `utils.make_perf_mat`
  (presumably one row per trial -- confirm against that helper), column means
  and standard errors are computed, and the column with the best mean is
  selected.

  Args:
    evals: sequence of per-trial results.
    y_key: metric name. 'rew', 'succ', 'tpr', 'tnr' and 'acc' are maximized;
      'crash', 'fnr', 'fpr', 'rolloutlen' and 'xent' are minimized.

  Returns:
    (mean, stderr) at the selected column.

  Raises:
    ValueError: if no optimization direction is defined for `y_key`. This is
      checked up front, before any aggregation work is done.
  """
  if y_key in ('rew', 'succ', 'tpr', 'tnr', 'acc'):
    pick_best = np.argmax
  elif y_key in ('crash', 'fnr', 'fpr', 'rolloutlen', 'xent'):
    pick_best = np.argmin
  else:
    raise ValueError('no optimization direction defined for metric %r' % (y_key,))

  perf_evals = list(zip(*evals))[0]
  mat = utils.make_perf_mat(perf_evals, y_key)
  means = utils.col_means(mat)
  stderrs = utils.col_stderrs(mat)
  idx = pick_best(means)
  return means[idx], stderrs[idx]

In [None]:
best_perfs = []
for prior_coeff, evals in zip(prior_coeffs, prior_coeff_evals):
  best_perfs.append((prior_coeff, *compute_best_perf(evals, y_key)))
best_perfs = sorted(best_perfs, key=lambda x: x[0])

best_rand_stoch_perf = compute_best_perf(rand_stoch_evals, y_key)
best_stoch_perf = compute_best_perf(stoch_evals, y_key)

In [None]:
# Best performance per prior coefficient, with error bars, on a symlog x-axis.
y_key = 'rew'

plt.title('Car Racing')

# Raw string: '\l' is not a valid escape sequence in an ordinary literal.
plt.xlabel(r'Regularization Constant $\lambda$')
plt.ylabel(label_of_key.get(y_key, y_key))

plt.axhline(
    y=opt_rew_eval['perf'][y_key],
    linestyle=':',
    label='Offline Reward Model',
    color='green',
    linewidth=3
    )

plt.axhline(
    y=rand_perf['rew'],
    linestyle=':',
    label='Random Policy (Baseline)',
    color='gray',
    linewidth=3
    )

xs, ys, yerrs = list(zip(*best_perfs))

# The last x value is replaced by 100 so it can be drawn on the axis; its tick
# is relabelled as infinity below. This assumes the last entry of best_perfs
# is the infinite coefficient -- confirm if the coefficient list changes.
xs = list(xs)
xs[-1] = 100

plt.errorbar(
    xs,
    y=ys,
    yerr=yerrs,
    marker='o',
    color='orange',
    label='ReQueST (Ours)',
    capsize=5,
    linestyle=''
    )

# NOTE(review): newer Matplotlib releases spell this keyword `linthresh`;
# confirm against the installed version before changing it.
plt.xscale('symlog', linthreshx=1e-2)

old_xs = deepcopy(xs)
xs = [str(x) for x in xs]
xs[-1] = r'$\infty$'
plt.xticks(old_xs, xs)

plt.legend(loc='lower center', bbox_to_anchor=(0.5, -0.6), framealpha=0., ncol=1)

plt.savefig(
    os.path.join(utils.carracing_data_dir, 'figures', 'carracing-7.pdf'),
    dpi=500,
    bbox_inches='tight',
    transparent=True
    )

plt.show()

In [None]:
label_of_acq_func = {
  'min_rew': 'Min. Reward',
  'max_rew': 'Max. Reward',
  'max_nov': 'Max. Novelty',
  'rew_uncertainty': 'Max. Uncertainty'
}

In [None]:
smooth_win = 5

y_key = 'rew'

plt.title('Car Racing')

plt.xlabel(label_of_key.get(x_key, x_key))
plt.ylabel(label_of_key.get(y_key, y_key))

plt.axhline(
    y=opt_rew_eval['perf'][y_key], 
    linestyle=':', 
    label='Offline Reward Model', 
    color='green',
    linewidth=3
    )

plt.axhline(
    y=rand_perf['rew'], 
    linestyle=':', 
    label='Random Policy (Baseline)', 
    color='gray',
    linewidth=3
    )

prior_coeff_idx = best_prior_coeff_idx
perf_evals = list(zip(*prior_coeff_evals[prior_coeff_idx]))[0]
utils.plot_perf_evals(
    perf_evals, 
    x_key, 
    y_key, 
    label='All Acquisition Functions', 
    smooth_win=smooth_win,
    color='orange',
    )

colors = [
    'teal',
    'gray',
    'pink',
    'red'
    ]
assert len(colors) == len(query_loss_opts)

for query_loss_opt, evals, color in zip(query_loss_opts, query_loss_evals, colors):
  perf_evals = list(zip(*evals))[0]
  utils.plot_perf_evals(
      perf_evals, 
      x_key, 
      y_key, 
      label='All - %s' % label_of_acq_func.get(query_loss_opt, query_loss_opt), 
      smooth_win=smooth_win,
      color=color
      )
  
plt.legend(loc='lower center', bbox_to_anchor=(0.5, -0.7), framealpha=0., ncol=2)

plt.savefig(
    os.path.join(utils.carracing_data_dir, 'figures', 'carracing-8.pdf'), 
    dpi=500, 
    bbox_inches='tight',
    transparent=True
    )

plt.show()

In [None]:
smooth_win = 5

y_key = 'rew'

plt.title('Car Racing')

plt.xlabel(label_of_key.get(x_key, x_key))
plt.ylabel(label_of_key.get(y_key, y_key))

plt.axhline(
    y=opt_rew_eval['perf'][y_key], 
    linestyle=':', 
    label='Offline Reward Model', 
    color='green',
    linewidth=3
    )

plt.axhline(
    y=rand_perf[y_key], 
    linestyle=':', 
    label='Random Policy (Baseline)', 
    color='gray',
    linewidth=3
    )

utils.plot_perf_evals(
    unif_perf_eval, 
    x_key, 
    y_key, 
    label='Random Trajectories from Dynamics Model (Baseline)', 
    smooth_win=smooth_win*4, 
    color='teal'
    )
  
prior_coeff_idx = best_prior_coeff_idx
prior_coeff = prior_coeffs[prior_coeff_idx]
evals = prior_coeff_evals[prior_coeff_idx]
perf_evals = list(zip(*evals))[0]
utils.plot_perf_evals(
    perf_evals, 
    x_key, 
    y_key, 
    label='ReQueST (Ours)', 
    smooth_win=smooth_win, 
    color='orange'
    )

plt.legend(loc='lower center', bbox_to_anchor=(0.5, -0.7), framealpha=0., ncol=1)

plt.xlim([0, 9000])

plt.savefig(
    os.path.join(utils.carracing_data_dir, 'figures', 'carracing-9.pdf'), 
    dpi=500, 
    bbox_inches='tight',
    transparent=True
    )

plt.show()