/
dqn_tf_policy.py
431 lines (365 loc) · 16.7 KB
/
dqn_tf_policy.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
"""TensorFlow policy class used for DQN"""
from typing import Dict
import gym
import numpy as np
import ray
from ray.rllib.agents.dqn.distributional_q_tf_model import \
DistributionalQTFModel
from ray.rllib.agents.dqn.simple_q_tf_policy import TargetNetworkMixin
from ray.rllib.evaluation.postprocessing import adjust_nstep
from ray.rllib.models import ModelCatalog
from ray.rllib.models.modelv2 import ModelV2
from ray.rllib.models.tf.tf_action_dist import Categorical
from ray.rllib.policy.policy import Policy
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.policy.tf_policy import LearningRateSchedule
from ray.rllib.policy.tf_policy_template import build_tf_policy
from ray.rllib.utils.error import UnsupportedSpaceException
from ray.rllib.utils.exploration import ParameterNoise
from ray.rllib.utils.framework import try_import_tf
from ray.rllib.utils.numpy import convert_to_numpy
from ray.rllib.utils.tf_utils import (
huber_loss, make_tf_callable, minimize_and_clip, reduce_mean_ignore_inf)
from ray.rllib.utils.typing import (ModelGradients, TensorType,
TrainerConfigDict)
# TF module handles: `tf1` is used below for the TF1-style AdamOptimizer,
# `tf` for all tensor ops. `tfv` is not used in this file.
tf1, tf, tfv = try_import_tf()

# Model names for the online and the target Q-network (see build_q_model).
Q_SCOPE, Q_TARGET_SCOPE = "q_func", "target_q_func"

