[rllib] implemented compute_advantages without gae (#6941)
roireshef committed Feb 1, 2020
1 parent 92525f3 commit 3c60caa
Showing 10 changed files with 66 additions and 18 deletions.
7 changes: 7 additions & 0 deletions rllib/agents/a3c/a3c.py
@@ -6,6 +6,13 @@
# yapf: disable
# __sphinx_doc_begin__
DEFAULT_CONFIG = with_common_config({
# Should use a critic as a baseline (otherwise don't use value baseline;
# required for using GAE).
"use_critic": True,
# If true, use the Generalized Advantage Estimator (GAE)
# with a value function, see https://arxiv.org/pdf/1506.02438.pdf.
"use_gae": True,

# Size of rollout batch
"sample_batch_size": 10,
# GAE(gamma) parameter
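These two new flags land in A3C's default config, so they can be overridden like any other trainer option. A minimal sketch of such an override, assuming the RLlib 0.8-era A3CTrainer entry point and a CartPole environment (neither appears in this diff; only the two flag names come from the commit):

# Sketch only: an A3C trainer that skips both GAE and the value baseline.
import ray
from ray.rllib.agents.a3c import A3CTrainer

ray.init()
trainer = A3CTrainer(
    env="CartPole-v0",
    config={
        "use_gae": False,     # plain discounted returns instead of GAE(lambda)
        "use_critic": False,  # no value-function baseline at all
    })
print(trainer.train())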
5 changes: 3 additions & 2 deletions rllib/agents/a3c/a3c_tf_policy.py
@@ -61,8 +61,9 @@ def postprocess_advantages(policy,
sample_batch[SampleBatch.ACTIONS][-1],
sample_batch[SampleBatch.REWARDS][-1],
*next_state)
return compute_advantages(sample_batch, last_r, policy.config["gamma"],
policy.config["lambda"])
return compute_advantages(
sample_batch, last_r, policy.config["gamma"], policy.config["lambda"],
policy.config["use_gae"], policy.config["use_critic"])


def add_value_function_fetch(policy):
5 changes: 3 additions & 2 deletions rllib/agents/a3c/a3c_torch_policy.py
@@ -43,8 +43,9 @@ def add_advantages(policy,
last_r = 0.0
else:
last_r = policy._value(sample_batch[SampleBatch.NEXT_OBS][-1])
return compute_advantages(sample_batch, last_r, policy.config["gamma"],
policy.config["lambda"])
return compute_advantages(
sample_batch, last_r, policy.config["gamma"], policy.config["lambda"],
policy.config["use_gae"], policy.config["use_critic"])


def model_value_predictions(policy, input_dict, state_batches, model,
6 changes: 5 additions & 1 deletion rllib/agents/marwil/marwil_policy.py
@@ -80,7 +80,11 @@ def postprocess_advantages(policy,
sample_batch[SampleBatch.REWARDS][-1],
*next_state)
return compute_advantages(
sample_batch, last_r, policy.config["gamma"], use_gae=False)
sample_batch,
last_r,
policy.config["gamma"],
use_gae=False,
use_critic=False)


class MARWILLoss(object):
17 changes: 12 additions & 5 deletions rllib/agents/pg/pg_tf_policy.py
@@ -8,19 +8,26 @@
tf = try_import_tf()


def post_process_advantages(policy, sample_batch, other_agent_batches=None,
def post_process_advantages(policy,
sample_batch,
other_agent_batches=None,
episode=None):
"""This adds the "advantages" column to the sample train_batch."""
return compute_advantages(sample_batch, 0.0, policy.config["gamma"],
use_gae=False)
return compute_advantages(
sample_batch,
0.0,
policy.config["gamma"],
use_gae=False,
use_critic=False)


def pg_tf_loss(policy, model, dist_class, train_batch):
"""The basic policy gradients loss."""
logits, _ = model.from_batch(train_batch)
action_dist = dist_class(logits, model)
return -tf.reduce_mean(action_dist.logp(train_batch[SampleBatch.ACTIONS])
* train_batch[Postprocessing.ADVANTAGES])
return -tf.reduce_mean(
action_dist.logp(train_batch[SampleBatch.ACTIONS]) *
train_batch[Postprocessing.ADVANTAGES])


PGTFPolicy = build_tf_policy(
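With this change the PG policy's postprocessing goes through the no-critic path, so each step's advantage is simply its discounted return. A small sketch of calling the updated helper directly on a hand-built batch, assuming the RLlib import paths of this era and made-up numbers:

# Sketch: the pure-REINFORCE path (use_gae=False, use_critic=False).
import numpy as np
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.evaluation.postprocessing import Postprocessing, compute_advantages

# Tiny invented 3-step rollout containing only rewards.
batch = SampleBatch({
    SampleBatch.REWARDS: np.array([1.0, 1.0, 1.0], dtype=np.float32),
})
out = compute_advantages(batch, 0.0, gamma=0.99, use_gae=False, use_critic=False)
# Advantages are the per-step discounted returns: [2.9701, 1.99, 1.0]
print(out[Postprocessing.ADVANTAGES])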
3 changes: 3 additions & 0 deletions rllib/agents/ppo/appo.py
@@ -11,6 +11,9 @@
"vtrace": False,

# == These two options only apply if vtrace: False ==
# Should use a critic as a baseline (otherwise don't use value baseline;
# required for using GAE).
"use_critic": True,
# If true, use the Generalized Advantage Estimator (GAE)
# with a value function, see https://arxiv.org/pdf/1506.02438.pdf.
"use_gae": True,
3 changes: 2 additions & 1 deletion rllib/agents/ppo/appo_policy.py
@@ -389,7 +389,8 @@ def postprocess_trajectory(policy,
last_r,
policy.config["gamma"],
policy.config["lambda"],
use_gae=policy.config["use_gae"])
use_gae=policy.config["use_gae"],
use_critic=policy.config["use_critic"])
else:
batch = sample_batch
del batch.data["new_obs"] # not used, so save some bandwidth
4 changes: 4 additions & 0 deletions rllib/agents/ppo/ppo.py
@@ -14,9 +14,13 @@
# yapf: disable
# __sphinx_doc_begin__
DEFAULT_CONFIG = with_common_config({
# Should use a critic as a baseline (otherwise don't use value baseline;
# required for using GAE).
"use_critic": True,
# If true, use the Generalized Advantage Estimator (GAE)
# with a value function, see https://arxiv.org/pdf/1506.02438.pdf.
"use_gae": True,

# The GAE(lambda) parameter.
"lambda": 1.0,
# Initial coefficient for KL divergence.
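For PPO, the combination newly expressible here is a value-function baseline without GAE smoothing: advantages become full discounted returns minus the value predictions (the middle branch added to postprocessing.py below). A sketch, assuming PPO's postprocessing forwards use_gae from the config as before and the standard PPOTrainer entry point (both assumptions, not part of this diff):

# Sketch only: PPO with a critic baseline but without GAE(lambda).
from ray.rllib.agents.ppo import PPOTrainer

trainer = PPOTrainer(
    env="CartPole-v0",
    config={
        "use_gae": False,    # advantages = discounted returns - V(s)
        "use_critic": True,  # value function still trained on the returns
    })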
31 changes: 25 additions & 6 deletions rllib/evaluation/postprocessing.py
@@ -16,7 +16,12 @@ class Postprocessing:


@DeveloperAPI
def compute_advantages(rollout, last_r, gamma=0.9, lambda_=1.0, use_gae=True):
def compute_advantages(rollout,
last_r,
gamma=0.9,
lambda_=1.0,
use_gae=True,
use_critic=True):
"""
Given a rollout, compute its value targets and the advantage.
@@ -26,6 +31,8 @@ def compute_advantages(rollout, last_r, gamma=0.9, lambda_=1.0, use_gae=True):
gamma (float): Discount factor.
lambda_ (float): Parameter for GAE
use_gae (bool): Using Generalized Advantage Estimation
use_critic (bool): Whether to use critic (value estimates). Setting
this to False will use 0 as baseline.
Returns:
SampleBatch (SampleBatch): Object with experience from rollout and
@@ -37,8 +44,12 @@ def compute_advantages(rollout, last_r, gamma=0.9, lambda_=1.0, use_gae=True):
for key in rollout:
traj[key] = np.stack(rollout[key])

assert SampleBatch.VF_PREDS in rollout or not use_critic, \
"use_critic=True but values not found"
assert use_critic or not use_gae, \
"Can't use gae without using a value function"

if use_gae:
assert SampleBatch.VF_PREDS in rollout, "Values not found!"
vpred_t = np.concatenate(
[rollout[SampleBatch.VF_PREDS],
np.array([last_r])])
@@ -54,10 +65,18 @@ def compute_advantages(rollout, last_r, gamma=0.9, lambda_=1.0, use_gae=True):
rewards_plus_v = np.concatenate(
[rollout[SampleBatch.REWARDS],
np.array([last_r])])
traj[Postprocessing.ADVANTAGES] = discount(rewards_plus_v, gamma)[:-1]
# TODO(ekl): support using a critic without GAE
traj[Postprocessing.VALUE_TARGETS] = np.zeros_like(
traj[Postprocessing.ADVANTAGES])
discounted_returns = discount(rewards_plus_v,
gamma)[:-1].copy().astype(np.float32)

if use_critic:
traj[Postprocessing.
ADVANTAGES] = discounted_returns - rollout[SampleBatch.
VF_PREDS]
traj[Postprocessing.VALUE_TARGETS] = discounted_returns
else:
traj[Postprocessing.ADVANTAGES] = discounted_returns
traj[Postprocessing.VALUE_TARGETS] = np.zeros_like(
traj[Postprocessing.ADVANTAGES])

traj[Postprocessing.ADVANTAGES] = traj[
Postprocessing.ADVANTAGES].copy().astype(np.float32)
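The resulting branching can be summarized in a short standalone NumPy sketch. This mirrors the logic in the hunk above but is not RLlib's actual helper, and the reward and value arrays are invented:

import numpy as np

def discount(x, gamma):
    # Discounted cumulative sum: out[t] = sum_k gamma**k * x[t + k]
    out = np.zeros_like(x, dtype=np.float32)
    running = 0.0
    for t in reversed(range(len(x))):
        running = x[t] + gamma * running
        out[t] = running
    return out

rewards = np.array([1.0, 0.0, 2.0], dtype=np.float32)   # invented rollout rewards
vf_preds = np.array([0.5, 0.4, 0.9], dtype=np.float32)  # invented critic outputs
last_r, gamma, lambda_ = 0.0, 0.99, 1.0

# use_gae=True (requires the critic): GAE(lambda) advantages.
vpred_t = np.concatenate([vf_preds, [last_r]])
delta_t = rewards + gamma * vpred_t[1:] - vpred_t[:-1]
gae_adv = discount(delta_t, gamma * lambda_)
gae_value_targets = gae_adv + vf_preds

# use_gae=False, use_critic=True: discounted returns minus the baseline;
# the value targets are the returns themselves.
returns = discount(np.concatenate([rewards, [last_r]]), gamma)[:-1]
adv_with_critic = returns - vf_preds

# use_gae=False, use_critic=False: raw discounted returns, zero value targets.
adv_no_critic = returns

print(gae_adv, adv_with_critic, adv_no_critic)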
3 changes: 2 additions & 1 deletion rllib/tests/test_rollout_worker.py
@@ -32,7 +32,8 @@ def postprocess_trajectory(self,
other_agent_batches=None,
episode=None):
assert episode is not None
return compute_advantages(batch, 100.0, 0.9, use_gae=False)
return compute_advantages(
batch, 100.0, 0.9, use_gae=False, use_critic=False)


class BadPolicy(MockPolicy):
