In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
from __future__ import division

import pickle
import os
from copy import deepcopy
import uuid

import numpy as np

from sensei import utils
from sensei.user_models import GridWorldNavUser
from sensei.guide_models import GridWorldGuide
from sensei.envs import GridWorldNavEnv, GuideEnv
from sensei import ase

In [None]:
from matplotlib import pyplot as plt
import matplotlib as mpl
%matplotlib inline

In [None]:
import warnings
warnings.filterwarnings('ignore')

In [None]:
sess = utils.make_tf_session(gpu_mode=False)

In [None]:
data_dir = utils.gw_data_dir
fig_dir = os.path.join(data_dir, 'figures')

Create the grid-world navigation environment

In [None]:
gw_size = 5
n_goals = 25

In [None]:
n_states = 4*gw_size**2
ground_truth_obs_model = np.eye(n_states)
ground_truth_obs_model = utils.smooth_matrix(ground_truth_obs_model, n_states, eps=0.1)
ground_truth_obs_model = np.log(ground_truth_obs_model)

In [None]:
plt.ylabel('observation')
plt.xlabel('state')
plt.imshow(ground_truth_obs_model)
plt.show()

In [None]:
env = GridWorldNavEnv(
  gw_size=gw_size,
  n_goals=n_goals,
  max_ep_len=25,
  ground_truth_obs_model=ground_truth_obs_model
)

In [None]:
cache_path = os.path.join(data_dir, 'env.pkl')

In [None]:
env.save_to_cache(cache_path)

In [None]:
env.load_from_cache(cache_path)

Create the simulated users and the guides

In [None]:
ground_truth_obs_model = env.ground_truth_obs_model
dynamics_model = env.make_dynamics_model(eps=1e-9)
q_func = env.Q

In [None]:
init_belief_conf = 1-1e-9

In [None]:
ideal_user_model = GridWorldNavUser(
  env, 
  ground_truth_obs_model, 
  dynamics_model, 
  q_func=q_func,
  init_belief_conf=init_belief_conf
)

In [None]:
internal_dynamics_model = env.make_dynamics_model(eps=0.2)

In [None]:
internal_obs_model = np.zeros(env.ground_truth_obs_model.shape)
idxes = np.arange(0, env.n_states, 1)
internal_obs_model[-idxes-1, idxes] = 1
internal_obs_model = utils.smooth_matrix(internal_obs_model, env.n_states)
internal_obs_model = np.log(internal_obs_model)

In [None]:
plt.ylabel('observation')
plt.xlabel('state')
plt.imshow(internal_obs_model)
plt.show()

In [None]:
user_init_belief_conf = 1-1e-9

In [None]:
user_model = GridWorldNavUser(
  env, 
  internal_obs_model, 
  internal_dynamics_model, 
  q_func=q_func,
  init_belief_conf=user_init_belief_conf
)

In [None]:
guide_env = GuideEnv(env, user_model, n_obs_per_act=1)

In [None]:
iden_guide_policy = lambda obs, info: obs
iden_guide_policy = utils.StutteredPolicy(iden_guide_policy, guide_env.n_obs_per_act)
unif_guide_policy = lambda obs, info: np.random.choice(env.n_obses)
unif_guide_policy = utils.StutteredPolicy(unif_guide_policy, guide_env.n_obs_per_act)

In [None]:
# The oracle and naive guides differ only in which internal observation model
# they assume the user has. Neither learns it. Everything else is shared, so it
# is built once here instead of being copy-pasted into both constructor calls.
reference_guide_kwargs = dict(
  n_obs_per_act=guide_env.n_obs_per_act,
  internal_dynamics_model=internal_dynamics_model,
  learn_internal_obs_model=False,
  init_belief_conf=init_belief_conf,
  user_init_belief_conf=user_init_belief_conf
)