# SampleBatch key holding the prioritized-replay importance-sampling weights.
PRIO_WEIGHTS = "weights"
class QLoss:
    """DQN loss container for scalar or distributional Q-learning.

    Constructing an instance builds three attributes:
      * `td_error`: per-sample error. In the scalar case this is the TD error
        (Q(s_t, a_t) minus the stop-gradient Bellman target). In the
        distributional case it is the per-sample softmax cross-entropy
        between the projected target distribution and the predicted logits.
      * `loss`: mean over the batch of the importance-weighted per-sample
        loss (Huber loss of the TD error in the scalar case).
      * `stats`: dict of scalar summary tensors.
    """

    def __init__(self,
                 q_t_selected: TensorType,
                 q_logits_t_selected: TensorType,
                 q_tp1_best: TensorType,
                 q_dist_tp1_best: TensorType,
                 importance_weights: TensorType,
                 rewards: TensorType,
                 done_mask: TensorType,
                 gamma: float = 0.99,
                 n_step: int = 1,
                 num_atoms: int = 1,
                 v_min: float = -10.0,
                 v_max: float = 10.0):
        """Build the loss tensors.

        Args:
            q_t_selected: Q-values of the actions actually taken at time t
                (used when num_atoms == 1).
            q_logits_t_selected: Support logits of the taken actions at t
                (used when num_atoms > 1).
            q_tp1_best: Q-value of the chosen best action at t+1
                (used when num_atoms == 1).
            q_dist_tp1_best: Distribution over the support for the chosen
                best action at t+1 (used when num_atoms > 1). Presumably
                probabilities (the `dist` output of compute_q_values);
                TODO confirm.
            importance_weights: Prioritized-replay importance-sampling
                weights; cast to float32 before weighting the loss.
            rewards: Rewards for each transition.
            done_mask: Float mask; where it is 1.0 the t+1 term is zeroed
                out (no bootstrapping).
            gamma: Discount factor; the t+1 term is scaled by
                gamma ** n_step.
            n_step: Number of steps of the (n-step) return.
            num_atoms: Number of support atoms; > 1 selects distributional
                Q-learning.
            v_min: Lowest value of the distributional support.
            v_max: Highest value of the distributional support.
        """
        if num_atoms > 1:
            # Distributional Q-learning, trained with a cross-entropy loss.
            # z: `num_atoms` evenly spaced support values from v_min to v_max.
            z = tf.range(num_atoms, dtype=tf.float32)
            z = v_min + z * (v_max - v_min) / float(num_atoms - 1)

            # Bellman-updated support atoms r + gamma^n * (1 - done) * z,
            # clipped back into [v_min, v_max].
            # (batch_size, 1) * (1, num_atoms) = (batch_size, num_atoms)
            r_tau = tf.expand_dims(
                rewards, -1) + gamma**n_step * tf.expand_dims(
                    1.0 - done_mask, -1) * tf.expand_dims(z, 0)
            r_tau = tf.clip_by_value(r_tau, v_min, v_max)
            # b: fractional index of each updated atom on the support grid;
            # its mass is split between the neighbouring atoms lb and ub.
            b = (r_tau - v_min) / ((v_max - v_min) / float(num_atoms - 1))
            lb = tf.floor(b)
            ub = tf.math.ceil(b)

            # Edge case many implementations miss: when b is an exact
            # integer, lb == ub, so both (ub - b) and (b - lb) are 0 and the
            # probability pr_j(s', a*) would be dropped. floor_equal_ceil is
            # 1.0 exactly in that case and is added to the lower weight, so
            # the whole mass lands on atom lb.
            floor_equal_ceil = tf.cast(tf.less(ub - lb, 0.5), tf.float32)

            l_project = tf.one_hot(
                tf.cast(lb, dtype=tf.int32),
                num_atoms)  # (batch_size, num_atoms, num_atoms)
            u_project = tf.one_hot(
                tf.cast(ub, dtype=tf.int32),
                num_atoms)  # (batch_size, num_atoms, num_atoms)
            # Mass assigned to the lower / upper neighbour, proportional to
            # the distance from b to the opposite neighbour.
            ml_delta = q_dist_tp1_best * (ub - b + floor_equal_ceil)
            mu_delta = q_dist_tp1_best * (b - lb)
            # Scatter each source atom's mass onto its target atom and sum
            # over source atoms (axis=1).
            ml_delta = tf.reduce_sum(
                l_project * tf.expand_dims(ml_delta, -1), axis=1)
            mu_delta = tf.reduce_sum(
                u_project * tf.expand_dims(mu_delta, -1), axis=1)
            # m: projected target distribution, (batch_size, num_atoms).
            m = ml_delta + mu_delta

            # Rainbow paper claims that using this cross entropy loss for
            # priority is robust and insensitive to `prioritized_replay_alpha`
            self.td_error = tf.nn.softmax_cross_entropy_with_logits(
                labels=m, logits=q_logits_t_selected)
            self.loss = tf.reduce_mean(
                self.td_error * tf.cast(importance_weights, tf.float32))
            self.stats = {
                # TODO: better Q stats for dist dqn
                "mean_td_error": tf.reduce_mean(self.td_error),
            }
        else:
            # Zero out the bootstrap term for transitions that ended.
            q_tp1_best_masked = (1.0 - done_mask) * q_tp1_best

            # compute RHS of bellman equation
            q_t_selected_target = rewards + gamma**n_step * q_tp1_best_masked

            # TD error against a stop-gradient target. The error itself is
            # not clipped; huber_loss is applied to it below.
            self.td_error = (
                q_t_selected - tf.stop_gradient(q_t_selected_target))
            self.loss = tf.reduce_mean(
                tf.cast(importance_weights, tf.float32) * huber_loss(
                    self.td_error))
            self.stats = {
                "mean_q": tf.reduce_mean(q_t_selected),
                "min_q": tf.reduce_min(q_t_selected),
                "max_q": tf.reduce_max(q_t_selected),
                "mean_td_error": tf.reduce_mean(self.td_error),
            }
