In [None]:
from __future__ import division
import pickle
import random
import os
import math
import types
import uuid
import time
from copy import copy
from collections import defaultdict, Counter

import numpy as np
from gym.envs.classic_control import rendering
from pyglet.window import key as pygkey
import gym
from gym import spaces, wrappers

import dill
import tempfile
import tensorflow as tf
from tensorflow.contrib import rnn
import zipfile

import baselines.common.tf_util as U

from baselines import logger
from baselines.common.schedules import LinearSchedule
from baselines.deepq.models import mlp as deepq_mlp
from baselines.deepq import learn as deepq_learn
from baselines.deepq.replay_buffer import ReplayBuffer, PrioritizedReplayBuffer
from baselines.deepq.simple import ActWrapper

from scipy.special import logsumexp

from pyquaternion import Quaternion

import rospy
from nav_msgs.msg import Odometry
from sensor_msgs.msg import Image
from geometry_msgs.msg import TransformStamped, Twist, Vector3
from std_msgs.msg import Empty

from transforms3d.euler import quat2euler

In [None]:
from matplotlib import pyplot as plt
from matplotlib.backends.backend_agg import FigureCanvasAgg as FigureCanvas
from matplotlib.figure import Figure
import matplotlib as mpl

In [None]:
data_dir = os.path.join('data', 'quadrotor')

In [None]:
latest_vicon_state = None
latest_ardrone_img = None

In [None]:
def vicon_callback(data):
  """Cache the latest Vicon pose in the global `latest_vicon_state`.

  Layout of the stored array: [tx, ty, tz, qx, qy, qz, qw], i.e. the
  TransformStamped translation followed by its rotation fields in x, y, z, w order.
  """
  global latest_vicon_state
  translation = data.transform.translation
  rotation = data.transform.rotation
  latest_vicon_state = np.array([
    translation.x,
    translation.y,
    translation.z,
    rotation.x,
    rotation.y,
    rotation.z,
    rotation.w,
  ])

def img_callback(data):
  """Cache the latest camera frame in `latest_ardrone_img` as a (height, width, 3) uint8 array.

  NOTE(review): assumes a 3-channel 8-bit encoding with no row padding
  (data.step == data.width * 3) -- confirm against data.encoding / data.step.
  """
  global latest_ardrone_img
  # np.frombuffer replaces np.fromstring's deprecated binary mode. It returns a
  # read-only view of the message bytes; the env copy()s the frame before storing it.
  latest_ardrone_img = np.frombuffer(data.data, dtype=np.uint8).reshape((data.height, data.width, 3))

