In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
from __future__ import division
from collections import defaultdict
from copy import deepcopy
import pickle
import random
import uuid
import os

from scipy.misc import logsumexp
import tensorflow as tf
import numpy as np

from rqst.traj_opt import GDTrajOptimizer, StochTrajOptimizer
from rqst.reward_models import RewardModel, BCRewardModel
from rqst.reward_opt import InteractiveRewardOptimizer
from rqst.dynamics_models import ObsPriorModel
from rqst.encoder_models import VAEModel, IdenModel, MNISTVAEModel
from rqst import reward_models
from rqst import utils
from rqst import envs

In [None]:
from matplotlib import pyplot as plt
import matplotlib as mpl

%matplotlib inline

In [None]:
import warnings
warnings.filterwarnings('ignore')

In [None]:
sess = utils.make_tf_session(gpu_mode=False)

In [None]:
env = envs.make_clfbandit_env()
trans_env = envs.make_clfbandit_trans_env(env)

In [None]:
env.set_expert_policy(IdenModel(sess, env))
trans_env.set_expert_policy(IdenModel(sess, env))

In [None]:
expert_policy = env.expert_policy
random_policy = utils.make_random_policy(env)

def plot_traj(traj, *args, **kwargs):
  """Plot a single trajectory.

  Wraps `traj` in a one-element list and forwards it, together with any extra
  positional/keyword arguments, to `utils.plot_trajs`; returns whatever that
  call returns.
  """
  return utils.plot_trajs([traj], *args, **kwargs)

In [None]:
plot_traj(utils.traj_of_rollout(utils.run_ep(expert_policy, env)), env)

In [None]:
plot_traj(utils.traj_of_rollout(utils.run_ep(random_policy, env)), env)

In [None]:
plot_traj(utils.traj_of_rollout(utils.run_ep(expert_policy, trans_env)), trans_env)

In [None]:
plot_traj(utils.traj_of_rollout(utils.run_ep(random_policy, trans_env)), trans_env)

In [None]:
n_demo_rollouts = 10000

In [None]:
raw_demo_rollouts = [utils.run_ep(expert_policy, env) for _ in range(n_demo_rollouts)]

In [None]:
#raw_demo_rollouts += [utils.run_ep(expert_policy, trans_env) for _ in range(n_demo_rollouts)] # DEBUG

In [None]:
with open(os.path.join(utils.clfbandit_data_dir, 'raw_demo_rollouts.pkl'), 'wb') as f:
  pickle.dump(raw_demo_rollouts, f, pickle.HIGHEST_PROTOCOL)

In [None]:
with open(os.path.join(utils.clfbandit_data_dir, 'raw_demo_rollouts.pkl'), 'rb') as f:
  raw_demo_rollouts = pickle.load(f)

In [None]:
n_aug_rollouts = 10000

In [None]:
raw_aug_rollouts = raw_demo_rollouts + [utils.run_ep(expert_policy, trans_env) for _ in range(n_aug_rollouts)]

In [None]:
with open(os.path.join(utils.clfbandit_data_dir, 'raw_aug_rollouts.pkl'), 'wb') as f:
  pickle.dump(raw_aug_rollouts, f, pickle.HIGHEST_PROTOCOL)

In [None]:
with open(os.path.join(utils.clfbandit_data_dir, 'raw_aug_rollouts.pkl'), 'rb') as f:
  raw_aug_rollouts = pickle.load(f)

In [None]:
# Re-balance the augmented dataset: keep the first `n_demo_rollouts` rollouts
# untouched and replace everything after them with a fixed-size random
# subsample per class.
#
# Bucket every rollout past the first `n_demo_rollouts` by the argmax of the
# second field of its first step (presumably the one-hot expert action, i.e.
# the digit label -- TODO confirm against the env's step tuple layout).
rollouts_of_dig = defaultdict(list)
for r in raw_aug_rollouts[n_demo_rollouts:]:
  rollouts_of_dig[np.argmax(r[0][1])].append(r)
# Freeze into a plain dict so a missing class raises KeyError below instead of
# silently yielding an empty bucket.
rollouts_of_dig = dict(rollouts_of_dig)

# Slicing builds a new list, so the `+=` below does not touch the list object
# that was loaded/constructed earlier.
raw_aug_rollouts = raw_aug_rollouts[:n_demo_rollouts]

# Draw the same number of rollouts for each of the keys 0..4.
# NOTE: `random.sample` is unseeded here (results differ between runs) and
# raises ValueError if a bucket holds fewer than `n_samps_per_trans_dig` items.
# NOTE: this cell rebinds `raw_aug_rollouts`; re-running it operates on the
# already-subsampled list rather than on the original one.
n_samps_per_trans_dig = 1000
for dig in range(5):
  raw_aug_rollouts += random.sample(rollouts_of_dig[dig], n_samps_per_trans_dig)

In [None]:
# Flatten all steps of all augmented rollouts into two parallel arrays:
# field 0 of every step tuple (observation) and field 1 (action).
raw_aug_obses = np.array([x[0] for rollout in raw_aug_rollouts for x in rollout])
raw_aug_actions = np.array([x[1] for rollout in raw_aug_rollouts for x in rollout])
# Package the arrays for encoder training; `utils.split_rollouts` presumably
# adds a train/validation split -- TODO confirm.
raw_aug_obs_data = utils.split_rollouts({
    'obses': raw_aug_obses, 
    'actions': raw_aug_actions
    })
# Displayed as the cell output: sanity-check the number of frames and frame shape.
raw_aug_obses.shape

In [None]:
encoder = MNISTVAEModel(
    sess,
    env,
    kl_tolerance=utils.inf,
    #scope=str(uuid.uuid4()),
    scope_file=os.path.join(utils.clfbandit_data_dir, 'enc_scope.pkl'),
    tf_file=os.path.join(utils.clfbandit_data_dir, 'enc.tf')
    )

In [None]:
encoder.train(
    raw_aug_obs_data,
    iterations=100000,
    ftol=1e-6,
    learning_rate=1e-3,
    batch_size=32,
    val_update_freq=1000,
    verbose=True
    )

In [None]:
encoder.save()

In [None]:
encoder.load()

In [None]:
encoder = IdenModel(sess, env)

In [None]:
obs = raw_aug_rollouts[-6][0][0]
plt.imshow(obs[:, :, 0], cmap=mpl.cm.binary)
plt.show()

