[rllib] Support continuous action distributions in IMPALA/APPO (#4771)

ericl committed May 17, 2019
1 parent ffd596d commit 7d5ef6d99c4f784bb50efad90b814ded3e46176b
@@ -95,7 +95,7 @@ Asynchronous Proximal Policy Optimization (APPO)
`[implementation] <https://github.com/ray-project/ray/blob/master/python/ray/rllib/agents/ppo/appo.py>`__
We include an asynchronous variant of Proximal Policy Optimization (PPO) based on the IMPALA architecture. This is similar to IMPALA but using a surrogate policy loss with clipping. Compared to synchronous PPO, APPO is more efficient in wall-clock time due to its use of asynchronous sampling. Using a clipped loss also allows for multiple SGD passes, and therefore the potential for better sample efficiency compared to IMPALA. V-trace can also be enabled to correct for off-policy samples.

This implementation is currently *experimental*. Consider also using `PPO <rllib-algorithms.html#proximal-policy-optimization-ppo>`__ or `IMPALA <rllib-algorithms.html#importance-weighted-actor-learner-architecture-impala>`__.
APPO is not always more efficient; it is often better to simply use `PPO <rllib-algorithms.html#proximal-policy-optimization-ppo>`__ or `IMPALA <rllib-algorithms.html#importance-weighted-actor-learner-architecture-impala>`__.

Tuned examples: `PongNoFrameskip-v4 <https://github.com/ray-project/ray/blob/master/python/ray/rllib/tuned_examples/pong-appo.yaml>`__
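
For readers unfamiliar with the "surrogate policy loss with clipping" mentioned above, here is a minimal sketch of that objective (illustrative names only, not RLlib's exact implementation):

import tensorflow as tf

def clipped_surrogate_loss(curr_logp, behaviour_logp, advantages, clip_param=0.3):
    """Sketch of the PPO-style clipped surrogate objective that APPO reuses."""
    # Probability ratio pi_new(a|s) / pi_old(a|s), computed from log-probs.
    ratio = tf.exp(curr_logp - behaviour_logp)
    unclipped = ratio * advantages
    clipped = tf.clip_by_value(ratio, 1.0 - clip_param, 1.0 + clip_param) * advantages
    # Maximize the surrogate objective, i.e. minimize its negation.
    return -tf.reduce_mean(tf.minimum(unclipped, clipped))

Because the ratio is clipped, the same batch can be reused for several SGD passes, which is where APPO's potential sample-efficiency gain over IMPALA comes from.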

@@ -13,7 +13,7 @@ Algorithm Discrete Actions Continuous Actions Multi-Agent Recurre
A2C, A3C **Yes** `+parametric`_ **Yes** **Yes** **Yes**
PPO, APPO **Yes** `+parametric`_ **Yes** **Yes** **Yes**
PG **Yes** `+parametric`_ **Yes** **Yes** **Yes**
IMPALA **Yes** `+parametric`_ No **Yes** **Yes**
IMPALA **Yes** `+parametric`_ **Yes** **Yes** **Yes**
DQN, Rainbow **Yes** `+parametric`_ No **Yes** No
DDPG, TD3 No **Yes** **Yes** No
APEX-DQN **Yes** `+parametric`_ No **Yes** No
@@ -34,6 +34,7 @@

import collections

from ray.rllib.models.action_dist import Categorical
from ray.rllib.utils import try_import_tf

tf = try_import_tf()
@@ -48,12 +49,15 @@
VTraceReturns = collections.namedtuple("VTraceReturns", "vs pg_advantages")


def log_probs_from_logits_and_actions(policy_logits, actions):
return multi_log_probs_from_logits_and_actions([policy_logits],
[actions])[0]
def log_probs_from_logits_and_actions(policy_logits,
actions,
dist_class=Categorical):
return multi_log_probs_from_logits_and_actions([policy_logits], [actions],
dist_class)[0]


def multi_log_probs_from_logits_and_actions(policy_logits, actions):
def multi_log_probs_from_logits_and_actions(policy_logits, actions,
dist_class):
"""Computes action log-probs from policy logits and actions.
In the notation used throughout documentation and comments, T refers to the
@@ -68,11 +72,11 @@ def multi_log_probs_from_logits_and_actions(policy_logits, actions):
...,
[T, B, ACTION_SPACE[-1]]
with un-normalized log-probabilities parameterizing a softmax policy.
actions: A list with length of ACTION_SPACE of int32
actions: A list with length of ACTION_SPACE of
tensors of shapes
[T, B],
[T, B, ...],
...,
[T, B]
[T, B, ...]
with actions.
Returns:
@@ -87,8 +91,16 @@ def multi_log_probs_from_logits_and_actions(policy_logits, actions):

log_probs = []
for i in range(len(policy_logits)):
log_probs.append(-tf.nn.sparse_softmax_cross_entropy_with_logits(
logits=policy_logits[i], labels=actions[i]))
p_shape = tf.shape(policy_logits[i])
a_shape = tf.shape(actions[i])
policy_logits_flat = tf.reshape(policy_logits[i],
tf.concat([[-1], p_shape[2:]], axis=0))
actions_flat = tf.reshape(actions[i],
tf.concat([[-1], a_shape[2:]], axis=0))
log_probs.append(
tf.reshape(
dist_class(policy_logits_flat).logp(actions_flat),
a_shape[:2]))

return log_probs
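
The reshaping above folds the leading [T, B] dimensions into one flat batch before calling the distribution, then unfolds the log-probs back to [T, B]. A self-contained toy version of that pattern, using a plain categorical log-prob in place of dist_class (exactly the expression the old code used, and what a Categorical logp boils down to):

import numpy as np
import tensorflow as tf

T, B, A = 4, 2, 3  # time, batch, number of discrete actions
logits = tf.constant(np.random.randn(T, B, A), dtype=tf.float32)
actions = tf.constant(np.random.randint(0, A, size=(T, B)), dtype=tf.int32)

p_shape = tf.shape(logits)
a_shape = tf.shape(actions)

# Fold [T, B, ...] into [T * B, ...] so the distribution sees a flat batch.
logits_flat = tf.reshape(logits, tf.concat([[-1], p_shape[2:]], axis=0))
actions_flat = tf.reshape(actions, tf.concat([[-1], a_shape[2:]], axis=0))

# Categorical log-prob of the taken actions (the dist_class(...).logp(...) step).
logp_flat = -tf.nn.sparse_softmax_cross_entropy_with_logits(
    logits=logits_flat, labels=actions_flat)

# Unfold back to [T, B] so downstream V-trace math stays time-major.
log_probs = tf.reshape(logp_flat, a_shape[:2])

with tf.Session() as sess:
    print(sess.run(log_probs).shape)  # (4, 2)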

@@ -100,6 +112,7 @@ def from_logits(behaviour_policy_logits,
rewards,
values,
bootstrap_value,
dist_class=Categorical,
clip_rho_threshold=1.0,
clip_pg_rho_threshold=1.0,
name="vtrace_from_logits"):
@@ -111,6 +124,7 @@ def from_logits(behaviour_policy_logits,
rewards,
values,
bootstrap_value,
dist_class,
clip_rho_threshold=clip_rho_threshold,
clip_pg_rho_threshold=clip_pg_rho_threshold,
name=name)
@@ -133,6 +147,7 @@ def multi_from_logits(behaviour_policy_logits,
rewards,
values,
bootstrap_value,
dist_class,
clip_rho_threshold=1.0,
clip_pg_rho_threshold=1.0,
name="vtrace_from_logits"):
@@ -168,11 +183,11 @@ def multi_from_logits(behaviour_policy_logits,
[T, B, ACTION_SPACE[-1]]
with un-normalized log-probabilities parameterizing the softmax target
policy.
actions: A list with length of ACTION_SPACE of int32
actions: A list with length of ACTION_SPACE of
tensors of shapes
[T, B],
[T, B, ...],
...,
[T, B]
[T, B, ...]
with actions sampled from the behaviour policy.
discounts: A float32 tensor of shape [T, B] with the discount encountered
when following the behaviour policy.
@@ -182,6 +197,7 @@ def multi_from_logits(behaviour_policy_logits,
wrt. the target policy.
bootstrap_value: A float32 of shape [B] with the value function estimate at
time T.
dist_class: action distribution class for the logits.
clip_rho_threshold: A scalar float32 tensor with the clipping threshold for
importance weights (rho) when calculating the baseline targets (vs).
rho^bar in the paper.
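
For context, these thresholds enter the V-trace targets of Espeholt et al. (2018). In the paper's notation, with behaviour policy \mu and target policy \pi:

v_s = V(x_s) + \sum_{t=s}^{s+n-1} \gamma^{t-s} \left( \prod_{i=s}^{t-1} c_i \right) \delta_t V
\delta_t V = \rho_t \left( r_t + \gamma V(x_{t+1}) - V(x_t) \right)
\rho_t = \min\left( \bar{\rho},\ \frac{\pi(a_t \mid x_t)}{\mu(a_t \mid x_t)} \right), \qquad
c_i = \min\left( \bar{c},\ \frac{\pi(a_i \mid x_i)}{\mu(a_i \mid x_i)} \right)

clip_rho_threshold is \bar{\rho}; clip_pg_rho_threshold plays the analogous role for the \rho_s used in the policy-gradient advantages \rho_s (r_s + \gamma v_{s+1} - V(x_s)).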
@@ -208,13 +224,11 @@ def multi_from_logits(behaviour_policy_logits,
behaviour_policy_logits[i], dtype=tf.float32)
target_policy_logits[i] = tf.convert_to_tensor(
target_policy_logits[i], dtype=tf.float32)
actions[i] = tf.convert_to_tensor(actions[i], dtype=tf.int32)

# Make sure tensor ranks are as expected.
# The rest will be checked by from_action_log_probs.
behaviour_policy_logits[i].shape.assert_has_rank(3)
target_policy_logits[i].shape.assert_has_rank(3)
actions[i].shape.assert_has_rank(2)

with tf.name_scope(
name,
@@ -223,9 +237,9 @@ def multi_from_logits(behaviour_policy_logits,
discounts, rewards, values, bootstrap_value
]):
target_action_log_probs = multi_log_probs_from_logits_and_actions(
target_policy_logits, actions)
target_policy_logits, actions, dist_class)
behaviour_action_log_probs = multi_log_probs_from_logits_and_actions(
behaviour_policy_logits, actions)
behaviour_policy_logits, actions, dist_class)

log_rhos = get_log_rhos(target_action_log_probs,
behaviour_action_log_probs)
@@ -18,7 +18,6 @@
from ray.rllib.models.action_dist import MultiCategorical
from ray.rllib.models.catalog import ModelCatalog
from ray.rllib.utils.annotations import override
from ray.rllib.utils.error import UnsupportedSpaceException
from ray.rllib.utils.explained_variance import explained_variance
from ray.rllib.utils import try_import_tf

@@ -40,6 +39,7 @@ def __init__(self,
rewards,
values,
bootstrap_value,
dist_class,
valid_mask,
vf_loss_coeff=0.5,
entropy_coeff=0.01,
@@ -52,7 +52,7 @@ def __init__(self,
handle episode cut boundaries.
Args:
actions: An int32 tensor of shape [T, B, ACTION_SPACE].
actions: An int|float32 tensor of shape [T, B, ACTION_SPACE].
actions_logp: A float32 tensor of shape [T, B].
actions_entropy: A float32 tensor of shape [T, B].
dones: A bool tensor of shape [T, B].
@@ -70,6 +70,7 @@ def __init__(self,
rewards: A float32 tensor of shape [T, B].
values: A float32 tensor of shape [T, B].
bootstrap_value: A float32 tensor of shape [B].
dist_class: action distribution class for logits.
valid_mask: A bool tensor of valid RNN input elements (#2992).
"""

@@ -78,11 +79,12 @@ def __init__(self,
self.vtrace_returns = vtrace.multi_from_logits(
behaviour_policy_logits=behaviour_logits,
target_policy_logits=target_logits,
actions=tf.unstack(tf.cast(actions, tf.int32), axis=2),
actions=tf.unstack(actions, axis=2),
discounts=tf.to_float(~dones) * discount,
rewards=rewards,
values=values,
bootstrap_value=bootstrap_value,
dist_class=dist_class,
clip_rho_threshold=tf.cast(clip_rho_threshold, tf.float32),
clip_pg_rho_threshold=tf.cast(clip_pg_rho_threshold,
tf.float32))
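
A note on the changed actions= line above: tf.unstack along axis=2 turns the time-major [T, B, K] actions tensor into a Python list of K tensors of shape [T, B], one per action component, which is the list form multi_from_logits expects; dropping the tf.cast(..., tf.int32) is what lets float32 (continuous) actions pass through unchanged. A tiny illustration:

import tensorflow as tf

# Illustrative shapes only; the unstacked axis (K) must be statically known.
actions = tf.placeholder(tf.float32, [None, None, 2], name="toy_actions")  # [T, B, 2]
per_component = tf.unstack(actions, axis=2)  # list of two [T, B] tensors
print(len(per_component), per_component[0].shape)  # 2 (?, ?)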
@@ -140,30 +142,28 @@ def __init__(self,

if isinstance(action_space, gym.spaces.Discrete):
is_multidiscrete = False
actions_shape = [None]
output_hidden_shape = [action_space.n]
elif isinstance(action_space, gym.spaces.multi_discrete.MultiDiscrete):
is_multidiscrete = True
actions_shape = [None, len(action_space.nvec)]
output_hidden_shape = action_space.nvec.astype(np.int32)
else:
raise UnsupportedSpaceException(
"Action space {} is not supported for IMPALA.".format(
action_space))
is_multidiscrete = False
output_hidden_shape = 1

# Create input placeholders
dist_class, logit_dim = ModelCatalog.get_action_dist(
action_space, self.config["model"])
if existing_inputs:
actions, dones, behaviour_logits, rewards, observations, \
prev_actions, prev_rewards = existing_inputs[:7]
existing_state_in = existing_inputs[7:-1]
existing_seq_lens = existing_inputs[-1]
else:
actions = tf.placeholder(tf.int64, actions_shape, name="ac")
actions = ModelCatalog.get_action_placeholder(action_space)
dones = tf.placeholder(tf.bool, [None], name="dones")
rewards = tf.placeholder(tf.float32, [None], name="rewards")
behaviour_logits = tf.placeholder(
tf.float32, [None, sum(output_hidden_shape)],
name="behaviour_logits")
tf.float32, [None, logit_dim], name="behaviour_logits")
observations = tf.placeholder(
tf.float32, [None] + list(observation_space.shape))
existing_state_in = None
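
In short, the rewritten branch above stops rejecting non-discrete spaces: Discrete and MultiDiscrete keep their per-sub-action logit layout, and anything else (e.g. a Box space) falls through to a single distribution whose placeholder and logit width now come from ModelCatalog. A hypothetical helper restating just that mapping (action_layout is illustrative, not an RLlib function):

import gym
import numpy as np

def action_layout(action_space):
    """Mirror of the branch above: returns (is_multidiscrete, output_hidden_shape)."""
    if isinstance(action_space, gym.spaces.Discrete):
        return False, [action_space.n]                    # one categorical head
    elif isinstance(action_space, gym.spaces.MultiDiscrete):
        return True, action_space.nvec.astype(np.int32)   # one head per sub-action
    else:
        return False, 1                                    # e.g. Box: single distribution

print(action_layout(gym.spaces.Discrete(6)))                          # (False, [6])
print(action_layout(gym.spaces.MultiDiscrete([3, 3])))                # (True, array([3, 3], dtype=int32))
print(action_layout(gym.spaces.Box(low=-1.0, high=1.0, shape=(1,))))  # (False, 1)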
@@ -174,8 +174,6 @@ def __init__(self,
behaviour_logits, output_hidden_shape, axis=1)

# Setup the policy
dist_class, logit_dim = ModelCatalog.get_action_dist(
action_space, self.config["model"])
prev_actions = ModelCatalog.get_action_placeholder(action_space)
prev_rewards = tf.placeholder(tf.float32, [None], name="prev_reward")
self.model = ModelCatalog.get_model(
@@ -261,6 +259,7 @@ def make_time_major(tensor, drop_last=False):
rewards=make_time_major(rewards, drop_last=True),
values=make_time_major(values, drop_last=True),
bootstrap_value=make_time_major(values)[-1],
dist_class=dist_class,
valid_mask=make_time_major(mask, drop_last=True),
vf_loss_coeff=self.config["vf_loss_coeff"],
entropy_coeff=self.config["entropy_coeff"],
@@ -18,7 +18,6 @@
LearningRateSchedule
from ray.rllib.models.catalog import ModelCatalog
from ray.rllib.utils.annotations import override
from ray.rllib.utils.error import UnsupportedSpaceException
from ray.rllib.utils.explained_variance import explained_variance
from ray.rllib.models.action_dist import MultiCategorical
from ray.rllib.evaluation.postprocessing import compute_advantages
@@ -94,6 +93,7 @@ def __init__(self,
rewards,
values,
bootstrap_value,
dist_class,
valid_mask,
vf_loss_coeff=0.5,
entropy_coeff=0.01,
@@ -107,18 +107,19 @@ def __init__(self,
handle episode cut boundaries.
Arguments:
actions: An int32 tensor of shape [T, B, NUM_ACTIONS].
actions: An int|float32 tensor of shape [T, B, logit_dim].
prev_actions_logp: A float32 tensor of shape [T, B].
actions_logp: A float32 tensor of shape [T, B].
action_kl: A float32 tensor of shape [T, B].
actions_entropy: A float32 tensor of shape [T, B].
dones: A bool tensor of shape [T, B].
behaviour_logits: A float32 tensor of shape [T, B, NUM_ACTIONS].
target_logits: A float32 tensor of shape [T, B, NUM_ACTIONS].
behaviour_logits: A float32 tensor of shape [T, B, logit_dim].
target_logits: A float32 tensor of shape [T, B, logit_dim].
discount: A float32 scalar.
rewards: A float32 tensor of shape [T, B].
values: A float32 tensor of shape [T, B].
bootstrap_value: A float32 tensor of shape [B].
dist_class: action distribution class for logits.
valid_mask: A bool tensor of valid RNN input elements (#2992).
"""

@@ -127,11 +128,12 @@ def __init__(self,
self.vtrace_returns = vtrace.multi_from_logits(
behaviour_policy_logits=behaviour_logits,
target_policy_logits=target_logits,
actions=tf.unstack(tf.cast(actions, tf.int32), axis=2),
actions=tf.unstack(actions, axis=2),
discounts=tf.to_float(~dones) * discount,
rewards=rewards,
values=values,
bootstrap_value=bootstrap_value,
dist_class=dist_class,
clip_rho_threshold=tf.cast(clip_rho_threshold, tf.float32),
clip_pg_rho_threshold=tf.cast(clip_pg_rho_threshold,
tf.float32))
@@ -218,10 +220,6 @@ def __init__(self,
elif isinstance(action_space, gym.spaces.multi_discrete.MultiDiscrete):
is_multidiscrete = True
output_hidden_shape = action_space.nvec.astype(np.int32)
elif self.config["vtrace"]:
raise UnsupportedSpaceException(
"Action space {} is not supported for APPO + VTrace.",
format(action_space))
else:
is_multidiscrete = False
output_hidden_shape = 1
@@ -365,6 +363,7 @@ def make_time_major(tensor, drop_last=False):
rewards=make_time_major(rewards, drop_last=True),
values=make_time_major(values, drop_last=True),
bootstrap_value=make_time_major(values)[-1],
dist_class=dist_class,
valid_mask=make_time_major(mask, drop_last=True),
vf_loss_coeff=self.config["vf_loss_coeff"],
entropy_coeff=self.config["entropy_coeff"],
@@ -0,0 +1,12 @@
pendulum-appo-vt:
env: Pendulum-v0
run: APPO
stop:
episode_reward_mean: -900 # just check it learns a bit
timesteps_total: 500000
config:
num_gpus: 0
num_workers: 1
gamma: 0.95
train_batch_size: 50
vtrace: true
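
For reference, a rough Python equivalent of the tuned example above, assuming the tune.run API of this Ray release (the YAML itself is typically launched with the rllib train CLI's -f flag):

import ray
from ray import tune

ray.init()

# Mirrors the YAML above: a short APPO run on Pendulum-v0 with V-trace enabled.
tune.run(
    "APPO",
    stop={"episode_reward_mean": -900, "timesteps_total": 500000},
    config={
        "env": "Pendulum-v0",
        "num_gpus": 0,
        "num_workers": 1,
        "gamma": 0.95,
        "train_batch_size": 50,
        "vtrace": True,
    },
)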
