In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
from __future__ import division
from copy import deepcopy
import pickle
import uuid
import os

import tensorflow as tf
import numpy as np

from rqst.traj_opt import GDTrajOptimizer, StochTrajOptimizer
from rqst.reward_models import RewardModel, BCRewardModel
from rqst.reward_opt import InteractiveRewardOptimizer
from rqst.dynamics_models import DynamicsModel
from rqst import reward_models
from rqst import utils
from rqst import envs

In [None]:
from matplotlib import pyplot as plt
import matplotlib as mpl

%matplotlib inline

In [None]:
import warnings
# Suppress every warning category for the rest of the kernel session.
# NOTE(review): this also hides deprecation and numerical warnings that could
# matter while debugging; consider narrowing the filter.
warnings.filterwarnings('ignore')

In [None]:
# TensorFlow session handed to every model and optimizer constructed below;
# requested with gpu_mode=False.
sess = utils.make_tf_session(gpu_mode=False)

In [None]:
# Base bandit environment, plus a second environment derived from it.
# NOTE(review): `trans_env` is presumably the transfer/test environment (the
# plot labels near the end of the notebook call it "Test Env.") — confirm.
env = envs.make_bandit_env()
trans_env = envs.make_bandit_trans_env(env)

In [None]:
# Policy exposed by the environment itself; used below to generate demonstrations.
expert_policy = env.expert_policy
# Comparison policy built by utils.make_random_policy for the same environment.
random_policy = utils.make_random_policy(env)

def plot_traj(traj, *args, **kwargs):
  """Plot a single trajectory.

  Wraps `traj` in a one-element list and forwards it, together with any extra
  positional and keyword arguments, to `utils.plot_trajs`.

  Returns:
    Whatever `utils.plot_trajs` returns.
  """
  return utils.plot_trajs([traj], *args, **kwargs)

In [None]:
# Expert policy, one episode in the base environment.
plot_traj(utils.traj_of_rollout(utils.run_ep(expert_policy, env)), env)

In [None]:
# Random policy, one episode in the base environment.
plot_traj(utils.traj_of_rollout(utils.run_ep(random_policy, env)), env)

In [None]:
# Expert policy, one episode in trans_env.
plot_traj(utils.traj_of_rollout(utils.run_ep(expert_policy, trans_env)), trans_env)

In [None]:
# Random policy, one episode in trans_env.
plot_traj(utils.traj_of_rollout(utils.run_ep(random_policy, trans_env)), trans_env)

In [None]:
# Number of expert episodes to collect in the base environment.
n_demo_rollouts = 100

In [None]:
# One rollout per episode of the expert policy in `env`.
demo_rollouts = [utils.run_ep(expert_policy, env) for _ in range(n_demo_rollouts)]

In [None]:
# Disabled debugging variant: would additionally append expert rollouts from trans_env.
#demo_rollouts += [utils.run_ep(expert_policy, trans_env) for _ in range(n_demo_rollouts)] # DEBUG

In [None]:
# Persist the demonstrations under utils.bandit_data_dir so they can be
# reloaded later without re-running the expert.
with open(os.path.join(utils.bandit_data_dir, 'demo_rollouts.pkl'), 'wb') as f:
  pickle.dump(demo_rollouts, f, pickle.HIGHEST_PROTOCOL)

In [None]:
# Performance metrics of the expert demonstrations in `env`; the final figure
# of this notebook draws them as the "Demonstrations" reference line.
demo_perf = utils.compute_perf_metrics(demo_rollouts, env)

In [None]:
# Display the computed metrics.
demo_perf

In [None]:
# Number of extra expert episodes to collect in trans_env.
n_aug_rollouts = 100

In [None]:
# Augmented set: the base-environment demonstrations followed by expert
# rollouts gathered in trans_env.
aug_rollouts = demo_rollouts + [utils.run_ep(expert_policy, trans_env) for _ in range(n_aug_rollouts)]

In [None]:
# Persist the augmented rollouts next to the demonstrations.
with open(os.path.join(utils.bandit_data_dir, 'aug_rollouts.pkl'), 'wb') as f:
  pickle.dump(aug_rollouts, f, pickle.HIGHEST_PROTOCOL)

