[rllib] TensorFlow 2 compatibility (#4802)

ericl committed May 17, 2019
1 parent 7d5ef6d commit 3807fb505b121e1e9f583f5859941cf07b1d7438
Showing with 140 additions and 1,235 deletions.
  1. +1 −1 doc/source/rllib-env.rst
  2. +1 −1 doc/source/rllib-models.rst
  3. +15 −18 python/ray/rllib/agents/ddpg/ddpg_policy_graph.py
  4. +18 −15 python/ray/rllib/agents/dqn/dqn_policy_graph.py
  5. +0 −2 python/ray/rllib/agents/impala/vtrace.py
  6. +0 −3 python/ray/rllib/agents/impala/vtrace_policy_graph.py
  7. +3 −1 python/ray/rllib/agents/impala/vtrace_test.py
  8. +0 −3 python/ray/rllib/agents/ppo/appo_policy_graph.py
  9. +3 −1 python/ray/rllib/agents/ppo/test/test.py
  10. +11 −11 python/ray/rllib/examples/batch_norm_model.py
  11. +0 −14 python/ray/rllib/examples/carla/README
  12. +0 −684 python/ray/rllib/examples/carla/env.py
  13. +0 −108 python/ray/rllib/examples/carla/models.py
  14. +0 −131 python/ray/rllib/examples/carla/scenarios.py
  15. +0 −51 python/ray/rllib/examples/carla/train_a3c.py
  16. +0 −65 python/ray/rllib/examples/carla/train_dqn.py
  17. +0 −55 python/ray/rllib/examples/carla/train_ppo.py
  18. +3 −1 python/ray/rllib/examples/custom_fast_model.py
  19. +3 −1 python/ray/rllib/examples/custom_loss.py
  20. +3 −1 python/ray/rllib/examples/export/cartpole_dqn_export.py
  21. +15 −15 python/ray/rllib/examples/multiagent_cartpole.py
  22. +11 −10 python/ray/rllib/examples/parametric_action_cartpole.py
  23. +5 −1 python/ray/rllib/models/action_dist.py
  24. +8 −10 python/ray/rllib/models/fcnet.py
  25. +1 −3 python/ray/rllib/models/lstm.py
  26. +12 −13 python/ray/rllib/models/visionnet.py
  27. +3 −3 python/ray/rllib/optimizers/aso_multi_gpu_learner.py
  28. +3 −1 python/ray/rllib/tests/test_catalog.py
  29. +5 −4 python/ray/rllib/tests/test_lstm.py
  30. +6 −5 python/ray/rllib/tests/test_nested_spaces.py
  31. +3 −1 python/ray/rllib/tests/test_optimizers.py
  32. +7 −2 python/ray/rllib/utils/__init__.py
doc/source/rllib-env.rst
@@ -92,7 +92,7 @@ In the above example, note that the ``env_creator`` function takes in an ``env_c
 OpenAI Gym
 ----------
 
-RLlib uses Gym as its environment interface for single-agent training. For more information on how to implement a custom Gym environment, see the `gym.Env class definition <https://github.com/openai/gym/blob/master/gym/core.py>`__. You may also find the `SimpleCorridor <https://github.com/ray-project/ray/blob/master/python/ray/rllib/examples/custom_env.py>`__ and `Carla simulator <https://github.com/ray-project/ray/blob/master/python/ray/rllib/examples/carla/env.py>`__ example env implementations useful as a reference.
+RLlib uses Gym as its environment interface for single-agent training. For more information on how to implement a custom Gym environment, see the `gym.Env class definition <https://github.com/openai/gym/blob/master/gym/core.py>`__. You may find the `SimpleCorridor <https://github.com/ray-project/ray/blob/master/python/ray/rllib/examples/custom_env.py>`__ example useful as a reference.
 
 Performance
 ~~~~~~~~~~~
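Note (not part of the commit): the paragraph above describes implementing a custom Gym environment. A minimal sketch of that interface, with an invented corridor task loosely in the spirit of the referenced SimpleCorridor example:

    import gym
    from gym.spaces import Box, Discrete

    class CorridorSketch(gym.Env):
        """Hypothetical env: start at 0, walk right until reaching end_pos."""

        def __init__(self, config):
            self.end_pos = config.get("corridor_length", 5)
            self.cur_pos = 0
            self.action_space = Discrete(2)  # 0 = left, 1 = right
            self.observation_space = Box(0.0, self.end_pos, shape=(1,))

        def reset(self):
            self.cur_pos = 0
            return [self.cur_pos]

        def step(self, action):
            if action == 0 and self.cur_pos > 0:
                self.cur_pos -= 1
            elif action == 1:
                self.cur_pos += 1
            done = self.cur_pos >= self.end_pos
            return [self.cur_pos], 1.0 if done else -0.1, done, {}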
doc/source/rllib-models.rst
@@ -134,7 +134,7 @@ Custom TF models should subclass the common RLlib `model class <https://github.c
         },
     })
 
-For a full example of a custom model in code, see the `Carla RLlib model <https://github.com/ray-project/ray/blob/master/python/ray/rllib/examples/carla/models.py>`__ and associated `training scripts <https://github.com/ray-project/ray/tree/master/python/ray/rllib/examples/carla>`__. You can also reference the `unit tests <https://github.com/ray-project/ray/blob/master/python/ray/rllib/tests/test_nested_spaces.py>`__ for Tuple and Dict spaces, which show how to access nested observation fields.
+For a full example of a custom model in code, see the `custom env example <https://github.com/ray-project/ray/blob/master/python/ray/rllib/examples/custom_env.py>`__. You can also reference the `unit tests <https://github.com/ray-project/ray/blob/master/python/ray/rllib/tests/test_nested_spaces.py>`__ for Tuple and Dict spaces, which show how to access nested observation fields.
 
 Custom Recurrent Models
 ~~~~~~~~~~~~~~~~~~~~~~~
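Note (not part of the commit): the custom model interface this doc refers to, in the API of this era, is Model._build_layers_v2, which is also exercised in the batch_norm_model.py diff below. A minimal sketch with invented names:

    from ray.rllib.models import Model, ModelCatalog
    from ray.rllib.utils import try_import_tf

    tf = try_import_tf()

    class MyModelSketch(Model):
        def _build_layers_v2(self, input_dict, num_outputs, options):
            # input_dict["obs"] holds the (possibly nested) observation tensor.
            hidden = tf.layers.dense(
                input_dict["obs"], 64, activation=tf.nn.relu, name="hidden")
            output = tf.layers.dense(
                hidden, num_outputs, activation=None, name="out")
            return output, hidden

    ModelCatalog.register_custom_model("my_model", MyModelSketch)
    # Then select it via {"model": {"custom_model": "my_model"}} in the config.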
python/ray/rllib/agents/ddpg/ddpg_policy_graph.py
@@ -399,8 +399,6 @@ def set_state(self, state):
         self.set_pure_exploration_phase(state[2])
 
     def _build_q_network(self, obs, obs_space, action_space, actions):
-        import tensorflow.contrib.layers as layers
-
         if self.config["use_state_preprocessor"]:
             q_model = ModelCatalog.get_model({
                 "obs": obs,
@@ -413,16 +411,12 @@ def _build_q_network(self, obs, obs_space, action_space, actions):
 
         activation = getattr(tf.nn, self.config["critic_hidden_activation"])
         for hidden in self.config["critic_hiddens"]:
-            q_out = layers.fully_connected(
-                q_out, num_outputs=hidden, activation_fn=activation)
-        q_values = layers.fully_connected(
-            q_out, num_outputs=1, activation_fn=None)
+            q_out = tf.layers.dense(q_out, units=hidden, activation=activation)
+        q_values = tf.layers.dense(q_out, units=1, activation=None)
 
         return q_values, q_model
 
     def _build_policy_network(self, obs, obs_space, action_space):
-        import tensorflow.contrib.layers as layers
-
         if self.config["use_state_preprocessor"]:
             model = ModelCatalog.get_model({
                 "obs": obs,
@@ -434,16 +428,19 @@ def _build_policy_network(self, obs, obs_space, action_space):
             action_out = obs
 
         activation = getattr(tf.nn, self.config["actor_hidden_activation"])
-        normalizer_fn = layers.layer_norm if self.config["parameter_noise"] \
-            else None
         for hidden in self.config["actor_hiddens"]:
-            action_out = layers.fully_connected(
-                action_out,
-                num_outputs=hidden,
-                activation_fn=activation,
-                normalizer_fn=normalizer_fn)
-        action_out = layers.fully_connected(
-            action_out, num_outputs=self.dim_actions, activation_fn=None)
+            if self.config["parameter_noise"]:
+                import tensorflow.contrib.layers as layers
+                action_out = layers.fully_connected(
+                    action_out,
+                    num_outputs=hidden,
+                    activation_fn=activation,
+                    normalizer_fn=layers.layer_norm)
+            else:
+                action_out = tf.layers.dense(
+                    action_out, units=hidden, activation=activation)
+        action_out = tf.layers.dense(
+            action_out, units=self.dim_actions, activation=None)
 
         # Use sigmoid to scale to [0,1], but also double magnitude of input to
         # emulate behaviour of tanh activation used in DDPG and TD3 papers.
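Note (not part of the commit): the comment above relies on the identity sigmoid(2x) = (tanh(x) + 1) / 2, i.e. doubling the input makes a sigmoid trace the same curve as tanh, affinely remapped from [-1, 1] onto [0, 1]. A quick numpy check:

    import numpy as np
    x = np.linspace(-3.0, 3.0, 13)
    sigmoid = lambda v: 1.0 / (1.0 + np.exp(-v))
    assert np.allclose(sigmoid(2 * x), (np.tanh(x) + 1) / 2)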
@@ -507,7 +504,7 @@ def make_noisy_actions():
 
         def make_uniform_random_actions():
             # pure random exploration option
-            uniform_random_actions = tf.random.uniform(
+            uniform_random_actions = tf.random_uniform(
                 tf.shape(deterministic_actions))
             # rescale uniform random actions according to action range
             tf_range = tf.constant(action_range[None], dtype="float32")
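Note (not part of the commit): tf.random_uniform samples in [0, 1), so the rescaling the comment above refers to is the usual affine map onto the action bounds. A hypothetical sketch (action_low and action_high are invented names, not the variables this file actually uses):

    u = tf.random_uniform(tf.shape(deterministic_actions))
    random_actions = action_low + u * (action_high - action_low)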
python/ray/rllib/agents/dqn/dqn_policy_graph.py
@@ -154,8 +154,6 @@ def __init__(self,
                  v_max=10.0,
                  sigma0=0.5,
                  parameter_noise=False):
-        import tensorflow.contrib.layers as layers
-
         self.model = model
         with tf.variable_scope("action_value"):
             if hiddens:
@@ -164,13 +162,18 @@ def __init__(self,
                     if use_noisy:
                         action_out = self.noisy_layer(
                             "hidden_%d" % i, action_out, hiddens[i], sigma0)
-                    else:
+                    elif parameter_noise:
+                        import tensorflow.contrib.layers as layers
                         action_out = layers.fully_connected(
                             action_out,
                             num_outputs=hiddens[i],
                             activation_fn=tf.nn.relu,
-                            normalizer_fn=layers.layer_norm
-                            if parameter_noise else None)
+                            normalizer_fn=layers.layer_norm)
+                    else:
+                        action_out = tf.layers.dense(
+                            action_out,
+                            units=hiddens[i],
+                            activation=tf.nn.relu)
             else:
                 # Avoid postprocessing the outputs. This enables custom models
                 # to be used for parametric action DQN.
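Note (not part of the commit): the contrib import survives only inside the parameter_noise branch because tf.contrib.layers.layer_norm has no direct core-TF replacement at this point; layer normalization is the standard companion to parameter-space noise, keeping perturbed activations in a consistent range. A contrib-free sketch, assuming the default center/scale behavior of contrib's layer_norm:

    def layer_norm_sketch(x, eps=1e-12):
        # Normalize over the feature axis, then apply a learned affine.
        mean, var = tf.nn.moments(x, axes=[-1], keep_dims=True)
        normed = (x - mean) / tf.sqrt(var + eps)
        beta = tf.get_variable(
            "beta", x.shape[-1:], initializer=tf.zeros_initializer())
        gamma = tf.get_variable(
            "gamma", x.shape[-1:], initializer=tf.ones_initializer())
        return gamma * normed + beta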
@@ -183,10 +186,8 @@ def __init__(self,
                     sigma0,
                     non_linear=False)
             elif hiddens:
-                action_scores = layers.fully_connected(
-                    action_out,
-                    num_outputs=num_actions * num_atoms,
-                    activation_fn=None)
+                action_scores = tf.layers.dense(
+                    action_out, units=num_actions * num_atoms, activation=None)
             else:
                 action_scores = model.outputs
         if num_atoms > 1:
@@ -214,13 +215,15 @@ def __init__(self,
                         state_out = self.noisy_layer("dueling_hidden_%d" % i,
                                                      state_out, hiddens[i],
                                                      sigma0)
-                    else:
-                        state_out = layers.fully_connected(
+                    elif parameter_noise:
+                        state_out = tf.contrib.layers.fully_connected(
                             state_out,
                             num_outputs=hiddens[i],
                             activation_fn=tf.nn.relu,
-                            normalizer_fn=layers.layer_norm
-                            if parameter_noise else None)
+                            normalizer_fn=tf.contrib.layers.layer_norm)
+                    else:
+                        state_out = tf.layers.dense(
+                            state_out, units=hiddens[i], activation=tf.nn.relu)
                 if use_noisy:
                     state_score = self.noisy_layer(
                         "dueling_output",
@@ -229,8 +232,8 @@ def __init__(self,
                         sigma0,
                         non_linear=False)
                 else:
-                    state_score = layers.fully_connected(
-                        state_out, num_outputs=num_atoms, activation_fn=None)
+                    state_score = tf.layers.dense(
+                        state_out, units=num_atoms, activation=None)
             if num_atoms > 1:
                 support_logits_per_action_mean = tf.reduce_mean(
                     support_logits_per_action, 1)
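Note (not part of the commit): the support_logits_per_action_mean line in the trailing context is the distributional form of the standard dueling aggregation Q(s, a) = V(s) + A(s, a) - mean over a' of A(s, a'). In the scalar (num_atoms == 1) case that combination looks like:

    # Sketch of the standard dueling combination, not code from this file:
    advantages_mean = tf.reduce_mean(action_scores, 1)
    advantages_centered = action_scores - tf.expand_dims(advantages_mean, 1)
    q_values = state_score + advantages_centered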
python/ray/rllib/agents/impala/vtrace.py
@@ -38,8 +38,6 @@
 from ray.rllib.utils import try_import_tf
 
 tf = try_import_tf()
-if tf:
-    nest = tf.contrib.framework.nest
 
 VTraceFromLogitsReturns = collections.namedtuple("VTraceFromLogitsReturns", [
     "vs", "pg_advantages", "log_rhos", "behaviour_action_log_probs",
python/ray/rllib/agents/impala/vtrace_policy_graph.py
@@ -278,14 +278,11 @@ def make_time_major(tensor, drop_last=False):
                 self.KL_stats.update({
                     "mean_KL_{}".format(i): tf.reduce_mean(kl),
                     "max_KL_{}".format(i): tf.reduce_max(kl),
-                    "median_KL_{}".format(i): tf.contrib.distributions.
-                    percentile(kl, 50.0),
                 })
         else:
             self.KL_stats = {
                 "mean_KL": tf.reduce_mean(kls[0]),
                 "max_KL": tf.reduce_max(kls[0]),
-                "median_KL": tf.contrib.distributions.percentile(kls[0], 50.0),
             }
 
         # Initialize TFPolicyGraph
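Note (not part of the commit): the median_KL stats are simply dropped here because tf.contrib.distributions.percentile has no core-TF equivalent at this point. Were one to keep the stat without contrib, a top_k-based sketch would do (exact median for odd n, lower median for even n):

    def tf_median_sketch(x):
        flat = tf.reshape(x, [-1])
        k = tf.shape(flat)[0] // 2 + 1
        # k-th largest element of the flattened tensor.
        return tf.nn.top_k(flat, k).values[-1]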
python/ray/rllib/agents/impala/vtrace_test.py
@@ -26,8 +26,10 @@
 
 from absl.testing import parameterized
 import numpy as np
-import tensorflow as tf
 import vtrace
+from ray.rllib.utils import try_import_tf
+
+tf = try_import_tf()
 
 
 def _shaped_arange(*shape):
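Note (not part of the commit): several files in this diff swap a direct "import tensorflow as tf" for try_import_tf from ray.rllib.utils (itself touched in this commit, +7 -2 in the file list), presumably so RLlib modules stay importable when TensorFlow is absent. Roughly the shape such a helper takes, as a sketch rather than the actual utils/__init__.py contents:

    def try_import_tf():
        try:
            import tensorflow as tf
            return tf
        except ImportError:
            # Callers receive None and must guard their TF usage.
            return None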
python/ray/rllib/agents/ppo/appo_policy_graph.py
@@ -399,14 +399,11 @@ def make_time_major(tensor, drop_last=False):
                 self.KL_stats.update({
                     "mean_KL_{}".format(i): tf.reduce_mean(kl),
                     "max_KL_{}".format(i): tf.reduce_max(kl),
-                    "median_KL_{}".format(i): tf.contrib.distributions.
-                    percentile(kl, 50.0),
                 })
         else:
             self.KL_stats = {
                 "mean_KL": tf.reduce_mean(kls[0]),
                 "max_KL": tf.reduce_max(kls[0]),
-                "median_KL": tf.contrib.distributions.percentile(kls[0], 50.0),
             }
 
         # Initialize TFPolicyGraph
python/ray/rllib/agents/ppo/test/test.py
@@ -4,11 +4,13 @@
 
 import unittest
 import numpy as np
-import tensorflow as tf
 from numpy.testing import assert_allclose
 
 from ray.rllib.models.action_dist import Categorical
 from ray.rllib.agents.ppo.utils import flatten, concatenate
+from ray.rllib.utils import try_import_tf
+
+tf = try_import_tf()
 
 
 # TODO(ekl): move to rllib/models dir
python/ray/rllib/examples/batch_norm_model.py
@@ -5,13 +5,13 @@
 
 import argparse
 
-import tensorflow as tf
-import tensorflow.contrib.slim as slim
-
 import ray
 from ray import tune
 from ray.rllib.models import Model, ModelCatalog
 from ray.rllib.models.misc import normc_initializer
+from ray.rllib.utils import try_import_tf
+
+tf = try_import_tf()
 
 parser = argparse.ArgumentParser()
 parser.add_argument("--num-iters", type=int, default=200)
@@ -24,21 +24,21 @@ def _build_layers_v2(self, input_dict, num_outputs, options):
         hiddens = [256, 256]
         for i, size in enumerate(hiddens):
             label = "fc{}".format(i)
-            last_layer = slim.fully_connected(
+            last_layer = tf.layers.dense(
                 last_layer,
                 size,
-                weights_initializer=normc_initializer(1.0),
-                activation_fn=tf.nn.tanh,
-                scope=label)
+                kernel_initializer=normc_initializer(1.0),
+                activation=tf.nn.tanh,
+                name=label)
             # Add a batch norm layer
             last_layer = tf.layers.batch_normalization(
                 last_layer, training=input_dict["is_training"])
-        output = slim.fully_connected(
+        output = tf.layers.dense(
             last_layer,
             num_outputs,
-            weights_initializer=normc_initializer(0.01),
-            activation_fn=None,
-            scope="fc_out")
+            kernel_initializer=normc_initializer(0.01),
+            activation=None,
+            name="fc_out")
         return output, last_layer
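Note (not part of the commit): a general gotcha for the tf.layers.batch_normalization call kept above: in graph mode its moving-average updates are collected under tf.GraphKeys.UPDATE_OPS and must be run alongside the train op. A sketch of the standard pattern (optimizer and loss are placeholders, not names from this file):

    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    with tf.control_dependencies(update_ops):
        train_op = optimizer.minimize(loss)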

