# TO DO

- [ ] add reward scaling
- [ ] override train_agent_batch_with_evaluation
- [ ] support calllable object(hook)

In [1]:
from __future__ import print_function
from __future__ import unicode_literals
from __future__ import division
from __future__ import absolute_import
from future import standard_library
standard_library.install_aliases()  # NOQA
import argparse
import sys
import logging
from collections import deque

import chainer
from chainer import optimizers
import gym
from gym import spaces
import gym.wrappers
import numpy as np

import chainerrl
from chainerrl.agents.ddpg import DDPG
from chainerrl.agents.ddpg import DDPGModel
from chainerrl import experiments
from chainerrl import explorers
from chainerrl import misc
from chainerrl import policy
from chainerrl import q_functions
from chainerrl import replay_buffer

from flow.multiagent_envs import MultiWaveAttenuationMergePOEnv
from flow.scenarios import MergeScenario
from flow.utils.registry import make_create_env

benchmark_name = 'multi_merge'

In [5]:
gpu = None
seed = 0
final_exploration_steps = 10**6
actor_lr = 1e-4
critic_lr = 1e-3
steps = 10 ** 7
n_hidden_channels = 300
n_hidden_layers = 3
replay_start_size = 5000
n_update_times = 1
target_update_interval = 1
target_update_method = 'soft'
soft_update_tau = 1e-2
update_interval = 4
eval_n_runs = 100
eval_interval = 10**5
gamma = 0.995
minibatch_size = 200
use_bn = True
reward_scale_factor = 1e-2
return_window_size = 100
step_offset = 0
log_interval = 5
num_envs = 2

In [6]:
# Set a random seed used in ChainerRL
misc.set_random_seed(seed)

# Set different random seeds for different subprocesses.
# If seed=0 and processes=4, subprocess seeds are [0, 1, 2, 3].
# If seed=1 and processes=4, subprocess seeds are [4, 5, 6, 7].
process_seeds = np.arange(num_envs) + seed * num_envs
assert process_seeds.max() < 2 ** 32

outdir = experiments.prepare_output_dir('result')

In [7]:
benchmark  = __import__(
    "flow.benchmarks.%s" % benchmark_name, fromlist=["flow_params"])
flow_params = benchmark.flow_params
HORIZON = flow_params['env'].horizon

create_env, env_name = make_create_env(params=flow_params, version=0)

In [8]:
def make_env(create_env):
    def _thunk():
        env = create_env()
        return env
    return _thunk

def make_batch_env(test):
    return chainerrl.envs.MultiprocessVectorEnv(
        [make_env(create_env) for i in range(num_envs)])

In [9]:
sample_env = create_env()

In [10]:
timestep_limit = flow_params["env"].horizon
obs_size = np.asarray(sample_env.observation_space.shape).prod()
action_space = sample_env.action_space

In [11]:
action_size = np.asarray(action_space.shape).prod()
if use_bn:
    q_func = q_functions.FCBNLateActionSAQFunction(
        obs_size, action_size,
        n_hidden_channels=n_hidden_channels,
        n_hidden_layers=n_hidden_layers,
        normalize_input=True)
    pi = policy.FCBNDeterministicPolicy(
        obs_size, action_size=action_size,
        n_hidden_channels=n_hidden_channels,
        n_hidden_layers=n_hidden_layers,
        min_action=action_space.low, max_action=action_space.high,
        bound_action=True,
        normalize_input=True)
else:
    q_func = q_functions.FCSAQFunction(
        obs_size, action_size,
        n_hidden_channels=n_hidden_channels,
        n_hidden_layers=n_hidden_layers)
    pi = policy.FCDeterministicPolicy(
        obs_size, action_size=action_size,
        n_hidden_channels=n_hidden_channels,
        n_hidden_layers=n_hidden_layers,
        min_action=action_space.low, max_action=action_space.high,
        bound_action=True)

In [12]:
model = DDPGModel(q_func=q_func, policy=pi)

In [13]:
model = DDPGModel(q_func=q_func, policy=pi)
opt_a = optimizers.Adam(alpha=actor_lr)
opt_c = optimizers.Adam(alpha=critic_lr)
opt_a.setup(model['policy'])
opt_c.setup(model['q_function'])
opt_a.add_hook(chainer.optimizer.GradientClipping(1.0), 'hook_a')
opt_c.add_hook(chainer.optimizer.GradientClipping(1.0), 'hook_c')