In [None]:
# Reload both rollout sets from disk, so the notebook can resume from here
# without regenerating them. These pickles are the ones written by the cells
# above; pickle.load must only ever be used on trusted files.
with open(os.path.join(utils.bandit_data_dir, 'demo_rollouts.pkl'), 'rb') as f:
  demo_rollouts = pickle.load(f)

with open(os.path.join(utils.bandit_data_dir, 'aug_rollouts.pkl'), 'rb') as f:
  aug_rollouts = pickle.load(f)

In [None]:
# Vectorize each rollout set (using env.max_ep_len) and pass the result through
# utils.split_rollouts; the outputs are indexed by key, e.g. 'obses'.
# NOTE(review): presumably split_rollouts produces a train/validation split — confirm.
demo_data = utils.split_rollouts(utils.vectorize_rollouts(demo_rollouts, env.max_ep_len))
aug_data = utils.split_rollouts(utils.vectorize_rollouts(aug_rollouts, env.max_ep_len))
# Sanity check: show the shapes of the observation arrays in both datasets.
demo_data['obses'].shape, aug_data['obses'].shape

In [None]:
# Dynamics model for `env`; it is handed to the trajectory optimizers and to
# the reward-model evaluation below. No training call for it appears in this
# part of the notebook.
dynamics_model = DynamicsModel(sess, env)

In [None]:
# Sketch and preference inputs are disabled in this notebook: every slot is
# None, so only demonstrations reach the reward model.
sketch_data_for_reward_model = None
sketch_rollouts_for_reward_model = None

pref_data_for_reward_model = None
pref_logs_for_reward_model = None

In [None]:
# Use the full demonstration set (vectorized data and raw rollouts).
demo_data_for_reward_model = demo_data
demo_rollouts_for_reward_model = demo_rollouts

In [None]:
# Reward model for the offline experiment, configured with an ensemble of 4
# reward networks. A fresh UUID is used as the scope each time this cell runs.
# NOTE(review): presumably the random scope avoids TF variable-name clashes
# when the cell is re-executed — confirm.
reward_model = BCRewardModel(
    sess,
    env,
    n_rew_nets_in_ensemble=4,
    n_layers=0,
    layer_size=64,
    scope=str(uuid.uuid4()),
    scope_file=os.path.join(utils.bandit_data_dir, 'bc_rew_scope.pkl'),
    tf_file=os.path.join(utils.bandit_data_dir, 'bc_rew.tf'),
    rew_func_input='sa',
    use_discrete_actions=True
    )

In [None]:
# Fit the reward model. With the settings chosen above, sketch_data and
# pref_data are None, so only the demonstration data is supplied.
reward_model.train(
    demo_data=demo_data_for_reward_model,
    sketch_data=sketch_data_for_reward_model,
    pref_data=pref_data_for_reward_model,
    demo_coeff=1.,
    sketch_coeff=1.,
    iterations=100000,
    ftol=1e-4,
    batch_size=32,
    learning_rate=1e-2,
    val_update_freq=10,
    verbose=True
    )

In [None]:
# Persist the trained reward model.
# NOTE(review): presumably written to the tf_file/scope_file paths given at
# construction — confirm.
reward_model.save()

In [None]:
# Restore the reward model, e.g. when resuming without retraining.
# NOTE(review): presumably read from the same paths — confirm.
reward_model.load()

In [None]:
# Visualize the learned reward function.
reward_model.viz_learned_rew()

In [None]:
# Gradient-based query synthesizer built on the trained reward model and the
# dynamics model. The commented-out query_loss_opt lines are alternative
# acquisition functions; 'max_imi_pol_uncertainty' is the one in effect.
traj_optimizer = GDTrajOptimizer(
    sess,
    env,
    reward_model,
    dynamics_model,         
    traj_len=2,
    n_trajs=1,
    prior_coeff=0.,
    diversity_coeff=0.,
    #query_loss_opt='rew_uncertainty',
    #query_loss_opt='max_nov',
    query_loss_opt='max_imi_pol_uncertainty',
    opt_init_obs=True,
    join_trajs_at_init_state=False,
    learning_rate=1e-2,
    query_type='demo'
    )

