from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from ray.rllib.agents.dqn.dqn import DQNAgent, DEFAULT_CONFIG as DQN_CONFIG
from ray.rllib.utils import merge_dicts
from ray.rllib.utils.annotations import override

# yapf: disable
# __sphinx_doc_begin__
APEX_DEFAULT_CONFIG = merge_dicts(
    DQN_CONFIG,  # see also the options in dqn.py, which are also supported
    {
        "optimizer_class": "AsyncReplayOptimizer",
        "optimizer": merge_dicts(
            DQN_CONFIG["optimizer"], {
                "max_weight_sync_delay": 400,
                "num_replay_buffer_shards": 4,
                "debug": False
            }),
        "n_step": 3,
        "num_gpus": 1,
        "num_workers": 32,
        "buffer_size": 2000000,
        "learning_starts": 50000,
        "train_batch_size": 512,
        "sample_batch_size": 50,
        "target_network_update_freq": 500000,
        "timesteps_per_iteration": 25000,
        "per_worker_exploration": True,
        "worker_side_prioritization": True,
        "min_iter_time_s": 30,
    },
)
# __sphinx_doc_end__
# yapf: enable
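
# Illustrative sketch (not part of the original module): the defaults above
# target a large single node, so a common pattern is to merge in smaller
# values for a local smoke test. The name APEX_LOCAL_TEST_CONFIG and the
# specific override values are assumptions, not recommended settings.
APEX_LOCAL_TEST_CONFIG = merge_dicts(
    APEX_DEFAULT_CONFIG,
    {
        "num_gpus": 0,          # assume no GPU is available locally
        "num_workers": 2,       # a couple of sampling workers instead of 32
        "buffer_size": 10000,   # much smaller replay buffer
        "learning_starts": 1000,
        "timesteps_per_iteration": 1000,
        "min_iter_time_s": 1,
    })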


class ApexAgent(DQNAgent):
    """DQN variant that uses the Ape-X distributed policy optimizer.

    By default, this is configured for a large single node (32 cores). For
    running in a large cluster, increase the `num_workers` config var.
    """

    _agent_name = "APEX"
    _default_config = APEX_DEFAULT_CONFIG

    @override(DQNAgent)
    def update_target_if_needed(self):
        # Ape-X updates based on num steps trained, not sampled
        if self.optimizer.num_steps_trained - self.last_target_update_ts > \
                self.config["target_network_update_freq"]:
            self.local_evaluator.for_policy(lambda p: p.update_target())
            self.last_target_update_ts = self.optimizer.num_steps_trained
            self.num_target_updates += 1
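

# Usage sketch (illustrative, not part of the original module): the agent can
# be constructed directly and stepped with train(). "CartPole-v0" and the
# scaled-down config overrides below are assumptions for a quick local run,
# not tuned settings.
if __name__ == "__main__":
    import ray

    ray.init()
    agent = ApexAgent(
        env="CartPole-v0",
        config={
            "num_gpus": 0,
            "num_workers": 2,
            "buffer_size": 10000,
            "learning_starts": 1000,
            "timesteps_per_iteration": 1000,
            "min_iter_time_s": 1,
        })
    for _ in range(3):
        result = agent.train()
        # Print the full iteration result rather than a single metric, since
        # early iterations may not have completed any episodes yet.
        print(result)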