In [None]:
plt.hist(obs.ravel())
plt.show()

In [None]:
latent = encoder.encode_frame(obs)
latent

In [None]:
latent *= 0.

In [None]:
std = np.exp(-.1)#1
latent = np.random.normal(0, std, env.n_z_dim)

In [None]:
recon = encoder.decode_latent(latent)
plt.imshow(recon[:, :, 0], cmap=mpl.cm.binary)
plt.show()

In [None]:
plt.hist(recon.ravel())
plt.show()

In [None]:
env.set_expert_policy(encoder)

In [None]:
trans_env.set_expert_policy(encoder)

In [None]:
demo_rollouts = utils.map_frames(raw_demo_rollouts, encoder.encode_batch_frames, batch=True)

In [None]:
aug_rollouts = utils.map_frames(raw_aug_rollouts, encoder.encode_batch_frames, batch=True)

In [None]:
with open(os.path.join(utils.clfbandit_data_dir, 'demo_rollouts.pkl'), 'wb') as f:
  pickle.dump(demo_rollouts, f, pickle.HIGHEST_PROTOCOL)

In [None]:
with open(os.path.join(utils.clfbandit_data_dir, 'aug_rollouts.pkl'), 'wb') as f:
  pickle.dump(aug_rollouts, f, pickle.HIGHEST_PROTOCOL)

In [None]:
# Build scatter-plot inputs from the first step of every encoded demo rollout:
# the argmax of the step's second field is used as the colour label, and the
# first two components of the step's first field as the x/y coordinates.
# (Presumably field 0 is the latent observation and field 1 the one-hot action
# -- TODO confirm.)
labels = [np.argmax(x[0][1]) for x in demo_rollouts]
xs = [x[0][0][0] for x in demo_rollouts]
ys = [x[0][0][1] for x in demo_rollouts]

In [None]:
plt.scatter(xs, ys, c=labels, alpha=0.5)
plt.show()

In [None]:
with open(os.path.join(utils.clfbandit_data_dir, 'demo_rollouts.pkl'), 'rb') as f:
  demo_rollouts = pickle.load(f)

with open(os.path.join(utils.clfbandit_data_dir, 'aug_rollouts.pkl'), 'rb') as f:
  aug_rollouts = pickle.load(f)

In [None]:
demo_data = utils.split_rollouts(utils.vectorize_rollouts(demo_rollouts, env.max_ep_len))
aug_data = utils.split_rollouts(utils.vectorize_rollouts(aug_rollouts, env.max_ep_len))
demo_data['obses'].shape, aug_data['obses'].shape

In [None]:
demo_perf = utils.compute_perf_metrics(demo_rollouts, env)

In [None]:
demo_perf

In [None]:
dynamics_model = ObsPriorModel(sess, env, encoder=encoder)

In [None]:
sketch_data_for_reward_model = None
sketch_rollouts_for_reward_model = None

pref_data_for_reward_model = None
pref_logs_for_reward_model = None

In [None]:
# Shallow-copy the demo list instead of aliasing it: a later cell appends a
# synthetic query rollout to `demo_rollouts_for_reward_model` in place, and
# with a plain alias that append would also land in `demo_rollouts`, which is
# reused afterwards for slicing and vectorization. The rollouts themselves are
# shared (not deep-copied); only the outer list is independent.
demo_rollouts_for_reward_model = list(demo_rollouts)

In [None]:
demo_data_for_reward_model = utils.split_rollouts(utils.vectorize_rollouts(demo_rollouts_for_reward_model, env.max_ep_len))

In [None]:
reward_model = BCRewardModel(
    sess,
    env,
    n_rew_nets_in_ensemble=4,
    n_layers=1,
    layer_size=128,
    #scope=str(uuid.uuid4()),
    scope_file=os.path.join(utils.clfbandit_data_dir, 'bc_rew_scope.pkl'),
    tf_file=os.path.join(utils.clfbandit_data_dir, 'bc_rew.tf'),
    rew_func_input='sa',
    use_discrete_actions=True
    )

In [None]:
reward_model.train(
    demo_data=demo_data_for_reward_model,
    sketch_data=sketch_data_for_reward_model,
    pref_data=pref_data_for_reward_model,
    demo_coeff=1.,
    sketch_coeff=1.,
    iterations=100000,
    ftol=1e-4,
    batch_size=32,
    learning_rate=1e-3,
    val_update_freq=100,
    verbose=True
    )

In [None]:
reward_model.save()

In [None]:
reward_model.load()

In [None]:
reward_model.demo_data = demo_data_for_reward_model
reward_model.sketch_data = sketch_data_for_reward_model

In [None]:
reward_model.viz_learned_rew()

In [None]:
env.default_init_obs = demo_rollouts[0][0][0]

In [None]:
plt.imshow(encoder.decode_latent(env.default_init_obs)[:, :, 0], cmap=mpl.cm.binary)
plt.show()

In [None]:
# Hand-built action-sequence initialisation: a vector with ones at indices 4
# and 9, expanded to shape (1, 1, n_act_dim) and scaled by `utils.inf`.
# In this notebook it is only referenced from commented-out `act_seq=` arguments.
act_seq = np.zeros(env.n_act_dim)
act_seq[4] = 1
act_seq[9] = 1
# NOTE(review): if `utils.inf` is an IEEE infinity, the zero entries become NaN
# (0 * inf); this is only well-defined if `utils.inf` is a large finite
# constant -- confirm in `rqst.utils`.
act_seq = act_seq[np.newaxis, np.newaxis, :] * utils.inf

In [None]:
traj_optimizer = GDTrajOptimizer(
    sess,
    env,
    reward_model,
    dynamics_model,         
    traj_len=2,
    n_trajs=1,
    prior_coeff=1e-1,
    diversity_coeff=0.,
    #query_loss_opt='max_nov',
    query_loss_opt='max_imi_pol_uncertainty',
    #query_loss_opt='pref_uncertainty',
    #query_loss_opt='max_rew',
    opt_init_obs=True,
    opt_act_seq=True,
    join_trajs_at_init_state=True,
    learning_rate=1e-2,
    query_type='demo',
    shoot_steps=None
    )

In [None]:
data = traj_optimizer.run(
    #init_obs=env.default_init_obs,
    #act_seq=act_seq,
    init_obs=None,
    iterations=5000,
    ftol=1e-4,
    verbose=True,
    warm_start=False
    )