class ComputeTDErrorMixin:
    """Give the DQNTFPolicy a `compute_td_error` callable.

    With TD errors computable on the policy itself, replay priorities can be
    computed on the rollout workers (worker-side prioritization).
    """

    def __init__(self):
        # Loss-input keys, in the same order as compute_td_error's arguments.
        column_keys = (
            SampleBatch.CUR_OBS,
            SampleBatch.ACTIONS,
            SampleBatch.REWARDS,
            SampleBatch.NEXT_OBS,
            SampleBatch.DONES,
            PRIO_WEIGHTS,
        )

        @make_tf_callable(self.get_session(), dynamic_shape=True)
        def compute_td_error(obs_t, act_t, rew_t, obs_tp1, done_mask,
                             importance_weights):
            columns = (obs_t, act_t, rew_t, obs_tp1, done_mask,
                       importance_weights)
            loss_inputs = {
                key: tf.convert_to_tensor(column)
                for key, column in zip(column_keys, columns)
            }
            # Building the loss refreshes `self.q_loss` as a side effect;
            # its per-sample TD error is what we hand back.
            build_q_losses(self, self.model, None, loss_inputs)
            return self.q_loss.td_error

        self.compute_td_error = compute_td_error
def build_q_model(policy: Policy, obs_space: gym.spaces.Space,
                  action_space: gym.spaces.Space,
                  config: TrainerConfigDict) -> ModelV2:
    """Build q_model and target_model for DQN.

    Args:
        policy (Policy): The Policy, which will use the model for optimization.
        obs_space (gym.spaces.Space): The policy's observation space.
        action_space (gym.spaces.Space): The policy's action space. Must be
            Discrete.
        config (TrainerConfigDict): The policy/trainer config. Reads
            "hiddens", "model", "num_atoms", "dueling", "noisy", "v_min",
            "v_max", "sigma0" and "exploration_config". NOTE: when
            `config["hiddens"]` is non-empty, `config["model"]
            ["no_final_linear"]` is set to True in place.

    Returns:
        ModelV2: The (online) Q-model for the Policy to use.

    Raises:
        UnsupportedSpaceException: If `action_space` is not Discrete.

    Note: The target q model will not be returned, just assigned to
    `policy.target_model`.
    """
    if not isinstance(action_space, gym.spaces.Discrete):
        raise UnsupportedSpaceException(
            "Action space {} is not supported for DQN.".format(action_space))

    if config["hiddens"]:
        # try to infer the last layer size, otherwise fall back to 256
        num_outputs = ([256] + list(config["model"]["fcnet_hiddens"]))[-1]
        config["model"]["no_final_linear"] = True
    else:
        num_outputs = action_space.n

    # TODO(sven): Move option to add LayerNorm after each Dense
    # generically into ModelCatalog.
    add_layer_norm = (
        isinstance(getattr(policy, "exploration", None), ParameterNoise)
        or config["exploration_config"]["type"] == "ParameterNoise")

    # The online and target networks share every constructor argument except
    # their name; keep them in one place so the two cannot drift apart.
    model_kwargs = dict(
        obs_space=obs_space,
        action_space=action_space,
        num_outputs=num_outputs,
        model_config=config["model"],
        framework="tf",
        model_interface=DistributionalQTFModel,
        num_atoms=config["num_atoms"],
        dueling=config["dueling"],
        q_hiddens=config["hiddens"],
        use_noisy=config["noisy"],
        v_min=config["v_min"],
        v_max=config["v_max"],
        sigma0=config["sigma0"],
        add_layer_norm=add_layer_norm,
    )

    q_model = ModelCatalog.get_model_v2(name=Q_SCOPE, **model_kwargs)
    policy.target_model = ModelCatalog.get_model_v2(
        name=Q_TARGET_SCOPE, **model_kwargs)

    return q_model
