diff --git a/.travis.yml b/.travis.yml index 694915038..339a0f606 100644 --- a/.travis.yml +++ b/.travis.yml @@ -41,7 +41,6 @@ script: --ignore=tensor2tensor/problems_test.py --ignore=tensor2tensor/bin/t2t_trainer_test.py --ignore=tensor2tensor/data_generators/algorithmic_math_test.py - --ignore=tensor2tensor/rl/rl_trainer_lib_test.py - pytest tensor2tensor/utils/registry_test.py - pytest tensor2tensor/utils/trainer_lib_test.py - pytest tensor2tensor/visualization/visualization_test.py diff --git a/tensor2tensor/data_generators/gym.py b/tensor2tensor/data_generators/gym.py index 06b5ad0f3..c50b0db6b 100644 --- a/tensor2tensor/data_generators/gym.py +++ b/tensor2tensor/data_generators/gym.py @@ -35,8 +35,6 @@ import tensorflow as tf - - flags = tf.flags FLAGS = flags.FLAGS @@ -50,6 +48,17 @@ def __init__(self, *args, **kwargs): super(GymDiscreteProblem, self).__init__(*args, **kwargs) self._env = None + def example_reading_spec(self, label_repr=None): + + data_fields = { + "inputs": tf.FixedLenFeature([210, 160, 3], tf.int64), + "inputs_prev": tf.FixedLenFeature([210, 160, 3], tf.int64), + "targets": tf.FixedLenFeature([210, 160, 3], tf.int64), + "action": tf.FixedLenFeature([1], tf.int64) + } + + return data_fields, None + @property def env_name(self): # This is the name of the Gym environment for this problem. 
@@ -133,7 +142,7 @@ class GymPongRandom5k(GymDiscreteProblem): @property def env_name(self): - return "Pong-v0" + return "PongNoFrameskip-v4" @property def num_actions(self): @@ -148,21 +157,30 @@ def num_steps(self): return 5000 + @registry.register_problem class GymPongTrajectoriesFromPolicy(GymDiscreteProblem): """Pong game, loaded actions.""" - def __init__(self, event_dir, *args, **kwargs): + def __init__(self, *args, **kwargs): super(GymPongTrajectoriesFromPolicy, self).__init__(*args, **kwargs) self._env = None - self._event_dir = event_dir + self._last_policy_op = None + self._max_frame_pl = None + self._last_action = self.env.action_space.sample() + self._skip = 4 + self._skip_step = 0 + self._obs_buffer = np.zeros((2,) + self.env.observation_space.shape, + dtype=np.uint8) + + def generator(self, data_dir, tmp_dir): env_spec = lambda: atari_wrappers.wrap_atari( # pylint: disable=g-long-lambda gym.make("PongNoFrameskip-v4"), warp=False, frame_skip=4, frame_stack=False) hparams = rl.atari_base() - with tf.variable_scope("train"): + with tf.variable_scope("train", reuse=tf.AUTO_REUSE): policy_lambda = hparams.network policy_factory = tf.make_template( "network", @@ -173,14 +191,13 @@ def __init__(self, event_dir, *args, **kwargs): self._max_frame_pl, 0), 0)) policy = actor_critic.policy self._last_policy_op = policy.mode() - self._last_action = self.env.action_space.sample() - self._skip = 4 - self._skip_step = 0 - self._obs_buffer = np.zeros((2,) + self.env.observation_space.shape, - dtype=np.uint8) - self._sess = tf.Session() - model_saver = tf.train.Saver(tf.global_variables(".*network_parameters.*")) - model_saver.restore(self._sess, FLAGS.model_path) + with tf.Session() as sess: + model_saver = tf.train.Saver( + tf.global_variables(".*network_parameters.*")) + model_saver.restore(sess, FLAGS.model_path) + for item in super(GymPongTrajectoriesFromPolicy, + self).generator(data_dir, tmp_dir): + yield item # TODO(blazej0): For training of atari agents 
wrappers are usually used. # Below we have a hacky solution which is a workaround to be used together @@ -191,7 +208,7 @@ def get_action(self, observation=None): self._skip_step = (self._skip_step + 1) % self._skip if self._skip_step == 0: max_frame = self._obs_buffer.max(axis=0) - self._last_action = int(self._sess.run( + self._last_action = int(tf.get_default_session().run( self._last_policy_op, feed_dict={self._max_frame_pl: max_frame})[0, 0]) return self._last_action diff --git a/tensor2tensor/models/__init__.py b/tensor2tensor/models/__init__.py index 32ef49901..075840f2f 100644 --- a/tensor2tensor/models/__init__.py +++ b/tensor2tensor/models/__init__.py @@ -41,6 +41,7 @@ from tensor2tensor.models.research import aligned from tensor2tensor.models.research import attention_lm from tensor2tensor.models.research import attention_lm_moe +from tensor2tensor.models.research import basic_conv_gen from tensor2tensor.models.research import cycle_gan from tensor2tensor.models.research import gene_expression from tensor2tensor.models.research import multimodel diff --git a/tensor2tensor/models/research/basic_conv_gen.py b/tensor2tensor/models/research/basic_conv_gen.py new file mode 100644 index 000000000..c04e0b891 --- /dev/null +++ b/tensor2tensor/models/research/basic_conv_gen.py @@ -0,0 +1,65 @@ + +# coding=utf-8 +# Copyright 2018 The Tensor2Tensor Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Basic models for testing simple tasks.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +# Dependency imports + +from tensor2tensor.layers import common_hparams +from tensor2tensor.layers import common_layers +from tensor2tensor.utils import registry +from tensor2tensor.utils import t2t_model + +import tensorflow as tf + + +@registry.register_model +class BasicConvGen(t2t_model.T2TModel): + + def body(self, features): + print(features) + filters = self.hparams.hidden_size + cur_frame = tf.to_float(features["inputs"]) + prev_frame = tf.to_float(features["inputs_prev"]) + print(features["inputs"].shape, cur_frame.shape, prev_frame.shape) + action = common_layers.embedding(tf.to_int64(features["action"]), + 10, filters) + action = tf.reshape(action, [-1, 1, 1, filters]) + + frames = tf.concat([cur_frame, prev_frame], axis=3) + h1 = tf.layers.conv2d(frames, filters, kernel_size=(3, 3), padding="SAME") + h2 = tf.layers.conv2d(tf.nn.relu(h1 + action), filters, + kernel_size=(5, 5), padding="SAME") + res = tf.layers.conv2d(tf.nn.relu(h2 + action), 3 * 256, + kernel_size=(3, 3), padding="SAME") + + height = tf.shape(res)[1] + width = tf.shape(res)[2] + res = tf.reshape(res, [-1, height, width, 3, 256]) + return res + + +@registry.register_hparams +def basic_conv_small(): + # """Small conv model.""" + hparams = common_hparams.basic_params1() + hparams.hidden_size = 32 + hparams.batch_size = 2 + return hparams diff --git a/tensor2tensor/notebooks/hello_t2t-rl.ipynb b/tensor2tensor/notebooks/hello_t2t-rl.ipynb new file mode 100644 index 000000000..d7e0eb6e1 --- /dev/null +++ b/tensor2tensor/notebooks/hello_t2t-rl.ipynb @@ -0,0 +1,354 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + }, + "colab_type": "code", + "collapsed": true, + "id": "s19ucTii_wYb" + }, + "outputs": [], + "source": [ + 
"# Copyright 2018 Google LLC.\n", + "\n", + "# Licensed under the Apache License, Version 2.0 (the \"License\");\n", + "# you may not use this file except in compliance with the License.\n", + "# You may obtain a copy of the License at\n", + "\n", + "# https://www.apache.org/licenses/LICENSE-2.0\n", + "\n", + "# Unless required by applicable law or agreed to in writing, software\n", + "# distributed under the License is distributed on an \"AS IS\" BASIS,\n", + "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n", + "# See the License for the specific language governing permissions and\n", + "# limitations under the License." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": true + }, + "outputs": [], + "source": [ + "# Install deps\n", + "!pip install -q -U tensor2tensor tensorflow" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + }, + "colab_type": "code", + "collapsed": true, + "id": "oILRLCWN_16u" + }, + "outputs": [], + "source": [ + "import tensorflow as tf\n", + "import matplotlib.pyplot as plt\n", + "import numpy as np\n", + "import os\n", + "import collections\n", + "import sys\n", + "import tempfile\n", + "\n", + "from tensor2tensor import models\n", + "from tensor2tensor import problems\n", + "from tensor2tensor.rl import rl_trainer_lib\n", + "from tensor2tensor.utils import trainer_lib\n", + "from tensor2tensor.utils import t2t_model\n", + "from tensor2tensor.utils import registry\n", + "\n", + "# Other setup\n", + "Modes = tf.estimator.ModeKeys\n", + "\n", + "prefix = \"~/t2t_rl_data\"\n", + "# Setup data directories\n", + "data_dir = os.path.expanduser(prefix + \"/data\")\n", + "tmp_dir = os.path.expanduser(prefix + \"/tmp\")\n", + "tf.gfile.MakeDirs(data_dir)\n", + "tf.gfile.MakeDirs(tmp_dir)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Train 
policy\n", + "\n", + "The training of the policy will take around 1h on GPU." + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "INFO:tensorflow:Overriding hparams in atari_base with epochs_num=1\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "[2018-03-07 00:06:25,614] Overriding hparams in atari_base with epochs_num=1\n", + "[2018-03-07 00:06:25,620] Making new env: PongNoFrameskip-v4\n", + "[2018-03-07 00:06:25,860] Making new env: PongNoFrameskip-v4\n", + "[2018-03-07 00:06:25,865] Making new env: PongNoFrameskip-v4\n", + "[2018-03-07 00:06:25,872] Making new env: PongNoFrameskip-v4\n", + "[2018-03-07 00:06:25,883] Making new env: PongNoFrameskip-v4\n", + "[2018-03-07 00:06:25,892] Making new env: PongNoFrameskip-v4\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "WARNING:tensorflow:From /home/blazej.osinski/t2t/t2t_jupyter_kernel/local/lib/python2.7/site-packages/tensorflow/python/ops/distributions/categorical.py:310: calling argmax (from tensorflow.python.ops.math_ops) with dimension is deprecated and will be removed in a future version.\n", + "Instructions for updating:\n", + "Use the `axis` argument instead\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "[2018-03-07 00:06:26,589] From /home/blazej.osinski/t2t/t2t_jupyter_kernel/local/lib/python2.7/site-packages/tensorflow/python/ops/distributions/categorical.py:310: calling argmax (from tensorflow.python.ops.math_ops) with dimension is deprecated and will be removed in a future version.\n", + "Instructions for updating:\n", + "Use the `axis` argument instead\n", + "[2018-03-07 00:06:27,589] Making new env: PongNoFrameskip-v4\n", + "[2018-03-07 00:06:27,772] Making new env: PongNoFrameskip-v4\n", + "[2018-03-07 00:06:47,971] Starting new video recorder writing to 
/home/blazej.osinski/t2t_rl_data/data/ppo_Yr5Rjt/openaigym.video.0.144364.video000000.mp4\n", + "[2018-03-07 00:09:36,335] Finished writing results. You can upload them to the scoreboard via gym.upload('/home/blazej.osinski/t2t_rl_data/data/ppo_Yr5Rjt')\n" + ] + } + ], + "source": [ + "iteration_num=300\n", + "hparams = trainer_lib.create_hparams(\"atari_base\", \"epochs_num={}\".format(iteration_num+1))\n", + "ppo_dir = tempfile.mkdtemp(dir=data_dir, prefix=\"ppo_\")\n", + "rl_trainer_lib.train(hparams, \"stacked_pong\", ppo_dir)" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": { + "collapsed": true + }, + "outputs": [], + "source": [ + "model_path = os.path.join(ppo_dir, \"model{}.ckpt.index\".format(iteration_num))[:-6]" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Generate and review frames from policy" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": { + "collapsed": true + }, + "outputs": [], + "source": [ + "sys.argv = [sys.argv[0], \"--model_path\", model_path]" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": { + "collapsed": true + }, + "outputs": [], + "source": [ + "tf.reset_default_graph()" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "[2018-03-07 00:13:20,983] Making new env: PongNoFrameskip-v4\n", + "[2018-03-07 00:13:21,221] Making new env: PongNoFrameskip-v4\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "INFO:tensorflow:Restoring parameters from /home/blazej.osinski/t2t_rl_data/data/ppo_Yr5Rjt/model0.ckpt\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "[2018-03-07 00:13:21,586] Restoring parameters from /home/blazej.osinski/t2t_rl_data/data/ppo_Yr5Rjt/model0.ckpt\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "INFO:tensorflow:Generated 4998 
Examples\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "[2018-03-07 00:31:57,314] Generated 4998 Examples\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "INFO:tensorflow:Shuffling data...\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "[2018-03-07 00:31:57,319] Shuffling data...\n" + ] + } + ], + "source": [ + "# This step is also time consuming - takes around 30 minutes.\n", + "gym_problem = problems.problem(\"gym_pong_trajectories_from_policy\")\n", + "gym_problem.generate_data(data_dir, tmp_dir)" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "INFO:tensorflow:Reading data files from /home/blazej.osinski/t2t_rl_data/data/gym_pong_trajectories_from_policy-train*\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "[2018-03-07 00:32:00,394] Reading data files from /home/blazej.osinski/t2t_rl_data/data/gym_pong_trajectories_from_policy-train*\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "INFO:tensorflow:partition: 0 num_data_files: 10\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "[2018-03-07 00:32:00,399] partition: 0 num_data_files: 10\n" + ] + }, + { + "data": { + "image/png": "\n", + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "dataset = gym_problem.dataset(Modes.TRAIN, data_dir)\n", + "iterator = dataset.make_one_shot_iterator()\n", + "next_element = iterator.get_next()\n", + "\n", + "\n", + "fig=plt.figure(figsize=(20, 80))\n", + "columns = 10\n", + "rows = 40\n", + "\n", + "with tf.Session() as sess:\n", + " for inx in range(100):\n", + " value = sess.run(next_element)\n", + " for i in range(10): # skipping surplus frames.\n", + " value = sess.run(next_element)\n", + " fig.add_subplot(rows, columns, inx+1) \n", + 
" image = value[\"inputs\"].reshape([210,160,3])\n", + " plt.imshow(image[:, :, 0].astype(np.float32), cmap=plt.get_cmap('gray'))\n", + "plt.show()" + ] + } + ], + "metadata": { + "colab": { + "collapsed_sections": [], + "default_view": {}, + "name": "T2T with TF Eager", + "provenance": [ + { + "file_id": "1-VScmaLkMqWiSbqgUCFWefzisSREd8l1", + "timestamp": 1512175750497 + } + ], + "version": "0.3.2", + "views": {} + }, + "kernelspec": { + "display_name": "t2t_kernel", + "language": "python", + "name": "t2t_kernel" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 2 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython2", + "version": "2.7.12" + } + }, + "nbformat": 4, + "nbformat_minor": 1 +} diff --git a/tensor2tensor/rl/README.md b/tensor2tensor/rl/README.md index 46e40403f..ffd595911 100644 --- a/tensor2tensor/rl/README.md +++ b/tensor2tensor/rl/README.md @@ -7,16 +7,47 @@ for now and under heavy development. Currently the only supported algorithm is Proximy Policy Optimization - PPO. -## Sample usage - training in the Pendulum-v0 environment. - -```python rl/t2t_rl_trainer.py --problems=Pendulum-v0 --hparams_set continuous_action_base [--output_dir dir_location]``` - -## Sample usage - training in the PongNoFrameskip-v0 environment. - -```python tensor2tensor/rl/t2t_rl_trainer.py --problem stacked_pong --hparams_set atari_base --hparams num_agents=5 --output_dir /tmp/pong`date +%Y%m%d_%H%M%S```` - -## Sample usage - generation of a model - -```python tensor2tensor/bin/t2t-trainer --generate_data --data_dir=~/t2t_data --problems=gym_pong_trajectories_from_policy --hparams_set=base_atari --model_path [model]``` - -```python tensor2tensor/bin/t2t-datagen --data_dir=~/t2t_data --tmp_dir=~/t2t_data/tmp --problem=gym_pong_trajectories_from_policy --model_path [model]``` +# Sample usages + +## Training agent in the Pendulum-v0 environment. 
+ +``` +python rl/t2t_rl_trainer.py \ + --problems=Pendulum-v0 \ + --hparams_set continuous_action_base \ + [--output_dir dir_location] +``` + +## Training agent in the PongNoFrameskip-v4 environment. + +``` +python tensor2tensor/rl/t2t_rl_trainer.py \ + --problem stacked_pong \ + --hparams_set atari_base \ + --hparams num_agents=5 \ + [--output_dir dir_location] +``` + +## Generation of trajectories data + +``` +python tensor2tensor/bin/t2t-datagen \ + --data_dir=~/t2t_data \ + --tmp_dir=~/t2t_data/tmp \ + --problem=gym_pong_trajectories_from_policy \ + --model_path [model] +``` + +## Training model for frames generation based on randomly played games + +``` +python tensor2tensor/bin/t2t-trainer \ + --generate_data \ + --data_dir=~/t2t_data \ + --output_dir=~/t2t_data/output \ + --problems=gym_pong_random5k \ + --model=basic_conv_gen \ + --hparams_set=basic_conv_small \ + --train_steps=1000 \ + --eval_steps=10 +``` diff --git a/tensor2tensor/rl/rl_trainer_lib_test.py b/tensor2tensor/rl/rl_trainer_lib_test.py index 0f3aa2025..461e7a0da 100644 --- a/tensor2tensor/rl/rl_trainer_lib_test.py +++ b/tensor2tensor/rl/rl_trainer_lib_test.py @@ -25,14 +25,19 @@ class TrainTest(tf.test.TestCase): + test_config = ("epochs_num=4,eval_every_epochs=3,video_during_eval=False," + "num_agents=5,optimization_epochs=5,epoch_length=50") + def test_no_crash_pendulum(self): hparams = trainer_lib.create_hparams( - "continuous_action_base", "epochs_num=11,video_during_eval=False") + "continuous_action_base", + TrainTest.test_config) rl_trainer_lib.train(hparams, "Pendulum-v0") def test_no_crash_cartpole(self): hparams = trainer_lib.create_hparams( - "discrete_action_base", "epochs_num=11,video_during_eval=False") + "discrete_action_base", + TrainTest.test_config) rl_trainer_lib.train(hparams, "CartPole-v0")