trajs = data['traj']
act_seqs = data['act_seq']

In [None]:
# Convert the optimizer's action outputs into per-row probability vectors.
# Take index 0 along axis 1, leaving a 2-D array.
act_seqs = np.array(act_seqs)[:, 0, :]
# Softmax over axis 1, done in log space: subtract the log-sum-exp, then
# exponentiate, so each row sums to 1.
act_seqs -= logsumexp(act_seqs, axis=1, keepdims=True)
act_seqs = np.exp(act_seqs)
# NOTE: this cell rebinds `act_seqs` to the 2-D result, so running it twice
# fails on the 3-index slice above; re-run the optimizer cell first.

In [None]:
sorted(list(range(act_seqs.shape[1])), key=lambda i: np.abs(act_seqs[0, i] - act_seqs[1, i]))

In [None]:
act_seqs, np.argmax(act_seqs, axis=1)

In [None]:
votes = reward_model.vote_on_actions(trajs[0])[0, :, :]
agg_votes = np.mean(votes, axis=0)
agg_votes, np.argmax(agg_votes)

In [None]:
np.argmax(votes, axis=1), np.mean(np.var(votes, axis=0)), -(votes * np.log(votes)).sum(axis=1).mean(), utils.np_ens_disag(votes[np.newaxis, :, :])[0]

In [None]:
utils.plot_trajs(trajs, env, encoder)

In [None]:
# Wrap the first element of the first optimized trajectory as a one-step
# rollout. The step is a 6-tuple with only the first field populated; the
# remaining five fields are left as None.
rollout = [(trajs[0][0], None, None, None, None, None)]
# In-place append: every other name bound to the same list object will also
# see this extra rollout.
demo_rollouts_for_reward_model.append(rollout)

In [None]:
rew_eval = reward_models.evaluate_reward_model(
    sess,
    env,
    trans_env,
    reward_model, 
    dynamics_model, 
    offpol_eval_rollouts=aug_rollouts,
    n_eval_rollouts=100
    )

In [None]:
rew_eval['perf']

In [None]:
utils.viz_rew_eval(rew_eval, env, encoder)

In [None]:
with open(os.path.join(utils.clfbandit_data_dir, 'opt_rew_eval.pkl'), 'wb') as f:
  pickle.dump(rew_eval, f, pickle.HIGHEST_PROTOCOL)

In [None]:
with open(os.path.join(utils.clfbandit_data_dir, 'opt_rew_eval.pkl'), 'rb') as f:
  opt_rew_eval = pickle.load(f)

In [None]:
traj_optimizer = StochTrajOptimizer(
    sess,
    env,
    reward_model,
    dynamics_model,
    traj_len=2,
    rollout_len=1,
    #query_loss_opt='max_nov',
    query_loss_opt='max_imi_pol_uncertainty',
    #query_loss_opt='unif',
    #query_loss_opt='pref_uncertainty',
    #query_loss_opt='max_rew',
    use_rand_policy=False,
    query_type='demo'
    )

In [None]:
data = traj_optimizer.run(
    #act_seq=act_seq,
    n_trajs=1,
    n_samples=10,
    init_obs=None,
    verbose=True
    )
trajs = data['traj']
act_seqs = data['act_seq']

In [None]:
# Convert the optimizer's action outputs into per-row probability vectors.
# Take index 0 along axis 1, leaving a 2-D array.
act_seqs = np.array(act_seqs)[:, 0, :]
# Softmax over axis 1, done in log space: subtract the log-sum-exp, then
# exponentiate, so each row sums to 1.
act_seqs -= logsumexp(act_seqs, axis=1, keepdims=True)
act_seqs = np.exp(act_seqs)
# NOTE: this cell rebinds `act_seqs` to the 2-D result, so running it twice
# fails on the 3-index slice above; re-run the optimizer cell first.

In [None]:
sorted(list(range(act_seqs.shape[1])), key=lambda i: np.abs(act_seqs[0, i] - act_seqs[1, i]))

In [None]:
act_seqs, np.argmax(act_seqs, axis=1)

In [None]:
votes = reward_model.vote_on_actions(trajs[0])[0, :, :]
agg_votes = np.mean(votes, axis=0)
agg_votes, np.argmax(agg_votes)

In [None]:
np.argmax(votes, axis=1), np.mean(np.var(votes, axis=0)), -(votes * np.log(votes)).sum(axis=1).mean(), utils.np_ens_disag(votes[np.newaxis, :, :])[0]

In [None]:
utils.plot_trajs(trajs, env, encoder)

In [None]:
offpol_eval_rollouts = random.sample(aug_rollouts, 1000)

In [None]:
demo_rollouts_for_reward_model = demo_rollouts[:10]

In [None]:
reward_model = BCRewardModel(
    sess,
    env,
    n_rew_nets_in_ensemble=4,
    n_layers=1,
    layer_size=128,
    scope=str(uuid.uuid4()),
    scope_file=os.path.join(utils.clfbandit_data_dir, 'bc_rew_scope.pkl'),
    tf_file=os.path.join(utils.clfbandit_data_dir, 'bc_rew.tf'),
    rew_func_input='sa',
    use_discrete_actions=True
    )

In [None]:
dynamics_model = ObsPriorModel(sess, env, encoder=encoder)

In [None]:
rew_optimizer = InteractiveRewardOptimizer(
    sess,
    env, 
    trans_env,
    reward_model, 
    dynamics_model
    )

In [None]:
reward_train_kwargs = {
    'demo_coeff': 1.,
    'sketch_coeff': 1.,
    'iterations': 5000,
    'ftol': 1e-4,
    'batch_size': 32,
    'learning_rate': 1e-2,
    'val_update_freq': 100,
    'verbose': False
    }

dynamics_train_kwargs = {}

imitation_kwargs = {}

eval_kwargs = {
    'n_eval_rollouts': 100,
    'offpol_eval_rollouts': offpol_eval_rollouts
    }

In [None]:
gd_traj_opt_init_kwargs = {        
    'traj_len': 2,
    'n_trajs': 1,
    'prior_coeff': 1e-1,
    'diversity_coeff': 0.,
    'query_loss_opt': 'unif',
    'opt_init_obs': True,
    'opt_act_seq': True,
    'learning_rate': 1e-2,
    'join_trajs_at_init_state': True,
    'shoot_steps': None
    }