def get_distribution_inputs_and_class(policy: Policy,
                                      model: ModelV2,
                                      input_dict: SampleBatch,
                                      *,
                                      explore=True,
                                      **kwargs):
    """Compute Q-values to serve as action-distribution inputs.

    The Q-values are cached on `policy.q_values` (DQNTFPolicy's
    `extra_action_out_fn` reads them from there) and returned together with
    the Categorical distribution class and an empty state-out list.
    """
    q_out = compute_q_values(
        policy, model, input_dict, state_batches=None, explore=explore)
    if isinstance(q_out, tuple):
        # compute_q_values returns (q_values, logits, dist, state).
        q_out = q_out[0]
    policy.q_values = q_out
    return policy.q_values, Categorical, []  # state-out
def build_q_losses(policy: Policy, model, _,
                   train_batch: SampleBatch) -> TensorType:
    """Constructs the loss for DQNTFPolicy.

    Side effects: stores the QLoss instance on `policy.q_loss` (its `stats`
    and `td_error` are read by the stats, learn-fetch and TD-error hooks).
    On the first call, it also caches the target model's variables on
    `policy.target_q_func_vars`.

    Args:
        policy (Policy): The Policy to calculate the loss for.
        model (ModelV2): The Model to calculate the loss for.
        _: Unused (ComputeTDErrorMixin passes None here).
        train_batch (SampleBatch): The training data.

    Returns:
        TensorType: A single loss tensor.
    """
    config = policy.config

    # q network evaluation
    q_t, q_logits_t, _, _ = compute_q_values(
        policy,
        model,
        SampleBatch({
            "obs": train_batch[SampleBatch.CUR_OBS]
        }),
        state_batches=None,
        explore=False)

    # target q network evaluation
    q_tp1, _, q_dist_tp1, _ = compute_q_values(
        policy,
        policy.target_model,
        SampleBatch({
            "obs": train_batch[SampleBatch.NEXT_OBS]
        }),
        state_batches=None,
        explore=False)
    if not hasattr(policy, "target_q_func_vars"):
        policy.target_q_func_vars = policy.target_model.variables()

    num_actions = policy.action_space.n

    # q scores for actions which we know were selected in the given state.
    one_hot_selection = tf.one_hot(
        tf.cast(train_batch[SampleBatch.ACTIONS], tf.int32), num_actions)
    q_t_selected = tf.reduce_sum(q_t * one_hot_selection, 1)
    q_logits_t_selected = tf.reduce_sum(
        q_logits_t * tf.expand_dims(one_hot_selection, -1), 1)

    # compute estimate of best possible value starting from state at t + 1.
    # With double_q, the online network picks the argmax action and the
    # target network evaluates it; otherwise the target network does both.
    if config["double_q"]:
        q_tp1_for_argmax, _, _, _ = compute_q_values(
            policy,
            model,
            SampleBatch({"obs": train_batch[SampleBatch.NEXT_OBS]}),
            state_batches=None,
            explore=False)
    else:
        q_tp1_for_argmax = q_tp1
    q_tp1_best_one_hot_selection = tf.one_hot(
        tf.argmax(q_tp1_for_argmax, 1), num_actions)
    q_tp1_best = tf.reduce_sum(q_tp1 * q_tp1_best_one_hot_selection, 1)
    q_dist_tp1_best = tf.reduce_sum(
        q_dist_tp1 * tf.expand_dims(q_tp1_best_one_hot_selection, -1), 1)

    policy.q_loss = QLoss(
        q_t_selected, q_logits_t_selected, q_tp1_best, q_dist_tp1_best,
        train_batch[PRIO_WEIGHTS], train_batch[SampleBatch.REWARDS],
        tf.cast(train_batch[SampleBatch.DONES], tf.float32),
        config["gamma"], config["n_step"], config["num_atoms"],
        config["v_min"], config["v_max"])

    return policy.q_loss.loss
