In [None]:
# Automatically re-import edited modules (e.g. `sensei`) before each cell executes.
%load_ext autoreload
%autoreload 2

In [None]:
from __future__ import division

import pickle
import os
from copy import deepcopy
import collections

import numpy as np
import scipy
from matplotlib import animation

from sensei import utils
from sensei.user_models import GridWorldNavUser
from sensei.guide_models import GridWorldGuide
from sensei.envs import GridWorldNavEnv, GuideEnv, HabitatNavEnv
from sensei import ase

In [None]:
from matplotlib import pyplot as plt
import matplotlib as mpl
%matplotlib inline

In [None]:
import warnings
# NOTE(review): this silences *every* warning for the rest of the session
# (deprecations, numpy runtime warnings, ...). Consider narrowing it with
# `category=` / `module=` if a specific noisy warning is the target.
warnings.filterwarnings('ignore')

In [None]:
# TF session shared by the GridWorldGuide models and ase.evaluate_baseline_guides below.
# gpu_mode=False presumably keeps everything on CPU — confirm in sensei.utils.
sess = utils.make_tf_session(gpu_mode=False)

In [None]:
data_dir = utils.hab_data_dir
fig_dir = os.path.join(data_dir, 'figures')
# Figures (savefig) and the comparison video (anim.save) are written into
# fig_dir. Create it if missing so a fresh checkout does not fail at the
# first save.
if not os.path.isdir(fig_dir):
  os.makedirs(fig_dir)

## Create the Habitat navigation environment

In [None]:
# Grid-world size and number of goals, both passed to HabitatNavEnv below.
gw_size, n_goals = 30, 10

In [None]:
# Matterport3D scene loaded by HabitatNavEnv.
env_id = 'uNb9QFRL6hY'
dataset = 'mp3d'

# Scene bounds passed to HabitatNavEnv as (lower, upper) pairs:
# one for the xy-plane, one for the z-axis.
min_z, max_z = -2, 1.5
mins, maxs = np.array([-100, -100]), np.array([100, 100])
bbox_xy, bbox_z = (mins, maxs), (min_z, max_z)

In [None]:
# Build the navigation env. The meaning of the sampling and distance knobs
# is defined by HabitatNavEnv; the values are just grouped here for readability.
env = HabitatNavEnv(
  # scene selection and bounds
  env_id=env_id,
  dataset=dataset,
  bbox_xy=bbox_xy,
  bbox_z=bbox_z,
  # grid-world / goal setup
  gw_size=gw_size,
  n_goals=n_goals,
  n_navigable_poses=10 ** 6,
  radius=1e-1,
  # episode length and hop/lateral distance limits
  max_ep_len=501,
  max_hop_dist=0,
  max_lat_dist=0,
  # caching, rendering, logging
  use_cache=True,
  render_mode=True,
  verbose=True,
)

In [None]:
# Rows index observations (see env.str_of_obs in the next cell); columns are
# the axis labelled 'state' in the heatmap below.
env.ground_truth_obs_model.shape

In [None]:
# For every observation (row), print its index, readable name and mean
# likelihood across states. Entries are exponentiated, so the stored matrix
# presumably holds log-likelihoods.
for obs_idx, obs_log_liks in enumerate(env.ground_truth_obs_model):
  print(obs_idx, env.str_of_obs[obs_idx], np.exp(obs_log_liks).mean())

In [None]:
# For every state (column), count the observations whose likelihood is at
# least the uniform level 1 / n_observations.
n_obs_rows = env.ground_truth_obs_model.shape[0]
for state_idx in range(env.ground_truth_obs_model.shape[1]):
  state_liks = np.exp(env.ground_truth_obs_model[:, state_idx])
  print(state_idx, (state_liks >= (1 / n_obs_rows)).sum())

In [None]:
# Heatmap of the observation model: observations on the y-axis (rows),
# states on the x-axis (columns).
fig, ax = plt.subplots()
ax.imshow(env.ground_truth_obs_model)
ax.set_xlabel('state')
ax.set_ylabel('observation')
plt.show()

In [None]:
# NOTE(review): same shape check as at the top of this section; kept as a
# quick sanity check before building the users and guides.
env.ground_truth_obs_model.shape