gd_traj_opt_run_kwargs = {
    'init_obs': None,
    'iterations': 5000,
    'ftol': 1e-4,
    'verbose': False,
    'warm_start': False
    }

In [None]:
# Expand the single GD trajectory-optimizer template into one configuration
# per acquisition function, each paired with its own prior coefficient.
query_loss_opts = ['max_imi_pol_uncertainty', 'max_nov']
prior_coeffs = [1e-1, 1e-2]

# This cell rebinds the two template names from dicts to lists, so it is not
# re-runnable on its own. Fail with a clear message instead of an obscure
# "list indices must be integers" TypeError further down.
assert isinstance(gd_traj_opt_init_kwargs, dict) and isinstance(gd_traj_opt_run_kwargs, dict), \
    'gd_traj_opt_*_kwargs were already expanded; re-run the cell that defines the templates first'

traj_opt_init_kwargs = []
for query_loss_opt, prior_coeff in zip(query_loss_opts, prior_coeffs):
  # Independent copy of the template so per-optimizer overrides do not leak.
  kwargs = deepcopy(gd_traj_opt_init_kwargs)
  kwargs['query_loss_opt'] = query_loss_opt
  kwargs['prior_coeff'] = prior_coeff
  traj_opt_init_kwargs.append(kwargs)
gd_traj_opt_init_kwargs = traj_opt_init_kwargs
# One run-kwargs dict per optimizer. `[d] * n` would put the *same* dict object
# in every slot, so a mutation through one entry would show up in all of them;
# copy instead, matching how the init kwargs are handled above.
gd_traj_opt_run_kwargs = [deepcopy(gd_traj_opt_run_kwargs) for _ in gd_traj_opt_init_kwargs]

In [None]:
stoch_traj_opt_init_kwargs = {        
    'traj_len': 2,
    'rollout_len': 1,
    'query_loss_opt': 'unif',
    'use_rand_policy': False
    }

stoch_traj_opt_run_kwargs = {
    'n_trajs': 10,
    'n_samples': 10,
    'init_obs': None,
    'verbose': False
    }

In [None]:
rew_opt_kwargs = {
    'demo_rollouts': demo_rollouts_for_reward_model,
    'sketch_rollouts': sketch_rollouts_for_reward_model,
    'pref_logs': pref_logs_for_reward_model,
    'rollouts_for_dyn': [],#aug_rollouts,
    'reward_train_kwargs': reward_train_kwargs,
    'dynamics_train_kwargs': dynamics_train_kwargs,
    'imitation_kwargs': imitation_kwargs,
    'eval_kwargs': eval_kwargs,
    'init_train_dyn': False,
    'init_train_rew': True,
    'n_imitation_rollouts_per_dyn_update': 1,
    'n_queries': 2000,
    'reward_update_freq': 5,
    'reward_eval_freq': 5,
    'dyn_update_freq': None,
    'verbose': False,
    'warm_start_rew': False,
    'query_type': 'demo'
    }

In [None]:
rew_perf_evals, query_data = rew_optimizer.run(
    traj_opt_cls=GDTrajOptimizer,
    traj_opt_run_kwargs=gd_traj_opt_run_kwargs,
    traj_opt_init_kwargs=gd_traj_opt_init_kwargs,
    **rew_opt_kwargs
    )

In [None]:
rew_perf_evals = rew_optimizer.rew_perf_evals
query_data = rew_optimizer.query_data

In [None]:
plt.plot(rew_perf_evals['n_queries'], rew_perf_evals['trans_succ'])
plt.show()

In [None]:
utils.viz_query_data(query_data, env, encoder=encoder)

In [None]:
with open(os.path.join(utils.clfbandit_data_dir, 'query_data.pkl'), 'wb') as f:
  pickle.dump(query_data, f, pickle.HIGHEST_PROTOCOL)

In [None]:
with open(os.path.join(utils.clfbandit_data_dir, 'query_data.pkl'), 'rb') as f:
  query_data = pickle.load(f)

In [None]:
def make_eval_func(conf_key):
  """Build a function that runs one reward-optimization trial for a sweep setting.

  Args:
    conf_key: which knob the returned function varies.
      'prior_coeff'    -- the argument is a sequence with one prior coefficient
                          per entry of `gd_traj_opt_init_kwargs`; each entry's
                          'prior_coeff' is overridden on a deep copy.
      'query_loss_opt' -- the argument is an acquisition-function name to
                          ablate; every init-kwargs entry whose
                          'query_loss_opt' equals it is dropped.

  Returns:
    A one-argument callable `eval_func(conf_val)` that returns the result of
    `rew_optimizer.run(...)` with `GDTrajOptimizer`.

  Raises:
    ValueError: if `conf_key` is not one of the two supported keys. This is
      checked eagerly so a typo surfaces here rather than as an
      UnboundLocalError inside the first trial.
  """
  if conf_key not in ('prior_coeff', 'query_loss_opt'):
    raise ValueError('unsupported conf_key: %r' % (conf_key,))

  def eval_func(conf_val):
    if conf_key == 'prior_coeff':
      # Deep copy so the sweep never mutates the shared notebook-level config.
      traj_opt_init_kwargs = deepcopy(gd_traj_opt_init_kwargs)
      for i in range(len(traj_opt_init_kwargs)):
        traj_opt_init_kwargs[i][conf_key] = conf_val[i]
    else:  # conf_key == 'query_loss_opt'
      # Ablation: keep every optimizer except the one named by `conf_val`.
      # NOTE(review): the filtered list can be shorter than
      # `gd_traj_opt_run_kwargs`, which is passed unfiltered below -- confirm
      # that `rew_optimizer.run` tolerates the length mismatch.
      traj_opt_init_kwargs = [kwargs for kwargs in gd_traj_opt_init_kwargs if kwargs['query_loss_opt'] != conf_val]
    return rew_optimizer.run(
      traj_opt_cls=GDTrajOptimizer,
      traj_opt_run_kwargs=gd_traj_opt_run_kwargs,
      traj_opt_init_kwargs=traj_opt_init_kwargs,
      **rew_opt_kwargs
      )
  return eval_func

eval_prior_coeff = make_eval_func('prior_coeff')
eval_query_loss = make_eval_func('query_loss_opt')

In [None]:
n_trials = 3

In [None]:
prior_coeffs = list(zip([0, 1e-2, 1e-1, 1], [0, 1e-3, 1e-2, 1e-1]))
prior_coeffs