def adam_optimizer(policy: Policy, config: TrainerConfigDict
                   ) -> "tf.keras.optimizers.Optimizer":
    """Create the Adam optimizer used to train the Q-network.

    For the "tf2"/"tfe" frameworks a Keras Adam optimizer is returned,
    otherwise the TF1 `AdamOptimizer`. Both use `policy.cur_lr` as the
    learning rate and `config["adam_epsilon"]` as epsilon.
    """
    use_keras = policy.config["framework"] in ("tf2", "tfe")
    optimizer_cls = (tf.keras.optimizers.Adam
                     if use_keras else tf1.train.AdamOptimizer)
    return optimizer_cls(
        learning_rate=policy.cur_lr, epsilon=config["adam_epsilon"])
def clip_gradients(policy: Policy, optimizer: "tf.keras.optimizers.Optimizer",
                   loss: TensorType) -> ModelGradients:
    """Compute gradients of `loss` w.r.t. the online Q-model's variables.

    Delegates to `minimize_and_clip` with `config["grad_clip"]`. The model
    variables are cached on `policy.q_func_vars` the first time through.
    """
    if not hasattr(policy, "q_func_vars"):
        policy.q_func_vars = policy.model.variables()
    q_vars = policy.q_func_vars
    grad_clip = policy.config["grad_clip"]
    return minimize_and_clip(
        optimizer, loss, var_list=q_vars, clip_val=grad_clip)
def build_q_stats(policy: Policy, batch) -> Dict[str, TensorType]:
    """Return learner stats: the current learning rate plus the QLoss stats.

    Entries in `policy.q_loss.stats` override "cur_lr" on a key clash.
    """
    lr_stat = tf.cast(policy.cur_lr, tf.float64)
    return {"cur_lr": lr_stat, **policy.q_loss.stats}
def setup_mid_mixins(policy: Policy, obs_space, action_space, config) -> None:
    """Initialize the mixins needed before the loss is built.

    Registered as DQNTFPolicy's `before_loss_init`. Runs the learning-rate
    schedule first, then the TD-error mixin.
    """
    lr, lr_schedule = config["lr"], config["lr_schedule"]
    LearningRateSchedule.__init__(policy, lr, lr_schedule)
    ComputeTDErrorMixin.__init__(policy)
def setup_late_mixins(policy: Policy, obs_space: gym.spaces.Space,
                      action_space: gym.spaces.Space,
                      config: TrainerConfigDict) -> None:
    """Initialize the target-network mixin.

    Registered as DQNTFPolicy's `after_init`, i.e. it runs once the policy
    (including its models) has been constructed.
    """
    init_target_network = TargetNetworkMixin.__init__
    init_target_network(policy, obs_space, action_space, config)
def compute_q_values(policy: Policy,
                     model: ModelV2,
                     input_batch: SampleBatch,
                     state_batches=None,
                     seq_lens=None,
                     explore=None,
                     is_training: bool = False):
    """Run `model` on `input_batch` and turn its outputs into Q-values.

    Args:
        policy: Policy whose config selects "num_atoms" and "dueling".
        model: The (online or target) Q-model to evaluate.
        input_batch: Model input batch.
        state_batches: Optional RNN state inputs (None -> []).
        seq_lens: Optional sequence lengths, passed to the model.
        explore: Not used by this function.
        is_training: Not used by this function.

    Returns:
        Tuple (q_values, logits, dist, state_out). For the distributional
        dueling case, `logits` and `dist` are the dueling-combined support
        logits and their softmax.
    """
    cfg = policy.config
    model_out, state = model(input_batch, state_batches or [], seq_lens)

    distributional = cfg["num_atoms"] > 1
    q_heads = model.get_q_value_distributions(model_out)
    if distributional:
        action_scores, z, support_logits_per_action, logits, dist = q_heads
    else:
        action_scores, logits, dist = q_heads

    if not cfg["dueling"]:
        return action_scores, logits, dist, state

    state_score = model.get_state_value(model_out)

    if not distributional:
        # Dueling: V(s) plus advantages centered on their per-state mean
        # (computed via reduce_mean_ignore_inf).
        advantage_mean = reduce_mean_ignore_inf(action_scores, 1)
        advantages = action_scores - tf.expand_dims(advantage_mean, 1)
        return state_score + advantages, logits, dist, state

    # Distributional dueling: center the per-action support logits, add the
    # state-value logits, then take the expectation over the support z.
    logits_mean = tf.reduce_mean(support_logits_per_action, 1)
    centered_logits = support_logits_per_action - tf.expand_dims(
        logits_mean, 1)
    combined_logits = tf.expand_dims(state_score, 1) + centered_logits
    support_probs = tf.nn.softmax(logits=combined_logits)
    q_values = tf.reduce_sum(input_tensor=z * support_probs, axis=-1)
    return q_values, combined_logits, support_probs, state