In [None]:
# Synthesize query trajectories. init_obs=None leaves the initial observation
# unspecified (the optimizer was built with opt_init_obs=True); the
# commented-out line would pin it to env.default_init_obs instead.
data = traj_optimizer.run(
    #init_obs=env.default_init_obs,
    init_obs=None,
    iterations=10000,
    ftol=1e-6,
    verbose=True,
    warm_start=False
    )
# Keep only the synthesized trajectories from the result dict.
trajs = data['traj']

In [None]:
# Inspect the raw trajectory values.
trajs

In [None]:
utils.plot_trajs(trajs, env)

In [None]:
# Evaluate the trained reward model; both `env` and `trans_env` are supplied.
# offpol_eval_rollouts is None here because sketch_rollouts_for_reward_model
# was set to None above.
rew_eval = reward_models.evaluate_reward_model(
    sess,
    env,
    trans_env,
    reward_model, 
    dynamics_model, 
    offpol_eval_rollouts=sketch_rollouts_for_reward_model,
    n_eval_rollouts=100
    )

In [None]:
# Performance metrics from the evaluation.
rew_eval['perf']

In [None]:
utils.viz_rew_eval(rew_eval, env)

In [None]:
# Save the evaluation of the reward model trained on all demonstrations ...
with open(os.path.join(utils.bandit_data_dir, 'opt_rew_eval.pkl'), 'wb') as f:
  pickle.dump(rew_eval, f, pickle.HIGHEST_PROTOCOL)

In [None]:
# ... and reload it as `opt_rew_eval`, which the final figure of this notebook
# draws as the "Optimal" reference line.
with open(os.path.join(utils.bandit_data_dir, 'opt_rew_eval.pkl'), 'rb') as f:
  opt_rew_eval = pickle.load(f)

In [None]:
# Sampling-based query synthesizer; rebinds `traj_optimizer`, replacing the
# gradient-based one created earlier. 'unif' is the acquisition option in
# effect; the commented-out lines are alternatives.
traj_optimizer = StochTrajOptimizer(
    sess,
    env,
    reward_model,
    dynamics_model,
    traj_len=2,
    rollout_len=1,
    #query_loss_opt='max_nov',
    #query_loss_opt='max_imi_pol_uncertainty',
    query_loss_opt='unif',
    use_rand_policy=False,
    query_type='demo'
    )

In [None]:
# Synthesize one query trajectory with the sampling-based optimizer, passing
# n_samples=1000. init_obs=None leaves the initial observation unspecified; the
# commented-out line would pin it to env.default_init_obs.
data = traj_optimizer.run(
    n_trajs=1,
    n_samples=1000,
    init_obs=None,
    #init_obs=env.default_init_obs,
    verbose=True
    )
# Keep only the synthesized trajectories from the result dict.
trajs = data['traj']

In [None]:
utils.plot_trajs(trajs, env)

In [None]:
# Rollouts passed to the evaluation step via eval_kwargs['offpol_eval_rollouts'].
offpol_eval_rollouts = aug_rollouts

In [None]:
# Alternative setting: start the interactive loop with no demonstrations.
# NOTE(review): on a top-to-bottom run this value is immediately overwritten by
# the next cell, so only one of these two cells is meant to take effect.
demo_rollouts_for_reward_model = None

In [None]:
# Seed the interactive loop with only the first expert rollout.
demo_rollouts_for_reward_model = demo_rollouts[:1]

In [None]:
# Fresh, untrained reward model for the interactive experiments; rebinds
# `reward_model`. The configuration repeats the earlier BCRewardModel cell,
# including the same scope_file and tf_file paths.
# NOTE(review): saving this model would presumably overwrite the earlier
# checkpoint files — confirm that this is intended.
reward_model = BCRewardModel(
    sess,
    env,
    n_rew_nets_in_ensemble=4,
    n_layers=0,
    layer_size=64,
    scope=str(uuid.uuid4()),
    scope_file=os.path.join(utils.bandit_data_dir, 'bc_rew_scope.pkl'),
    tf_file=os.path.join(utils.bandit_data_dir, 'bc_rew.tf'),
    rew_func_input='sa',
    use_discrete_actions=True
    )

