In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
from __future__ import division

import pickle
import os
from collections import defaultdict

import numpy as np
import pandas as pd
from statsmodels.stats.anova import AnovaRM
import statsmodels.api as sm

from sensei import utils
from sensei import ase

In [None]:
from matplotlib import pyplot as plt
import matplotlib as mpl
%matplotlib inline

# Larger default font size for every figure drawn in this notebook.
mpl.rcParams['font.size'] = 14

In [None]:
data_dir = utils.car_human_data_dir
fig_dir = os.path.join(data_dir, 'figures')
if not os.path.exists(fig_dir):
  os.makedirs(fig_dir)
# Candidate participant ids are '0'..'11'; each participant's results live in
# the sub-directory <data_dir>/<id>/ (read by the loading cell below). Require an
# actual directory so a stray file with a numeric name is not treated as a
# user, and avoid re-listing data_dir once per candidate id.
user_ids = [str(i) for i in range(12) if os.path.isdir(os.path.join(data_dir, str(i)))]

In [None]:
def _load_guide_evals(uid):
  """Unpickle <data_dir>/<uid>/guide_evals.pkl and return its contents."""
  # NOTE(review): pickle.load can execute arbitrary code; only load result
  # files produced by this project.
  pkl_path = os.path.join(data_dir, uid, 'guide_evals.pkl')
  with open(pkl_path, 'rb') as pkl_file:
    return pickle.load(pkl_file)

# user id -> that user's pickled guide evaluations
baseline_guide_evals_of_user = {uid: _load_guide_evals(uid) for uid in user_ids}

In [None]:
# perfs_of_guide[guide][metric] -> list of that metric's values, one per user,
# in the iteration order of baseline_guide_evals_of_user. Lists for different
# guides line up by position only because they are filled in the same loop.
perfs_of_guide = defaultdict(lambda: defaultdict(list))
for guide_evals in baseline_guide_evals_of_user.values():
  for guide_name, guide_eval in guide_evals.items():
    for metric_name, metric_val in guide_eval['perf'].items():
      perfs_of_guide[guide_name][metric_name].append(metric_val)

In [None]:
# Per-user scatter: x = metric under x_guide, y = metric under y_guide.
# Points above the gray diagonal are users who did better under y_guide.
x_guide = 'iden'
y_guide = 'naive'
metric = 'return'
metric_label = utils.label_of_perf_met[metric]
# The 'naive' guide is labelled and colored by hand here instead of via the
# utils lookup tables used for the other guides.
if y_guide == 'naive':
  y_label = 'ASE (Our Method)'
  color = 'orange'
else:
  y_label = utils.label_of_guide[y_guide]
  color = utils.color_of_guide[y_guide]
xs = perfs_of_guide[x_guide][metric]
ys = perfs_of_guide[y_guide][metric]
# xs[k] and ys[k] are paired by position only; zip() would silently drop
# unmatched users, so fail loudly instead.
assert len(xs) == len(ys), (
  'every user needs both %s and %s results for %s' % (x_guide, y_guide, metric))

fig, ax = plt.subplots()
ax.set_xlabel('%s of %s' % (metric_label, utils.label_of_guide[x_guide]))
ax.set_ylabel('%s of %s' % (metric_label, y_label))
ax.set_title('Car Racing')
for x, y in zip(xs, ys):
  # vertical segment from the diagonal to the point: the per-user change
  ax.plot([x, x], [x, y], linestyle='dotted', color=color, alpha=0.75, linewidth=2)
ax.scatter(xs, ys, color=color, alpha=0.75, linewidth=0, s=100)
# Set the aspect on the existing axes. plt.axes() with no arguments adds a new,
# overlapping Axes in recent matplotlib, hiding the labelled plot underneath.
ax.set_aspect('equal', adjustable='box')
corner = [min(list(xs) + list(ys)), max(list(xs) + list(ys))]
ax.plot(corner, corner, linestyle='--', color='gray')
fig.savefig(os.path.join(fig_dir, 'car-user-study.pdf'), bbox_inches='tight')
plt.show()

