Merged
2 changes: 2 additions & 0 deletions setup.py
@@ -36,6 +36,8 @@
'future',
'gevent',
'gunicorn',
'gym<=0.9.5', # gym in version 0.9.6 has some temporary issues.
'munch',
'numpy',
'requests',
'scipy',
16 changes: 16 additions & 0 deletions tensor2tensor/bin/t2t-rl-trainer
@@ -0,0 +1,16 @@
#!/usr/bin/env python
"""t2t-rl-trainer."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensor2tensor.bin import t2t_rl_trainer

import tensorflow as tf

def main(argv):
t2t_rl_trainer.main(argv)


if __name__ == "__main__":
tf.app.run()
92 changes: 92 additions & 0 deletions tensor2tensor/bin/t2t_rl_trainer.py
@@ -0,0 +1,92 @@
# coding=utf-8
# Copyright 2018 The Tensor2Tensor Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Training of RL agent with PPO algorithm."""

from __future__ import absolute_import

import functools
from munch import Munch
import tensorflow as tf

from tensor2tensor.rl.collect import define_collect
from tensor2tensor.rl.envs.utils import define_batch_env
from tensor2tensor.rl.ppo import define_ppo_epoch


def define_train(policy_lambda, env_lambda, config):
env = env_lambda()
action_space = env.action_space
observation_space = env.observation_space

batch_env = define_batch_env(env_lambda, config["num_agents"])

policy_factory = tf.make_template(
'network',
functools.partial(policy_lambda, observation_space,
action_space, config))

(collect_op, memory) = define_collect(policy_factory, batch_env, config)

with tf.control_dependencies([collect_op]):
ppo_op = define_ppo_epoch(memory, policy_factory, config)

return ppo_op


def main(argv=None):
  del argv  # unused; the t2t-rl-trainer binary passes argv via tf.app.run.
  train(example_params())


def train(params):
policy_lambda, env_lambda, config = params
ppo_op = define_train(policy_lambda, env_lambda, config)

with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
for _ in range(config.epochs_num):
sess.run(ppo_op)


def example_params():
from tensor2tensor.rl import networks
config = {}
config['init_mean_factor'] = 0.1
config['init_logstd'] = 0.1
config['policy_layers'] = 100, 100
config['value_layers'] = 100, 100
config['num_agents'] = 30
config['clipping_coef'] = 0.2
config['gae_gamma'] = 0.99
config['gae_lambda'] = 0.95
config['entropy_loss_coef'] = 0.01
config['value_loss_coef'] = 1
config['optimizer'] = tf.train.AdamOptimizer
config['learning_rate'] = 1e-4
config['optimization_epochs'] = 15
config['epoch_length'] = 200
config['epochs_num'] = 2000

config = Munch(config)
return networks.feed_forward_gaussian_fun, pendulum_lambda, config


def pendulum_lambda():
import gym
return gym.make("Pendulum-v0")


if __name__ == '__main__':
main()
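As a usage note, `example_params()` above returns a `(policy, env_lambda, config)` triple, so another continuous-control gym environment can be trained by swapping only the env lambda. A minimal sketch, not part of this diff; it assumes `MountainCarContinuous-v0` is available in the pinned gym version:

```python
from tensor2tensor.bin import t2t_rl_trainer
from tensor2tensor.rl import networks


def mountain_car_lambda():
  import gym
  return gym.make("MountainCarContinuous-v0")


def mountain_car_params():
  # Reuse the default PPO config and Gaussian policy; swap only the env.
  _, _, config = t2t_rl_trainer.example_params()
  return networks.feed_forward_gaussian_fun, mountain_car_lambda, config


if __name__ == "__main__":
  t2t_rl_trainer.train(mountain_car_params())
```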
10 changes: 10 additions & 0 deletions tensor2tensor/rl/README.md
@@ -0,0 +1,10 @@
# Tensor2Tensor Reinforcement Learning starter.

The rl package aims to make it possible to run reinforcement learning
algorithms within TensorFlow's computation graph.

Currently the only supported algorithm is Proximal Policy Optimization (PPO).

## Sample usage - training in Pendulum-v0 environment.

```t2t-rl-trainer```
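The same training run can also be started from Python. A minimal sketch, assuming the modules added in this PR are on the path:

```python
from tensor2tensor.bin import t2t_rl_trainer

# Train PPO on Pendulum-v0 with the default hyperparameters
# returned by example_params().
t2t_rl_trainer.train(t2t_rl_trainer.example_params())
```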
Empty file added tensor2tensor/rl/__init__.py
Empty file.
94 changes: 94 additions & 0 deletions tensor2tensor/rl/collect.py
@@ -0,0 +1,94 @@
# coding=utf-8
# Copyright 2018 The Tensor2Tensor Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Collect trajectories from interactions of agent with environment."""

import tensorflow as tf


def define_collect(policy_factory, batch_env, config):

memory_shape = [config.epoch_length] + [batch_env.observ.shape.as_list()[0]]
memories_shapes_and_types = [
# observation
(memory_shape + [batch_env.observ.shape.as_list()[1]], tf.float32),
(memory_shape, tf.float32), # reward
(memory_shape, tf.bool), # done
(memory_shape + batch_env.action_shape, tf.float32), # action
(memory_shape, tf.float32), # pdf
(memory_shape, tf.float32), # value function
]
memory = [tf.Variable(tf.zeros(shape, dtype), trainable=False)
for (shape, dtype) in memories_shapes_and_types]
cumulative_rewards = tf.Variable(
tf.zeros(config.num_agents, tf.float32), trainable=False)

should_reset_var = tf.Variable(True, trainable=False)
reset_op = tf.cond(should_reset_var,
lambda: batch_env.reset(tf.range(config.num_agents)),
lambda: 0.0)
with tf.control_dependencies([reset_op]):
reset_once_op = tf.assign(should_reset_var, False)

with tf.control_dependencies([reset_once_op]):

def step(index, scores_sum, scores_num):
      # Note: the only way to ensure that a copy of a tensor is made is to run
      # a simple operation on it. We are waiting for tf.copy:
      # https://github.com/tensorflow/tensorflow/issues/11186
obs_copy = batch_env.observ + 0
actor_critic = policy_factory(tf.expand_dims(obs_copy, 0))
policy = actor_critic.policy
action = policy.sample()
postprocessed_action = actor_critic.action_postprocessing(action)
simulate_output = batch_env.simulate(postprocessed_action[0, ...])
pdf = policy.prob(action)[0]
with tf.control_dependencies(simulate_output):
reward, done = simulate_output
done = tf.reshape(done, (config.num_agents,))
to_save = [obs_copy, reward, done, action[0, ...], pdf,
actor_critic.value[0]]
save_ops = [tf.scatter_update(memory_slot, index, value)
for memory_slot, value in zip(memory, to_save)]
cumulate_rewards_op = cumulative_rewards.assign_add(reward)
        agent_indices_to_reset = tf.where(done)[:, 0]
        with tf.control_dependencies([cumulate_rewards_op]):
          scores_sum_delta = tf.reduce_sum(
              tf.gather(cumulative_rewards, agent_indices_to_reset))
          scores_num_delta = tf.count_nonzero(done, dtype=tf.int32)
        with tf.control_dependencies(save_ops + [scores_sum_delta,
                                                 scores_num_delta]):
          reset_env_op = batch_env.reset(agent_indices_to_reset)
          reset_cumulative_rewards_op = tf.scatter_update(
              cumulative_rewards, agent_indices_to_reset,
              tf.zeros(tf.shape(agent_indices_to_reset)))
with tf.control_dependencies([reset_env_op,
reset_cumulative_rewards_op]):
return [index + 1, scores_sum + scores_sum_delta,
scores_num + scores_num_delta]

init = [tf.constant(0), tf.constant(0.0), tf.constant(0)]
index, scores_sum, scores_num = tf.while_loop(
lambda c, _1, _2: c < config.epoch_length,
step,
init,
parallel_iterations=1,
back_prop=False)
mean_score = tf.cond(tf.greater(scores_num, 0),
lambda: scores_sum / tf.cast(scores_num, tf.float32),
lambda: 0.)
printing = tf.Print(0, [mean_score, scores_sum, scores_num], "mean_score: ")
with tf.control_dependencies([printing]):
return tf.identity(index), memory
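The `obs_copy = batch_env.observ + 0` trick above deserves a short illustration: in TF1 graph mode, fetching a variable can expose reference semantics, so a value read alongside a later assign is not guaranteed to be a stable copy, while adding zero materializes the value at read time. A minimal sketch of the idiom, not part of this diff:

```python
import tensorflow as tf

v = tf.Variable(1.0)
snapshot = v + 0  # copies the current value instead of holding a reference
with tf.control_dependencies([snapshot]):
  bump = tf.assign_add(v, 1.0)

with tf.Session() as sess:
  sess.run(tf.global_variables_initializer())
  snap, _ = sess.run([snapshot, bump])
  print(snap)  # 1.0: the snapshot is unaffected by the later assign_add
```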
Empty file added tensor2tensor/rl/envs/__init__.py
Empty file.
129 changes: 129 additions & 0 deletions tensor2tensor/rl/envs/batch_env.py
@@ -0,0 +1,129 @@
# coding=utf-8
# Copyright 2018 The Tensor2Tensor Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# The code was based on Danijar Hafner's code from tf.agents:
# https://github.com/tensorflow/agents/blob/master/agents/tools/batch_env.py

"""Combine multiple environments to step them in batch."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np


class BatchEnv(object):
"""Combine multiple environments to step them in batch."""

def __init__(self, envs, blocking):
"""Combine multiple environments to step them in batch.

    To step environments in parallel, environments must support a
    `blocking=False` argument to their step and reset functions, which makes
    them return callables that can be invoked later to retrieve the result.

Args:
envs: List of environments.
      blocking: Step environments one after another rather than in parallel.

Raises:
ValueError: Environments have different observation or action spaces.
"""
self._envs = envs
self._blocking = blocking
observ_space = self._envs[0].observation_space
if not all(env.observation_space == observ_space for env in self._envs):
raise ValueError('All environments must use the same observation space.')
action_space = self._envs[0].action_space
if not all(env.action_space == action_space for env in self._envs):
      raise ValueError('All environments must use the same action space.')

def __len__(self):
"""Number of combined environments."""
return len(self._envs)

def __getitem__(self, index):
"""Access an underlying environment by index."""
return self._envs[index]

def __getattr__(self, name):
"""Forward unimplemented attributes to one of the original environments.

Args:
name: Attribute that was accessed.

Returns:
      Value of the requested attribute on the first wrapped environment.
"""
return getattr(self._envs[0], name)

def step(self, actions):
"""Forward a batch of actions to the wrapped environments.

Args:
actions: Batched action to apply to the environment.

Raises:
ValueError: Invalid actions.

Returns:
Batch of observations, rewards, and done flags.
"""
for index, (env, action) in enumerate(zip(self._envs, actions)):
if not env.action_space.contains(action):
message = 'Invalid action at index {}: {}'
raise ValueError(message.format(index, action))
if self._blocking:
transitions = [
env.step(action)
for env, action in zip(self._envs, actions)]
else:
transitions = [
env.step(action, blocking=False)
for env, action in zip(self._envs, actions)]
transitions = [transition() for transition in transitions]
observs, rewards, dones, infos = zip(*transitions)
observ = np.stack(observs).astype(np.float32)
reward = np.stack(rewards).astype(np.float32)
done = np.stack(dones)
info = tuple(infos)
return observ, reward, done, info

def reset(self, indices=None):
"""Reset the environment and convert the resulting observation.

Args:
indices: The batch indices of environments to reset; defaults to all.

Returns:
Batch of observations.
"""
if indices is None:
indices = np.arange(len(self._envs))
if self._blocking:
observs = [self._envs[index].reset() for index in indices]
else:
observs = [self._envs[index].reset(blocking=False) for index in indices]
observs = [observ() for observ in observs]
observ = np.stack(observs)
observ = observ.astype(np.float32)
return observ

def close(self):
"""Send close messages to the external process and join them."""
for env in self._envs:
if hasattr(env, 'close'):
env.close()
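For reference, the `blocking=False` protocol assumed by `BatchEnv` can be satisfied by a simple wrapper. A minimal sketch under that assumption, not part of this diff (`DeferredEnv` is a hypothetical name; a real non-blocking environment would compute the result in a separate process or thread rather than eagerly):

```python
import gym


class DeferredEnv(object):
  """Hypothetical wrapper exposing the blocking=False call convention."""

  def __init__(self, env):
    self._env = env

  def __getattr__(self, name):
    # Forward observation_space, action_space, etc. to the wrapped env.
    return getattr(self._env, name)

  def step(self, action, blocking=True):
    result = self._env.step(action)
    # With blocking=False the caller receives a callable and invokes it later.
    return result if blocking else (lambda: result)

  def reset(self, blocking=True):
    observ = self._env.reset()
    return observ if blocking else (lambda: observ)


# Example wiring (hypothetical):
# envs = [DeferredEnv(gym.make("Pendulum-v0")) for _ in range(4)]
# batch_env = BatchEnv(envs, blocking=False)
```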