## Create the simulated users and guides

In [None]:
# Models shared by the simulated users and the guides. The dynamics model is
# built with eps=0.2 (its meaning is defined by env.make_dynamics_model).
# The tuple is evaluated left to right, so the call order is unchanged.
ground_truth_obs_model, dynamics_model, q_func = (
  env.ground_truth_obs_model,
  env.make_dynamics_model(eps=0.2),
  env.Q,
)

In [None]:
# Initial belief confidence (just below 1) for the ideal user and both guides.
init_belief_conf = 1. - 1e-9

In [None]:
# Reference user that perceives observations through the ground-truth
# observation model and the shared dynamics model.
ideal_user_model = GridWorldNavUser(
  env, 
  ground_truth_obs_model, 
  dynamics_model, 
  q_func=q_func,
  init_belief_conf=init_belief_conf
)

In [None]:
# The simulated user's internal dynamics model is the true one (no dynamics mismatch).
internal_dynamics_model = dynamics_model

In [None]:
# NOTE(review): the user's internal observation model is set to the ground
# truth here. As a result, user_model behaves like ideal_user_model, and the
# oracle and naive guides below receive the same prior. Presumably this is a
# placeholder for a misspecified or learned model — confirm before reading
# oracle-vs-naive results.
internal_obs_model = env.ground_truth_obs_model

In [None]:
# Initial belief confidence of the simulated user; the guides are told the same value.
user_init_belief_conf = 1. - 1e-9

In [None]:
# Simulated user being guided: acts on its internal observation/dynamics
# models rather than the ground-truth ones passed to ideal_user_model.
user_model = GridWorldNavUser(
  env, 
  internal_obs_model, 
  internal_dynamics_model, 
  q_func=q_func,
  init_belief_conf=user_init_belief_conf
)

In [None]:
# Environment from the guide's point of view, wrapping env and the simulated
# user. n_obs_per_act=5 is also the stride used below to subsample frames;
# by name, the guide emits 5 observations per user action — confirm in GuideEnv.
guide_env = GuideEnv(env, user_model, n_obs_per_act=5)

In [None]:
def _iden_guide(obs, info):
  # Identity baseline: pass the true observation through unchanged.
  return obs


def _unif_guide(obs, info):
  # Uniform baseline: ignore the true observation and show a random one.
  return np.random.choice(env.n_obses)


# Both baselines are wrapped so they act once per n_obs_per_act guide steps.
iden_guide_policy = utils.StutteredPolicy(_iden_guide, guide_env.n_obs_per_act)
unif_guide_policy = utils.StutteredPolicy(_unif_guide, guide_env.n_obs_per_act)

In [None]:
def _make_grid_world_guide(prior_internal_obs_model):
  """Build a GridWorldGuide from the shared models.

  Only the prior over the user's internal observation model varies between
  the guides; everything else is shared. The prior is not learned
  (learn_internal_obs_model=False).
  """
  return GridWorldGuide(
    sess,
    env,
    ground_truth_obs_model,
    dynamics_model,
    q_func,
    n_obs_per_act=guide_env.n_obs_per_act,
    internal_dynamics_model=internal_dynamics_model,
    prior_internal_obs_model=prior_internal_obs_model,
    learn_internal_obs_model=False,
    init_belief_conf=init_belief_conf,
    user_init_belief_conf=user_init_belief_conf
  )


# Oracle: given the same internal observation model the simulated user acts on.
oracle_guide_model = _make_grid_world_guide(internal_obs_model)
# Naive: assumes the user perceives observations through the ground-truth model.
naive_guide_model = _make_grid_world_guide(ground_truth_obs_model)

## Sanity-check the environments and agents

In [None]:
# Render frames for the sanity-check rollouts below (passed as render=render).
render = True

In [None]:
# Apply the render flag to env; guide_env wraps this same env object.
env.set_render_mode(render)

In [None]:
# NOTE(review): presumably resets the sequence of initial states so the
# rollouts below start from comparable states — confirm in HabitatNavEnv.
env.reset_init_order()

In [None]:
# The env's built-in oracle policy, rolled out directly in env (no user model, no guide).
rollout = utils.run_ep(env.oracle_policy, env, render=render)

