from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from ray.rllib.agents.agent import with_common_config
from ray.rllib.agents.dqn.dqn import DQNAgent
from ray.rllib.agents.qmix.qmix_policy_graph import QMixPolicyGraph

# yapf: disable
# __sphinx_doc_begin__
DEFAULT_CONFIG = with_common_config({
    # === QMix ===
    # Mixing network. Either "qmix", "vdn", or None
    "mixer": "qmix",
    # Size of the mixing network embedding
    "mixing_embed_dim": 32,
    # Whether to use double Q-learning
    "double_q": True,
    # Optimize over complete episodes by default.
    "batch_mode": "complete_episodes",

    # === Exploration ===
    # Max num timesteps for annealing schedules. Exploration is annealed from
    # 1.0 to exploration_final_eps over this number of timesteps scaled by
    # exploration_fraction
    "schedule_max_timesteps": 100000,
    # Number of env steps to optimize for before returning
    "timesteps_per_iteration": 1000,
    # Fraction of entire training period over which the exploration rate is
    # annealed
    "exploration_fraction": 0.1,
    # Final value of random action probability
    "exploration_final_eps": 0.02,
    # Update the target network every `target_network_update_freq` steps.
    "target_network_update_freq": 500,

    # === Replay buffer ===
    # Size of the replay buffer in steps.
    "buffer_size": 10000,

    # === Optimization ===
    # Learning rate for RMSProp optimizer
    "lr": 0.0005,
    # RMSProp alpha
    "optim_alpha": 0.99,
    # RMSProp epsilon
    "optim_eps": 0.00001,
    # If not None, clip gradients during optimization at this value
    "grad_norm_clipping": 10,
    # How many steps of the model to sample before learning starts.
    "learning_starts": 1000,
    # Update the replay buffer with this many samples at once. Note that
    # this setting applies per-worker if num_workers > 1.
    "sample_batch_size": 4,
    # Size of a batch sampled from the replay buffer for training. Note that
    # if async_updates is set, then each worker returns gradients for a
    # batch of this size.
    "train_batch_size": 32,

    # === Parallelism ===
    # Number of workers for collecting samples with. This only makes sense
    # to increase if your environment is particularly slow to sample, or if
    # you're using the Async or Ape-X optimizers.
    "num_workers": 0,
    # Optimizer class to use.
    "optimizer_class": "SyncBatchReplayOptimizer",
    # Whether to use a distribution of epsilons across workers for exploration.
    "per_worker_exploration": False,
    # Whether to compute priorities on workers.
    "worker_side_prioritization": False,
    # Prevent iterations from going lower than this time span
    "min_iter_time_s": 1,

    # === Model ===
    "model": {
        "lstm_cell_size": 64,
        "max_seq_len": 999999,
    },
})
# __sphinx_doc_end__
# yapf: enable
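
# Usage note (an illustration, not part of the original file): since
# DEFAULT_CONFIG is built with with_common_config(), it also carries RLlib's
# common agent options, and any key above can be overridden by the config
# dict passed at agent construction. For example (the env name below is a
# placeholder assumption for a registered grouped env):
#
#     agent = QMixAgent(env="my_grouped_env",
#                       config={"mixer": "vdn", "lr": 1e-4})
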
class QMixAgent(DQNAgent):
    """QMix implementation in PyTorch."""

    _agent_name = "QMIX"
    _default_config = DEFAULT_CONFIG
    _policy_graph = QMixPolicyGraph
    _optimizer_shared_configs = [
        "learning_starts", "buffer_size", "train_batch_size"
    ]
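

if __name__ == "__main__":
    # Minimal training sketch (an illustration appended for clarity, not part
    # of the original module). QMIX expects a grouped multi-agent env: agents
    # are merged into a single logical agent via
    # MultiAgentEnv.with_agent_groups(), and the grouped env is registered
    # with tune before use. "my_grouped_env" is a placeholder assumption for
    # such a registered env, and we assume this agent class is registered in
    # RLlib's agent registry under the name "QMIX".
    import ray
    from ray import tune

    ray.init()
    tune.run_experiments({
        "qmix_demo": {
            "run": "QMIX",
            "env": "my_grouped_env",
            "stop": {"timesteps_total": 50000},
            # Override a few of the defaults defined above.
            "config": {
                "mixer": "qmix",
                "num_workers": 0,
            },
        },
    })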