In [None]:
# Prior-coefficient sweep: for every setting in `prior_coeffs`, run `n_trials`
# independent trials and collect the results as
# prior_coeff_evals[setting_index][trial_index].
prior_coeff_evals = []
for prior_coeff in prior_coeffs:
  prior_coeff_evals.append([])
  for i in range(n_trials):
    # Progress line: the setting being evaluated and the trial number.
    print('%s %d' % (str(prior_coeff), i))
    prior_coeff_evals[-1].append(eval_prior_coeff(prior_coeff))

In [None]:
prior_coeff_eval_data = {
    'prior_coeffs': prior_coeffs,
    'prior_coeff_evals': prior_coeff_evals
    }

In [None]:
with open(os.path.join(utils.clfbandit_data_dir, 'prior_coeff_eval_data.pkl'), 'wb') as f:
  pickle.dump(prior_coeff_eval_data, f, pickle.HIGHEST_PROTOCOL)

In [None]:
with open(os.path.join(utils.clfbandit_data_dir, 'prior_coeff_eval_data.pkl'), 'rb') as f:
  prior_coeff_eval_data = pickle.load(f)

In [None]:
globals().update(prior_coeff_eval_data)

In [None]:
query_loss_opts = ['max_imi_pol_uncertainty', 'max_nov']

In [None]:
# Acquisition-function sweep: for every name in `query_loss_opts`, run
# `n_trials` independent trials and collect the results as
# query_loss_evals[option_index][trial_index].
query_loss_evals = []
for query_loss_opt in query_loss_opts:
  query_loss_evals.append([])
  for i in range(n_trials):
    # Progress line: the option being evaluated and the trial number.
    print('%s %d' % (query_loss_opt, i))
    query_loss_evals[-1].append(eval_query_loss(query_loss_opt))

In [None]:
query_loss_eval_data = {
    'query_loss_opts': query_loss_opts,
    'query_loss_evals': query_loss_evals
    }

In [None]:
# Cache the acquisition-function sweep next to the other artifacts of this
# notebook. `utils.clfbandit_data_dir` is the directory every other cache cell
# here uses; no `local_bandit_data_dir` is defined anywhere in the notebook.
with open(os.path.join(utils.clfbandit_data_dir, 'query_loss_eval_data.pkl'), 'wb') as f:
  pickle.dump(query_loss_eval_data, f, pickle.HIGHEST_PROTOCOL)

In [None]:
# Reload the cached acquisition-function sweep from the notebook's data
# directory (`utils.clfbandit_data_dir`, the same location it is written to
# and the one all other cache cells use).
with open(os.path.join(utils.clfbandit_data_dir, 'query_loss_eval_data.pkl'), 'rb') as f:
  query_loss_eval_data = pickle.load(f)

In [None]:
globals().update(query_loss_eval_data)

In [None]:
def compute_stoch_eval():
  """Run one reward-optimization trial that uses `StochTrajOptimizer`.

  Takes no arguments; reads the notebook-level stochastic optimizer settings
  and `rew_opt_kwargs`, and returns whatever `rew_optimizer.run` returns.
  """
  return rew_optimizer.run(
    traj_opt_cls=StochTrajOptimizer,
    traj_opt_run_kwargs=stoch_traj_opt_run_kwargs,
    traj_opt_init_kwargs=stoch_traj_opt_init_kwargs,
    **rew_opt_kwargs
    )

In [None]:
stoch_evals = [compute_stoch_eval() for _ in range(n_trials)]

In [None]:
# Cache the stochastic-optimizer baseline trials in the notebook's data
# directory. `utils.clfbandit_data_dir` is what every other cache cell here
# uses; no `local_bandit_data_dir` is defined anywhere in the notebook.
with open(os.path.join(utils.clfbandit_data_dir, 'stoch_evals.pkl'), 'wb') as f:
  pickle.dump(stoch_evals, f, pickle.HIGHEST_PROTOCOL)

In [None]:
# Reload the cached stochastic-optimizer baseline trials from the notebook's
# data directory (`utils.clfbandit_data_dir`, the same location they are
# written to and the one all other cache cells use).
with open(os.path.join(utils.clfbandit_data_dir, 'stoch_evals.pkl'), 'rb') as f:
  stoch_evals = pickle.load(f)

In [None]:
cloud_dir = os.path.join(utils.clfbandit_data_dir, 'cloud')

In [None]:
# Rebuild the result tables from the per-run pickles in `cloud_dir`.
# Each pickle holds a `(conf, result)` pair: conf[0] identifies the experiment
# setting and conf[1] is the trial index used as the slot in the tables below.
# Slots with no matching file stay None.
prior_coeff_evals = [[None] * n_trials for _ in prior_coeffs]
query_loss_evals = [[None] * n_trials for _ in query_loss_opts]
stoch_evals = [None] * n_trials
unif_evals = [None] * n_trials
demo_init_evals = [None] * n_trials
for fname in os.listdir(cloud_dir):
  if not fname.endswith('.pkl'):
    continue
  with open(os.path.join(cloud_dir, fname), 'rb') as f:
    conf, result = pickle.load(f)
  if conf[0] in prior_coeffs:
    # Runs whose file stem is 24, 25 or 26 go to the separate "demo init"
    # table rather than the prior-coefficient sweep, even though their conf
    # matches a sweep setting.
    if int(fname.split('.pkl')[0]) in (24, 25, 26):
      demo_init_evals[conf[1]] = result
    else:
      prior_coeff_evals[prior_coeffs.index(conf[0])][conf[1]] = result
  elif conf[0] in query_loss_opts:
    query_loss_evals[query_loss_opts.index(conf[0])][conf[1]] = result
  elif conf[0] == 'stoch':
    stoch_evals[conf[1]] = result
  elif conf[0] == 'unif':
    unif_evals[conf[1]] = result
  else:
    # Unrecognised experiment setting.
    raise ValueError

In [None]:
# Sort the (setting, trials) pairs by the first component of each setting and
# unzip them again. Both names are rebound to tuples (not lists) afterwards.
prior_coeffs, prior_coeff_evals = list(zip(*sorted(list(zip(prior_coeffs, prior_coeff_evals)), key=lambda x: x[0][0])))