In [None]:
# Close env after the rollout above.
env.close()

In [None]:
# Ideal user (ground-truth observation model) acting directly in env, unguided.
rollout = utils.run_ep(ideal_user_model, env, render=render)

In [None]:
# Close env after the rollout above.
env.close()

In [None]:
# Simulated user (internal models) acting directly in env, unguided.
rollout = utils.run_ep(user_model, env, render=render)

In [None]:
# Close env after the rollout above.
env.close()

In [None]:
# Identity guide in guide_env, capped at 20 steps; the next cell replays its frames.
rollout = utils.run_ep(iden_guide_policy, guide_env, render=render, max_ep_len=20)

In [None]:
# Keep one rendered frame every n_obs_per_act steps. The frame is stored
# under 'img' in the last element of each rollout step.
frames = []
for step in rollout[::guide_env.n_obs_per_act]:
  frames.append(step[-1]['img'])
utils.play_nb_vid(frames)

In [None]:
# Close guide_env after the rollout above.
guide_env.close()

In [None]:
# Uniform-random guide in guide_env (default episode length).
rollout = utils.run_ep(unif_guide_policy, guide_env, render=render)

In [None]:
# Close guide_env after the rollout above.
guide_env.close()

In [None]:
# Oracle guide in guide_env (default episode length).
rollout = utils.run_ep(oracle_guide_model, guide_env, render=render)

In [None]:
# Close guide_env after the rollout above.
guide_env.close()

In [None]:
# Naive guide in guide_env, capped at 50 steps. This rollout feeds the
# frame-playback, figure and hab-viz cache cells below.
rollout = utils.run_ep(naive_guide_model, guide_env, render=render, max_ep_len=50)

In [None]:
# Close guide_env after the rollout above.
guide_env.close()

In [None]:
# Keep one rendered frame every n_obs_per_act steps of the naive-guide
# rollout. The frame is stored under 'img' in the last element of each step.
frames = []
for step in rollout[::guide_env.n_obs_per_act]:
  frames.append(step[-1]['img'])
utils.play_nb_vid(frames)

In [None]:
# First frame of the naive-guide rollout, saved as the initial-state figure.
img = frames[0]
fig, ax = plt.subplots()
ax.axis('off')
ax.imshow(img)
fig.savefig(os.path.join(fig_dir, 'hab-viz-init.pdf'), bbox_inches='tight', dpi=500)
plt.show()

In [None]:
# Pickle cache of the rollout whose frames are used for the hab-viz figures.
habviz_path = os.path.join(fig_dir, 'hab-viz.pkl')

In [None]:
# Reuse the cached rollout behind the hab-viz figures when it exists, so the
# figures are reproducible across runs. On a fresh checkout, cache the
# naive-guide rollout generated above instead of failing on a missing file.
if os.path.exists(habviz_path):
  # NOTE: pickle.load can execute arbitrary code — only load trusted files.
  with open(habviz_path, 'rb') as f:
    rollout = pickle.load(f)
else:
  with open(habviz_path, 'wb') as f:
    pickle.dump(rollout, f, pickle.HIGHEST_PROTOCOL)

In [None]:
# Frame at step 120 of the (cached) rollout; it is split into the two figure
# panels below. Requires the rollout to have more than 120 steps.
img = rollout[120][-1]['img']