In [None]:
n_users = len(baseline_guide_evals_of_user)
depvar = 'response'
subject = 'user_id'
within = 'condition'
metrics = ['return']
# Each user contributes one observation per condition: the 'iden' guide is
# the unassisted condition and 'naive' the assisted one.
condition_of_guide = (('iden', 'unassisted'), ('naive', 'assisted'))
for metric in metrics:
  data = pd.DataFrame([
    {subject: user_id, depvar: guide_evals[guide]['perf'][metric], within: condition}
    for user_id, guide_evals in baseline_guide_evals_of_user.items()
    for guide, condition in condition_of_guide
  ])
  res = AnovaRM(data=data, depvar=depvar, subject=subject, within=[within]).fit()
  print(res)

In [None]:
# Survey statements rated by participants. questions[q] labels the q-th entry
# of each per-phase answer list in `responses` (next cell).
questions = [
  "I was able to keep the car on the road",
  "I could anticipate the consequences of my steering actions",
  "I could tell when the car was about to go off road",
  "I could tell when I needed to steer to keep the car on the road",
  "I was often able to determine the car's current position using the picture on the screen",
  "I could tell that the picture on the screen was sometimes delayed",
  "The delay made it harder to perform the task",
  "The delay made it easier to perform the task",
  "The lack of delay made it harder to perform the task",
  "The lack of delay made it easier to perform the task"
]

In [None]:
# Survey answers, nested as responses[user][phase][question]:
#   - row u is read as user u (the analysis below indexes it with int(user_id));
#   - phase 0 is analysed as the 'unassisted' condition, phase 1 as 'assisted';
#   - answer q corresponds to questions[q].
# All values here lie in 1..7 (presumably a 7-point agreement scale -- TODO confirm).
responses = [
  [[2, 3, 3, 5, 6, 7, 6, 3, 5, 2], [3, 5, 4, 3, 4, 6, 4, 4, 4, 4]],
  [[1, 1, 1, 1, 1, 7, 7, 1, 1, 7], [2, 4, 4, 5, 4, 2, 2, 2, 4, 4]],
  [[4, 3, 3, 3, 5, 6, 6, 2, 2, 7], [6, 5, 4, 6, 6, 6, 6, 2, 2, 6]],
  [[1, 1, 2, 5, 2, 6, 7, 1, 1, 7], [2, 2, 3, 5, 3, 1, 6, 1, 1, 7]],
  [[1, 2, 3, 5, 2, 7, 7, 1, 1, 7], [2, 4, 4, 3, 5, 6, 7, 1, 1, 7]],
  [[1, 3, 3, 5, 5, 7, 7, 1, 1, 7], [4, 3, 5, 6, 5, 5, 7, 1, 1, 7]],
  [[1, 2, 3, 3, 4, 7, 7, 1, 1, 7], [5, 6, 6, 6, 6, 5, 4, 1, 1, 7]],
  [[2, 3, 3, 2, 2, 7, 7, 1, 1, 7], [4, 5, 4, 5, 5, 7, 6, 6, 2, 2]],
  [[2, 2, 2, 3, 4, 7, 7, 1, 1, 7], [4, 2, 3, 4, 3, 3, 5, 1, 1, 6]],
  [[2, 3, 6, 2, 5, 7, 5, 3, 3, 5], [5, 6, 6, 5, 6, 3, 3, 3, 3, 6]],
  [[1, 2, 5, 1, 4, 7, 6, 2, 2, 6], [4, 4, 4, 5, 5, 2, 4, 4, 4, 4]],
  [[2, 2, 3, 3, 2, 7, 7, 1, 3, 6], [4, 4, 5, 5, 5, 5, 4, 4, 4, 5]],
]

In [None]:
n_users = len(responses)
n_phases = len(responses[0])
# Transpose responses[user][phase][question] into
# responses_of_q[phase][question][user]; answers a user did not give stay NaN.
responses_of_q = []
for phase_idx in range(n_phases):
  phase_table = [[np.nan] * n_users for _ in questions]
  for user_idx, user_responses in enumerate(responses):
    for q_idx, answer in enumerate(user_responses[phase_idx]):
      phase_table[q_idx][user_idx] = answer
  responses_of_q.append(phase_table)