Set up [ROS](http://www.ros.org/) interfaces for the [ARDrone](https://ardrone-autonomy.readthedocs.io/en/latest/) and [Vicon](https://github.com/ethz-asl/vicon_bridge).

In [None]:
# Shell commands, not Python: run each one in its own terminal before continuing
# (left as bare code they made "Run All" fail with a SyntaxError).
#   roscore
#   cd ~/ardrone_ws; source devel/setup.bash; rosrun ardrone_autonomy ardrone_driver -ip 192.168.42.1
#   cd ~/catkin_ws; source devel/setup.bash; roslaunch vicon_bridge vicon.launch
#   rostopic echo ardrone/navdata   # sanity check: prints messages on the navdata topic

In [None]:
rospy.init_node('ed', anonymous=True)

In [None]:
rospy.Subscriber('vicon/ardrone/main', TransformStamped, vicon_callback)
rospy.Subscriber('ardrone/image_raw', Image, img_callback)

In [None]:
publishers = {
  'cmd_vel': rospy.Publisher('cmd_vel', Twist, queue_size=1),
  'takeoff': rospy.Publisher('ardrone/takeoff', Empty, queue_size=1),
  'land': rospy.Publisher('ardrone/land', Empty, queue_size=1),
  'reset': rospy.Publisher('ardrone/reset', Empty, queue_size=1)
}

In [None]:
def ardronecmdvel(data):
  """Publish a Twist on the 'cmd_vel' publisher.

  `data` is a dict with keys linear_x/y/z and angular_x/y/z.
  """
  msg = Twist(
    linear=Vector3(x=data['linear_x'], y=data['linear_y'], z=data['linear_z']),
    angular=Vector3(x=data['angular_x'], y=data['angular_y'], z=data['angular_z']),
  )
  publishers['cmd_vel'].publish(msg)

def ardronecmd(cmd):
  """Publish an Empty message on the publisher named `cmd` ('takeoff', 'land' or 'reset')."""
  publisher = publishers[cmd]
  publisher.publish(Empty())

In [None]:
def twist_of_quat(quat):
  """Return the first angle of transforms3d's quat2euler (default axes) for `quat`.

  NOTE(review): `quat.elements` is pyquaternion's (w, x, y, z) array, which is
  the order quat2euler expects. However, the callers build the Quaternion from the
  Vicon slice pos[3:7], which is stored as (x, y, z, w), so the components appear
  to be shifted by one. The resulting "twist" is presumably the empirically
  validated heading signal the copilot was trained on; confirm before changing.
  """
  euler_angles = quat2euler(quat.elements)
  return euler_angles[0]

In [None]:
n_act_dim = 9
n_obs_dim = 10

In [None]:
NOOP = 8

In [None]:
class ARDrone(gym.Env):
  """Gym env for flying a real ARDrone to a landing pad, with state from Vicon.

  The episode frame is fixed by set_transform(): positions are relative to the
  Vicon pose read at construction, and so is the twist angle. Since _reset()
  puts the pad at that same pose, the pad is where the drone sat when the env
  was created.

  Observation (length n_obs_dim):
    [x, y, z, twist, vx, vy, vz, vtwist, goal_dx, goal_dy]
  - twist is made non-negative by adding 2*pi once when negative.
  - The velocities are finite differences between consecutive steps.
  - (goal_dx, goal_dy) is the pad offset rotated by -twist.

  Actions (Discrete(n_act_dim)):
    0/1: +/- linear_x, 2/3: +/- linear_y, 4: zero velocity (same as NOOP),
    5: land if the horizontal speed is below 0.1 (else zero velocity),
    6/7: +/- angular_z, NOOP (8): zero velocity.

  Reward: the per-step change of a shaping term (-100 * distance to the pad, or
  -100 * z once within half the goal radius). On termination it is replaced by
  -100 (out of bounds, or on the ground off the pad) or +100 (on the pad and
  the experimenter confirms the camera faces the goal object).
  """
  metadata = {'render.modes': ['human']}

  def __init__(self):
    self.action_space = spaces.Discrete(n_act_dim)
    obs_low = -np.array([np.inf] * n_obs_dim)
    obs_high = -obs_low
    self.observation_space = spaces.Box(obs_low, obs_high)

    self.img_height = 368
    self.img_width = 640
    self.ground_contact_z_thresh = 0.1  # z at or below this counts as on the ground
    self.max_pos = np.array([2, 2, 2])  # |x|, |y|, |z| bounds in the episode frame
    self.max_ep_len = 100000
    self.max_ep_duration = 30  # seconds (compared against time.time() differences)
    self.start_z_thresh = 0.25  # takeoff is complete once z exceeds this

    self.max_xy_speed = 0.5  # horizontal speed that triggers forced NOOPs
    self.min_n_stab_noops = 25  # number of forced NOOPs once triggered
    self.n_stabilizing_noops = 0
    self.landed = True

    self.terrain = self._init_terrain()

    self.goal_dist_thresh = 0.3
    self.pad_pos = self._init_pad_pos()
    self.obj_idx = None
    self.obj_names = ['the red chair', 'the gray chair', 'the door', 'the white styrofoam boards that say R-TECH']

    self.translation = None
    self.init_twist = None
    self.prev_pos = None
    self.prev_time = None
    self.prev_shaping = None
    self.curr_step = None
    self.ep_start_time = None
    self.prev_time_left = None

    self.viewer = None

    self.set_transform()

  def _init_terrain(self):
    """Return the 10x10 grid that only sizes the god-mode plot."""
    return np.ones((10, 10))

  def _get_vicon_pos(self):
    """Return a copy of the latest raw Vicon state [tx, ty, tz, qx, qy, qz, qw]."""
    return copy(latest_vicon_state)

  def _init_pad_pos(self):
    """Return the current raw Vicon state as the initial pad position."""
    return copy(latest_vicon_state)

  def set_transform(self, pos=None):
    """Make `pos` (default: the current Vicon state) the origin and zero-twist reference."""
    if pos is None:
      pos = self._get_vicon_pos()
    self.translation = -pos[:3]
    self.init_twist = twist_of_quat(Quaternion(pos[3:7]))

  def _transform_pos(self, pos):
    """Convert a raw Vicon state into the episode frame IN PLACE and return it.

    pos[:3] is shifted by self.translation. pos[6] (the raw qw component) is
    overwritten with the twist relative to self.init_twist, with 2*pi added once
    when negative. Callers pass a fresh copy from _get_vicon_pos().
    """
    pos[:3] += self.translation
    pos[6] = twist_of_quat(Quaternion(pos[3:7])) - self.init_twist
    if pos[6] < 0:
      pos[6] += 2*math.pi
    return pos

  def _obs(self, update=True):
    """Return the current observation (layout in the class docstring).

    Velocities are finite differences against the position and time stored by
    the last call made with update=True. _exec_action passes update=False to
    peek without moving that reference. Otherwise the velocity returned by
    _step would span only the few milliseconds between two back-to-back reads
    instead of a whole step.
    """
    pos = self._transform_pos(self._get_vicon_pos())
    pos = np.concatenate((pos[:3], pos[6:7]))
    curr_time = time.time()
    dt = None if self.prev_pos is None else curr_time - self.prev_time
    if dt is not None and dt > 0:
      vel = (pos - self.prev_pos) / dt
    else:
      # First read of the episode, or two reads within one clock tick.
      vel = np.zeros(pos.shape)
    if update:
      self.prev_pos = pos
      self.prev_time = curr_time

    rot_ang = -pos[3]
    rot = np.array([[np.cos(rot_ang), -np.sin(rot_ang)], [np.sin(rot_ang), np.cos(rot_ang)]])
    delta_xy_to_goal = rot.dot(self._get_pad_pos()[:2] - pos[:2])

    return np.concatenate((pos, vel, delta_xy_to_goal))

  def _get_pad_pos(self):
    """Return the pad's (x, y, z) in the episode frame."""
    return self.pad_pos[:3] + self.translation

  def _at_site(self):
    """True if the drone is horizontally within goal_dist_thresh of the pad."""
    pos = self._transform_pos(self._get_vicon_pos())
    return np.linalg.norm(pos[:2] - self._get_pad_pos()[:2]) <= self.goal_dist_thresh

  def _exec_action(self, action, xy_speed=0.1, z_speed=0.5, w_speed=1):
    """Send the command for `action`, forcing NOOPs while the drone stabilizes.

    When the horizontal speed reaches self.max_xy_speed, the next
    self.min_n_stab_noops actions (including this one) are replaced by NOOP.
    z_speed is accepted for interface compatibility but is not used.

    Raises:
      ValueError: if `action` is not one of 0..8.
    """
    curr_xy_speed = np.linalg.norm(self._obs(update=False)[4:6])
    if curr_xy_speed >= self.max_xy_speed:
      self.n_stabilizing_noops = self.min_n_stab_noops
    if self.n_stabilizing_noops > 0:
      action = NOOP
      self.n_stabilizing_noops -= 1

    vel = {
      'linear_x': 0,
      'linear_y': 0,
      'linear_z': 0,
      'angular_x': 0,
      'angular_y': 0,
      'angular_z': 0
    }

    if action == 0:
      vel['linear_x'] = xy_speed
    elif action == 1:
      vel['linear_x'] = -xy_speed
    elif action == 2:
      vel['linear_y'] = xy_speed
    elif action == 3:
      vel['linear_y'] = -xy_speed
    elif action == 4:
      pass
    elif action == 5:
      # Only land when nearly stationary; otherwise fall through to a zero-velocity command.
      if curr_xy_speed < 0.1:
        self._exec_cmd('land')
        return
    elif action == 6:
      vel['angular_z'] = w_speed
    elif action == 7:
      vel['angular_z'] = -w_speed
    elif action == NOOP:
      pass
    else:
      raise ValueError('invalid action: %r' % (action,))

    ardronecmdvel(vel)

  def _out_of_bounds(self, pos=None):
    """True if any of |x|, |y|, |z| (episode frame) reaches self.max_pos."""
    if pos is None:
      pos = self._transform_pos(self._get_vicon_pos())
    return (np.abs(pos[:3]) >= self.max_pos).any()

  def _step(self, action):
    """Execute `action` and return (obs, reward, done, info).

    The episode ends when the drone leaves the bounds, max_ep_duration elapses,
    or it is on the ground/landed. At the end the drone is landed, the
    experimenter is asked whether the camera points at the goal object, and
    info records the termination reasons and the final camera frame.
    """
    self._exec_action(action)

    obs = self._obs()

    dist_to_goal = np.linalg.norm(obs[-2:])
    if dist_to_goal < 0.5*self.goal_dist_thresh:
      shaping = -100*obs[2]
    else:
      shaping = -100*dist_to_goal
    r = shaping
    if self.prev_shaping is not None:
      r -= self.prev_shaping
    self.prev_shaping = shaping

    oob = self._out_of_bounds()
    timeout = time.time() - self.ep_start_time > self.max_ep_duration
    on_ground = obs[2] <= self.ground_contact_z_thresh or self.landed
    done = oob or timeout or on_ground

    self.curr_step += 1

    at_loc = self._at_site()

    info = {}
    if done:
      self._land()
      info['duration'] = time.time() - self.ep_start_time
      info['oob'] = oob
      info['timeout'] = timeout
      info['on_ground'] = on_ground
      info['at_loc'] = at_loc

      if oob:
        print('You flew out of bounds.')
      if timeout:
        print('You ran out of time.')
      if at_loc and on_ground:
        print('You landed on the landing pad.')
      elif on_ground and not at_loc:
        print('You missed the landing pad.')

      at_rot = None
      while at_rot not in ('y', 'n'):
        at_rot = input('Is the camera pointed at %s? (y/n): ' % self.obj_names[self.obj_idx])
      print('')
      info['at_rot'] = at_rot
      info['final_img'] = copy(latest_ardrone_img)
      info['goal_obj'] = self.obj_names[self.obj_idx]

      if oob or (on_ground and not at_loc):
        r = -100
      elif at_loc and on_ground and at_rot == 'y':
        r = 100

    return obs, r, done, info

  def _takeoff(self, poll_interval=0.01):
    """Command takeoff and block until z (episode frame) reaches start_z_thresh."""
    self._exec_cmd('takeoff')
    while self._transform_pos(self._get_vicon_pos())[2] < self.start_z_thresh:
      # The Vicon callback presumably updates the cache from another thread;
      # sleep between polls instead of spinning the CPU.
      time.sleep(poll_interval)
    self.landed = False

  def _exec_cmd(self, cmd):
    """Send 'land' via the rostopic CLI (and mark landed); other commands via the publishers."""
    if cmd == 'land':
      os.system('rostopic pub --once ardrone/land std_msgs/Empty')
      self.landed = True
    else:
      ardronecmd(cmd)

  def _land(self):
    """Land the drone."""
    self._exec_cmd('land')

  def _reset(self):
    """Land, reset per-episode state, pick a goal object, wait for the experimenter, take off.

    Returns the observation from an initial NOOP step.
    """
    self._land()

    self.prev_pos = None
    self.prev_time = None
    self.prev_shaping = None
    self.prev_obs = None
    self.curr_step = 0
    self.n_stabilizing_noops = 0
    self.prev_time_left = None

    self.obj_idx = random.randint(0, len(self.obj_names) - 1)

    print('Your goal for this episode is to point the camera at %s.' % self.obj_names[self.obj_idx])

    # The pad is the raw pose that set_transform() used as the origin.
    self.pad_pos = -self.translation

    input('Experimenter, place the drone and hit ENTER: ')
    while self._out_of_bounds():
      input('The drone is out of bounds. Experimenter, try again and hit ENTER: ')

    self._takeoff()

    print('The episode has begun.')

    self.ep_start_time = time.time()
    return self._step(NOOP)[0]

  def _time_left(self):
    """Seconds left in the episode (the full budget before the first reset)."""
    elapsed = (time.time() - self.ep_start_time) if self.ep_start_time is not None else 0
    return self.max_ep_duration - elapsed

  def _plot_loc(self, pos):
    """Map pos[:2] from [-max_pos, max_pos] into terrain-grid plot coordinates.

    The second coordinate is flipped. loc[1] is drawn on the x axis and loc[0]
    on the y axis.
    """
    loc = (pos[:2] + self.max_pos[:2]) / (2 * self.max_pos[:2])
    loc[1] = 1 - loc[1]
    loc[0] = loc[0] * self.terrain.shape[0] - 0.5
    loc[1] = loc[1] * self.terrain.shape[1] - 0.5
    return loc

  @staticmethod
  def _fig_to_rgb(fig, canvas):
    """Render `fig` with Agg and return it as a (height, width, 3) uint8 array."""
    agg = canvas.switch_backends(FigureCanvas)
    agg.draw()
    width, height = fig.get_size_inches() * fig.get_dpi()
    return np.frombuffer(agg.tostring_rgb(), dtype='uint8').reshape(int(height), int(width), 3)

  def _render(self, mode='human', close=False, godmode=False):
    """Show an overhead map (godmode=True) or a countdown-only view in the viewer.

    The map draws the pad in red and the drone in blue (marker grows with
    altitude), with the twist angle, time left and normalized altitude. The
    countdown view shows the time left minus 5 s, floored at 0.
    """
    if close:
      if self.viewer is not None:
        self.viewer.close()
        self.viewer = None
      return

    if self.viewer is None:
      self.viewer = rendering.SimpleImageViewer()

    if godmode:
      fig = plt.figure(figsize=(10, 10))
      canvas = FigureCanvas(fig)
      ax = plt.axes(xlim=(-0.5, self.terrain.shape[0] - 0.5), ylim=(-0.5, self.terrain.shape[1] - 0.5))
      pursuer_pos = self._transform_pos(self._get_vicon_pos())
      evader_loc = self._plot_loc(self._get_pad_pos())
      ax.scatter([evader_loc[1]], [evader_loc[0]], s=500, color='red', linewidth=0, alpha=0.75)
      pursuer_size = 500 + 10000 * pursuer_pos[2] / self.max_pos[2]
      pursuer_loc = self._plot_loc(pursuer_pos)
      ax.scatter([pursuer_loc[1]], [pursuer_loc[0]], s=pursuer_size, c='blue', linewidth=0, alpha=0.75)
      # Same value as _obs()[3]. Calling _obs() here would move the velocity reference.
      ang = pursuer_pos[6]
      ax.set_ylabel('ang = %f' % ang, fontsize=30)
      ax.set_title('%d seconds left' % self._time_left(), fontsize=30)
      ax.set_xlabel('z = %0.2f' % (pursuer_pos[2] / self.max_pos[2]), fontsize=30)
    else:
      fig = plt.figure(figsize=(10, 8))
      canvas = FigureCanvas(fig)
      ax = plt.axes()
      time_left = max(0, self._time_left() - 5)
      ax.set_xlabel('%d seconds left' % (time_left), fontsize=60)

    ax.set_xticks([])
    ax.set_yticks([])
    self.viewer.imshow(self._fig_to_rgb(fig, canvas))
    plt.close(fig)

In [None]:
env = ARDrone()

In [None]:
max_ep_len = env.max_ep_len
def run_ep(policy, env, max_ep_len=max_ep_len, render=False, pilot_is_human=False):
    if pilot_is_human:
      global human_agent_action
      human_agent_action = init_human_action()
    obs = env.reset()
    done = False
    totalr = 0.
    trajectory = [obs]
    actions = []
    for step_idx in range(max_ep_len+1):
        if done:
            break
        action = policy(obs[None, :])
        obs, r, done, info = env.step(action)
        actions.append(action)
        trajectory.append(obs)
        if render:
          env.render()
        totalr += r
    outcome = r if r % 100 == 0 else 0
    return totalr, outcome, trajectory, actions, info

In [None]:
def noop_pilot_policy(obs):
  return 8

In [None]:
init_human_action = lambda: NOOP
human_agent_action = init_human_action()

action_of_key = {
  pygkey.UP: 0,
  pygkey.DOWN: 1,
  pygkey.LEFT: 2,
  pygkey.RIGHT: 3,
  pygkey.W: 4,
  pygkey.S: 5,
  pygkey.A: 6,
  pygkey.D: 7,
}

def key_press(key, mod):
  k = int(key)
  if k in action_of_key:
    global human_agent_action
    human_agent_action = action_of_key[k]

def key_release(key, mod):
  k = int(key)
  if k in action_of_key:
    global human_agent_action
    human_agent_action = 8
      
def human_pilot_policy(obs):
  return human_agent_action

In [None]:
def _vars_in_scope(scope, variables):
  """Return the variables whose names lie under `scope`; '' selects all of them."""
  # The old `scope + '/'` test turned '' into the prefix '/', which matches no
  # variable, so tf.train.Saver received an empty list and raised.
  prefix = scope + '/' if scope else ''
  return [v for v in variables if v.name.startswith(prefix)]

def save_tf_vars(scope, path):
  """Save every TF global variable under `scope` ('' = root scope) to checkpoint `path`."""
  sess = U.get_session()
  saver = tf.train.Saver(_vars_in_scope(scope, tf.global_variables()))
  saver.save(sess, save_path=path)

In [None]:
def load_tf_vars(scope, path):
  """Restore every TF global variable under `scope` from checkpoint `path`.

  An empty `scope` (as used for the pretrained copilot) means the root scope,
  i.e. all global variables. The old `scope + '/'` prefix matched nothing for '',
  and tf.train.Saver then raised on the empty variable list.
  """
  prefix = scope + '/' if scope else ''
  sess = U.get_session()
  saver = tf.train.Saver([v for v in tf.global_variables() if v.name.startswith(prefix)])
  saver.restore(sess, path)

Train the assistive copilot

In [None]:
n_training_episodes = 500

In [None]:
make_q_func = lambda: deepq_mlp([64, 64])

In [None]:
copilot_dqn_learn_kwargs = {
  'lr': 1e-4,
  'exploration_fraction': 0.1,
  'exploration_final_eps': 0.02,
  'target_network_update_freq': 3000,
  'print_freq': 100,
  'num_cpu': 5,
  'gamma': 0.99
}

In [None]:
def onehot_encode(i, n=n_act_dim):
    """Return a float vector of length `n` that is 1 at index `i` and 0 elsewhere."""
    vec = np.zeros(n)
    vec[i] = 1
    return vec

def onehot_decode(x):
    """Return the index of the single nonzero entry of `x`; AssertionError if there is not exactly one."""
    hot_idxes = np.nonzero(x)[0]
    assert len(hot_idxes) == 1
    return hot_idxes[0]

In [None]:
def make_co_env(pilot_policy, **extras):
  """Create a new ARDrone env and attach `pilot_policy` to its unwrapped env.

  Extra keyword arguments are accepted and ignored.
  """
  new_env = ARDrone()
  new_env.unwrapped.pilot_policy = pilot_policy
  return new_env

In [None]:
def co_build_act(make_obs_ph, q_func, num_actions, scope="deepq", reuse=None, using_control_sharing=True):
  """Build the copilot's action-selection function.

  With control sharing, the returned function takes (observation, pilot_action,
  pilot_tol). Each row of Q-values is first shifted so its minimum is 0. The
  pilot's action is kept when its shifted Q-value is at least
  (1 - pilot_tol) * the row's best shifted Q-value; otherwise the greedy
  action is returned.

  Without control sharing, it is baselines' epsilon-greedy act function taking
  (observation, stochastic=True, update_eps=-1).

  Returns:
    A U.function whose single output is the batch of chosen actions.
  """
  with tf.variable_scope(scope, reuse=reuse):
    observations_ph = U.ensure_tf_input(make_obs_ph("observation"))
    if using_control_sharing:
      pilot_action_ph = tf.placeholder(tf.int32, (), name='pilot_action')
      pilot_tol_ph = tf.placeholder(tf.float32, (), name='pilot_tol')
    else:
      eps = tf.get_variable("eps", (), initializer=tf.constant_initializer(0))
      stochastic_ph = tf.placeholder(tf.bool, (), name="stochastic")
      update_eps_ph = tf.placeholder(tf.float32, (), name="update_eps")

    q_values = q_func(observations_ph.get(), num_actions, scope="q_func")

    batch_size = tf.shape(q_values)[0]

    if using_control_sharing:
      # Per-row minimum kept as [batch, 1] so it broadcasts across actions. A bare
      # [batch] vector broadcasts along the action axis and is only correct for batch size 1.
      q_values -= tf.expand_dims(tf.reduce_min(q_values, axis=1), 1)
      opt_actions = tf.argmax(q_values, axis=1, output_type=tf.int32)
      opt_q_values = tf.reduce_max(q_values, axis=1)

      batch_idxes = tf.reshape(tf.range(0, batch_size, 1), [batch_size, 1])
      reshaped_batch_size = tf.reshape(batch_size, [1])

      # Q-value of the pilot's (scalar) action in every row of the batch.
      pi_actions = tf.tile(tf.reshape(pilot_action_ph, [1]), reshaped_batch_size)
      pi_act_idxes = tf.concat([batch_idxes, tf.reshape(pi_actions, [batch_size, 1])], axis=1)
      pi_act_q_values = tf.gather_nd(q_values, pi_act_idxes)

      actions = tf.where(pi_act_q_values >= (1 - pilot_tol_ph) * opt_q_values, pi_actions, opt_actions)

      act = U.function(inputs=[
        observations_ph, pilot_action_ph, pilot_tol_ph
      ],
                       outputs=[actions])
    else:
      deterministic_actions = tf.argmax(q_values, axis=1)

      random_actions = tf.random_uniform(tf.stack([batch_size]), minval=0, maxval=num_actions, dtype=tf.int64)
      chose_random = tf.random_uniform(tf.stack([batch_size]), minval=0, maxval=1, dtype=tf.float32) < eps
      stochastic_actions = tf.where(chose_random, random_actions, deterministic_actions)

      output_actions = tf.cond(stochastic_ph, lambda: stochastic_actions, lambda: deterministic_actions)
      update_eps_expr = eps.assign(tf.cond(update_eps_ph >= 0, lambda: update_eps_ph, lambda: eps))
      act = U.function(inputs=[observations_ph, stochastic_ph, update_eps_ph],
                       outputs=[output_actions],
                       givens={update_eps_ph: -1.0, stochastic_ph: True},
                       updates=[update_eps_expr])
    return act

In [None]:
def co_build_train(make_obs_ph, q_func, num_actions, optimizer, grad_norm_clipping=None, gamma=1.0,
    double_q=True, scope="deepq", reuse=None, using_control_sharing=True):
    """Build the copilot DQN graph: act function, TD training step and target update.

    Mirrors baselines' deepq build_train, except that the act function comes from
    co_build_act (optionally with control sharing).

    Returns:
      (act_f, train, update_target, {'q_values': q_values}).
      - train(obs_t, act_t, rew_t, obs_tp1, done_mask, importance_weights) runs one
        optimizer step on the Huber TD loss and returns the TD errors.
      - update_target() copies the online Q-network into the target network.
    """
    act_f = co_build_act(make_obs_ph, q_func, num_actions, scope=scope, reuse=reuse, using_control_sharing=using_control_sharing)

    with tf.variable_scope(scope, reuse=reuse):
        # set up placeholders
        obs_t_input = U.ensure_tf_input(make_obs_ph("obs_t"))
        act_t_ph = tf.placeholder(tf.int32, [None], name="action")
        rew_t_ph = tf.placeholder(tf.float32, [None], name="reward")
        obs_tp1_input = U.ensure_tf_input(make_obs_ph("obs_tp1"))
        done_mask_ph = tf.placeholder(tf.float32, [None], name="done")
        importance_weights_ph = tf.placeholder(tf.float32, [None], name="weight")

        obs_t_input_get = obs_t_input.get()
        obs_tp1_input_get = obs_tp1_input.get()

        # q network evaluation (reuses the parameters created by the act function)
        q_t = q_func(obs_t_input_get, num_actions, scope='q_func', reuse=True)
        q_func_vars = U.scope_vars(U.absolute_scope_name('q_func'))

        # target q network evaluation
        q_tp1 = q_func(obs_tp1_input_get, num_actions, scope="target_q_func")
        target_q_func_vars = U.scope_vars(U.absolute_scope_name("target_q_func"))

        # q scores for the actions that were actually taken
        q_t_selected = tf.reduce_sum(q_t * tf.one_hot(act_t_ph, num_actions), 1)

        # estimate of the best value reachable from t + 1
        if double_q:
            # Double DQN: the online net picks the action, the target net scores it.
            q_tp1_using_online_net = q_func(obs_tp1_input_get, num_actions, scope='q_func', reuse=True)
            q_tp1_best_using_online_net = tf.argmax(q_tp1_using_online_net, 1)
            q_tp1_best = tf.reduce_sum(q_tp1 * tf.one_hot(q_tp1_best_using_online_net, num_actions), 1)
        else:
            q_tp1_best = tf.reduce_max(q_tp1, 1)
        q_tp1_best_masked = (1.0 - done_mask_ph) * q_tp1_best

        # RHS of the Bellman equation
        q_t_selected_target = rew_t_ph + gamma * q_tp1_best_masked

        # Huber TD error, weighted per sample
        td_error = q_t_selected - tf.stop_gradient(q_t_selected_target)
        errors = U.huber_loss(td_error)
        weighted_error = tf.reduce_mean(importance_weights_ph * errors)

        # optimization op (optionally with gradient clipping)
        if grad_norm_clipping is not None:
            optimize_expr = U.minimize_and_clip(optimizer,
                                                weighted_error,
                                                var_list=q_func_vars,
                                                clip_val=grad_norm_clipping)
        else:
            optimize_expr = optimizer.minimize(weighted_error, var_list=q_func_vars)

        # Copy each online variable onto its target counterpart (paired by sorted name).
        update_target_expr = tf.group(*[
            var_target.assign(var)
            for var, var_target in zip(sorted(q_func_vars, key=lambda v: v.name),
                                       sorted(target_q_func_vars, key=lambda v: v.name))
        ])

        # Create callable functions
        train = U.function(
            inputs=[
                obs_t_input,
                act_t_ph,
                rew_t_ph,
                obs_tp1_input,
                done_mask_ph,
                importance_weights_ph
            ],
            outputs=td_error,
            updates=[optimize_expr]
        )
        update_target = U.function([], [], updates=[update_target_expr])

        q_values = U.function([obs_t_input], q_t)

    return act_f, train, update_target, {'q_values': q_values}

In [None]:
def _recent_episode_stats(episode_rewards, episode_outcomes, window=100):
    """Return (mean reward, success rate, crash rate) over the last `window` episodes.

    A success or crash is a terminal outcome of exactly +100 or -100. The values
    are rounded to 1, 2 and 2 decimals respectively, matching the logged columns.
    """
    recent_outcomes = episode_outcomes[-window:]
    mean_reward = round(np.mean(episode_rewards[-window:]), 1)
    succ_rate = round(np.mean([1 if x == 100 else 0 for x in recent_outcomes]), 2)
    crash_rate = round(np.mean([1 if x == -100 else 0 for x in recent_outcomes]), 2)
    return mean_reward, succ_rate, crash_rate

def co_dqn_learn(
    env,
    q_func,
    lr=1e-3,
    max_timesteps=100000,
    buffer_size=50000,
    train_freq=1,
    batch_size=32,
    print_freq=1,
    checkpoint_freq=10000,
    learning_starts=1000,
    gamma=1.0,
    target_network_update_freq=500,
    exploration_fraction=0.1,
    exploration_final_eps=0.02,
    num_cpu=5,
    callback=None,
    scope='deepq',
    pilot_tol=0,
    pilot_is_human=False,
    reuse=False):
    """Build a DQN copilot and (if max_timesteps > 0) train it online in `env`.

    With pilot_tol > 0 the copilot uses control sharing. Each step it queries
    env.unwrapped.pilot_policy, and the act function keeps or overrides that
    action according to pilot_tol (forced to 0 for a NOOP pilot action).
    Otherwise actions are epsilon-greedy with a linear exploration schedule.
    Gradient updates happen only when an episode ends, once t > learning_starts:
    one minibatch per timestep since the previous episode ended.
    train_freq and callback are accepted but not used.

    Returns:
      If max_timesteps == 0: only the ActWrapper, with no episodes run (the
      pretrained-copilot cell relies on this). Otherwise (ActWrapper, reward_data),
      where reward_data has keys 'rewards', 'outcomes', 'trajectories', 'actions'.
    """
    # Declared once for the whole function. The old code repeated it after the name
    # had already been assigned, which is a SyntaxError on Python 3.6+.
    global human_agent_action

    sess = U.get_session()
    if sess is None:
        sess = U.make_session(num_cpu=num_cpu)
        sess.__enter__()

    def make_obs_ph(name):
        return U.BatchInput(env.observation_space.shape, name=name)

    using_control_sharing = pilot_tol > 0

    act, train, update_target, debug = co_build_train(
        scope=scope,
        make_obs_ph=make_obs_ph,
        q_func=q_func,
        num_actions=env.action_space.n,
        optimizer=tf.train.AdamOptimizer(learning_rate=lr),
        gamma=gamma,
        grad_norm_clipping=10,
        reuse=reuse,
        using_control_sharing=using_control_sharing
    )

    act_params = {
        'make_obs_ph': make_obs_ph,
        'q_func': q_func,
        'num_actions': env.action_space.n,
    }

    replay_buffer = ReplayBuffer(buffer_size)

    # Initialize the parameters and copy them to the target network.
    U.initialize()
    update_target()

    if max_timesteps == 0:
        return ActWrapper(act, act_params)

    episode_trajectories = []
    episode_actions = []
    episode_rewards = []
    episode_outcomes = []
    saved_mean_reward = None
    obs = env.reset()
    prev_t = 0
    episode_reward = 0
    episode_trajectory = []
    episode_action = []

    if pilot_is_human:
        human_agent_action = init_human_action()

    exploration = LinearSchedule(schedule_timesteps=int(exploration_fraction * max_timesteps),
                                 initial_p=1.0,
                                 final_p=exploration_final_eps)

    with tempfile.TemporaryDirectory() as td:
        model_saved = False
        model_file = os.path.join(td, 'model')
        for t in range(max_timesteps):
            episode_trajectory.append(obs)

            act_kwargs = {}
            if using_control_sharing:
                pilot_action = env.unwrapped.pilot_policy(obs[None, :n_obs_dim])
                act_kwargs['pilot_action'] = pilot_action
                act_kwargs['pilot_tol'] = pilot_tol if pilot_action != NOOP else 0
            else:
                act_kwargs['update_eps'] = exploration.value(t)

            action = act(obs[None, :], **act_kwargs)[0][0]
            new_obs, rew, done, info = env.step(action)
            episode_action.append(action)

            if pilot_is_human:
                env.render()

            # Store transition in the replay buffer.
            replay_buffer.add(obs, action, rew, new_obs, float(done))
            obs = new_obs

            episode_reward += rew

            if done:
                if t > learning_starts:
                    # Train only between episodes: one minibatch per timestep elapsed
                    # since the previous episode ended.
                    for _ in range(t - prev_t):
                        obses_t, actions, rewards, obses_tp1, dones = replay_buffer.sample(batch_size)
                        weights = np.ones_like(rewards)
                        train(obses_t, actions, rewards, obses_tp1, dones, weights)

                obs = env.reset()

                episode_outcomes.append(rew)
                episode_rewards.append(episode_reward)
                episode_trajectories.append(episode_trajectory + [new_obs])
                episode_actions.append(episode_action)
                episode_trajectory = []
                episode_action = []
                episode_reward = 0

                if pilot_is_human:
                    human_agent_action = init_human_action()

                prev_t = t

                if pilot_is_human:
                    time.sleep(2)

            if t > learning_starts and t % target_network_update_freq == 0:
                # Update target network periodically.
                update_target()

            num_episodes = len(episode_rewards)
            log_now = done and print_freq is not None and num_episodes % print_freq == 0
            checkpoint_due = (checkpoint_freq is not None and t > learning_starts
                              and num_episodes > 100 and t % checkpoint_freq == 0)
            # Only summarize when it will be used (avoids empty-slice warnings before the first episode ends).
            if log_now or checkpoint_due:
                mean_100ep_reward, mean_100ep_succ, mean_100ep_crash = _recent_episode_stats(
                    episode_rewards, episode_outcomes)

            if log_now:
                logger.record_tabular("steps", t)
                logger.record_tabular("episodes", num_episodes)
                logger.record_tabular("mean 100 episode reward", mean_100ep_reward)
                logger.record_tabular("mean 100 episode succ", mean_100ep_succ)
                logger.record_tabular("mean 100 episode crash", mean_100ep_crash)
                logger.dump_tabular()

            if checkpoint_due and (saved_mean_reward is None or mean_100ep_reward > saved_mean_reward):
                if print_freq is not None:
                    print('Saving model due to mean reward increase:')
                    print(saved_mean_reward, mean_100ep_reward)
                U.save_state(model_file)
                model_saved = True
                saved_mean_reward = mean_100ep_reward

        if model_saved:
            U.load_state(model_file)

    reward_data = {
        'rewards': episode_rewards,
        'outcomes': episode_outcomes,
        'trajectories': episode_trajectories,
        'actions': episode_actions
    }

    return ActWrapper(act, act_params), reward_data

In [None]:
def make_co_policy(
  env, scope=None, pilot_tol=0, pilot_is_human=False, 
  n_eps=n_training_episodes, copilot_scope=None, 
  copilot_q_func=None,
  reuse=False, **extras):
  """Build a copilot with co_dqn_learn and train it for n_eps episodes (n_eps=0: build only).

  Args:
    scope: TF variable scope. A random UUID is used if neither this nor
      copilot_scope is given.
    copilot_scope: scope of an existing copilot; overrides `scope`.
    copilot_q_func: Q-function builder for copilot_scope. make_q_func() is used
      when it is missing (previously None was passed through and crashed).
    n_eps: training runs for max_ep_len * n_eps timesteps.
    extras: accepted and ignored (e.g. pilot_policy).

  Returns:
    ((scope, q_func), result), where result is whatever co_dqn_learn returns:
    an ActWrapper alone when n_eps == 0, otherwise (ActWrapper, reward_data).
  """
  if copilot_scope is not None:
    scope = copilot_scope
  elif scope is None:
    scope = str(uuid.uuid4())
  if copilot_scope is not None and copilot_q_func is not None:
    q_func = copilot_q_func
  else:
    q_func = make_q_func()

  return (scope, q_func), co_dqn_learn(
    env,
    scope=scope,
    q_func=q_func,
    max_timesteps=max_ep_len*n_eps,
    pilot_tol=pilot_tol,
    pilot_is_human=pilot_is_human,
    reuse=reuse,
    **copilot_dqn_learn_kwargs
  )

Load the pretrained copilot

In [None]:
copilot_path = os.path.join(data_dir, 'pretrained_noop_copilot')
copilot_scope = ''

In [None]:
co_env = make_co_env(noop_pilot_policy)

In [None]:
(scope, q_func), raw_copilot_policy = make_co_policy(
  co_env, pilot_tol=1e-3, pilot_is_human=False, n_eps=0,
  copilot_scope=copilot_scope,
  copilot_q_func=make_q_func(),
  reuse=False,
  pilot_policy=noop_pilot_policy
)

In [None]:
load_tf_vars(copilot_scope, copilot_path)

Evaluate the solo (unassisted) pilot

In [None]:
pilot_id = 'spike'

In [None]:
n_eval_eps = 20

In [None]:
env.render()
env.unwrapped.viewer.window.on_key_press = key_press
env.unwrapped.viewer.window.on_key_release = key_release

In [None]:
rollouts = []

In [None]:
rollouts_checkpoint_path = os.path.join(data_dir, '%s_rollouts_checkpoint.pkl' % pilot_id)

In [None]:
with open(rollouts_checkpoint_path, 'rb') as f:
  rollouts = pickle.load(f)

In [None]:
len(rollouts)

In [None]:
for _ in range(n_eval_eps - len(rollouts)):
  print('This will be episode %d of %d' % (len(rollouts)+1, n_eval_eps))
  rollouts.append(run_ep(human_pilot_policy, env, render=True))

In [None]:
with open(rollouts_checkpoint_path, 'wb') as f:
  pickle.dump(rollouts, f, pickle.HIGHEST_PROTOCOL)

In [None]:
env.close()

In [None]:
eval_of_pilot = {pilot_id: list(zip(*rollouts))}

In [None]:
with open(os.path.join(data_dir, '%s_pilot_eval.pkl' % pilot_id), 'wb') as f:
  pickle.dump(eval_of_pilot, f, pickle.HIGHEST_PROTOCOL)

Evaluate the pilot with copilot assistance

In [None]:
def copilot_policy(obs):
  """Combine the human's keyboard action with the pretrained copilot.

  Pilot actions 5, 6 and 7 (land, rotate) get pilot_tol=1, which makes the
  acceptance threshold (1 - pilot_tol) * best equal to 0. All other pilot
  actions get pilot_tol=0 and are kept only if they match the copilot's greedy
  choice. `obs` is the (1, obs_dim) batch passed in by run_ep.
  """
  with tf.variable_scope(copilot_scope, reuse=None):
    pilot_action = human_pilot_policy(obs[None, :n_obs_dim])
    pilot_tol = 0
    if pilot_action in [5, 6, 7]:
      pilot_tol = 1
    chosen = raw_copilot_policy._act(
      obs,
      pilot_tol=pilot_tol,
      pilot_action=pilot_action
    )
    first_action = chosen[0][0]
  return first_action

In [None]:
n_eval_eps = 20

In [None]:
co_env = make_co_env(pilot_policy=copilot_policy)

In [None]:
co_env.render()
co_env.unwrapped.viewer.window.on_key_press = key_press
co_env.unwrapped.viewer.window.on_key_release = key_release

In [None]:
assisted_rollouts = []

In [None]:
assisted_rollouts_checkpoint_path = os.path.join(data_dir, '%s_assisted_rollouts_checkpoint.pkl' % pilot_id)

In [None]:
with open(assisted_rollouts_checkpoint_path, 'rb') as f:
  assisted_rollouts = pickle.load(f)

In [None]:
len(assisted_rollouts)

In [None]:
for _ in range(n_eval_eps - len(assisted_rollouts)):
  print('This will be episode %d of %d' % (len(assisted_rollouts)+1, n_eval_eps))
  assisted_rollouts.append(run_ep(copilot_policy, co_env, render=True))

In [None]:
with open(assisted_rollouts_checkpoint_path, 'wb') as f:
  pickle.dump(assisted_rollouts, f, pickle.HIGHEST_PROTOCOL)

In [None]:
co_env.close()

In [None]:
eval_of_copilot = {pilot_id: list(zip(*assisted_rollouts))}

In [None]:
with open(os.path.join(data_dir, '%s_copilot_eval.pkl' % pilot_id), 'wb') as f:
  pickle.dump(eval_of_copilot, f, pickle.HIGHEST_PROTOCOL)