Skip to content
This repository was archived by the owner on Jul 7, 2023. It is now read-only.
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion .travis.yml
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,6 @@ script:
--ignore=tensor2tensor/problems_test.py
--ignore=tensor2tensor/bin/t2t_trainer_test.py
--ignore=tensor2tensor/data_generators/algorithmic_math_test.py
--ignore=tensor2tensor/rl/rl_trainer_lib_test.py
- pytest tensor2tensor/utils/registry_test.py
- pytest tensor2tensor/utils/trainer_lib_test.py
- pytest tensor2tensor/visualization/visualization_test.py
Expand Down
47 changes: 32 additions & 15 deletions tensor2tensor/data_generators/gym.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,6 @@
import tensorflow as tf




flags = tf.flags
FLAGS = flags.FLAGS

Expand All @@ -50,6 +48,17 @@ def __init__(self, *args, **kwargs):
super(GymDiscreteProblem, self).__init__(*args, **kwargs)
self._env = None

def example_reading_spec(self, label_repr=None):

data_fields = {
"inputs": tf.FixedLenFeature([210, 160, 3], tf.int64),
"inputs_prev": tf.FixedLenFeature([210, 160, 3], tf.int64),
"targets": tf.FixedLenFeature([210, 160, 3], tf.int64),
"action": tf.FixedLenFeature([1], tf.int64)
}

return data_fields, None

@property
def env_name(self):
# This is the name of the Gym environment for this problem.
Expand Down Expand Up @@ -133,7 +142,7 @@ class GymPongRandom5k(GymDiscreteProblem):

@property
def env_name(self):
return "Pong-v0"
return "PongNoFrameskip-v4"

@property
def num_actions(self):
Expand All @@ -148,21 +157,30 @@ def num_steps(self):
return 5000



@registry.register_problem
class GymPongTrajectoriesFromPolicy(GymDiscreteProblem):
"""Pong game, loaded actions."""

def __init__(self, event_dir, *args, **kwargs):
def __init__(self, *args, **kwargs):
super(GymPongTrajectoriesFromPolicy, self).__init__(*args, **kwargs)
self._env = None
self._event_dir = event_dir
self._last_policy_op = None
self._max_frame_pl = None
self._last_action = self.env.action_space.sample()
self._skip = 4
self._skip_step = 0
self._obs_buffer = np.zeros((2,) + self.env.observation_space.shape,
dtype=np.uint8)

def generator(self, data_dir, tmp_dir):
env_spec = lambda: atari_wrappers.wrap_atari( # pylint: disable=g-long-lambda
gym.make("PongNoFrameskip-v4"),
warp=False,
frame_skip=4,
frame_stack=False)
hparams = rl.atari_base()
with tf.variable_scope("train"):
with tf.variable_scope("train", reuse=tf.AUTO_REUSE):
policy_lambda = hparams.network
policy_factory = tf.make_template(
"network",
Expand All @@ -173,14 +191,13 @@ def __init__(self, event_dir, *args, **kwargs):
self._max_frame_pl, 0), 0))
policy = actor_critic.policy
self._last_policy_op = policy.mode()
self._last_action = self.env.action_space.sample()
self._skip = 4
self._skip_step = 0
self._obs_buffer = np.zeros((2,) + self.env.observation_space.shape,
dtype=np.uint8)
self._sess = tf.Session()
model_saver = tf.train.Saver(tf.global_variables(".*network_parameters.*"))
model_saver.restore(self._sess, FLAGS.model_path)
with tf.Session() as sess:
model_saver = tf.train.Saver(
tf.global_variables(".*network_parameters.*"))
model_saver.restore(sess, FLAGS.model_path)
for item in super(GymPongTrajectoriesFromPolicy,
self).generator(data_dir, tmp_dir):
yield item

# TODO(blazej0): For training of atari agents wrappers are usually used.
# Below we have a hacky solution which is a workaround to be used together
Expand All @@ -191,7 +208,7 @@ def get_action(self, observation=None):
self._skip_step = (self._skip_step + 1) % self._skip
if self._skip_step == 0:
max_frame = self._obs_buffer.max(axis=0)
self._last_action = int(self._sess.run(
self._last_action = int(tf.get_default_session().run(
self._last_policy_op,
feed_dict={self._max_frame_pl: max_frame})[0, 0])
return self._last_action
Expand Down
1 change: 1 addition & 0 deletions tensor2tensor/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
from tensor2tensor.models.research import aligned
from tensor2tensor.models.research import attention_lm
from tensor2tensor.models.research import attention_lm_moe
from tensor2tensor.models.research import basic_conv_gen
from tensor2tensor.models.research import cycle_gan
from tensor2tensor.models.research import gene_expression
from tensor2tensor.models.research import multimodel
Expand Down
65 changes: 65 additions & 0 deletions tensor2tensor/models/research/basic_conv_gen.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@

# coding=utf-8
# Copyright 2018 The Tensor2Tensor Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Basic models for testing simple tasks."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

# Dependency imports

from tensor2tensor.layers import common_hparams
from tensor2tensor.layers import common_layers
from tensor2tensor.utils import registry
from tensor2tensor.utils import t2t_model

import tensorflow as tf


@registry.register_model
class BasicConvGen(t2t_model.T2TModel):

def body(self, features):
print(features)
filters = self.hparams.hidden_size
cur_frame = tf.to_float(features["inputs"])
prev_frame = tf.to_float(features["inputs_prev"])
print(features["inputs"].shape, cur_frame.shape, prev_frame.shape)
action = common_layers.embedding(tf.to_int64(features["action"]),
10, filters)
action = tf.reshape(action, [-1, 1, 1, filters])

frames = tf.concat([cur_frame, prev_frame], axis=3)
h1 = tf.layers.conv2d(frames, filters, kernel_size=(3, 3), padding="SAME")
h2 = tf.layers.conv2d(tf.nn.relu(h1 + action), filters,
kernel_size=(5, 5), padding="SAME")
res = tf.layers.conv2d(tf.nn.relu(h2 + action), 3 * 256,
kernel_size=(3, 3), padding="SAME")

height = tf.shape(res)[1]
width = tf.shape(res)[2]
res = tf.reshape(res, [-1, height, width, 3, 256])
return res


@registry.register_hparams
def basic_conv_small():
# """Small conv model."""
hparams = common_hparams.basic_params1()
hparams.hidden_size = 32
hparams.batch_size = 2
return hparams
Loading