In [None]:
# Each trial result is a sequence; transposing the trials and taking element 0
# yields a tuple with the first component of every trial (presumably the
# performance-evaluation dict, with the query data as the second component --
# TODO confirm against what `rew_optimizer.run` returns).
# NOTE: fails with a TypeError if any trial slot is still None.
stoch_perf_eval = list(zip(*stoch_evals))[0]
unif_perf_eval = list(zip(*unif_evals))[0]
demo_init_perf_eval = list(zip(*demo_init_evals))[0]

In [None]:
print('\n'.join(list(zip(*prior_coeff_evals[0]))[0][0].keys()))

In [None]:
label_of_key = {
    'rew': 'Reward',
    'succ': 'Classification Accuracy in Training Env.',
    'ens_unc': 'Ensemble Uncertainty',
    'xent': 'Cross-Entropy',
    'ent': 'Entropy',
    'acc': 'Classification Accuracy',
    'n_queries': 'Number of Queries',
    'trans_succ': 'Classification Accuracy in Test Env.',
    'trans_rew': 'Log-Likelihood in Test Env.'
    }

In [None]:
plt.rcParams.update({'font.size': 14})

In [None]:
x_key = 'n_queries'
best_prior_coeff_idx = 2

In [None]:
# Figure clfbandit-4: metric `y_key` versus `x_key` for the selected
# prior-coefficient setting, against the stochastic-optimizer baseline and two
# horizontal reference lines.
smooth_win = 5

plt.title('MNIST')

y_key = 'trans_succ'

# Fall back to the raw key when no human-readable label is registered.
plt.xlabel(label_of_key.get(x_key, x_key))
plt.ylabel(label_of_key.get(y_key, y_key))

# Reference line: value of the same metric stored in the reloaded
# `opt_rew_eval` evaluation.
plt.axhline(
    y=opt_rew_eval['perf'][y_key], 
    linestyle=':', 
    label='Offline Classifier', 
    color='green',
    linewidth=3
    )

# Reference line: hard-coded 0.2 level labelled as the random policy.
plt.axhline(
    y=0.2, 
    linestyle=':', 
    label='Random Policy (Baseline)', 
    color='gray',
    linewidth=3
    )

utils.plot_perf_evals(
    stoch_perf_eval, 
    x_key, 
    y_key, 
    label='Random Digits from Training Env. (Baseline)', 
    smooth_win=smooth_win, 
    color='teal'
    )

# Curve for the setting selected by `best_prior_coeff_idx`; element 0 of each
# trial result is what gets plotted.
prior_coeff_idx = best_prior_coeff_idx
prior_coeff = prior_coeffs[prior_coeff_idx]
evals = prior_coeff_evals[prior_coeff_idx]
perf_evals = list(zip(*evals))[0]
utils.plot_perf_evals(
    perf_evals, 
    x_key, 
    y_key, 
    label='ReQueST (Ours)', 
    smooth_win=smooth_win, 
    color='orange'
    )

# Legend is anchored below the axes; bbox_inches='tight' keeps it in the PDF.
plt.legend(loc='lower center', bbox_to_anchor=(0.5, -0.7), framealpha=0.)

# NOTE: the 'figures' sub-directory must already exist; savefig does not create it.
plt.savefig(
    os.path.join(utils.clfbandit_data_dir, 'figures', 'clfbandit-4.pdf'), 
    dpi=500, 
    bbox_inches='tight',
    transparent=True
    )

plt.show()

In [None]:
y_key = 'trans_succ'

In [None]:
# Figure clfbandit-6: one curve per prior-coefficient setting for the metric
# currently held in `y_key`, plus two horizontal reference lines.
smooth_win = 5

plt.title('MNIST')

# Fall back to the raw key when no human-readable label is registered.
plt.xlabel(label_of_key.get(x_key, x_key))
plt.ylabel(label_of_key.get(y_key, y_key))

# Reference line: value of the same metric stored in the reloaded
# `opt_rew_eval` evaluation.
plt.axhline(
    y=opt_rew_eval['perf'][y_key],
    linestyle=':',
    label='Offline Classifier',
    color='green',
    linewidth=3
    )

# Reference line: hard-coded 0.2 level labelled as the random policy.
plt.axhline(
    y=0.2,
    linestyle=':',
    label='Random Policy (Baseline)',
    color='gray',
    linewidth=3
    )

# One colour per sweep setting; the assert catches a sweep whose size changed
# without the palette being updated (zip would otherwise silently truncate).
colors = [
    'teal',
    'gray',
    'orange',
    'red'
    ]
assert len(colors) == len(prior_coeffs)

for prior_coeff, evals, color in zip(prior_coeffs, prior_coeff_evals, colors):
  # Element 0 of each trial result is what gets plotted.
  perf_evals = list(zip(*evals))[0]
  # Raw strings: '\l' and '\i' are not valid escape sequences, so the non-raw
  # spelling triggers an invalid-escape warning. The resulting text is unchanged.
  label = r'$\lambda = $ %0.2f' % prior_coeff[0]
  label = label.replace('inf', r'$\infty$')
  utils.plot_perf_evals(
      perf_evals,
      x_key,
      y_key,
      label=label,
      smooth_win=smooth_win,
      color=color
      )

# Legend is anchored below the axes; bbox_inches='tight' keeps it in the PDF.
plt.legend(loc='lower center', bbox_to_anchor=(0.5, -0.6), framealpha=0., ncol=2)

# NOTE: the 'figures' sub-directory must already exist; savefig does not create it.
plt.savefig(
    os.path.join(utils.clfbandit_data_dir, 'figures', 'clfbandit-6.pdf'),
    dpi=500,
    bbox_inches='tight',
    transparent=True
    )
plt.show()

In [None]:
def compute_best_perf(evals, y_key):
  """Return the best aggregated value of metric `y_key` and its standard error.

  Args:
    evals: sequence of trial results; element 0 of every trial is passed on to
      `utils.make_perf_mat` (presumably one performance-evaluation dict per
      trial -- TODO confirm).
    y_key: metric name. 'rew', 'succ', 'trans_succ' and 'acc' are treated as
      higher-is-better; 'xent' as lower-is-better.

  Returns:
    `(mean, stderr)` taken at the index where the aggregated mean is best
    (largest or smallest, depending on `y_key`).

  Raises:
    ValueError: if no optimisation direction is defined for `y_key`. Checked
      before any aggregation work is done.
  """
  # Resolve the direction first so an unsupported key fails fast with a
  # message that names the offending key.
  if y_key in ['rew', 'succ', 'trans_succ', 'acc']:
    pick_best = np.argmax
  elif y_key in ['xent']:
    pick_best = np.argmin
  else:
    raise ValueError('no best-value direction defined for y_key=%r' % (y_key,))

  perf_evals = list(zip(*evals))[0]
  mat = utils.make_perf_mat(perf_evals, y_key)
  means = utils.col_means(mat)
  stderrs = utils.col_stderrs(mat)
  idx = pick_best(means)
  return means[idx], stderrs[idx]