# Oracle guide: given the same internal observation model that user_model uses.
oracle_guide_model = GridWorldGuide(
  sess,
  env,
  ground_truth_obs_model,
  dynamics_model,
  q_func,
  prior_internal_obs_model=internal_obs_model,
  **reference_guide_kwargs
)

# Naive guide: assumes the user perceives observations through the ground-truth model.
naive_guide_model = GridWorldGuide(
  sess,
  env,
  ground_truth_obs_model,
  dynamics_model,
  q_func,
  prior_internal_obs_model=ground_truth_obs_model,
  **reference_guide_kwargs
)

Sanity-check the environments, simulated users and guides

In [None]:
# Presumably resets the order in which episode start states are drawn, so the
# rollouts below start from a known sequence; confirm in GridWorldNavEnv.
env.reset_init_order()

In [None]:
# Oracle navigation policy acting directly in env, rendered.
rollout = utils.run_ep(env.oracle_policy, env, render=True)

In [None]:
# Close the env after the rendered rollout (presumably tears down the render window).
env.close()

In [None]:
# Ideal user: ground-truth observation model, eps=1e-9 dynamics.
rollout = utils.run_ep(ideal_user_model, env, render=True)

In [None]:
# Close the env after the rendered rollout.
env.close()

In [None]:
# Simulated user with the anti-diagonal internal observation model and eps=0.2
# internal dynamics, acting directly in env with no guide.
rollout = utils.run_ep(user_model, env, render=True)

In [None]:
# Close the env after the rendered rollout.
env.close()

In [None]:
# Guide-level rollouts run in guide_env, which wraps env and user_model.
# Identity guide: passes observations through unchanged.
rollout = utils.run_ep(iden_guide_policy, guide_env, render=True)

In [None]:
# Close the guide env after the rendered rollout.
guide_env.close()

In [None]:
# Uniform guide: emits uniformly random observations.
rollout = utils.run_ep(unif_guide_policy, guide_env, render=True)

In [None]:
# Close the guide env after the rendered rollout.
guide_env.close()

In [None]:
# Oracle guide: knows the user's internal observation model.
rollout = utils.run_ep(oracle_guide_model, guide_env, render=True)

In [None]:
# Close the guide env after the rendered rollout.
guide_env.close()

In [None]:
# Naive guide: assumes the user sees through the ground-truth observation model.
rollout = utils.run_ep(naive_guide_model, guide_env, render=True)

In [None]:
# Close the guide env after the rendered rollout.
guide_env.close()

Evaluate the baseline guides

In [None]:
guides = {
  'iden': iden_guide_policy,
  'oracle': oracle_guide_model,
  'unif': unif_guide_policy,
  'naive': naive_guide_model
}

In [None]:
baseline_guide_evals = ase.evaluate_baseline_guides(
  sess, 
  guide_env, 
  guides, 
  n_eval_rollouts=100
)

In [None]:
baselines_eval_path = os.path.join(data_dir, 'baselines_eval.pkl')

In [None]:
with open(baselines_eval_path, 'wb') as f:
  pickle.dump(baseline_guide_evals, f, pickle.HIGHEST_PROTOCOL)

In [None]:
with open(baselines_eval_path, 'rb') as f:
  baseline_guide_evals = pickle.load(f)

In [None]:
for k, v in baseline_guide_evals.items():
  print(k, {x: y for x, y in v['perf'].items() if not x.endswith('_t')})

Learn the user's internal observation model

In [None]:
unassisted_train_rollouts = utils.evaluate_policy(
  sess,
  guide_env,
  iden_guide_policy,
  n_eval_rollouts=50
)['rollouts']

In [None]:
unassisted_rollouts_path = os.path.join(data_dir, 'unassisted_rollouts.pkl')

In [None]:
with open(unassisted_rollouts_path, 'wb') as f:
  pickle.dump(unassisted_train_rollouts, f, pickle.HIGHEST_PROTOCOL)

In [None]:
with open(unassisted_rollouts_path, 'rb') as f:
  unassisted_train_rollouts = pickle.load(f)

