diff --git a/setup.py b/setup.py
index 1d3f14a94..ee0eb0d09 100644
--- a/setup.py
+++ b/setup.py
@@ -36,6 +36,8 @@
       'future',
       'gevent',
       'gunicorn',
+      'gym<=0.9.5',  # gym in version 0.9.6 has some temporary issues.
+      'munch',
       'numpy',
       'requests',
       'scipy',
diff --git a/tensor2tensor/bin/t2t-rl-trainer b/tensor2tensor/bin/t2t-rl-trainer
new file mode 100644
index 000000000..06c97d2d5
--- /dev/null
+++ b/tensor2tensor/bin/t2t-rl-trainer
@@ -0,0 +1,16 @@
+#!/usr/bin/env python
+"""t2t-rl-trainer."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensor2tensor.bin import t2t_rl_trainer
+
+import tensorflow as tf
+
+def main(argv):
+  t2t_rl_trainer.main(argv)
+
+
+if __name__ == "__main__":
+  tf.app.run()
diff --git a/tensor2tensor/bin/t2t_rl_trainer.py b/tensor2tensor/bin/t2t_rl_trainer.py
new file mode 100644
index 000000000..b53692ccc
--- /dev/null
+++ b/tensor2tensor/bin/t2t_rl_trainer.py
@@ -0,0 +1,92 @@
+# coding=utf-8
+# Copyright 2018 The Tensor2Tensor Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Training of RL agent with PPO algorithm."""
+
+from __future__ import absolute_import
+
+import functools
+from munch import Munch
+import tensorflow as tf
+
+from tensor2tensor.rl.collect import define_collect
+from tensor2tensor.rl.envs.utils import define_batch_env
+from tensor2tensor.rl.ppo import define_ppo_epoch
+
+
+def define_train(policy_lambda, env_lambda, config):
+  env = env_lambda()
+  action_space = env.action_space
+  observation_space = env.observation_space
+
+  batch_env = define_batch_env(env_lambda, config["num_agents"])
+
+  policy_factory = tf.make_template(
+      'network',
+      functools.partial(policy_lambda, observation_space,
+                        action_space, config))
+
+  (collect_op, memory) = define_collect(policy_factory, batch_env, config)
+
+  with tf.control_dependencies([collect_op]):
+    ppo_op = define_ppo_epoch(memory, policy_factory, config)
+
+  return ppo_op
+
+
+def main(argv=None):  # Accepts argv so the t2t-rl-trainer binary can pass it.
+  train(example_params())
+
+
+def train(params):
+  policy_lambda, env_lambda, config = params
+  ppo_op = define_train(policy_lambda, env_lambda, config)
+
+  with tf.Session() as sess:
+    sess.run(tf.global_variables_initializer())
+    for _ in range(config.epochs_num):
+      sess.run(ppo_op)
+
+
+def example_params():
+  from tensor2tensor.rl import networks
+  config = {}
+  config['init_mean_factor'] = 0.1
+  config['init_logstd'] = 0.1
+  config['policy_layers'] = 100, 100
+  config['value_layers'] = 100, 100
+  config['num_agents'] = 30
+  config['clipping_coef'] = 0.2
+  config['gae_gamma'] = 0.99
+  config['gae_lambda'] = 0.95
+  config['entropy_loss_coef'] = 0.01
+  config['value_loss_coef'] = 1
+  config['optimizer'] = tf.train.AdamOptimizer
+  config['learning_rate'] = 1e-4
+  config['optimization_epochs'] = 15
+  config['epoch_length'] = 200
+  config['epochs_num'] = 2000
+
+  config = Munch(config)
+  return networks.feed_forward_gaussian_fun, pendulum_lambda, config
+
+
+def pendulum_lambda():
+  import gym
+  return gym.make("Pendulum-v0")
+
+
+if __name__ == '__main__':
+  main()
diff --git a/tensor2tensor/rl/README.md b/tensor2tensor/rl/README.md
new file mode 100644
index 000000000..bf21ab1ad
--- /dev/null
+++ b/tensor2tensor/rl/README.md
@@ -0,0 +1,10 @@
+# Tensor2Tensor Reinforcement Learning starter.
+
+The rl package provides the ability to run reinforcement learning
+algorithms within TensorFlow's computation graph.
+
+Currently the only supported algorithm is Proximal Policy Optimization (PPO).
+
+## Sample usage - training in the Pendulum-v0 environment.
+
+```t2t-rl-trainer```
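Besides the `t2t-rl-trainer` binary shown in the README, training can be driven directly from Python via the entry points added in `t2t_rl_trainer.py` above. A minimal sketch (the shortened `epochs_num` is only illustrative, mirroring what `train_test.py` below does; it assumes the package from this PR plus `gym<=0.9.5` and `munch` are installed):

```python
# Programmatic training sketch; not part of the diff itself.
from tensor2tensor.bin import t2t_rl_trainer

# example_params() returns (policy_fun, env_fun, config), with config as a Munch.
policy_fun, env_fun, config = t2t_rl_trainer.example_params()
config.epochs_num = 10  # Shorten the run; example_params() defaults to 2000.
t2t_rl_trainer.train((policy_fun, env_fun, config))
```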
diff --git a/tensor2tensor/rl/__init__.py b/tensor2tensor/rl/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/tensor2tensor/rl/collect.py b/tensor2tensor/rl/collect.py
new file mode 100644
index 000000000..dadab4d92
--- /dev/null
+++ b/tensor2tensor/rl/collect.py
@@ -0,0 +1,94 @@
+# coding=utf-8
+# Copyright 2018 The Tensor2Tensor Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Collect trajectories from interactions of agent with environment."""
+
+import tensorflow as tf
+
+
+def define_collect(policy_factory, batch_env, config):
+
+  memory_shape = [config.epoch_length] + [batch_env.observ.shape.as_list()[0]]
+  memories_shapes_and_types = [
+      # observation
+      (memory_shape + [batch_env.observ.shape.as_list()[1]], tf.float32),
+      (memory_shape, tf.float32),  # reward
+      (memory_shape, tf.bool),  # done
+      (memory_shape + batch_env.action_shape, tf.float32),  # action
+      (memory_shape, tf.float32),  # pdf
+      (memory_shape, tf.float32),  # value function
+  ]
+  memory = [tf.Variable(tf.zeros(shape, dtype), trainable=False)
+            for (shape, dtype) in memories_shapes_and_types]
+  cumulative_rewards = tf.Variable(
+      tf.zeros(config.num_agents, tf.float32), trainable=False)
+
+  should_reset_var = tf.Variable(True, trainable=False)
+  reset_op = tf.cond(should_reset_var,
+                     lambda: batch_env.reset(tf.range(config.num_agents)),
+                     lambda: 0.0)
+  with tf.control_dependencies([reset_op]):
+    reset_once_op = tf.assign(should_reset_var, False)
+
+  with tf.control_dependencies([reset_once_op]):
+
+    def step(index, scores_sum, scores_num):
+      # Note: the only way to ensure that a tensor is actually copied is to
+      # run a simple operation on it. We are waiting for tf.copy:
+      # https://github.com/tensorflow/tensorflow/issues/11186
+      obs_copy = batch_env.observ + 0
+      actor_critic = policy_factory(tf.expand_dims(obs_copy, 0))
+      policy = actor_critic.policy
+      action = policy.sample()
+      postprocessed_action = actor_critic.action_postprocessing(action)
+      simulate_output = batch_env.simulate(postprocessed_action[0, ...])
+      pdf = policy.prob(action)[0]
+      with tf.control_dependencies(simulate_output):
+        reward, done = simulate_output
+        done = tf.reshape(done, (config.num_agents,))
+        to_save = [obs_copy, reward, done, action[0, ...], pdf,
+                   actor_critic.value[0]]
+        save_ops = [tf.scatter_update(memory_slot, index, value)
+                    for memory_slot, value in zip(memory, to_save)]
+        cumulate_rewards_op = cumulative_rewards.assign_add(reward)
+        agent_indices_to_reset = tf.where(done)[:, 0]
+      with tf.control_dependencies([cumulate_rewards_op]):
+        scores_sum_delta = tf.reduce_sum(
+            tf.gather(cumulative_rewards, agent_indices_to_reset))
+        scores_num_delta = tf.count_nonzero(done, dtype=tf.int32)
+      with tf.control_dependencies(save_ops + [scores_sum_delta,
+                                               scores_num_delta]):
+        reset_env_op = batch_env.reset(agent_indices_to_reset)
+        reset_cumulative_rewards_op = tf.scatter_update(
+            cumulative_rewards, agent_indices_to_reset,
+            tf.zeros(tf.shape(agent_indices_to_reset)))
+      with tf.control_dependencies([reset_env_op,
+                                    reset_cumulative_rewards_op]):
+        return [index + 1, scores_sum + scores_sum_delta,
+                scores_num + scores_num_delta]
+
+    init = [tf.constant(0), tf.constant(0.0), tf.constant(0)]
+    index, scores_sum, scores_num = tf.while_loop(
+        lambda c, _1, _2: c < config.epoch_length,
+        step,
+        init,
+        parallel_iterations=1,
+        back_prop=False)
+    mean_score = tf.cond(tf.greater(scores_num, 0),
+                         lambda: scores_sum / tf.cast(scores_num, tf.float32),
+                         lambda: 0.)
+    printing = tf.Print(0, [mean_score, scores_sum, scores_num], "mean_score: ")
+    with tf.control_dependencies([printing]):
+      return tf.identity(index), memory
diff --git a/tensor2tensor/rl/envs/__init__.py b/tensor2tensor/rl/envs/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/tensor2tensor/rl/envs/batch_env.py b/tensor2tensor/rl/envs/batch_env.py
new file mode 100644
index 000000000..30bfdce55
--- /dev/null
+++ b/tensor2tensor/rl/envs/batch_env.py
@@ -0,0 +1,129 @@
+# coding=utf-8
+# Copyright 2018 The Tensor2Tensor Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# The code was based on Danijar Hafner's code from tf.agents:
+# https://github.com/tensorflow/agents/blob/master/agents/tools/batch_env.py
+
+"""Combine multiple environments to step them in batch."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+
+
+class BatchEnv(object):
+  """Combine multiple environments to step them in batch."""
+
+  def __init__(self, envs, blocking):
+    """Combine multiple environments to step them in batch.
+
+    To step environments in parallel, environments must support a
+    `blocking=False` argument to their step and reset functions that makes them
+    return callables instead to receive the result at a later time.
+
+    Args:
+      envs: List of environments.
+      blocking: Step environments one after another rather than in parallel.
+
+    Raises:
+      ValueError: Environments have different observation or action spaces.
+    """
+    self._envs = envs
+    self._blocking = blocking
+    observ_space = self._envs[0].observation_space
+    if not all(env.observation_space == observ_space for env in self._envs):
+      raise ValueError('All environments must use the same observation space.')
+    action_space = self._envs[0].action_space
+    if not all(env.action_space == action_space for env in self._envs):
+      raise ValueError('All environments must use the same action space.')
+
+  def __len__(self):
+    """Number of combined environments."""
+    return len(self._envs)
+
+  def __getitem__(self, index):
+    """Access an underlying environment by index."""
+    return self._envs[index]
+
+  def __getattr__(self, name):
+    """Forward unimplemented attributes to one of the original environments.
+
+    Args:
+      name: Attribute that was accessed.
+
+    Returns:
+      Value behind the attribute name in one of the wrapped environments.
+    """
+    return getattr(self._envs[0], name)
+
+  def step(self, actions):
+    """Forward a batch of actions to the wrapped environments.
+
+    Args:
+      actions: Batched action to apply to the environment.
+
+    Raises:
+      ValueError: Invalid actions.
+
+    Returns:
+      Batch of observations, rewards, and done flags.
+    """
+    for index, (env, action) in enumerate(zip(self._envs, actions)):
+      if not env.action_space.contains(action):
+        message = 'Invalid action at index {}: {}'
+        raise ValueError(message.format(index, action))
+    if self._blocking:
+      transitions = [
+          env.step(action)
+          for env, action in zip(self._envs, actions)]
+    else:
+      transitions = [
+          env.step(action, blocking=False)
+          for env, action in zip(self._envs, actions)]
+      transitions = [transition() for transition in transitions]
+    observs, rewards, dones, infos = zip(*transitions)
+    observ = np.stack(observs).astype(np.float32)
+    reward = np.stack(rewards).astype(np.float32)
+    done = np.stack(dones)
+    info = tuple(infos)
+    return observ, reward, done, info
+
+  def reset(self, indices=None):
+    """Reset the environment and convert the resulting observation.
+
+    Args:
+      indices: The batch indices of environments to reset; defaults to all.
+
+    Returns:
+      Batch of observations.
+    """
+    if indices is None:
+      indices = np.arange(len(self._envs))
+    if self._blocking:
+      observs = [self._envs[index].reset() for index in indices]
+    else:
+      observs = [self._envs[index].reset(blocking=False) for index in indices]
+      observs = [observ() for observ in observs]
+    observ = np.stack(observs)
+    observ = observ.astype(np.float32)
+    return observ
+
+  def close(self):
+    """Send close messages to the external process and join them."""
+    for env in self._envs:
+      if hasattr(env, 'close'):
+        env.close()
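To make the wrapper's contract concrete, here is a minimal usage sketch (standalone, not part of the diff; it assumes `gym` is installed and uses blocking mode, since plain gym environments do not accept a `blocking` argument):

```python
import gym
import numpy as np

from tensor2tensor.rl.envs.batch_env import BatchEnv

# Plain gym environments only support synchronous stepping, hence blocking=True.
envs = [gym.make("Pendulum-v0") for _ in range(4)]
batch = BatchEnv(envs, blocking=True)

observ = batch.reset()  # Stacked float32 observations, one row per environment.
actions = np.stack([env.action_space.sample() for env in envs])
observ, reward, done, info = batch.step(actions)
batch.close()
```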
diff --git a/tensor2tensor/rl/envs/in_graph_batch_env.py b/tensor2tensor/rl/envs/in_graph_batch_env.py
new file mode 100644
index 000000000..d0e1e4c26
--- /dev/null
+++ b/tensor2tensor/rl/envs/in_graph_batch_env.py
@@ -0,0 +1,163 @@
+# coding=utf-8
+# Copyright 2018 The Tensor2Tensor Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# The code was based on Danijar Hafner's code from tf.agents:
+# https://github.com/tensorflow/agents/blob/master/agents/tools/in_graph_batch_env.py
+
+"""Batch of environments inside the TensorFlow graph."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import gym
+import tensorflow as tf
+
+
+class InGraphBatchEnv(object):
+  """Batch of environments inside the TensorFlow graph.
+
+  The batch of environments will be stepped and reset inside of the graph using
+  a tf.py_func(). The current batch of observations, actions, rewards, and done
+  flags are held in corresponding variables.
+  """
+
+  def __init__(self, batch_env):
+    """Batch of environments inside the TensorFlow graph.
+
+    Args:
+      batch_env: Batch environment.
+    """
+    self._batch_env = batch_env
+    observ_shape = self._parse_shape(self._batch_env.observation_space)
+    observ_dtype = self._parse_dtype(self._batch_env.observation_space)
+    self.action_shape = list(self._parse_shape(self._batch_env.action_space))
+    self.action_dtype = self._parse_dtype(self._batch_env.action_space)
+    with tf.variable_scope('env_temporary'):
+      self._observ = tf.Variable(
+          tf.zeros((len(self._batch_env),) + observ_shape, observ_dtype),
+          name='observ', trainable=False)
+
+  def __getattr__(self, name):
+    """Forward unimplemented attributes to one of the original environments.
+
+    Args:
+      name: Attribute that was accessed.
+
+    Returns:
+      Value behind the attribute name in one of the original environments.
+    """
+    return getattr(self._batch_env, name)
+
+  def __len__(self):
+    """Number of combined environments."""
+    return len(self._batch_env)
+
+  def __getitem__(self, index):
+    """Access an underlying environment by index."""
+    return self._batch_env[index]
+
+  def simulate(self, action):
+    """Step the batch of environments.
+
+    The results of the step can be accessed from the variables defined below.
+
+    Args:
+      action: Tensor holding the batch of actions to apply.
+
+    Returns:
+      Operation.
+    """
+    with tf.name_scope('environment/simulate'):
+      if action.dtype in (tf.float16, tf.float32, tf.float64):
+        action = tf.check_numerics(action, 'action')
+      observ_dtype = self._parse_dtype(self._batch_env.observation_space)
+      observ, reward, done = tf.py_func(
+          lambda a: self._batch_env.step(a)[:3], [action],
+          [observ_dtype, tf.float32, tf.bool], name='step')
+      observ = tf.check_numerics(observ, 'observ')
+      reward = tf.check_numerics(reward, 'reward')
+      with tf.control_dependencies([self._observ.assign(observ)]):
+        return tf.identity(reward), tf.identity(done)
+
+
+  def reset(self, indices=None):
+    """Reset the batch of environments.
+
+    Args:
+      indices: The batch indices of the environments to reset.
+
+    Returns:
+      Batch tensor of the new observations.
+    """
+    return tf.cond(
+        tf.cast(tf.shape(indices)[0], tf.bool),
+        lambda: self._reset_non_empty(indices), lambda: 0.0)
+
+  def _reset_non_empty(self, indices):
+    """Reset the batch of environments.
+
+    Args:
+      indices: The batch indices of the environments to reset; defaults to all.
+
+    Returns:
+      Batch tensor of the new observations.
+    """
+    observ_dtype = self._parse_dtype(self._batch_env.observation_space)
+    observ = tf.py_func(
+        self._batch_env.reset, [indices], observ_dtype, name='reset')
+    observ = tf.check_numerics(observ, 'observ')
+    with tf.control_dependencies([
+        tf.scatter_update(self._observ, indices, observ)]):
+      return tf.identity(observ)
+
+  @property
+  def observ(self):
+    """Access the variable holding the current observation."""
+    return self._observ
+
+  def close(self):
+    """Send close messages to the external process and join them."""
+    self._batch_env.close()
+
+  def _parse_shape(self, space):
+    """Get a tensor shape from an OpenAI Gym space.
+
+    Args:
+      space: Gym space.
+
+    Returns:
+      Shape tuple.
+    """
+    if isinstance(space, gym.spaces.Discrete):
+      return ()
+    if isinstance(space, gym.spaces.Box):
+      return space.shape
+    raise NotImplementedError()
+
+  def _parse_dtype(self, space):
+    """Get a tensor dtype from an OpenAI Gym space.
+
+    Args:
+      space: Gym space.
+
+    Returns:
+      TensorFlow data type.
+    """
+    if isinstance(space, gym.spaces.Discrete):
+      return tf.int32
+    if isinstance(space, gym.spaces.Box):
+      return tf.float32
+    raise NotImplementedError()
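A minimal sketch of how the in-graph wrapper is meant to be driven (standalone, not part of the diff; it mirrors the `reset`/`simulate` pattern used by `collect.py`, but with the blocking `BatchEnv` above instead of the external-process wrapper defined later in `envs/utils.py`, so treat the exact session behaviour as an assumption):

```python
import gym
import tensorflow as tf

from tensor2tensor.rl.envs.batch_env import BatchEnv
from tensor2tensor.rl.envs.in_graph_batch_env import InGraphBatchEnv

envs = [gym.make("Pendulum-v0") for _ in range(2)]
env = InGraphBatchEnv(BatchEnv(envs, blocking=True))

action = tf.fill((2, 1), 0.0)      # Pendulum-v0 has a single continuous action.
reset_op = env.reset(tf.range(2))  # Fills the in-graph observation variable.
with tf.control_dependencies([reset_op]):
  reward, done = env.simulate(action)

with tf.Session() as sess:
  sess.run(tf.global_variables_initializer())
  print(sess.run([reward, done]))
```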
diff --git a/tensor2tensor/rl/envs/utils.py b/tensor2tensor/rl/envs/utils.py
new file mode 100644
index 000000000..2b81af270
--- /dev/null
+++ b/tensor2tensor/rl/envs/utils.py
@@ -0,0 +1,225 @@
+# coding=utf-8
+# Copyright 2018 The Tensor2Tensor Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# The code was based on Danijar Hafner's code from tf.agents:
+# https://github.com/tensorflow/agents/blob/master/agents/tools/wrappers.py
+# https://github.com/tensorflow/agents/blob/master/agents/scripts/utility.py
+
+"""Utilities for using batched environments."""
+
+import atexit
+import multiprocessing
+import sys
+import traceback
+import tensorflow as tf
+
+from tensor2tensor.rl.envs import batch_env
+from tensor2tensor.rl.envs import in_graph_batch_env
+
+class ExternalProcessEnv(object):
+  """Step environment in a separate process for lock-free parallelism."""
+
+  # Message types for communication via the pipe.
+  _ACCESS = 1
+  _CALL = 2
+  _RESULT = 3
+  _EXCEPTION = 4
+  _CLOSE = 5
+
+  def __init__(self, constructor):
+    """Step environment in a separate process for lock-free parallelism.
+
+    The environment will be created in the external process by calling the
+    specified callable. This can be an environment class, or a function
+    creating the environment and potentially wrapping it. The returned
+    environment should not access global variables.
+
+    Args:
+      constructor: Callable that creates and returns an OpenAI gym environment.
+
+    Attributes:
+      observation_space: The cached observation space of the environment.
+      action_space: The cached action space of the environment.
+    """
+    self._conn, conn = multiprocessing.Pipe()
+    self._process = multiprocessing.Process(
+        target=self._worker, args=(constructor, conn))
+    atexit.register(self.close)
+    self._process.start()
+    self._observ_space = None
+    self._action_space = None
+
+  @property
+  def observation_space(self):
+    if not self._observ_space:
+      self._observ_space = self.__getattr__('observation_space')
+    return self._observ_space
+
+  @property
+  def action_space(self):
+    if not self._action_space:
+      self._action_space = self.__getattr__('action_space')
+    return self._action_space
+
+  def __getattr__(self, name):
+    """Request an attribute from the environment.
+
+    Note that this involves communication with the external process, so it can
+    be slow.
+
+    Args:
+      name: Attribute to access.
+
+    Returns:
+      Value of the attribute.
+    """
+    self._conn.send((self._ACCESS, name))
+    return self._receive()
+
+  def call(self, name, *args, **kwargs):
+    """Asynchronously call a method of the external environment.
+
+    Args:
+      name: Name of the method to call.
+      *args: Positional arguments to forward to the method.
+      **kwargs: Keyword arguments to forward to the method.
+
+    Returns:
+      Promise object that blocks and provides the return value when called.
+    """
+    payload = name, args, kwargs
+    self._conn.send((self._CALL, payload))
+    return self._receive
+
+  def close(self):
+    """Send a close message to the external process and join it."""
+    try:
+      self._conn.send((self._CLOSE, None))
+      self._conn.close()
+    except IOError:
+      # The connection was already closed.
+      pass
+    self._process.join()
+
+  def step(self, action, blocking=True):
+    """Step the environment.
+
+    Args:
+      action: The action to apply to the environment.
+      blocking: Whether to wait for the result.
+
+    Returns:
+      Transition tuple when blocking, otherwise callable that returns the
+      transition tuple.
+    """
+    promise = self.call('step', action)
+    if blocking:
+      return promise()
+    else:
+      return promise
+
+  def reset(self, blocking=True):
+    """Reset the environment.
+
+    Args:
+      blocking: Whether to wait for the result.
+
+    Returns:
+      New observation when blocking, otherwise callable that returns the new
+      observation.
+    """
+    promise = self.call('reset')
+    if blocking:
+      return promise()
+    else:
+      return promise
+
+  def _receive(self):
+    """Wait for a message from the worker process and return its payload.
+
+    Raises:
+      Exception: An exception was raised inside the worker process.
+      KeyError: The received message is of an unknown type.
+
+    Returns:
+      Payload object of the message.
+    """
+    message, payload = self._conn.recv()
+    # Re-raise exceptions in the main process.
+    if message == self._EXCEPTION:
+      stacktrace = payload
+      raise Exception(stacktrace)
+    if message == self._RESULT:
+      return payload
+    raise KeyError('Received message of unexpected type {}'.format(message))
+
+  def _worker(self, constructor, conn):
+    """The process waits for actions and sends back environment results.
+
+    Args:
+      constructor: Constructor for the OpenAI Gym environment.
+      conn: Connection for communication to the main process.
+    """
+    try:
+      env = constructor()
+      while True:
+        try:
+          # Only block for short times to have keyboard exceptions be raised.
+          if not conn.poll(0.1):
+            continue
+          message, payload = conn.recv()
+        except (EOFError, KeyboardInterrupt):
+          break
+        if message == self._ACCESS:
+          name = payload
+          result = getattr(env, name)
+          conn.send((self._RESULT, result))
+          continue
+        if message == self._CALL:
+          name, args, kwargs = payload
+          result = getattr(env, name)(*args, **kwargs)
+          conn.send((self._RESULT, result))
+          continue
+        if message == self._CLOSE:
+          assert payload is None
+          break
+        raise KeyError('Received message of unknown type {}'.format(message))
+    except Exception:  # pylint: disable=broad-except
+      stacktrace = ''.join(traceback.format_exception(*sys.exc_info()))
+      tf.logging.error('Error in environment process: {}'.format(stacktrace))
+      conn.send((self._EXCEPTION, stacktrace))
+    conn.close()
+
+def define_batch_env(constructor, num_agents, env_processes=True):
+  """Create environments and apply all desired wrappers.
+
+  Args:
+    constructor: Constructor of an OpenAI gym environment.
+    num_agents: Number of environments to combine in the batch.
+    env_processes: Whether to step environment in external processes.
+
+  Returns:
+    In-graph environments object.
+  """
+  with tf.variable_scope('environments'):
+    if env_processes:
+      envs = [
+          ExternalProcessEnv(constructor)
+          for _ in range(num_agents)]
+    else:
+      envs = [constructor() for _ in range(num_agents)]
+    env = batch_env.BatchEnv(envs, blocking=not env_processes)
+    env = in_graph_batch_env.InGraphBatchEnv(env)
+    return env
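Because the promise-style API above is easy to misread, here is a small sketch of using `ExternalProcessEnv` on its own (standalone, not part of the diff; it assumes a POSIX platform where the constructor can be shipped to a forked worker process):

```python
import gym

from tensor2tensor.rl.envs.utils import ExternalProcessEnv


def make_env():
  return gym.make("Pendulum-v0")


env = ExternalProcessEnv(make_env)
observ = env.reset()  # Blocking by default: returns the observation itself.

# Non-blocking calls return a promise: a callable resolved later for the result.
promise = env.step(env.action_space.sample(), blocking=False)
# ... other environments could be stepped here while the worker runs ...
observ, reward, done, info = promise()
env.close()
```

`define_batch_env` builds on exactly this pattern: with `env_processes=True` it starts one such process per agent and wraps them in a `BatchEnv` with `blocking=False`, so all agents step concurrently.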
diff --git a/tensor2tensor/rl/networks.py b/tensor2tensor/rl/networks.py
new file mode 100644
index 000000000..af8709191
--- /dev/null
+++ b/tensor2tensor/rl/networks.py
@@ -0,0 +1,65 @@
+# coding=utf-8
+# Copyright 2018 The Tensor2Tensor Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Neural networks for actor-critic algorithms."""
+
+import operator
+import functools
+import collections
+import tensorflow as tf
+import gym
+
+
+NetworkOutput = collections.namedtuple(
+    'NetworkOutput', 'policy, value, action_postprocessing')
+
+
+def feed_forward_gaussian_fun(observation_space, action_space, config,
+                              observations):
+  assert isinstance(observation_space, gym.spaces.box.Box)
+
+  mean_weights_initializer = tf.contrib.layers.variance_scaling_initializer(
+      factor=config.init_mean_factor)
+  logstd_initializer = tf.random_normal_initializer(config.init_logstd, 1e-10)
+
+  flat_observations = tf.reshape(observations, [
+      tf.shape(observations)[0], tf.shape(observations)[1],
+      functools.reduce(operator.mul, observations.shape.as_list()[2:], 1)])
+
+  with tf.variable_scope('policy'):
+    x = flat_observations
+    for size in config.policy_layers:
+      x = tf.contrib.layers.fully_connected(x, size, tf.nn.relu)
+    mean = tf.contrib.layers.fully_connected(
+        x, action_space.shape[0], tf.tanh,
+        weights_initializer=mean_weights_initializer)
+    logstd = tf.get_variable(
+        'logstd', mean.shape[2:], tf.float32, logstd_initializer)
+    logstd = tf.tile(
+        logstd[None, None],
+        [tf.shape(mean)[0], tf.shape(mean)[1]] + [1] * (mean.shape.ndims - 2))
+  with tf.variable_scope('value'):
+    x = flat_observations
+    for size in config.value_layers:
+      x = tf.contrib.layers.fully_connected(x, size, tf.nn.relu)
+    value = tf.contrib.layers.fully_connected(x, 1, None)[..., 0]
+  mean = tf.check_numerics(mean, 'mean')
+  logstd = tf.check_numerics(logstd, 'logstd')
+  value = tf.check_numerics(value, 'value')
+
+  policy = tf.contrib.distributions.MultivariateNormalDiag(mean,
+                                                           tf.exp(logstd))
+
+  return NetworkOutput(policy, value, lambda a: tf.clip_by_value(a, -2., 2))
diff --git a/tensor2tensor/rl/ppo.py b/tensor2tensor/rl/ppo.py
new file mode 100644
index 000000000..1c9654608
--- /dev/null
+++ b/tensor2tensor/rl/ppo.py
@@ -0,0 +1,98 @@
+# coding=utf-8
+# Copyright 2018 The Tensor2Tensor Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""PPO algorithm implementation.
+
+Based on: https://arxiv.org/abs/1707.06347
+"""
+
+import tensorflow as tf
+
+def define_ppo_step(observation, action, reward, done, value, old_pdf,
+                    policy_factory, config):
+
+  new_policy_dist, new_value, _ = policy_factory(observation)
+  new_pdf = new_policy_dist.prob(action)
+
+  ratio = new_pdf/old_pdf
+  clipped_ratio = tf.clip_by_value(ratio, 1 - config.clipping_coef,
+                                   1 + config.clipping_coef)
+
+  advantage = calculate_discounted_return(
+      reward, value, done, config.gae_gamma, config.gae_lambda) - value
+
+  advantage_mean, advantage_variance = tf.nn.moments(advantage, axes=[0, 1],
+                                                     keep_dims=True)
+  advantage_normalized = tf.stop_gradient(
+      (advantage - advantage_mean)/(tf.sqrt(advantage_variance) + 1e-8))
+
+  surrogate_objective = tf.minimum(clipped_ratio * advantage_normalized,
+                                   ratio * advantage_normalized)
+  policy_loss = -tf.reduce_mean(surrogate_objective)
+
+  value_error = calculate_discounted_return(
+      reward, new_value, done, config.gae_gamma, config.gae_lambda) - value
+  value_loss = config.value_loss_coef * tf.reduce_mean(value_error ** 2)
+
+  entropy = new_policy_dist.entropy()
+  entropy_loss = -config.entropy_loss_coef * tf.reduce_mean(entropy)
+
+  total_loss = policy_loss + value_loss + entropy_loss
+
+  optimization_op = config.optimizer(config.learning_rate).minimize(total_loss)
+
+  with tf.control_dependencies([optimization_op]):
+    return [tf.identity(x) for x in (policy_loss, value_loss, entropy_loss)]
+
+
+def define_ppo_epoch(memory, policy_factory, config):
+  observation, reward, done, action, old_pdf, value = memory
+
+  # This is to avoid propagating gradients back through the simulation.
+  observation = tf.stop_gradient(observation)
+  action = tf.stop_gradient(action)
+  reward = tf.stop_gradient(reward)
+  done = tf.stop_gradient(done)
+  value = tf.stop_gradient(value)
+  old_pdf = tf.stop_gradient(old_pdf)
+
+  policy_loss, value_loss, entropy_loss = tf.scan(
+      lambda _1, _2: define_ppo_step(observation, action, reward, done, value,
+                                     old_pdf, policy_factory, config),
+      tf.range(config.optimization_epochs),
+      [0., 0., 0.],
+      parallel_iterations=1)
+
+  print_losses = tf.group(
+      tf.Print(0, [tf.reduce_mean(policy_loss)], 'policy loss: '),
+      tf.Print(0, [tf.reduce_mean(value_loss)], 'value loss: '),
+      tf.Print(0, [tf.reduce_mean(entropy_loss)], 'entropy loss: '))
+
+  return print_losses
+
+
+def calculate_discounted_return(reward, value, done, discount, unused_lambda):
+  """Discounted Monte-Carlo returns."""
+  done = tf.cast(done, tf.float32)
+  reward2 = done[-1, :] * reward[-1, :] + (1 - done[-1, :]) * value[-1, :]
+  reward = tf.concat([reward[:-1,], reward2[None, ...]], axis=0)
+  return_ = tf.reverse(tf.scan(
+      lambda agg, cur: cur[0] + (1 - cur[1]) * discount * agg,  # fn
+      [tf.reverse(reward, [0]),  # elem
+       tf.reverse(done, [0])],
+      tf.zeros_like(reward[0, :]),  # initializer
+      1,
+      False), [0])
+  return tf.check_numerics(return_, 'return')
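For reference, the surrogate maximized by `define_ppo_step` above is the clipped objective from the PPO paper; the code minimizes its negation, with `config.clipping_coef` playing the role of epsilon and the normalized Monte-Carlo advantage standing in for the advantage estimate:

```latex
L^{CLIP}(\theta) = \hat{\mathbb{E}}_t\left[
    \min\left(r_t(\theta)\,\hat{A}_t,\;
    \operatorname{clip}\big(r_t(\theta),\,1-\epsilon,\,1+\epsilon\big)\,\hat{A}_t\right)\right],
\qquad
r_t(\theta) = \frac{\pi_\theta(a_t \mid s_t)}{\pi_{\theta_{\mathrm{old}}}(a_t \mid s_t)}
```

The total loss in the code additionally adds the value regression term weighted by `value_loss_coef` and the entropy bonus weighted by `entropy_loss_coef`, matching the combined objective in the paper.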
diff --git a/tensor2tensor/rl/train_test.py b/tensor2tensor/rl/train_test.py
new file mode 100644
index 000000000..ac14c2083
--- /dev/null
+++ b/tensor2tensor/rl/train_test.py
@@ -0,0 +1,36 @@
+# coding=utf-8
+# Copyright 2018 The Tensor2Tensor Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Tests of basic flow of collecting trajectories and training PPO."""
+
+import tensorflow as tf
+
+from tensor2tensor.bin import t2t_rl_trainer
+
+
+FLAGS = tf.app.flags.FLAGS
+
+
+class TrainTest(tf.test.TestCase):
+
+  def test_no_crash_pendulum(self):
+    params = t2t_rl_trainer.example_params()
+    params[2].epochs_num = 10
+    t2t_rl_trainer.train(params)
+
+
+if __name__ == '__main__':
+  FLAGS.config = 'unused'
+  tf.test.main()