Skip to content

Commit

Permalink
support dueling
Browse files Browse the repository at this point in the history
  • Loading branch information
ericl committed Nov 22, 2018
1 parent 1ca2992 commit 78c765c
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 4 deletions.
10 changes: 9 additions & 1 deletion python/ray/rllib/agents/dqn/dqn_policy_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ def __init__(self,
self.logits = support_logits_per_action
self.dist = support_prob_per_action
else:
action_scores_mean = tf.reduce_mean(action_scores, 1)
action_scores_mean = _reduce_mean_ignore_inf(action_scores, 1)
action_scores_centered = action_scores - tf.expand_dims(
action_scores_mean, 1)
self.value = state_score + action_scores_centered
Expand Down Expand Up @@ -518,6 +518,14 @@ def _postprocess_dqn(policy_graph, sample_batch):
return batch


def _reduce_mean_ignore_inf(x, axis):
"""Same as tf.reduce_mean() but ignores -inf values."""
mask = tf.not_equal(x, tf.float32.min)
x_zeroed = tf.where(mask, x, tf.zeros_like(x))
return (tf.reduce_sum(x_zeroed, axis) / tf.reduce_sum(
tf.cast(mask, tf.float32), axis))


def _huber_loss(x, delta=1.0):
"""Reference: https://en.wikipedia.org/wiki/Huber_loss"""
return tf.where(
Expand Down
5 changes: 2 additions & 3 deletions python/ray/rllib/examples/parametric_action_cartpole.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,8 @@
Note that since the model outputs now include "-inf" tf.float32.min
values, not all algorithm options are supported at the moment. For example,
dueling DQN will crash since it has a tf.reduce_mean() that is not robust to
the -inf action scores.
algorithms might crash if they don't properly ignore the -inf action scores.
Working configurations are given below.
"""

from __future__ import absolute_import
Expand Down Expand Up @@ -176,7 +176,6 @@ def _build_layers_v2(self, input_dict, num_outputs, options):
elif args.run == "DQN":
cfg = {
"hiddens": [], # don't postprocess the action scores
"dueling": False, # doesn't support action masking
}
else:
cfg = {}
Expand Down

0 comments on commit 78c765c

Please sign in to comment.