diff --git a/doc/source/rllib-algorithms.rst b/doc/source/rllib-algorithms.rst index 66bf08a6c3999..1d0501215745c 100644 --- a/doc/source/rllib-algorithms.rst +++ b/doc/source/rllib-algorithms.rst @@ -133,10 +133,10 @@ Tuned examples: `Pendulum-v0 `__ `[implementation] `__ -RLlib DQN is implemented using the SyncReplayOptimizer. The algorithm can be scaled by increasing the number of workers, using the AsyncGradientsOptimizer for async DQN, or using Ape-X. Memory usage is reduced by compressing samples in the replay buffer with LZ4. All of the DQN improvements evaluated in `Rainbow `__ are available, though not all are enabled by default. +RLlib DQN is implemented using the SyncReplayOptimizer. The algorithm can be scaled by increasing the number of workers, using the AsyncGradientsOptimizer for async DQN, or using Ape-X. Memory usage is reduced by compressing samples in the replay buffer with LZ4. All of the DQN improvements evaluated in `Rainbow `__ are available, though not all are enabled by default. See also how to use `parametric-actions in DQN `__. Tuned examples: `PongDeterministic-v4 `__, `Rainbow configuration `__, `{BeamRider,Breakout,Qbert,SpaceInvaders}NoFrameskip-v4 `__, `with Dueling and Double-Q `__, `with Distributional DQN `__. diff --git a/doc/source/rllib-env.rst b/doc/source/rllib-env.rst index c1381f561cd4c..ca36186e1a5f6 100644 --- a/doc/source/rllib-env.rst +++ b/doc/source/rllib-env.rst @@ -7,20 +7,22 @@ RLlib works with several different types of environments, including `OpenAI Gym **Compatibility matrix**: -============= ================ ================== =========== ================== -Algorithm Discrete Actions Continuous Actions Multi-Agent Recurrent Policies -============= ================ ================== =========== ================== -A2C, A3C **Yes** **Yes** **Yes** **Yes** -PPO **Yes** **Yes** **Yes** **Yes** -PG **Yes** **Yes** **Yes** **Yes** -IMPALA **Yes** No **Yes** **Yes** -DQN, Rainbow **Yes** No **Yes** No -DDPG, TD3 No **Yes** **Yes** No -APEX-DQN **Yes** No **Yes** No -APEX-DDPG No **Yes** **Yes** No -ES **Yes** **Yes** No No -ARS **Yes** **Yes** No No -============= ================ ================== =========== ================== +============= ======================= ================== =========== ================== +Algorithm Discrete Actions Continuous Actions Multi-Agent Recurrent Policies +============= ======================= ================== =========== ================== +A2C, A3C **Yes** `+parametric`_ **Yes** **Yes** **Yes** +PPO **Yes** `+parametric`_ **Yes** **Yes** **Yes** +PG **Yes** `+parametric`_ **Yes** **Yes** **Yes** +IMPALA **Yes** `+parametric`_ No **Yes** **Yes** +DQN, Rainbow **Yes** `+parametric`_ No **Yes** No +DDPG, TD3 No **Yes** **Yes** No +APEX-DQN **Yes** `+parametric`_ No **Yes** No +APEX-DDPG No **Yes** **Yes** No +ES **Yes** **Yes** No No +ARS **Yes** **Yes** No No +============= ======================= ================== =========== ================== + +.. _`+parametric`: rllib-models.html#variable-length-parametric-action-spaces In the high-level agent APIs, environments are identified with string names. 
By default, the string will be interpreted as a gym `environment name `__, however you can also register custom environments by name: diff --git a/doc/source/rllib-models.rst b/doc/source/rllib-models.rst index 5fde37f53087d..883b0a6bb5266 100644 --- a/doc/source/rllib-models.rst +++ b/doc/source/rllib-models.rst @@ -110,6 +110,43 @@ Custom models should subclass the common RLlib `model class `__ and associated `training scripts `__. You can also reference the `unit tests `__ for Tuple and Dict spaces, which show how to access nested observation fields. +Custom Recurrent Models +~~~~~~~~~~~~~~~~~~~~~~~ + +Instead of using the ``use_lstm: True`` option, it can be preferable to use a custom recurrent model. This provides more control over postprocessing of the LSTM output and can also allow the use of multiple LSTM cells to process different portions of the input. The only difference from a normal custom model is that you have to define ``self.state_init``, ``self.state_in``, and ``self.state_out``. You can refer to the existing `lstm.py `__ model as an example to implement your own model: + + +.. code-block:: python + + class MyCustomLSTM(Model): + def _build_layers_v2(self, input_dict, num_outputs, options): + # Some initial layers to process inputs, shape [BATCH, OBS...]. + features = some_hidden_layers(input_dict["obs"]) + + # Add back the nested time dimension for tf.dynamic_rnn, new shape + # will be [BATCH, MAX_SEQ_LEN, OBS...]. + last_layer = add_time_dimension(features, self.seq_lens) + + # Setup the LSTM cell (see lstm.py for an example) + lstm = rnn.BasicLSTMCell(256, state_is_tuple=True) + self.state_init = ... + self.state_in = ... + lstm_out, lstm_state = tf.nn.dynamic_rnn( + lstm, + last_layer, + initial_state=..., + sequence_length=self.seq_lens, + time_major=False, + dtype=tf.float32) + self.state_out = list(lstm_state) + + # Drop the time dimension again so back to shape [BATCH, OBS...]. + # Note that we retain the zero padding (see issue #2992). + last_layer = tf.reshape(lstm_out, [-1, cell_size]) + logits = linear(last_layer, num_outputs, "action", + normc_initializer(0.01)) + return logits, last_layer + Custom Preprocessors -------------------- @@ -188,6 +225,53 @@ Then, you can create an agent with your custom policy graph by: In this example we overrode existing methods of the existing DDPG policy graph, i.e., `_build_q_network`, `_build_p_network`, `_build_action_network`, `_build_actor_critic_loss`, but you can also replace the entire graph class entirely. +Variable-length / Parametric Action Spaces +------------------------------------------ + +Custom models can be used to work with environments where (1) the set of valid actions varies per step, and/or (2) the number of valid actions is very large, as in `OpenAI Five `__ and `Horizon `__. The general idea is that the meaning of actions can be completely conditioned on the observation, that is, the ``a`` in ``Q(s, a)`` is just a token in ``[0, MAX_AVAIL_ACTIONS)`` that only has meaning in the context of ``s``. This works with algorithms in the `DQN and policy-gradient families `__ and can be implemented as follows: + +1. The environment should return a mask and/or list of valid action embeddings as part of the observation for each step. To enable batching, the number of actions can be allowed to vary from 1 to some max number: + +..
code-block:: python + + class MyParamActionEnv(gym.Env): + def __init__(self, max_avail_actions): + self.action_space = Discrete(max_avail_actions) + self.observation_space = Dict({ + "action_mask": Box(0, 1, shape=(max_avail_actions, )), + "avail_actions": Box(-1, 1, shape=(max_avail_actions, action_embedding_sz)), + "real_obs": ..., + }) + +2. A custom model can be defined that can interpret the ``action_mask`` and ``avail_actions`` portions of the observation. Here the model computes the action logits via the dot product of some network output and each action embedding. Invalid actions can be masked out of the softmax by scaling the probability to zero: + +.. code-block:: python + + class MyParamActionModel(Model): + def _build_layers_v2(self, input_dict, num_outputs, options): + avail_actions = input_dict["obs"]["avail_actions"] + action_mask = input_dict["obs"]["action_mask"] + + output = FullyConnectedNetwork( + input_dict["obs"]["real_obs"], num_outputs=action_embedding_sz) + + # Expand the model output to [BATCH, 1, EMBED_SIZE]. Note that the + # avail actions tensor is of shape [BATCH, MAX_ACTIONS, EMBED_SIZE]. + intent_vector = tf.expand_dims(output, 1) + + # Shape of logits is [BATCH, MAX_ACTIONS]. + action_logits = tf.reduce_sum(avail_actions * intent_vector, axis=2) + + # Mask out invalid actions (use tf.float32.min for stability) + inf_mask = tf.maximum(tf.log(action_mask), tf.float32.min) + masked_logits = inf_mask + action_logits + + return masked_logits, last_layer + + +Depending on your use case it may make sense to use just the masking, just action embeddings, or both. For a runnable example of this in code, check out `parametric_action_cartpole.py `__. Note that since masking introduces ``tf.float32.min`` values into the model output, this technique might not work with all algorithm options. For example, algorithms might crash if they incorrectly process the ``tf.float32.min`` values. The cartpole example has working configurations for DQN and several policy gradient algorithms. + + Model-Based Rollouts -------------------- diff --git a/doc/source/rllib-training.rst b/doc/source/rllib-training.rst index 6b1366f4ee081..dc37d22943ba7 100644 --- a/doc/source/rllib-training.rst +++ b/doc/source/rllib-training.rst @@ -73,13 +73,13 @@ In an example below, we train A2C by specifying 8 workers through the config fla python ray/python/ray/rllib/train.py --env=PongDeterministic-v4 \ --run=A2C --config '{"num_workers": 8}' -.. image:: rllib-config.svg - Specifying Resources ~~~~~~~~~~~~~~~~~~~~ You can control the degree of parallelism used by setting the ``num_workers`` hyperparameter for most agents. The number of GPUs the driver should use can be set via the ``num_gpus`` option. Similarly, the resource allocation to workers can be controlled via ``num_cpus_per_worker``, ``num_gpus_per_worker``, and ``custom_resources_per_worker``. The number of GPUs can be a fractional quantity to allocate only a fraction of a GPU. For example, with DQN you can pack five agents onto one GPU by setting ``num_gpus: 0.2``. Note that in Ray < 0.6.0 fractional GPU support requires setting the environment variable ``RAY_USE_XRAY=1``. +.. 
image:: rllib-config.svg + Common Parameters ~~~~~~~~~~~~~~~~~ diff --git a/doc/source/rllib.rst b/doc/source/rllib.rst index 2de444b52965b..e96bd6fccbcb9 100644 --- a/doc/source/rllib.rst +++ b/doc/source/rllib.rst @@ -56,7 +56,7 @@ Algorithms - `Deep Deterministic Policy Gradients (DDPG, TD3) `__ - - `Deep Q Networks (DQN, Rainbow) `__ + - `Deep Q Networks (DQN, Rainbow, Parametric DQN) `__ - `Policy Gradients `__ @@ -75,6 +75,7 @@ Models and Preprocessors * `Custom Models `__ * `Custom Preprocessors `__ * `Customizing Policy Graphs `__ +* `Variable-length / Parametric Action Spaces `__ * `Model-Based Rollouts `__ RLlib Concepts diff --git a/python/ray/rllib/agents/dqn/dqn_policy_graph.py b/python/ray/rllib/agents/dqn/dqn_policy_graph.py index 6125cd9d387b7..2bbff99246f3c 100644 --- a/python/ray/rllib/agents/dqn/dqn_policy_graph.py +++ b/python/ray/rllib/agents/dqn/dqn_policy_graph.py @@ -30,16 +30,21 @@ def __init__(self, sigma0=0.5): self.model = model with tf.variable_scope("action_value"): - action_out = model.last_layer - for i in range(len(hiddens)): - if use_noisy: - action_out = self.noisy_layer("hidden_%d" % i, action_out, - hiddens[i], sigma0) - else: - action_out = layers.fully_connected( - action_out, - num_outputs=hiddens[i], - activation_fn=tf.nn.relu) + if hiddens: + action_out = model.last_layer + for i in range(len(hiddens)): + if use_noisy: + action_out = self.noisy_layer( + "hidden_%d" % i, action_out, hiddens[i], sigma0) + else: + action_out = layers.fully_connected( + action_out, + num_outputs=hiddens[i], + activation_fn=tf.nn.relu) + else: + # Avoid postprocessing the outputs. This enables custom models + # to be used for parametric action DQN. + action_out = model.outputs if use_noisy: action_scores = self.noisy_layer( "output", @@ -47,11 +52,13 @@ def __init__(self, num_actions * num_atoms, sigma0, non_linear=False) - else: + elif hiddens: action_scores = layers.fully_connected( action_out, num_outputs=num_actions * num_atoms, activation_fn=None) + else: + action_scores = model.outputs if num_atoms > 1: # Distributional Q-learning uses a discrete support z # to represent the action value distribution @@ -107,7 +114,7 @@ def __init__(self, self.logits = support_logits_per_action self.dist = support_prob_per_action else: - action_scores_mean = tf.reduce_mean(action_scores, 1) + action_scores_mean = _reduce_mean_ignore_inf(action_scores, 1) action_scores_centered = action_scores - tf.expand_dims( action_scores_mean, 1) self.value = state_score + action_scores_centered @@ -176,11 +183,15 @@ class QValuePolicy(object): def __init__(self, q_values, observations, num_actions, stochastic, eps): deterministic_actions = tf.argmax(q_values, axis=1) batch_size = tf.shape(observations)[0] - random_actions = tf.random_uniform( - tf.stack([batch_size]), - minval=0, - maxval=num_actions, - dtype=tf.int64) + + # Special case masked out actions (q_value ~= -inf) so that we don't + # even consider them for exploration. 
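+ # For example, if one row of q_values is [1.2, tf.float32.min, 0.4], the + # logits built below are [1, tf.float32.min, 1], so tf.multinomial ends + # up sampling uniformly over just the two valid actions (values here are + # purely illustrative).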
+ random_valid_action_logits = tf.where( + tf.equal(q_values, tf.float32.min), + tf.ones_like(q_values) * tf.float32.min, tf.ones_like(q_values)) + random_actions = tf.squeeze( + tf.multinomial(random_valid_action_logits, 1), axis=1) + chose_random = tf.random_uniform( tf.stack([batch_size]), minval=0, maxval=1, dtype=tf.float32) < eps stochastic_actions = tf.where(chose_random, random_actions, @@ -368,8 +379,8 @@ def _build_q_network(self, obs, space): qnet = QNetwork( ModelCatalog.get_model({ "obs": obs - }, space, 1, self.config["model"]), self.num_actions, - self.config["dueling"], self.config["hiddens"], + }, space, self.num_actions, self.config["model"]), + self.num_actions, self.config["dueling"], self.config["hiddens"], self.config["noisy"], self.config["num_atoms"], self.config["v_min"], self.config["v_max"], self.config["sigma0"]) return qnet.value, qnet.logits, qnet.dist, qnet.model @@ -507,6 +518,14 @@ def _postprocess_dqn(policy_graph, sample_batch): return batch +def _reduce_mean_ignore_inf(x, axis): + """Same as tf.reduce_mean() but ignores -inf values.""" + mask = tf.not_equal(x, tf.float32.min) + x_zeroed = tf.where(mask, x, tf.zeros_like(x)) + return (tf.reduce_sum(x_zeroed, axis) / tf.reduce_sum( + tf.cast(mask, tf.float32), axis)) + + def _huber_loss(x, delta=1.0): """Reference: https://en.wikipedia.org/wiki/Huber_loss""" return tf.where( diff --git a/python/ray/rllib/agents/ppo/ppo.py b/python/ray/rllib/agents/ppo/ppo.py index eb556877c5a70..722f9263d8167 100644 --- a/python/ray/rllib/agents/ppo/ppo.py +++ b/python/ray/rllib/agents/ppo/ppo.py @@ -110,6 +110,11 @@ def _validate_config(self): and not self.config["simple_optimizer"]): logger.warn("forcing simple_optimizer=True in multi-agent mode") self.config["simple_optimizer"] = True + if self.config["observation_filter"] != "NoFilter": + # TODO(ekl): consider setting the default to be NoFilter + logger.warn( + "By default, observations will be normalized with {}".format( + self.config["observation_filter"])) def _train(self): prev_steps = self.optimizer.num_steps_sampled diff --git a/python/ray/rllib/examples/parametric_action_cartpole.py b/python/ray/rllib/examples/parametric_action_cartpole.py new file mode 100644 index 0000000000000..a1438f0a24123 --- /dev/null +++ b/python/ray/rllib/examples/parametric_action_cartpole.py @@ -0,0 +1,196 @@ +"""Example of handling variable length and/or parametric action spaces. + +This is a toy example of the action-embedding based approach for handling large +discrete action spaces (potentially infinite in size), similar to how +OpenAI Five works: + + https://neuro.cs.ut.ee/the-use-of-embeddings-in-openai-five/ + +This currently works with RLlib's policy gradient style algorithms +(e.g., PG, PPO, IMPALA, A2C) and also DQN. + +Note that since the model outputs now include "-inf" tf.float32.min +values, not all algorithm options are supported at the moment. For example, +algorithms might crash if they don't properly ignore the -inf action scores. +Working configurations are given below. 
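+ +You can run this example directly, for instance: + + python parametric_action_cartpole.py --run=PPO --stop=50 + python parametric_action_cartpole.py --run=DQN --stop=50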
+""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import argparse +import random +import numpy as np +import gym +from gym.spaces import Box, Discrete, Dict +import tensorflow as tf +import tensorflow.contrib.slim as slim + +import ray +from ray.rllib.models import Model, ModelCatalog +from ray.rllib.models.misc import normc_initializer +from ray.tune import run_experiments +from ray.tune.registry import register_env + +parser = argparse.ArgumentParser() +parser.add_argument("--stop", type=int, default=200) +parser.add_argument("--run", type=str, default="PPO") + + +class ParametricActionCartpole(gym.Env): + """Parametric action version of CartPole. + + In this env there are only ever two valid actions, but we pretend there are + actually up to `max_avail_actions` actions that can be taken, and the two + valid actions are randomly hidden among this set. + + At each step, we emit a dict of: + - the actual cart observation + - a mask of valid actions (e.g., [0, 0, 1, 0, 0, 1] for 6 max avail) + - the list of action embeddings (w/ zeroes for invalid actions) (e.g., + [[0, 0], + [0, 0], + [-0.2322, -0.2569], + [0, 0], + [0, 0], + [0.7878, 1.2297]] for max_avail_actions=6) + + In a real environment, the actions embeddings would be larger than two + units of course, and also there would be a variable number of valid actions + per step instead of always [LEFT, RIGHT]. + """ + + def __init__(self, max_avail_actions): + # Use simple random 2-unit action embeddings for [LEFT, RIGHT] + self.left_action_embed = np.random.randn(2) + self.right_action_embed = np.random.randn(2) + self.action_space = Discrete(max_avail_actions) + self.wrapped = gym.make("CartPole-v0") + self.observation_space = Dict({ + "action_mask": Box(0, 1, shape=(max_avail_actions, )), + "avail_actions": Box(-1, 1, shape=(max_avail_actions, 2)), + "cart": self.wrapped.observation_space, + }) + + def update_avail_actions(self): + self.action_assignments = [[0, 0]] * self.action_space.n + self.action_mask = [0] * self.action_space.n + self.left_idx, self.right_idx = random.sample( + range(self.action_space.n), 2) + self.action_assignments[self.left_idx] = self.left_action_embed + self.action_assignments[self.right_idx] = self.right_action_embed + self.action_mask[self.left_idx] = 1 + self.action_mask[self.right_idx] = 1 + + def reset(self): + self.update_avail_actions() + return { + "action_mask": self.action_mask, + "avail_actions": self.action_assignments, + "cart": self.wrapped.reset(), + } + + def step(self, action): + if action == self.left_idx: + actual_action = 0 + elif action == self.right_idx: + actual_action = 1 + else: + raise ValueError( + "Chosen action was not one of the non-zero action embeddings", + action, self.action_assignments, self.action_mask, + self.left_idx, self.right_idx) + orig_obs, rew, done, info = self.wrapped.step(actual_action) + self.update_avail_actions() + obs = { + "action_mask": self.action_mask, + "avail_actions": self.action_assignments, + "cart": orig_obs, + } + return obs, rew, done, info + + +class ParametricActionsModel(Model): + """Parametric action model that handles the dot product and masking. + + This assumes the outputs are logits for a single Categorical action dist. + Getting this to work with a more complex output (e.g., if the action space + is a tuple of several distributions) is also possible but left as an + exercise to the reader. 
+ """ + + def _build_layers_v2(self, input_dict, num_outputs, options): + # Extract the available actions tensor from the observation. + avail_actions = input_dict["obs"]["avail_actions"] + action_mask = input_dict["obs"]["action_mask"] + action_embed_size = avail_actions.shape[2].value + if num_outputs != avail_actions.shape[1].value: + raise ValueError( + "This model assumes num outputs is equal to max avail actions", + num_outputs, avail_actions) + + # Standard FC net component. + last_layer = input_dict["obs"]["cart"] + hiddens = [256, 256] + for i, size in enumerate(hiddens): + label = "fc{}".format(i) + last_layer = slim.fully_connected( + last_layer, + size, + weights_initializer=normc_initializer(1.0), + activation_fn=tf.nn.tanh, + scope=label) + output = slim.fully_connected( + last_layer, + action_embed_size, + weights_initializer=normc_initializer(0.01), + activation_fn=None, + scope="fc_out") + + # Expand the model output to [BATCH, 1, EMBED_SIZE]. Note that the + # avail actions tensor is of shape [BATCH, MAX_ACTIONS, EMBED_SIZE]. + intent_vector = tf.expand_dims(output, 1) + + # Batch dot product => shape of logits is [BATCH, MAX_ACTIONS]. + action_logits = tf.reduce_sum(avail_actions * intent_vector, axis=2) + + # Mask out invalid actions (use tf.float32.min for stability) + inf_mask = tf.maximum(tf.log(action_mask), tf.float32.min) + masked_logits = inf_mask + action_logits + + return masked_logits, last_layer + + +if __name__ == "__main__": + args = parser.parse_args() + ray.init() + + ModelCatalog.register_custom_model("pa_model", ParametricActionsModel) + register_env("pa_cartpole", lambda _: ParametricActionCartpole(10)) + if args.run == "PPO": + cfg = { + "observation_filter": "NoFilter", # don't filter the action list + "vf_share_layers": True, # don't create duplicate value model + } + elif args.run == "DQN": + cfg = { + "hiddens": [], # don't postprocess the action scores + } + else: + cfg = {} + run_experiments({ + "parametric_cartpole": { + "run": args.run, + "env": "pa_cartpole", + "stop": { + "episode_reward_mean": args.stop, + }, + "config": dict({ + "model": { + "custom_model": "pa_model", + }, + "num_workers": 0, + }, **cfg), + }, + }) diff --git a/python/ray/rllib/models/catalog.py b/python/ray/rllib/models/catalog.py index 8f0b8ac82540e..63a7e73890ccf 100644 --- a/python/ray/rllib/models/catalog.py +++ b/python/ray/rllib/models/catalog.py @@ -217,7 +217,7 @@ def _get_model(input_dict, obs_space, num_outputs, options, state_in, seq_lens): if options.get("custom_model"): model = options["custom_model"] - logger.info("Using custom model {}".format(model)) + logger.debug("Using custom model {}".format(model)) return _global_registry.get(RLLIB_MODEL, model)( input_dict, obs_space, diff --git a/python/ray/rllib/models/preprocessors.py b/python/ray/rllib/models/preprocessors.py index 074fda29b96a3..a4af708b79151 100644 --- a/python/ray/rllib/models/preprocessors.py +++ b/python/ray/rllib/models/preprocessors.py @@ -2,6 +2,7 @@ from __future__ import division from __future__ import print_function +from collections import OrderedDict import cv2 import logging import numpy as np @@ -164,6 +165,8 @@ def _init_shape(self, obs_space, options): return (size, ) def transform(self, observation): + if not isinstance(observation, OrderedDict): + observation = OrderedDict(sorted(list(observation.items()))) assert len(observation) == len(self.preprocessors), \ (len(observation), len(self.preprocessors)) return np.concatenate([ diff --git 
a/python/ray/rllib/tuned_examples/atari-dist-dqn.yaml b/python/ray/rllib/tuned_examples/atari-dist-dqn.yaml index 57cd5635d78b4..d351e403f2e23 100644 --- a/python/ray/rllib/tuned_examples/atari-dist-dqn.yaml +++ b/python/ray/rllib/tuned_examples/atari-dist-dqn.yaml @@ -27,5 +27,5 @@ basic-dqn: prioritized_replay_alpha: 0.5 beta_annealing_fraction: 1.0 final_prioritized_replay_beta: 1.0 - num_gpus: 1 + num_gpus: 0.2 timesteps_per_iteration: 10000 diff --git a/python/ray/rllib/tuned_examples/atari-dqn.yaml b/python/ray/rllib/tuned_examples/atari-dqn.yaml index 264ddfd27b413..b8731bb054ef3 100644 --- a/python/ray/rllib/tuned_examples/atari-dqn.yaml +++ b/python/ray/rllib/tuned_examples/atari-dqn.yaml @@ -1,4 +1,4 @@ -# Runs on a single g3.16xl node +# Runs on a single g3.4xl node # See https://github.com/ray-project/rl-experiments for results atari-basic-dqn: env: @@ -29,5 +29,5 @@ atari-basic-dqn: prioritized_replay_alpha: 0.5 beta_annealing_fraction: 1.0 final_prioritized_replay_beta: 1.0 - num_gpus: 1 + num_gpus: 0.2 timesteps_per_iteration: 10000 diff --git a/python/ray/rllib/tuned_examples/atari-duel-ddqn.yaml b/python/ray/rllib/tuned_examples/atari-duel-ddqn.yaml index be59d15ba8070..b5a13162b61e4 100644 --- a/python/ray/rllib/tuned_examples/atari-duel-ddqn.yaml +++ b/python/ray/rllib/tuned_examples/atari-duel-ddqn.yaml @@ -1,3 +1,5 @@ +# Runs on a single g3.4xl node +# See https://github.com/ray-project/rl-experiments for results dueling-ddqn: env: grid_search: @@ -27,5 +29,5 @@ dueling-ddqn: prioritized_replay_alpha: 0.5 beta_annealing_fraction: 1.0 final_prioritized_replay_beta: 1.0 - num_gpus: 1 + num_gpus: 0.2 timesteps_per_iteration: 10000 diff --git a/python/ray/rllib/tuned_examples/pong-impala-fast.yaml b/python/ray/rllib/tuned_examples/pong-impala-fast.yaml index 3466b63ea1c4a..3c29f4e0c08e4 100644 --- a/python/ray/rllib/tuned_examples/pong-impala-fast.yaml +++ b/python/ray/rllib/tuned_examples/pong-impala-fast.yaml @@ -9,7 +9,7 @@ pong-impala-fast: config: sample_batch_size: 50 train_batch_size: 1000 - num_workers: 256 + num_workers: 128 num_envs_per_worker: 5 broadcast_interval: 5 max_sample_requests_in_flight_per_worker: 1 diff --git a/test/jenkins_tests/run_multi_node_tests.sh b/test/jenkins_tests/run_multi_node_tests.sh index 40e9635d9d8aa..86fd98af21d09 100755 --- a/test/jenkins_tests/run_multi_node_tests.sh +++ b/test/jenkins_tests/run_multi_node_tests.sh @@ -257,6 +257,15 @@ docker run --rm --shm-size=10G --memory=10G $DOCKER_SHA \ docker run --rm --shm-size=10G --memory=10G $DOCKER_SHA \ python /ray/python/ray/rllib/test/test_external_env.py +docker run --rm --shm-size=10G --memory=10G $DOCKER_SHA \ + python /ray/python/ray/rllib/examples/parametric_action_cartpole.py --run=PG --stop=50 + +docker run --rm --shm-size=10G --memory=10G $DOCKER_SHA \ + python /ray/python/ray/rllib/examples/parametric_action_cartpole.py --run=PPO --stop=50 + +docker run --rm --shm-size=10G --memory=10G $DOCKER_SHA \ + python /ray/python/ray/rllib/examples/parametric_action_cartpole.py --run=DQN --stop=50 + docker run --rm --shm-size=10G --memory=10G $DOCKER_SHA \ python /ray/python/ray/rllib/test/test_lstm.py