rbuf = replay_buffer.ReplayBuffer(5 * 10 ** 5)

bufferには
1. state: self.batch_last_obs[i]
2. action: self.batch_last_action[i]
3. reward: batch_reward[i]
4. next_state: batch_obs[i]
5. next_action: action or None
6. is_state_terminal: batch_done

が追加される

一方、envからは
1. batch_obs: [{"key": "value"}, {"key": nparray}]
2. batch_rew: [{"key": "value"}, {"key": float}]
3. batch_done: [{"key": "value"}, {"key": bool}]

が追加される

agentが可能なのは
1. batch_act: obs[np.array, ...]
2. batch_train_act: obs[np.array, ...]
3. batch_obs_and_train: obs[np.array, ...], rew[float, ...], done[bool, ...]


In [14]:
class DDPG_MA(DDPG):
    def batch_act(self, batch_obs):
        """Select a batch of actions for evaluation.
        Args:
            batch_obs (Sequence of ~object): Observations.
        Returns:
            Sequence of ~object: Actions.
        """

        batch_actions = []
        for env_obs in batch_obs:
            keys, obss = list(env_obs.keys()), list(env_obs.values())
            obss = [obs.astype(np.float32) for obs in obss]
            
            with chainer.using_config('train', False), chainer.no_backprop_mode():
                batch_xs = self.batch_states(obss, self.xp, self.phi)
                batch_action = self.policy(batch_xs).sample()
                # Q is not needed here, but log it just for information
                q = self.q_function(batch_xs, batch_action)

            batch_actions.append({key:action for key, action in zip(keys, batch_action)})
            
            # Update stats
            self.average_q *= self.average_q_decay
            self.average_q += (1 - self.average_q_decay) * float(
                q.array.mean(axis=0))
            
        self.logger.debug('t:%s a:%s q:%s',
                          self.t, batch_action.array[0], q.array)
        return batch_actions

    def batch_act_and_train(self, batch_obs):
        """Select a batch of actions for training.
        Args:
            batch_obs (Sequence of ~object): Observations.
        Returns:
            Sequence of ~object: Actions.
        """

        batch_greedy_action = self.batch_act(batch_obs)
        batch_action = [
            self.explorer.select_action(
                self.t, lambda: batch_greedy_action[i])
            for i in range(len(batch_greedy_action))]

        self.batch_last_obs = list(batch_obs)
        self.batch_last_action = list(batch_action)

        return batch_action

    def batch_observe_and_train(
            self, batch_obs, batch_reward, batch_done, batch_reset):
        """Observe a batch of action consequences for training.
        Args:
            batch_obs (Sequence of ~object): Observations.
            batch_reward (Sequence of float): Rewards.
            batch_done (Sequence of boolean): Boolean values where True
                indicates the current state is terminal.
            batch_reset (Sequence of boolean): Boolean values where True
                indicates the current episode will be reset, even if the
                current state is not terminal.
        Returns:
            None
        """
        # when a new car is added to the highway
        if len(batch_obs) > len(self.batch_last_obs):
            self.batch_last_obs.append(None)
            self.batch_last_action.append(None)
            
        for i in range(len(batch_obs)):
            self.t += 1
            # Update the target network
            if self.t % self.target_update_interval == 0:
                self.sync_target_network()
            if self.batch_last_obs[i] is not None:
                assert self.batch_last_action[i] is not None
                # Add a transition to the replay buffer
                self.replay_buffer.append(
                    state=self.batch_last_obs[i],
                    action=self.batch_last_action[i],
                    reward=batch_reward[i],
                    next_state=batch_obs[i],
                    next_action=None,
                    is_state_terminal=batch_done[i],
                )
                if batch_reset[i] or batch_done[i]:
                    self.batch_last_obs[i] = None
            self.replay_updater.update_if_necessary(self.t)

In [15]:
ou_sigma = (action_space.high - action_space.low) * 0.2
explorer = explorers.AdditiveOU(sigma=ou_sigma)
agent = DDPG_MA(model, opt_a, opt_c, rbuf, gamma=gamma,
             explorer=explorer, replay_start_size=replay_start_size,
             target_update_method=target_update_method,
             target_update_interval=target_update_interval,
             update_interval=update_interval,
             soft_update_tau=soft_update_tau,
             n_times_update=n_update_times,
             gpu=gpu, minibatch_size=minibatch_size)