In [None]:
# One repeated-measures ANOVA per survey question (unassisted vs. assisted),
# printed as a LaTeX table row: question & p-value & unassisted mean & assisted mean.
depvar = 'response'
subject = 'user_id'
within = 'condition'
alpha = 0.05  # significance level; significant rows are bolded in the LaTeX output
# Row i of `responses` is read as user i, so the string user ids double as row indices.
tested_user_idxes = [int(user_id) for user_id in user_ids]
for i, q in enumerate(questions):
  unassisted = [responses_of_q[0][i][user_idx] for user_idx in tested_user_idxes]
  assisted = [responses_of_q[1][i][user_idx] for user_idx in tested_user_idxes]
  rows = []
  for user_idx, unassisted_resp, assisted_resp in zip(tested_user_idxes, unassisted, assisted):
    rows.append({subject: user_idx, depvar: unassisted_resp, within: 'unassisted'})
    rows.append({subject: user_idx, depvar: assisted_resp, within: 'assisted'})
  data = pd.DataFrame(rows)
  aovrm = AnovaRM(data=data, depvar=depvar, subject=subject, within=[within])
  res = aovrm.fit()
  p = res.anova_table['Pr > F'].values[0]
  significant = p < alpha
  math_open, math_close = ('\\mathbf{', '}') if significant else ('', '')
  text_open, text_close = ('\\textbf{', '}') if significant else ('', '')
  # Means are taken over the same users that entered the ANOVA, so the table
  # columns and the p-value describe one sample.
  print('%s & $%s%s%s$ & %0.2f & %s%0.2f%s \\\\' % (
    q, math_open, utils.discretize_p_value(p), math_close,
    np.nanmean(unassisted), text_open, np.nanmean(assisted), text_close))

In [None]:
from IPython.core.display import display
from IPython.core.display import HTML
from matplotlib import animation

def outline_img(img, thickness=1, intensity=255):
  """Return a copy of `img` with a solid border drawn around its edges.

  Border pixels get channel 0 set to `intensity` and every other channel set
  to 0 (a red border if the channels are RGB).

  Args:
    img: image array indexed as (row, col, channel); the channel indexing
      below requires at least three dimensions.
    thickness: border width in pixels. Values <= 0 leave the image unchanged.
    intensity: value written into channel 0 of the border pixels.

  Returns:
    A new array. `img` itself is not modified, so frames stored in rollout
    data are not permanently outlined by visualization.
  """
  img = np.array(img, copy=True)
  if thickness <= 0:
    # Guard explicitly: img[-0:] would select the whole image.
    return img
  borders = (np.s_[:thickness, :], np.s_[-thickness:, :],
             np.s_[:, :thickness], np.s_[:, -thickness:])
  for border in borders:
    img[border] = 0
    img[border + (0,)] = intensity
  return img

gap = (np.ones((64, 1, 3)) * 255).astype('uint8')
def viz_rollout(rollout, guide_name):
  """Build side-by-side frames (delayed | predicted | current) for a rollout.

  Args:
    rollout: sequence of steps; the last element of each step is a dict with
      'delayed_img', 'pred_img' and 'img' arrays of equal height.
    guide_name: 'prac', 'iden' or 'naive' outlines the current, delayed or
      predicted image respectively; any other value draws no outline.

  Returns:
    List of images, each the three views joined horizontally with a
    1-pixel white separator between them.
  """
  frames = []
  for x in rollout:
    step_info = x[-1]
    delayed_img = step_info['delayed_img']
    pred_img = step_info['pred_img']
    img = step_info['img']
    if guide_name == 'prac':
      img = outline_img(img)
    elif guide_name == 'iden':
      delayed_img = outline_img(delayed_img)
    elif guide_name == 'naive':
      pred_img = outline_img(pred_img)
    # Build the separator from the image height instead of reading the global
    # `gap`, which later cells reassign to differently-shaped arrays (re-running
    # this cell's callers afterwards would otherwise fail to concatenate).
    sep = np.full((delayed_img.shape[0], 1) + delayed_img.shape[2:], 255, dtype='uint8')
    frame = np.concatenate((delayed_img, sep, pred_img, sep, img), axis=1)
    frames.append(frame)
  return frames