In [None]:
# Fresh dynamics model to pair with the new reward model.
dynamics_model = DynamicsModel(sess, env)

In [None]:
# Driver for the interactive reward-learning loop; it receives both
# environments together with the fresh reward and dynamics models.
rew_optimizer = InteractiveRewardOptimizer(
    sess,
    env, 
    trans_env,
    reward_model, 
    dynamics_model
    )

In [None]:
# Keyword-argument bundles consumed by the interactive experiments below. They
# are plain dicts, built with dict(...) since every key is a valid identifier.

# Reward-model training settings (5000 iterations, non-verbose).
reward_train_kwargs = dict(
    demo_coeff=1.,
    sketch_coeff=1.,
    iterations=5000,
    ftol=1e-4,
    batch_size=32,
    learning_rate=1e-2,
    val_update_freq=100,
    verbose=False
)

# No extra options for dynamics training.
dynamics_train_kwargs = dict()

# Constructor options for the gradient-based trajectory optimizer.
gd_traj_opt_init_kwargs = dict(
    traj_len=2,
    n_trajs=1,
    prior_coeff=0.,
    diversity_coeff=0.,
    query_loss_opt='max_imi_pol_uncertainty',
    opt_init_obs=True,
    learning_rate=1e-2,
    join_trajs_at_init_state=False
)

# Per-run options for the gradient-based trajectory optimizer.
gd_traj_opt_run_kwargs = dict(
    init_obs=None,
    iterations=10000,
    ftol=1e-6,
    verbose=False,
    warm_start=False
)

# Constructor options for the sampling-based trajectory optimizer.
stoch_traj_opt_init_kwargs = dict(
    traj_len=2,
    rollout_len=1,
    query_loss_opt='unif',
    use_rand_policy=False
)

# Per-run options for the sampling-based trajectory optimizer (n_samples=1).
stoch_traj_opt_run_kwargs = dict(
    n_trajs=1,
    n_samples=1,
    init_obs=None,
    verbose=False
)

# No extra options for the imitation step.
imitation_kwargs = dict()

# Evaluation settings: 100 evaluation rollouts, plus the rollouts chosen in
# `offpol_eval_rollouts`.
eval_kwargs = dict(
    n_eval_rollouts=100,
    offpol_eval_rollouts=offpol_eval_rollouts
)

In [None]:
# Arguments shared by every rew_optimizer.run(...) call below, whichever
# trajectory optimizer is plugged in.
rew_opt_kwargs = dict(
    # Initial data for the reward model.
    demo_rollouts=demo_rollouts_for_reward_model,
    sketch_rollouts=sketch_rollouts_for_reward_model,
    pref_logs=pref_logs_for_reward_model,
    rollouts_for_dyn=[],  # aug_rollouts
    # Nested option bundles.
    reward_train_kwargs=reward_train_kwargs,
    dynamics_train_kwargs=dynamics_train_kwargs,
    imitation_kwargs=imitation_kwargs,
    eval_kwargs=eval_kwargs,
    # Loop schedule: 20 queries, reward update and evaluation after each one,
    # no dynamics updates.
    init_train_dyn=False,
    init_train_rew=True,
    n_imitation_rollouts_per_dyn_update=1,
    n_queries=20,
    reward_update_freq=1,
    reward_eval_freq=1,
    dyn_update_freq=None,
    verbose=True,
    warm_start_rew=False,
    query_type='demo'
)

In [None]:
# Run the interactive reward-learning loop with gradient-based query synthesis.
rew_perf_evals, query_data = rew_optimizer.run(
    traj_opt_cls=GDTrajOptimizer,
    traj_opt_run_kwargs=gd_traj_opt_run_kwargs,
    traj_opt_init_kwargs=gd_traj_opt_init_kwargs,
    **rew_opt_kwargs
    )