In [None]:
# For every sweep setting, record (first prior coefficient, best mean, stderr)
# for the metric currently held in the notebook-level `y_key`.
best_perfs = []
for prior_coeff, evals in zip(prior_coeffs, prior_coeff_evals):
  # The star unpacks the (mean, stderr) pair returned by compute_best_perf.
  best_perfs.append((prior_coeff[0], *compute_best_perf(evals, y_key)))
# Order by coefficient so the triples can be plotted left to right.
best_perfs = sorted(best_perfs, key=lambda x: x[0])

# Same summary for the stochastic-optimizer baseline trials.
best_stoch_perf = compute_best_perf(stoch_evals, y_key)

In [None]:
# Figure clfbandit-7: best achieved value of `y_key` (with error bars) as a
# function of the prior coefficient, using the triples in `best_perfs`.
y_key = 'trans_succ'

plt.title('MNIST')

# Raw string: '\l' is not a valid escape sequence, so the non-raw spelling
# triggers an invalid-escape warning. The resulting text is unchanged.
plt.xlabel(r'Regularization Constant $\lambda$')
# Fall back to the raw key when no human-readable label is registered.
plt.ylabel(label_of_key.get(y_key, y_key))

# Reference line: value of the same metric stored in the reloaded
# `opt_rew_eval` evaluation.
plt.axhline(
    y=opt_rew_eval['perf'][y_key],
    linestyle=':',
    label='Offline Classifier',
    color='green',
    linewidth=3
    )

# Random-policy reference line, disabled for this figure:
# plt.axhline(
#     y=0.2,
#     linestyle=':',
#     label='Random Policy (Baseline)',
#     color='gray',
#     linewidth=3
#     )

# Unzip the (coefficient, mean, stderr) triples into three parallel tuples.
xs, ys, yerrs = list(zip(*best_perfs))
plt.errorbar(
    xs,
    y=ys,
    yerr=yerrs,
    marker='o',
    color='orange',
    label='ReQueST (Ours)',
    capsize=5,
    linestyle=''
    )

# Symmetric log scale so a coefficient of exactly 0 can share the axis with
# the log-spaced positive values.
# NOTE(review): `linthreshx` is the older keyword; newer matplotlib releases
# spell it `linthresh` -- confirm against the pinned matplotlib version.
plt.xscale('symlog', linthreshx=1e-2)

# Legend is anchored below the axes; bbox_inches='tight' keeps it in the PDF.
plt.legend(loc='lower center', bbox_to_anchor=(0.5, -0.5), framealpha=0., ncol=1)

# NOTE: the 'figures' sub-directory must already exist; savefig does not create it.
plt.savefig(
    os.path.join(utils.clfbandit_data_dir, 'figures', 'clfbandit-7.pdf'),
    dpi=500,
    bbox_inches='tight',
    transparent=True
    )

plt.show()

In [None]:
label_of_acq_func = {
  'min_rew': 'Min. Reward',
  'max_rew': 'Max. Reward',
  'max_nov': 'Max. Novelty',
  'rew_uncertainty': 'Max. Uncertainty',
  'max_imi_pol_uncertainty': 'Max. Uncertainty'
}

In [None]:
smooth_win = 5

plt.title('MNIST')

y_key = 'trans_succ'

plt.xlabel(label_of_key.get(x_key, x_key))
plt.ylabel(label_of_key.get(y_key, y_key))

plt.axhline(
    y=opt_rew_eval['perf'][y_key], 
    linestyle=':', 
    label='Offline Classifier', 
    color='green',
    linewidth=3
    )
'''
plt.axhline(
    y=0.2, 
    linestyle=':', 
    label='Random Policy (Baseline)', 
    color='gray',
    linewidth=3
    )
'''
prior_coeff_idx = best_prior_coeff_idx
perf_evals = list(zip(*prior_coeff_evals[prior_coeff_idx]))[0]
utils.plot_perf_evals(
    perf_evals, 
    x_key, 
    y_key, 
    label='All Acquisition Functions', 
    smooth_win=smooth_win,
    color='orange'
    )

colors = [
  'teal',
  'gray'
]
assert len(colors) == len(query_loss_opts)

for query_loss_opt, evals, color in zip(query_loss_opts, query_loss_evals, colors):
  perf_evals = list(zip(*evals))[0]
  label = 'All - %s' % label_of_acq_func.get(query_loss_opt, query_loss_opt)
  utils.plot_perf_evals(
      perf_evals, 
      x_key, 
      y_key, 
      label=label, 
      smooth_win=smooth_win,
      color=color
      )
    
plt.legend(loc='lower center', bbox_to_anchor=(0.5, -0.7), framealpha=0., ncol=1)

plt.ylim([0.5, None])

plt.savefig(
    os.path.join(utils.clfbandit_data_dir, 'figures', 'clfbandit-8.pdf'), 
    dpi=500, 
    bbox_inches='tight',
    transparent=True
    )

plt.show()

In [None]:
# Decode, per acquisition function, the first-step observations of the
# rollouts it contributed to `query_data['demo_rollouts']`.
# Each function is paired with the offset of its first rollout and then takes
# every second entry from there. (Presumably the first 10 entries are the seed
# demonstrations and the two functions alternate afterwards -- TODO confirm
# against the run that produced `query_data`.)
qs_of_query_loss_opt = {}
for query_loss_opt, init_idx in (('max_imi_pol_uncertainty', 10), ('max_nov', 11)):
  qs = []
  for r in query_data['demo_rollouts'][init_idx::2]:
    # First field of the rollout's first step: the queried latent.
    qs.append(r[0][0])
  # Stack the latents and map them back to image space in one batch.
  qs = encoder.decode_batch_latents(np.array(qs))
  qs_of_query_loss_opt[query_loss_opt] = qs