def postprocess_nstep_and_prio(policy: Policy,
                               batch: SampleBatch,
                               other_agent=None,
                               episode=None) -> SampleBatch:
    """Postprocess a collected sample batch for DQN.

    Steps:
      1. If `n_step > 1`, apply the n-step adjustment via `adjust_nstep`.
      2. If the batch has no PRIO_WEIGHTS column, add one filled with ones.
      3. If the batch is non-empty and worker-side prioritization is on,
         overwrite PRIO_WEIGHTS with |TD error| + prioritized_replay_eps.

    Returns:
        The same `batch` object, modified in place.
    """
    cfg = policy.config

    # N-step Q adjustments.
    n_step = cfg["n_step"]
    if n_step > 1:
        adjust_nstep(n_step, cfg["gamma"], batch)

    # Default importance weights of 1.0 when none are present.
    if PRIO_WEIGHTS not in batch:
        batch[PRIO_WEIGHTS] = np.ones_like(batch[SampleBatch.REWARDS])

    # Prioritize on the worker side (non-empty batches only).
    if batch.count <= 0 or not cfg["worker_side_prioritization"]:
        return batch

    td_errors = policy.compute_td_error(
        batch[SampleBatch.OBS], batch[SampleBatch.ACTIONS],
        batch[SampleBatch.REWARDS], batch[SampleBatch.NEXT_OBS],
        batch[SampleBatch.DONES], batch[PRIO_WEIGHTS])
    abs_td_errors = np.abs(convert_to_numpy(td_errors))
    batch[PRIO_WEIGHTS] = abs_td_errors + cfg["prioritized_replay_eps"]
    return batch
def _dqn_default_config():
    """Return DQN's DEFAULT_CONFIG, looked up at call time."""
    return ray.rllib.agents.dqn.dqn.DEFAULT_CONFIG


def _dqn_extra_action_out(policy):
    """Expose the Q-values cached by get_distribution_inputs_and_class."""
    return {"q_values": policy.q_values}


def _dqn_extra_learn_fetches(policy):
    """Fetch the per-sample TD errors of the current QLoss."""
    return {"td_error": policy.q_loss.td_error}


DQNTFPolicy = build_tf_policy(
    name="DQNTFPolicy",
    get_default_config=_dqn_default_config,
    # Model construction and action selection.
    make_model=build_q_model,
    action_distribution_fn=get_distribution_inputs_and_class,
    extra_action_out_fn=_dqn_extra_action_out,
    # Loss, statistics and optimization.
    loss_fn=build_q_losses,
    stats_fn=build_q_stats,
    optimizer_fn=adam_optimizer,
    compute_gradients_fn=clip_gradients,
    extra_learn_fetches_fn=_dqn_extra_learn_fetches,
    # Sample postprocessing (n-step, priorities).
    postprocess_fn=postprocess_nstep_and_prio,
    # Lifecycle hooks and mixins (mixin order is preserved).
    before_loss_init=setup_mid_mixins,
    after_init=setup_late_mixins,
    mixins=[
        TargetNetworkMixin,
        ComputeTDErrorMixin,
        LearningRateSchedule,
    ])