In [None]:
# Read the results from the optimizer's attributes instead of the return value.
# NOTE(review): presumably lets partial results be recovered after interrupting
# the run above — confirm.
rew_perf_evals = rew_optimizer.rew_perf_evals
query_data = rew_optimizer.query_data

In [None]:
# Learning curve: the 'rew' metric against the number of queries.
plt.plot(rew_perf_evals['n_queries'], rew_perf_evals['rew'])
plt.show()

In [None]:
# NOTE(review): `viz_query_data` is neither defined nor imported anywhere in
# the visible part of this notebook, so this cell raises NameError on a fresh
# kernel. It may have been meant as a helper in `utils` — confirm and qualify
# the call accordingly.
viz_query_data(query_data, env)

In [None]:
def eval_query_loss(query_loss_opt):
  """Run one interactive reward-learning trial with a given acquisition function.

  Works on a deep copy of `gd_traj_opt_init_kwargs`, so the shared defaults are
  never modified, and overrides its 'query_loss_opt' entry before handing it to
  `rew_optimizer.run` together with `GDTrajOptimizer`.

  Args:
    query_loss_opt: acquisition-function name, e.g. 'max_nov'.

  Returns:
    The value returned by `rew_optimizer.run` (elsewhere in this notebook it is
    unpacked as a `(rew_perf_evals, query_data)` pair).
  """
  init_kwargs = deepcopy(gd_traj_opt_init_kwargs)
  init_kwargs.update(query_loss_opt=query_loss_opt)
  trial_result = rew_optimizer.run(
      traj_opt_cls=GDTrajOptimizer,
      traj_opt_init_kwargs=init_kwargs,
      traj_opt_run_kwargs=gd_traj_opt_run_kwargs,
      **rew_opt_kwargs
  )
  return trial_result

In [None]:
# Number of independent repetitions per acquisition function (also reused for
# the sampling-based baseline further down).
n_trials = 3

In [None]:
# Acquisition functions compared with the gradient-based optimizer.
query_loss_opts = ['max_nov', 'max_imi_pol_uncertainty']

In [None]:
# For every acquisition function, run n_trials independent trials. Each
# per-function list is registered in query_loss_evals before its trials start,
# so results gathered so far remain reachable if a run is interrupted.
query_loss_evals = []
for query_loss_opt in query_loss_opts:
  trial_results = []
  query_loss_evals.append(trial_results)
  for trial_idx in range(n_trials):
    print('%s %d' % (query_loss_opt, trial_idx))
    trial_results.append(eval_query_loss(query_loss_opt))

In [None]:
# Bundle the acquisition-function names with their per-trial results so both
# are stored, and later restored, together.
query_loss_eval_data = {
    'query_loss_opts': query_loss_opts,
    'query_loss_evals': query_loss_evals
    }

In [None]:
# Persist the bundle under utils.bandit_data_dir.
with open(os.path.join(utils.bandit_data_dir, 'query_loss_eval_data.pkl'), 'wb') as f:
  pickle.dump(query_loss_eval_data, f, pickle.HIGHEST_PROTOCOL)

In [None]:
# Reload the bundle, e.g. when resuming without re-running the trials.
# pickle.load must only ever be used on trusted files.
with open(os.path.join(utils.bandit_data_dir, 'query_loss_eval_data.pkl'), 'rb') as f:
  query_loss_eval_data = pickle.load(f)

In [None]:
# Restore the two experiment variables from the reloaded bundle. Explicit
# assignments are used instead of globals().update(...) so that only these two
# known names are rebound; an unexpected key in the pickle can no longer
# silently overwrite an unrelated notebook variable.
query_loss_opts = query_loss_eval_data['query_loss_opts']
query_loss_evals = query_loss_eval_data['query_loss_evals']

In [None]:
def compute_stoch_eval():
  """Run one interactive reward-learning trial with the sampling-based optimizer.

  Calls `rew_optimizer.run` with `StochTrajOptimizer` and the stoch_* keyword
  bundles, plus the shared `rew_opt_kwargs`.

  Returns:
    The value returned by `rew_optimizer.run` (elsewhere in this notebook it is
    unpacked as a `(rew_perf_evals, query_data)` pair).
  """
  return rew_optimizer.run(
    traj_opt_cls=StochTrajOptimizer,
    traj_opt_run_kwargs=stoch_traj_opt_run_kwargs,
    traj_opt_init_kwargs=stoch_traj_opt_init_kwargs,
    **rew_opt_kwargs
    )