In [None]:
# Left third of the frame (floor(width / 3) columns) -> hab-viz-1.pdf.
left_panel = img[:, :img.shape[1] // 3]
fig, ax = plt.subplots()
ax.imshow(left_panel)
ax.axis('off')
fig.savefig(os.path.join(fig_dir, 'hab-viz-1.pdf'), bbox_inches='tight', dpi=500)
plt.show()

In [None]:
# Right third of the frame -> hab-viz-2.pdf. Unary minus binds before `//`,
# so this slice takes ceil(width / 3) columns.
right_panel = img[:, -img.shape[1] // 3:]
fig, ax = plt.subplots()
ax.imshow(right_panel)
ax.axis('off')
fig.savefig(os.path.join(fig_dir, 'hab-viz-2.pdf'), bbox_inches='tight', dpi=500)
plt.show()

## Make the comparison video

In [None]:
# Guides shown in the comparison video. The stacking cell below places their
# videos top-to-bottom in the iteration order of baseline_guide_evals
# (presumably the order of this dict — confirm in ase.evaluate_baseline_guides).
guides = {
  'iden': iden_guide_policy,
  'naive': naive_guide_model
}

In [None]:
# Pickle cache of the rendered evaluation rollouts used to build the video.
vid_path = os.path.join(data_dir, 'baselines_eval_for_vid.pkl')

In [None]:
# Rendered evaluation of the video guides, cached at vid_path.
# RECOMPUTE_VID_EVALS = True re-runs the (slow) rendered rollouts and
# refreshes the cache. Set it to False to reuse the cached evaluation; a
# fresh evaluation still runs if no cache exists yet.
RECOMPUTE_VID_EVALS = True
if RECOMPUTE_VID_EVALS or not os.path.exists(vid_path):
  baseline_guide_evals = ase.evaluate_baseline_guides(
    sess,
    guide_env,
    guides,
    n_eval_rollouts=4
  )
  with open(vid_path, 'wb') as f:
    pickle.dump(baseline_guide_evals, f, pickle.HIGHEST_PROTOCOL)
else:
  # NOTE: pickle.load can execute arbitrary code — only load trusted files.
  with open(vid_path, 'rb') as f:
    baseline_guide_evals = pickle.load(f)

In [None]:
def animate_frames(frames, interval=1000, repeat_delay=1000, figsize=(20, 10)):
  """Build a matplotlib animation that shows `frames` one after another.

  Args:
    frames: sequence of images accepted by `imshow`.
    interval: delay between frames, in milliseconds.
    repeat_delay: pause before the animation loops, in milliseconds.
    figsize: figure size in inches.

  Returns:
    A `matplotlib.animation.ArtistAnimation`; call `.save(path)` on it to
    write a video. The figure is closed before returning so the inline
    backend does not also display it as a static image.
  """
  fig, ax = plt.subplots(figsize=figsize)
  ax.axis('off')
  # One single-artist "frame" per image; blitting redraws only that artist.
  ims = [[ax.imshow(frame, animated=True)] for frame in frames]
  plt.close(fig)
  return animation.ArtistAnimation(
    fig, ims, interval=interval, blit=True, repeat_delay=repeat_delay)

In [None]:
# Build one frame sequence per guide, keeping one frame every n_obs_per_act
# steps. Shorter episodes are padded with their last frame, so episode i
# starts at the same video frame for every guide.
frame_stride = guide_env.n_obs_per_act
# Derive the episode count from the data instead of hard-coding it, so this
# stays in sync with n_eval_rollouts. Only episodes that every guide has are
# used, which keeps the guides' videos aligned.
n_vid_eps = min(len(guide_evals['rollouts']) for guide_evals in baseline_guide_evals.values())
max_ep_lens = [
  max(len(guide_evals['rollouts'][ep_idx][::frame_stride])
      for guide_evals in baseline_guide_evals.values())
  for ep_idx in range(n_vid_eps)
]
split_frames = []
for guide_name, guide_evals in baseline_guide_evals.items():
  guide_frames = []
  # `ep_rollout` (not `rollout`) so this loop does not clobber the global rollout.
  for ep_idx, ep_rollout in enumerate(guide_evals['rollouts'][:n_vid_eps]):
    ep_frames = [x[-1]['img'] for x in ep_rollout[::frame_stride]]
    if len(ep_frames) < max_ep_lens[ep_idx]:
      ep_frames.extend([ep_frames[-1]] * (max_ep_lens[ep_idx] - len(ep_frames)))
    guide_frames.extend(ep_frames)
  split_frames.append(guide_frames)

In [None]:
# Stack the guides' videos vertically (first guide on top), separated by a
# 10-pixel white band. The band copies the frames' width, channel layout and
# dtype, so RGB and RGBA frames both work and uint8 frames are not upcast to
# int64 by the concatenation.
sample_frame = split_frames[0][0]
gap_height = 10
gap = np.full((gap_height,) + sample_frame.shape[1:], 255, dtype=sample_frame.dtype)
frames = []
for step_frames in zip(*split_frames):
  panels = []
  for panel_idx, panel in enumerate(step_frames):
    if panel_idx > 0:
      panels.append(gap)
    panels.append(panel)
  frames.append(np.concatenate(panels, axis=0))
anim = animate_frames(frames)
anim.save(os.path.join(fig_dir, 'habitat.mp4'))

## Evaluate the baseline guides

In [None]:
# Turn rendering off for the full evaluation; the analysis cells below do not read frames.
env.set_render_mode(False)

In [None]:
# All four baselines. The table and observation-count cells below look up
# baseline_guide_evals by these names.
guides = {
  'iden': iden_guide_policy,
  'oracle': oracle_guide_model,
  'unif': unif_guide_policy,
  'naive': naive_guide_model
}

In [None]:
# Evaluate every guide in `guides` with n_eval_rollouts=4, rendering off.
baseline_guide_evals = ase.evaluate_baseline_guides(
  sess, 
  guide_env, 
  guides, 
  n_eval_rollouts=4
)

In [None]:
# Pickle file holding the full (unrendered) baseline evaluation.
baselines_eval_path = os.path.join(data_dir, 'baselines_eval.pkl')

In [None]:
# Save the evaluation so the analysis cells can be re-run without re-evaluating.
with open(baselines_eval_path, 'wb') as f:
  pickle.dump(baseline_guide_evals, f, pickle.HIGHEST_PROTOCOL)

In [None]:
# Reload point: run from here to analyse a saved evaluation without re-running it.
# NOTE: pickle.load can execute arbitrary code — only load trusted files.
with open(baselines_eval_path, 'rb') as f:
  baseline_guide_evals = pickle.load(f)

In [None]:
# Print each guide's performance dict, omitting keys that end in '_t'
# (presumably per-timestep series rather than scalar summaries).
for guide_name, guide_eval in baseline_guide_evals.items():
  scalar_perf = {}
  for metric, value in guide_eval['perf'].items():
    if not metric.endswith('_t'):
      scalar_perf[metric] = value
  print(guide_name, scalar_perf)

In [None]:
# Print one LaTeX table row per guide:
#   <label> & $mean \pm stderr$ & ... & $mean \pm stderr$ \\
perf_metrics = ['succ', 'dist_to_goal', 'rollout_len', 'user_belief_in_true_state']
guide_names = ['iden', 'unif', 'naive', 'oracle']
for guide_name in guide_names:
  perf = baseline_guide_evals[guide_name]['perf']
  cells = [utils.label_of_guide[guide_name]]
  for metric in perf_metrics:
    # Raw string: a plain '\pm' contains the invalid escape '\p', which warns
    # (DeprecationWarning, SyntaxWarning on Python 3.12+).
    cells.append(r'$%0.2f \pm %0.2f$' % (perf[metric], perf['%s_stderr' % metric]))
  print(' & '.join(cells) + ' \\\\')

In [None]:
# For each guide, count how often each observation (by readable name)
# appears in its evaluation rollouts. Element 1 of each rollout step is
# presumably the observation the guide emitted; steps where it is None are
# skipped.
counts_of_guide = {}
for guide in guides:
  shown = (
    env.str_of_obs[step[1]]
    for ep in baseline_guide_evals[guide]['rollouts']
    for step in ep
    if step[1] is not None
  )
  counts_of_guide[guide] = collections.Counter(shown)

In [None]:
# For the naive and oracle guides, rank observations by how much more often
# they are shown than under the identity guide. Observations the identity
# guide never showed get an infinite ratio. Each printed line is:
# name, ratio, guide count, identity count.
iden_counts = counts_of_guide['iden']
for guide in ['naive', 'oracle']:
  guide_counts = counts_of_guide[guide]
  ratios = {}
  for obs, count in guide_counts.items():
    ref_count = iden_counts[obs]
    ratios[obs] = np.inf if ref_count == 0 else count / ref_count
  ratios = sorted(ratios.items(), key=lambda kv: kv[1], reverse=True)
  rows = ['%s %0.2f %d %d' % (obs, ratio, guide_counts[obs], iden_counts[obs])
          for obs, ratio in ratios]
  print(guide)
  print('\n'.join(rows))
  print()