[RLlib] Attention Net integration into ModelV2 and learning RL example. (#8371)
sven1977 committed May 18, 2020
1 parent 9347a5d commit 796a834
Showing 44 changed files with 1,279 additions and 911 deletions.
12 changes: 6 additions & 6 deletions .travis.yml
@@ -149,7 +149,7 @@ matrix:
- os: linux
env:
- RLLIB_TESTING=1 RLLIB_REGRESSION_TESTS=1
- TF_VERSION=2.0.0b1
- TF_VERSION=2.1.0
- TFP_VERSION=0.8
- TORCH_VERSION=1.4
- PYTHON=3.6
@@ -182,7 +182,7 @@ matrix:
- os: linux
env:
- RLLIB_TESTING=1 RLLIB_REGRESSION_TESTS_TORCH=1
- TF_VERSION=2.0.0b1
- TF_VERSION=2.1.0
- TFP_VERSION=0.8
- TORCH_VERSION=1.4
- PYTHON=3.6
@@ -200,7 +200,7 @@ matrix:
env:
- RLLIB_TESTING=1 RLLIB_QUICK_TRAIN_AND_MISC_TESTS=1
- PYTHON=3.6
- TF_VERSION=2.0.0b1
- TF_VERSION=2.1.0
- TFP_VERSION=0.8
- TORCH_VERSION=1.4
- PYTHONWARNINGS=ignore
@@ -220,7 +220,7 @@ matrix:
env:
- RLLIB_TESTING=1 RLLIB_EXAMPLE_DIR_TESTS=1
- PYTHON=3.6
- TF_VERSION=2.0.0b1
- TF_VERSION=2.1.0
- TFP_VERSION=0.8
- TORCH_VERSION=1.4
- PYTHONWARNINGS=ignore
@@ -239,7 +239,7 @@ matrix:
env:
- RLLIB_TESTING=1 RLLIB_TESTS_DIR_TESTS_A_TO_L=1
- PYTHON=3.6
- TF_VERSION=2.0.0b1
- TF_VERSION=2.1.0
- TFP_VERSION=0.8
- TORCH_VERSION=1.4
- PYTHONWARNINGS=ignore
@@ -255,7 +255,7 @@ matrix:
env:
- RLLIB_TESTING=1 RLLIB_TESTS_DIR_TESTS_M_TO_Z=1
- PYTHON=3.6
- TF_VERSION=2.0.0b1
- TF_VERSION=2.1.0
- TFP_VERSION=0.8
- TORCH_VERSION=1.4
- PYTHONWARNINGS=ignore
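Taken together, these CI entries bump the pinned stack from TF 2.0.0b1 to TF 2.1.0, keeping TFP 0.8 and Torch 1.4. A rough local equivalent of the new pins (an editor's convenience command inferred from the env vars above, not part of this diff):

    pip install tensorflow==2.1.0 tensorflow-probability==0.8.0 torch==1.4.0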
2 changes: 1 addition & 1 deletion ci/travis/install-dependencies.sh
@@ -209,7 +209,7 @@ install_dependencies() {
msys*) pip install torch==1.5.0+cpu torchvision==0.6.0+cpu -f "${torch_url}";;
esac

pip_packages=(scipy tensorflow=="${TF_VERSION:-2.0.0b1}" cython==0.29.0 gym \
pip_packages=(scipy tensorflow=="${TF_VERSION:-2.1.0}" cython==0.29.0 gym \
opencv-python-headless pyyaml pandas==0.24.2 requests feather-format lxml openpyxl xlrd \
py-spy pytest pytest-timeout networkx tabulate aiohttp uvicorn dataclasses pygments werkzeug \
kubernetes flask grpcio pytest-sugar pytest-rerunfailures pytest-asyncio scikit-learn==0.22.2 numba \
4 changes: 2 additions & 2 deletions doc/source/rllib-models.rst
@@ -80,9 +80,9 @@ For a full example of a custom model in code, see the `keras model example <http
Recurrent Models
~~~~~~~~~~~~~~~~

Instead of using the ``use_lstm: True`` option, it can be preferable use a custom recurrent model. This provides more control over postprocessing of the LSTM output and can also allow the use of multiple LSTM cells to process different portions of the input. For a RNN model it is preferred to subclass ``RecurrentTFModelV2`` to implement ``__init__()``, ``get_initial_state()``, and ``forward_rnn()``. You can check out the `custom_keras_rnn_model.py <https://github.com/ray-project/ray/blob/master/rllib/examples/custom_keras_rnn_model.py>`__ model as an example to implement your own model:
Instead of using the ``use_lstm: True`` option, it can be preferable to use a custom recurrent model. This provides more control over postprocessing of the LSTM output and can also allow the use of multiple LSTM cells to process different portions of the input. For an RNN model it is preferred to subclass ``RecurrentNetwork`` and implement ``__init__()``, ``get_initial_state()``, and ``forward_rnn()``. You can check out the `custom_rnn_model.py <https://github.com/ray-project/ray/blob/master/rllib/examples/custom_rnn_model.py>`__ example as a template for implementing your own model (a minimal sketch also follows the method list below):

.. autoclass:: ray.rllib.models.tf.recurrent_tf_modelv2.RecurrentTFModelV2
.. autoclass:: ray.rllib.models.tf.recurrent_net.RecurrentNetwork

    .. automethod:: __init__
    .. automethod:: forward_rnn
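As a quick illustration, here is a minimal sketch of such a subclass, patterned on the referenced custom_rnn_model.py example (the 64-unit LSTM size and the value-head wiring are illustrative choices by the editor, not part of this diff):

    import numpy as np

    from ray.rllib.models.tf.recurrent_net import RecurrentNetwork
    from ray.rllib.utils.framework import try_import_tf

    tf = try_import_tf()


    class MyRNNModel(RecurrentNetwork):
        """Keras-LSTM model exposing the RecurrentNetwork API shown above."""

        def __init__(self, obs_space, action_space, num_outputs, model_config,
                     name):
            super(MyRNNModel, self).__init__(obs_space, action_space,
                                             num_outputs, model_config, name)
            self.cell_size = 64  # illustrative hidden size

            # Keras graph: [B, T, obs] plus seq-lens and h/c states in,
            # logits/values plus new h/c states out.
            inputs = tf.keras.layers.Input(
                shape=(None, obs_space.shape[0]), name="inputs")
            seq_in = tf.keras.layers.Input(
                shape=(), name="seq_in", dtype=tf.int32)
            state_in_h = tf.keras.layers.Input(
                shape=(self.cell_size, ), name="h")
            state_in_c = tf.keras.layers.Input(
                shape=(self.cell_size, ), name="c")
            lstm_out, state_h, state_c = tf.keras.layers.LSTM(
                self.cell_size, return_sequences=True, return_state=True,
                name="lstm")(
                    inputs=inputs,
                    mask=tf.sequence_mask(seq_in),
                    initial_state=[state_in_h, state_in_c])
            logits = tf.keras.layers.Dense(
                self.num_outputs, name="logits")(lstm_out)
            values = tf.keras.layers.Dense(1, name="values")(lstm_out)
            self.rnn_model = tf.keras.Model(
                inputs=[inputs, seq_in, state_in_h, state_in_c],
                outputs=[logits, values, state_h, state_c])
            self.register_variables(self.rnn_model.variables)

        def forward_rnn(self, inputs, state, seq_lens):
            # inputs: [B, max_seq_len, obs-dim]; state: list of [B, cell_size].
            model_out, self._value_out, h, c = self.rnn_model(
                [inputs, seq_lens] + state)
            return model_out, [h, c]

        def get_initial_state(self):
            return [
                np.zeros(self.cell_size, np.float32),
                np.zeros(self.cell_size, np.float32),
            ]

        def value_function(self):
            return tf.reshape(self._value_out, [-1])

Note the division of labor this API implies: RLlib handles batching and folding samples into sequences; the model only defines the per-sequence computation and its initial state.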
27 changes: 23 additions & 4 deletions rllib/BUILD
@@ -1452,6 +1452,25 @@ py_test(
# --------------------------------------------------------------------


py_test(
    name = "examples/attention_net_tf",
    main = "examples/attention_net.py",
    tags = ["examples", "examples_A"],
    size = "large",
    srcs = ["examples/attention_net.py"],
    args = ["--as-test", "--stop-reward=80"]
)
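# An editor's note: running just this new target locally would presumably be
# done from a configured Bazel workspace at the repo root, e.g.:
#   bazel test //rllib:examples/attention_net_tf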

# TODO(sven): GTrXL PyTorch.
# py_test(
#     name = "examples/attention_net_torch",
#     main = "examples/attention_net.py",
#     tags = ["examples", "examples_A"],
#     size = "large",
#     srcs = ["examples/attention_net.py"],
#     args = ["--as-test", "--torch", "--stop-reward=90"]
# )

py_test(
    name = "examples/autoregressive_action_dist_tf",
    main = "examples/autoregressive_action_dist.py",
@@ -1492,7 +1511,7 @@ py_test(
name = "examples/batch_norm_model_dqn_tf",
main = "examples/batch_norm_model.py",
tags = ["examples", "examples_B"],
size = "medium", # DQN learns much slower with BatchNorm.
size = "large", # DQN learns much slower with BatchNorm.
srcs = ["examples/batch_norm_model.py"],
args = ["--as-test", "--run=DQN", "--stop-reward=70"]
)
@@ -1501,7 +1520,7 @@ py_test(
name = "examples/batch_norm_model_dqn_torch",
main = "examples/batch_norm_model.py",
tags = ["examples", "examples_B"],
size = "medium", # DQN learns much slower with BatchNorm.
size = "large", # DQN learns much slower with BatchNorm.
srcs = ["examples/batch_norm_model.py"],
args = ["--as-test", "--torch", "--run=DQN", "--stop-reward=70"]
)
@@ -1555,7 +1574,7 @@ py_test(
name = "examples/cartpole_lstm_ppo_torch",
main = "examples/cartpole_lstm.py",
tags = ["examples", "examples_C"],
size = "small",
size = "medium",
srcs = ["examples/cartpole_lstm.py"],
args = ["--as-test", "--torch", "--run=PPO", "--stop-reward=40", "--num-cpus=4"]
)
@@ -1871,7 +1890,7 @@ py_test(
name = "examples/multi_agent_two_trainers_mixed_torch_tf",
main = "examples/multi_agent_two_trainers.py",
tags = ["examples", "examples_M"],
size = "small",
size = "medium",
srcs = ["examples/multi_agent_two_trainers.py"],
args = ["--as-test", "--mixed-torch-tf", "--stop-reward=70"]
)
2 changes: 1 addition & 1 deletion rllib/agents/sac/README.md
@@ -1,6 +1,6 @@
Implementation of the Soft Actor-Critic algorithm:

[1] Soft Actor-Critic Algorithms and Applications - T. Haarnoja, A. Zhou, K. Hartikainen, et. al
[1] Soft Actor-Critic Algorithms and Applications - T. Haarnoja, A. Zhou, K. Hartikainen, et al.
https://arxiv.org/abs/1812.05905.pdf

For supporting discrete action spaces, we implemented this patch on top of the original algorithm:
75 changes: 75 additions & 0 deletions rllib/examples/attention_net.py
@@ -0,0 +1,75 @@
import argparse

import ray
from ray import tune
from ray.rllib.utils import try_import_tf
from ray.rllib.models.tf.attention_net import GTrXLNet
from ray.rllib.examples.env.look_and_push import LookAndPush, OneHot
from ray.rllib.examples.env.repeat_after_me_env import RepeatAfterMeEnv
from ray.rllib.examples.env.repeat_initial_obs_env import RepeatInitialObsEnv
from ray.rllib.examples.env.stateless_cartpole import StatelessCartPole
from ray.rllib.utils.test_utils import check_learning_achieved
from ray.tune import registry

tf = try_import_tf()

parser = argparse.ArgumentParser()
parser.add_argument("--run", type=str, default="PPO")
parser.add_argument("--env", type=str, default="RepeatAfterMeEnv")
parser.add_argument("--num-cpus", type=int, default=0)
parser.add_argument("--torch", action="store_true")
parser.add_argument("--as-test", action="store_true")
parser.add_argument("--stop-iters", type=int, default=200)
parser.add_argument("--stop-timesteps", type=int, default=500000)
parser.add_argument("--stop-reward", type=float, default=80)

if __name__ == "__main__":
    args = parser.parse_args()

    assert not args.torch, "PyTorch not supported for AttentionNets yet!"

    ray.init(num_cpus=args.num_cpus or None, local_mode=True)

    registry.register_env("RepeatAfterMeEnv", lambda c: RepeatAfterMeEnv(c))
    registry.register_env("RepeatInitialObsEnv",
                          lambda _: RepeatInitialObsEnv())
    registry.register_env("LookAndPush", lambda _: OneHot(LookAndPush()))
    registry.register_env("StatelessCartPole", lambda _: StatelessCartPole())

    config = {
        "env": args.env,
        "env_config": {
            "repeat_delay": 2,
        },
        "gamma": 0.99,
        "num_workers": 0,
        "num_envs_per_worker": 20,
        "entropy_coeff": 0.001,
        "num_sgd_iter": 5,
        "vf_loss_coeff": 1e-5,
        "model": {
"custom_model": GTrXLNet,
"max_seq_len": 50,
"custom_options": {
"num_transformer_units": 1,
"attn_dim": 64,
"num_heads": 2,
"memory_tau": 50,
"head_dim": 32,
"ff_hidden_dim": 32,
},
},
"use_pytorch": args.torch,
}

    stop = {
        "training_iteration": args.stop_iters,
        "timesteps_total": args.stop_timesteps,
        "episode_reward_mean": args.stop_reward,
    }

    results = tune.run(args.run, config=config, stop=stop, verbose=1)

    if args.as_test:
        check_learning_achieved(results, args.stop_reward)
    ray.shutdown()
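To reproduce what the new BUILD target above exercises, a local run (assuming a source checkout of Ray) would look like:

    python rllib/examples/attention_net.py --as-test --stop-reward=80

With the defaults, this trains PPO with the GTrXLNet custom model on RepeatAfterMeEnv until the mean episode reward reaches 80 or an iteration/timestep limit is hit.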
@@ -1,11 +1,8 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from gym.spaces import Box, Discrete
import numpy as np

from rllib.models.tf import attention
from ray.rllib.utils import try_import_tf
from rllib.models.tf.attention_net import TrXLNet
from ray.rllib.utils.framework import try_import_tf

tf = try_import_tf()

@@ -19,16 +16,6 @@ def bit_shift_generator(seq_length, shift, batch_size):
        yield seq, targets


def make_model(seq_length, num_tokens, num_layers, attn_dim, num_heads,
               head_dim, ff_hidden_dim):

    return tf.keras.Sequential((
        attention.make_TrXL(seq_length, num_layers, attn_dim, num_heads,
                            head_dim, ff_hidden_dim),
        tf.keras.layers.Dense(num_tokens),
    ))


def train_loss(targets, outputs):
    loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
        labels=targets, logits=outputs)
@@ -39,10 +26,13 @@ def train_bit_shift(seq_length, num_iterations, print_every_n):

    optimizer = tf.keras.optimizers.Adam(1e-3)

    model = make_model(
        seq_length,
        num_tokens=2,
        num_layers=1,
    model = TrXLNet(
        observation_space=Box(low=0, high=1, shape=(1, ), dtype=np.int32),
        action_space=Discrete(2),
        num_outputs=2,
        model_config={"max_seq_len": seq_length},
        name="trxl",
        num_transformer_units=1,
        attn_dim=10,
        num_heads=5,
        head_dim=20,
@@ -59,13 +49,20 @@ def train_bit_shift(seq_length, num_iterations, print_every_n):

    @tf.function
    def update_step(inputs, targets):

        optimizer.minimize(lambda: train_loss(targets, model(inputs)),
        model_out = model(
            {
                "obs": inputs
            },
            state=[tf.reshape(inputs, [-1, seq_length, 1])],
            seq_lens=np.full(shape=(train_batch, ), fill_value=seq_length))
        optimizer.minimize(lambda: train_loss(targets, model_out),
                           lambda: model.trainable_variables)

    for i, (inputs, targets) in zip(range(num_iterations), data_gen):
        inputs_in = np.reshape(inputs, [-1, 1])
        targets_in = np.reshape(targets, [-1])
        update_step(
            tf.convert_to_tensor(inputs), tf.convert_to_tensor(targets))
            tf.convert_to_tensor(inputs_in), tf.convert_to_tensor(targets_in))

        if i % print_every_n == 0:
            test_inputs, test_targets = next(test_gen)
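A note on the reworked update_step: in the ModelV2 call convention used here, the model is invoked with an input dict, a list of state tensors, and per-sequence lengths, and the model folds the flat batch back into [batch, seq_len, ...] internally via its max_seq_len config. The reshapes added to the training loop produce exactly that flat layout:

    # Shapes fed into update_step above (an editor's annotation):
    #   inputs_in:  [train_batch * seq_length, 1]  -> the "obs" input-dict entry
    #   state:      [train_batch, seq_length, 1]   (a reshaped copy of inputs)
    #   targets_in: [train_batch * seq_length]     (labels for the sparse CE loss)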
4 changes: 2 additions & 2 deletions rllib/examples/custom_env.py
@@ -8,16 +8,16 @@
You can visualize experiment results in ~/ray_results using TensorBoard.
"""
import argparse
import numpy as np
import gym
from gym.spaces import Discrete, Box
import numpy as np

import ray
from ray import tune
from ray.tune import grid_search
from ray.rllib.models import ModelCatalog
from ray.rllib.models.tf.tf_modelv2 import TFModelV2
from ray.rllib.models.tf.fcnet_v2 import FullyConnectedNetwork
from ray.rllib.models.tf.fcnet import FullyConnectedNetwork
from ray.rllib.models.torch.torch_modelv2 import TorchModelV2
from ray.rllib.models.torch.fcnet import FullyConnectedNetwork as TorchFC
from ray.rllib.utils.framework import try_import_tf, try_import_torch
2 changes: 1 addition & 1 deletion rllib/examples/custom_keras_model.py
@@ -10,7 +10,7 @@
from ray.rllib.agents.dqn.distributional_q_tf_model import \
    DistributionalQTFModel
from ray.rllib.utils import try_import_tf
from ray.rllib.models.tf.visionnet_v2 import VisionNetwork as MyVisionNetwork
from ray.rllib.models.tf.visionnet import VisionNetwork as MyVisionNetwork

tf = try_import_tf()

23 changes: 23 additions & 0 deletions rllib/examples/env/debug_counter_env.py
@@ -0,0 +1,23 @@
import gym


class DebugCounterEnv(gym.Env):
    """Simple Env that yields a ts counter as observation (0-based).

    Actions have no effect. The episode length is always 15.
    Reward is always: current ts % 3.
    """

    def __init__(self):
        self.action_space = gym.spaces.Discrete(2)
        self.observation_space = gym.spaces.Box(0, 100, (1, ))
        self.i = 0

    def reset(self):
        self.i = 0
        return [self.i]

    def step(self, action):
        self.i += 1
        return [self.i], self.i % 3, self.i >= 15, {}
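As a quick sanity check (an editor's sketch, not part of the commit), one full rollout of this env behaves as follows:

    env = DebugCounterEnv()
    assert env.reset() == [0]
    done, total_reward = False, 0
    while not done:
        obs, rew, done, _ = env.step(env.action_space.sample())
        total_reward += rew
    # 15 steps whose rewards cycle 1, 2, 0 -> obs == [15], total_reward == 15.
    assert obs == [15] and total_reward == 15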