In [None]:
# Repeat the sampling-based baseline n_trials times.
stoch_evals = [compute_stoch_eval() for _ in range(n_trials)]

In [None]:
# Persist the baseline results.
with open(os.path.join(utils.bandit_data_dir, 'stoch_evals.pkl'), 'wb') as f:
  pickle.dump(stoch_evals, f, pickle.HIGHEST_PROTOCOL)

In [None]:
# Reload the baseline results; pickle.load must only be used on trusted files.
with open(os.path.join(utils.bandit_data_dir, 'stoch_evals.pkl'), 'rb') as f:
  stoch_evals = pickle.load(f)

In [None]:
# Keep the first element of every trial's result, i.e. the performance
# evaluations, given how rew_optimizer.run's result is unpacked earlier.
stoch_perf_eval = list(zip(*stoch_evals))[0]

In [None]:
# List the metric names available for plotting.
stoch_perf_eval[0].keys()

In [None]:
# Human-readable axis labels, keyed by metric name. The plotting cell looks
# labels up with .get(), so a metric missing here falls back to its raw key.
label_of_key = dict(
    rew='Reward',
    succ='Classification Accuracy in Training Env.',
    ens_unc='Ensemble Uncertainty',
    xent='Cross-Entropy',
    ent='Entropy',
    acc='Classification Accuracy',
    n_queries='Number of Queries',
    trans_succ='Classification Accuracy in Test Env.',
    trans_rew='Log-Likelihood in Test Env.'
)

In [None]:
# Legend labels for the acquisition functions compared in the final figure.
label_of_acq_func = dict(
    max_nov='Max. Novelty',
    max_imi_pol_uncertainty='Max. Uncertainty'
)

In [None]:
# Set the matplotlib font size to 14 for every figure drawn from here on.
plt.rcParams.update({'font.size': 14})

In [None]:
# Final comparison figure: one curve per acquisition function, the
# sampling-based baseline, and horizontal reference lines.

# Forwarded to utils.plot_perf_evals.
# NOTE(review): presumably a smoothing-window size — confirm.
smooth_win = 3

# Metrics placed on the x and y axes.
x_key = 'n_queries'
y_key = 'trans_succ'

# Fall back to the raw key when no readable label is registered.
plt.xlabel(label_of_key.get(x_key, x_key))
plt.ylabel(label_of_key.get(y_key, y_key))

for query_loss_opt, evals in zip(query_loss_opts, query_loss_evals):
  # Each trial result is a pair; keep the first element of every pair.
  perf_evals = list(zip(*evals))[0]
  utils.plot_perf_evals(
      perf_evals,
      x_key,
      y_key,
      label=label_of_acq_func[query_loss_opt],
      smooth_win=smooth_win
      )

utils.plot_perf_evals(stoch_perf_eval, x_key, y_key, label='Random (Baseline)', smooth_win=smooth_win)

# The demonstration metrics may not contain every plotted metric, so this
# reference line is drawn only when y_key is available.
if y_key in demo_perf:
  plt.axhline(y=np.mean(demo_perf[y_key]), linestyle='--', label='Demonstrations')

plt.axhline(y=opt_rew_eval['perf'][y_key], linestyle=':', label='Optimal')

plt.legend(loc='best')

# savefig does not create missing parent directories, so make sure the output
# folder exists first. An isdir check is used rather than
# makedirs(exist_ok=True) to stay compatible with Python 2, which the
# notebook's `from __future__ import division` suggests is still supported.
figures_dir = os.path.join(utils.bandit_data_dir, 'figures')
if not os.path.isdir(figures_dir):
  os.makedirs(figures_dir)

plt.savefig(
    os.path.join(figures_dir, 'bandit-1.pdf'),
    dpi=500,
    bbox_inches='tight',
    transparent=True
    )

plt.show()