In [None]:
init_train_rollouts = unassisted_train_rollouts

In [None]:
tabular_obs_model_kwargs = {
  'scope_file': os.path.join(data_dir, 'guide_scope.pkl'),
  'tf_file': os.path.join(data_dir, 'guide.tf'),
  'user_init_belief_conf': user_init_belief_conf,
  'obs_params_only': False,
  'prior_coeff': 0.,
  'warm_start': False
}

guide_train_kwargs = {
  'iterations': 2000,
  'ftol': 1e-4,
  'batch_size': 32,
  'learning_rate': 1e-2,
  'val_update_freq': 10,
  'verbose': True
}

guide_model = GridWorldGuide(
  sess, 
  env, 
  ground_truth_obs_model, 
  dynamics_model, 
  q_func,
  n_obs_per_act=guide_env.n_obs_per_act, 
  prior_internal_obs_model=ground_truth_obs_model,
  internal_dynamics_model=internal_dynamics_model,
  tabular_obs_model_kwargs=tabular_obs_model_kwargs,
  learn_internal_obs_model=True,
  init_belief_conf=init_belief_conf,
  user_init_belief_conf=user_init_belief_conf
)

In [None]:
guide_optimizer = ase.InteractiveGuideOptimizer(sess, env, guide_env)

In [None]:
n_reps = 5

In [None]:
train_logs = [guide_optimizer.run(
  guide_model, 
  n_train_batches=20, 
  n_rollouts_per_batch=50, 
  guide_train_kwargs=guide_train_kwargs,
  verbose=True,
  init_train_rollouts=init_train_rollouts,
  n_eval_rollouts=100
) for _ in range(n_reps)]

In [None]:
train_logs_path = os.path.join(data_dir, 'train_logs.pkl')

In [None]:
with open(train_logs_path, 'wb') as f:
  pickle.dump(train_logs, f, pickle.HIGHEST_PROTOCOL)

In [None]:
with open(train_logs_path, 'rb') as f:
  train_logs = pickle.load(f)

In [None]:
mpl.rcParams.update({'font.size': 14})

In [None]:
for metric in ['succ', 'user_belief_in_true_state']:
  plt.title('2D Navigation')
  plt.xlabel('Number of Training Rollouts')
  plt.ylabel(utils.label_of_perf_met[metric])
  
  for guide_name, guide_eval in baseline_guide_evals.items():
    if guide_name == 'oracle':
      continue
    ys = guide_eval['perf'][metric]
    label = utils.label_of_guide[guide_name]
    if guide_name == 'naive':
      label = 'Naive ASE (Baseline)'
    color = utils.color_of_guide[guide_name]
    linestyle = '-' if guide_name == 'oracle' else '--'
    plt.axhline(y=baseline_guide_evals[guide_name]['perf'][metric], label=label, color=color, linestyle=linestyle, linewidth=2)
  
  guide_name = 'learned'
  label = utils.label_of_guide[guide_name]
  color = utils.color_of_guide[guide_name]
  utils.plot_perf_evals([train_log['guide_perf_evals'] for train_log in train_logs], 'n_train_rollouts', metric, color=color, label=label)
  
  plt.legend(loc='lower right', prop={'size': 12})
  save_path = os.path.join(fig_dir, 'gw_%s_vs_trainsize.pdf' % metric)
  plt.savefig(save_path, bbox_inches='tight')
  plt.show()

In [None]:
guide_model.internal_obs_model.obs_logits_eval = guide_model.internal_obs_model.sess.run(guide_model.internal_obs_model.obs_logits)
learned_obs_model = np.mean(np.exp(guide_model.internal_obs_model.obs_logits_eval), axis=0)

In [None]:
plt.hist(learned_obs_model.ravel(), bins=20)
plt.show()

In [None]:
for i, x in enumerate(np.argmax(learned_obs_model, axis=0)[::-1]):
  print('%d %d' % (i, x))

In [None]:
plt.imshow(learned_obs_model)
plt.show()