def animate_frames(frames):
  """Wrap a sequence of images in a matplotlib ArtistAnimation.

  Each frame is drawn with imshow on one large, axis-free figure and shown for
  100 ms; playback pauses 1 s before repeating. The figure is closed before
  returning (presumably so the notebook does not also render it as a static
  image -- the animation still holds a reference to it).
  """
  fig = plt.figure(figsize=(20, 10))
  plt.axis('off')
  artist_frames = []
  for frame in frames:
    image_artist = plt.imshow(frame, animated=True)
    artist_frames.append([image_artist])
  plt.close()
  return animation.ArtistAnimation(
    fig, artist_frames, interval=100, blit=True, repeat_delay=1000)

In [None]:
# Animate the last 'iden' (unassisted) rollout of user '4', outlining the delayed view.
iden_guide_evals = baseline_guide_evals_of_user['4']['iden']
rollout = iden_guide_evals['rollouts'][-1]
iden_frames = viz_rollout(rollout, guide_name='iden')
iden_anim = animate_frames(iden_frames)

In [None]:
iden_video_html = iden_anim.to_html5_video()
display(HTML(iden_video_html))

In [None]:
iden_video_path = os.path.join(fig_dir, 'iden.mp4')
iden_anim.save(iden_video_path)

In [None]:
# Animate the last 'naive' (assisted) rollout of user '4', outlining the predicted view.
naive_guide_evals = baseline_guide_evals_of_user['4']['naive']
rollout = naive_guide_evals['rollouts'][-1]
naive_frames = viz_rollout(rollout, guide_name='naive')
naive_anim = animate_frames(naive_frames)

In [None]:
naive_video_html = naive_anim.to_html5_video()
display(HTML(naive_video_html))

In [None]:
naive_video_path = os.path.join(fig_dir, 'naive.mp4')
naive_anim.save(naive_video_path)

In [None]:
gap = (np.ones((1, 64, 3)) * 255).astype('uint8')
def viz_rollout(rollout, guide_name):
  """Build vertically stacked frames (delayed / predicted / current) for a rollout.

  NOTE(review): this rebinds the module-level name `viz_rollout`, shadowing the
  side-by-side version defined earlier; re-running the earlier animation cells
  after this one will call this version instead.

  Args:
    rollout: sequence of steps; the last element of each step is a dict with
      'delayed_img', 'pred_img' and 'img' arrays of equal width.
    guide_name: accepted for signature compatibility; unused (no outline).

  Returns:
    List of images, each the three views stacked top to bottom with a
    1-pixel white separator row between them.
  """
  frames = []
  for x in rollout:
    step_info = x[-1]
    delayed_img = step_info['delayed_img']
    pred_img = step_info['pred_img']
    img = step_info['img']
    # Build the separator from the image width instead of reading the global
    # `gap`, which the film-strip cell reassigns to a differently-shaped array.
    sep = np.full((1,) + delayed_img.shape[1:], 255, dtype='uint8')
    frames.append(np.concatenate((delayed_img, sep, pred_img, sep, img), axis=0))
  return frames

In [None]:
# Stacked frames for the last 'iden' rollout of user '4' (used by the film strip).
iden_rollout_evals = baseline_guide_evals_of_user['4']['iden']
rollout = iden_rollout_evals['rollouts'][-1]
frames = viz_rollout(rollout, guide_name='iden')

In [None]:
# Film strip: consecutive stacked frames laid out left to right.
t = 250  # index of the first frame shown in the strip
n_strip_frames = 10
strip_frames = frames[t:t + n_strip_frames]
assert len(strip_frames) == n_strip_frames, (
  'need %d frames starting at %d, but the rollout has %d' % (n_strip_frames, t, len(frames)))
# 1-pixel white vertical separator, sized from the frames themselves rather than
# assuming a fixed image height.
gap = np.full((strip_frames[0].shape[0], 1) + strip_frames[0].shape[2:], 255, dtype='uint8')
gapped_frames = []
for i, frame in enumerate(strip_frames):
  if i > 0:
    gapped_frames.append(gap)
  gapped_frames.append(frame)
img = np.concatenate(gapped_frames, axis=1)

plt.imshow(img)
plt.axis('off')
plt.savefig(os.path.join(fig_dir, 'car-film-strip.pdf'), bbox_inches='tight', dpi=500)
plt.show()