In [20]:
agent.batch_act(batch_obs)

[{'flow_1.1': variable([0.08803371]), 'flow_1.2': variable([0.03473603])},
 {'flow_1.1': variable([0.08545651]), 'flow_1.2': variable([0.034552])}]

In [17]:
env = make_batch_env(False)

1. experiments.train_agent_batch_with_evaluation()

2. contents in train_agent_batch

    1. actions = agent.batch_act_and_train(obss)
    2. env.step(actions)
    3. compute resets
    4. agent.batch_observe_and_train(obss, rs, dones, resets): in this env, done==True when horizon is reached. So resets is same as dones

In [18]:
def batch_act_and_train_MA(batch_obs):
    """
    do batch_act_and_train in a multi-agent batch env
    batch is by the agents in an env, and for loop by the # of env
    """
    batch_actions = []
    for batch in batch_obs:
        keys, obss = list(batch.keys()), list(batch.values())
        obss = [obs.astype(np.float32) for obs in obss]
        actions = agent.batch_act_and_train(obss)
        batch_actions.append({key:action for key, action in zip(keys, actions)})
    return batch_actions

def batch_observe_and_train_MA(batch_obs, batch_rs, batch_dones):
    """
    do batch_observe_and_train in a multi-agent batch env
    batch is by the agents in an env, and for loop by the # of env
    """
    # train: the agents in a same env is batched, and for_loop by the env num
    for obss, rss, doness in zip(batch_obs, batch_rs, batch_dones):
        keys, obss = list(obss.keys()), list(obss.values())
        obss = [obs.astype(np.float32) for obs in obss]
        rss = [float(rs) for rs in rss.values()]
        doness = list([doness[key] for key in keys])  # without __all__
        resetss = doness
        print("last {}".format(len(agent.batch_last_obs)))
        print("current {}".format(len(obss)))
        agent.batch_observe_and_train(obss, rss, doness, resetss)   

In [19]:
logger = logging.getLogger(__name__)
recent_returns = deque(maxlen=return_window_size)

num_envs = env.num_envs
episode_r = np.zeros(num_envs, dtype=np.float64)
episode_idx = np.zeros(num_envs, dtype='i')
episode_len = np.zeros(num_envs, dtype='i')

# o_0, r_0
batch_obs = env.reset()
rs = np.zeros(num_envs, dtype='f')

t = step_offset

In [157]:
for _ in range(1):
    # a_t
    batch_actions = batch_act_and_train_MA(batch_obs)
    print(batch_actions)
    # o_{t+1}, r_{t+1}
    batch_obs, batch_rs, batch_dones, batch_infos = env.step(batch_actions)
    rs = [np.mean(list(rss.values())) for rss in batch_rs] # each env mean reward
    episode_r += rs
    episode_len += 1

    # mask for reset the env(when collision or horizon)
    batch_reset = [done["__all__"] for done in batch_dones]
    not_batch_reset = np.logical_not(batch_reset)  # doesn't reset when True

    # Agent observes the consequences
    batch_observe_and_train_MA(batch_obs, batch_rs, batch_dones)

    episode_idx += batch_reset
    recent_returns.extend(episode_r[batch_reset])

    for _ in range(num_envs):
        t += 1

    # logger should be here
    # evaluator should be here

    if t >= steps:
        break

    # Start new episodes if needed
    episode_r[batch_reset] = 0
    episode_len[batch_reset] = 0
    batch_obs = env.reset(not_batch_reset)

[{'flow_1.10': array([-0.40924096], dtype=float32), 'flow_1.9': array([-0.20380986], dtype=float32), 'flow_1.12': array([-2.7184482], dtype=float32), 'flow_1.11': array([-3.8482678], dtype=float32)}, {'flow_1.2': array([-1.7086694], dtype=float32), 'flow_1.1': array([-3.1255903], dtype=float32)}]
last 2
current 4


IndexError: list index out of range

In [85]:
episode_idx

array([0, 0], dtype=int32)

In [52]:
batch_reset

[False, False]