In [None]:
def plot_dig_tile(query_loss_opt, fig_num, n_row=10, n_col=10):
  """Save and show a grid of decoded query images for one acquisition function.

  The images in `qs_of_query_loss_opt[query_loss_opt]` are sub-sampled at a
  fixed stride so that the grid spans the whole sequence in order.

  Args:
    query_loss_opt: key into `qs_of_query_loss_opt`.
    fig_num: integer used in the output file name 'clfbandit-<fig_num>.pdf'.
    n_row: number of grid rows (default 10, the original layout).
    n_col: number of grid columns (default 10, the original layout).

  Side effects:
    Writes the PDF into the 'figures' sub-directory of
    `utils.clfbandit_data_dir` (which must already exist) and shows the figure.
  """
  # squeeze=False always yields a 2-D array of axes, so flattening also works
  # for single-row / single-column grids.
  _, axs = plt.subplots(n_row, n_col, figsize=(20, 20), squeeze=False)
  axs = axs.flatten()
  qs = qs_of_query_loss_opt[query_loss_opt]
  # Stride between shown images. Clamp to 1: with fewer images than tiles the
  # integer division gives 0, which would draw image 0 in every tile.
  chunk_size = max(1, qs.shape[0] // len(axs))
  for i, ax in enumerate(axs):
    idx = i * chunk_size
    # Tiles beyond the available images are left blank.
    if idx < qs.shape[0]:
      ax.imshow(qs[idx, :, :, 0], cmap=mpl.cm.binary)
    ax.grid(False)
    ax.axis('off')

  plt.savefig(
      os.path.join(utils.clfbandit_data_dir, 'figures', 'clfbandit-%d.pdf' % fig_num),
      bbox_inches='tight',
      dpi=500,
      transparent=True
      )

  plt.show()

In [None]:
plot_dig_tile('max_imi_pol_uncertainty', 10)

In [None]:
plot_dig_tile('max_nov', 11)

In [None]:
smooth_win = 5

y_key = 'trans_rew'

plt.title('MNIST')

plt.xlabel(label_of_key.get(x_key, x_key))
plt.ylabel(label_of_key.get(y_key, y_key))

plt.axhline(
    y=opt_rew_eval['perf'][y_key], 
    linestyle=':', 
    label='Offline Classifier', 
    color='green',
    linewidth=3
    )

plt.axhline(
    y=np.log(0.2), 
    linestyle=':', 
    label='Random Policy (Baseline)', 
    color='gray',
    linewidth=3
    )

utils.plot_perf_evals(
    unif_perf_eval, 
    x_key, 
    y_key, 
    label='Random Samples from VAE Prior (Baseline)', 
    smooth_win=smooth_win, 
    color='teal'
    )

utils.plot_perf_evals(
    demo_init_perf_eval, 
    x_key, 
    y_key, 
    label='ReQueST (Ours)', 
    smooth_win=smooth_win, 
    color='orange'
    )
  
plt.legend(loc='lower center', bbox_to_anchor=(0.5, -0.7), framealpha=0.)

plt.xlim([0, 500])
plt.ylim([-5, None])

plt.savefig(
    os.path.join(utils.clfbandit_data_dir, 'figures', 'clfbandit-12.pdf'), 
    dpi=500, 
    bbox_inches='tight',
    transparent=True
    )

plt.show()

In [None]:
smooth_win = 5

y_key = 'succ'

plt.title('MNIST')

plt.xlabel(label_of_key.get(x_key, x_key))
plt.ylabel(label_of_key.get(y_key, y_key))

plt.axhline(
    y=opt_rew_eval['perf'][y_key], 
    linestyle=':', 
    label='Offline Classifier', 
    color='green',
    linewidth=3
    )

plt.axhline(
    y=0.2, 
    linestyle=':', 
    label='Random Policy (Baseline)', 
    color='gray',
    linewidth=3
    )

utils.plot_perf_evals(
    stoch_perf_eval, 
    x_key, 
    y_key, 
    label='Random Digits from Training Env. (Baseline)', 
    smooth_win=smooth_win, 
    color='teal'
    )

prior_coeff_idx = best_prior_coeff_idx
prior_coeff = prior_coeffs[prior_coeff_idx]
evals = prior_coeff_evals[prior_coeff_idx]
perf_evals = list(zip(*evals))[0]
utils.plot_perf_evals(
    perf_evals, 
    x_key, 
    y_key, 
    label='ReQueST (Ours)', 
    smooth_win=smooth_win, 
    color='orange'
    )
  
plt.legend(loc='lower center', bbox_to_anchor=(0.5, -0.7), framealpha=0.)

plt.ylim([0.5, None])

plt.savefig(
    os.path.join(utils.clfbandit_data_dir, 'figures', 'clfbandit-13.pdf'), 
    dpi=500, 
    bbox_inches='tight',
    transparent=True
    )

plt.show()

In [None]:
# Figure clfbandit-14: illustrative class-frequency bar chart for the training
# environment. The distribution is hard-coded (not measured from data):
# classes 5-9 get frequency 0.2 each, classes 0-4 get 0.
distrn = np.zeros(10)
distrn[5:] = 0.2

xs = list(range(10))

plt.xlabel('Class')
plt.ylabel('Frequency in Training Env.')
plt.bar(x=xs, height=distrn, color='orange')
# Show every class index as its own tick.
plt.xticks(xs, xs)
plt.ylim([0, 1])
# NOTE: the 'figures' sub-directory must already exist; savefig does not create it.
plt.savefig(
    os.path.join(utils.clfbandit_data_dir, 'figures', 'clfbandit-14.pdf'), 
    dpi=500, 
    bbox_inches='tight',
    transparent=True
    )
plt.show()

In [None]:
# Figure clfbandit-15: illustrative class-frequency bar chart for the transfer
# environment. The distribution is hard-coded (not measured from data):
# classes 0-4 get frequency 0.2 each, classes 5-9 get 0.
distrn = np.zeros(10)
distrn[:5] = 0.2

xs = list(range(10))

plt.xlabel('Class')
plt.ylabel('Frequency in Transfer Env.')
plt.bar(x=xs, height=distrn, color='orange')
# Show every class index as its own tick.
plt.xticks(xs, xs)
plt.ylim([0, 1])
# NOTE: the 'figures' sub-directory must already exist; savefig does not create it.
plt.savefig(
    os.path.join(utils.clfbandit_data_dir, 'figures', 'clfbandit-15.pdf'), 
    dpi=500, 
    bbox_inches='tight',
    transparent=True
    )
plt.show()