This repository was archived by the owner on Jul 7, 2023. It is now read-only.
Merged

1.6.2 #771

8 changes: 5 additions & 3 deletions .travis.yml
@@ -21,11 +21,13 @@ matrix:
     - python: "3.6"
       env: TF_VERSION="1.7.*"
 before_install:
-  - echo "deb [arch=amd64] http://storage.googleapis.com/tensorflow-serving-apt stable tensorflow-model-server tensorflow-model-server-universal" | sudo tee /etc/apt/sources.list.d/tensorflow-serving.list
-  - curl https://storage.googleapis.com/tensorflow-serving-apt/tensorflow-serving.release.pub.gpg | sudo apt-key add -
+  # Disabled TensorFlow Serving install until bug fixed. See "Export and query"
+  # section below.
+  # - echo "deb [arch=amd64] http://storage.googleapis.com/tensorflow-serving-apt stable tensorflow-model-server tensorflow-model-server-universal" | sudo tee /etc/apt/sources.list.d/tensorflow-serving.list
+  # - curl https://storage.googleapis.com/tensorflow-serving-apt/tensorflow-serving.release.pub.gpg | sudo apt-key add -
   - sudo apt-get update -qq
   - sudo apt-get install -qq libhdf5-dev
-  - sudo apt-get install -qq tensorflow-model-server
+  # - sudo apt-get install -qq tensorflow-model-server
 install:
   - pip install -q "tensorflow==$TF_VERSION"
   - pip install -q .[tests]
5 changes: 5 additions & 0 deletions docs/cloud_tpu.md
@@ -18,6 +18,11 @@ See the official tutorial for [running Transformer
 on Cloud TPUs](https://cloud.google.com/tpu/docs/tutorials/transformer)
 for some examples and try out your own problems.

+You can train an Automatic Speech Recognition (ASR) model with Transformer
+on TPU by using `transformer` as `model` with `transformer_librispeech_tpu` as
+`hparams_set` and `librispeech` as `problem`. See this [tutorial](tutorials/asr_with_transformer.md) for more details on training it and this
+[notebook](https://colab.research.google.com/github/tensorflow/tensor2tensor/blob/master/tensor2tensor/notebooks/asr_transformer.ipynb) to see how the resulting model transcribes your speech to text.
+
 Image Transformer:
 * `imagetransformer` with `imagetransformer_base_tpu` (or
   `imagetransformer_tiny_tpu`)
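The ASR paragraph added to `docs/cloud_tpu.md` names three settings: `model`, `hparams_set`, and `problem`. As a rough sketch of how they might be assembled into a `t2t-trainer` invocation, the snippet below builds the argument list in Python. The bucket paths are hypothetical placeholders, and TPU-specific flags are deliberately left out because the diff does not mention them.

```python
def build_trainer_args(data_dir, output_dir):
  """Returns a t2t-trainer argument list for the ASR settings named above."""
  settings = [
      ("model", "transformer"),
      ("hparams_set", "transformer_librispeech_tpu"),
      ("problem", "librispeech"),
      ("data_dir", data_dir),      # hypothetical location of generated data
      ("output_dir", output_dir),  # hypothetical checkpoint directory
  ]
  return ["t2t-trainer"] + ["--%s=%s" % pair for pair in settings]


# Placeholder bucket paths, chosen only for illustration.
args = build_trainer_args("gs://my-bucket/data", "gs://my-bucket/asr_out")
print(" ".join(args))
```

Keeping the settings as an ordered list of pairs makes it easy to compare against the documentation sentence one flag at a time.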
7 changes: 6 additions & 1 deletion docs/index.md
@@ -16,10 +16,14 @@ accessible and [accelerate ML
 research](https://research.googleblog.com/2017/06/accelerating-deep-learning-research.html).


-## Basics
+## Introduction

 * [Walkthrough](walkthrough.md): Install and run.
 * [IPython notebook](https://colab.research.google.com/github/tensorflow/tensor2tensor/blob/master/tensor2tensor/notebooks/hello_t2t.ipynb): Get a hands-on experience.
+* [Automatic Speech Recognition notebook](https://colab.research.google.com/github/tensorflow/tensor2tensor/blob/master/tensor2tensor/notebooks/asr_transformer.ipynb): Transcribe speech to text with a T2T model.
+
+## Basics
+
 * [Overview](overview.md): How all parts of T2T code are connected.
 * [New Problem](new_problem.md): Train T2T models on your data.
 * [New Model](new_model.md): Create your own T2T model.
@@ -29,6 +33,7 @@ research](https://research.googleblog.com/2017/06/accelerating-deep-learning-res
 * [Training on Google Cloud ML](cloud_mlengine.md)
 * [Training on Google Cloud TPUs](cloud_tpu.md)
 * [Distributed Training](distributed_training.md)
+* [Automatic Speech Recognition (ASR) with Transformer](tutorials/asr_with_transformer.md)

 ## Solving your task

3 changes: 3 additions & 0 deletions docs/tutorials/asr_with_transformer.md
@@ -1,10 +1,13 @@
 # Automatic Speech Recognition (ASR) with Transformer

+Check out the [Automatic Speech Recognition notebook](https://colab.research.google.com/github/tensorflow/tensor2tensor/blob/master/tensor2tensor/notebooks/asr_transformer.ipynb) to see how the resulting model transcribes your speech to text.
+
 ## Data set

 This tutorial uses the publicly available
 [Librispeech](http://www.openslr.org/12/) ASR corpus.

+
 ## Generate the dataset

 To generate the dataset use `t2t-datagen`. You need to create environment
2 changes: 1 addition & 1 deletion setup.py
@@ -5,7 +5,7 @@

 setup(
     name='tensor2tensor',
-    version='1.6.1',
+    version='1.6.2',
     description='Tensor2Tensor',
     author='Google Inc.',
     author_email='no-reply@google.com',
1 change: 1 addition & 0 deletions tensor2tensor/data_generators/all_problems.py
@@ -36,6 +36,7 @@
     "tensor2tensor.data_generators.ice_parsing",
     "tensor2tensor.data_generators.imagenet",
     "tensor2tensor.data_generators.imdb",
+    "tensor2tensor.data_generators.lambada",
     "tensor2tensor.data_generators.librispeech",
     "tensor2tensor.data_generators.lm1b",
     "tensor2tensor.data_generators.mnist",
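`all_problems.py` keeps a flat list of module paths, so enabling `lambada` is a one-line addition. The code that consumes the list is not part of this diff; assuming it imports each entry by name, a minimal sketch of that pattern could look like the following. The helper `import_all` and the demo module names are hypothetical stand-ins, using standard-library modules in place of the T2T ones.

```python
import importlib


def import_all(module_names):
  """Imports each module by dotted name; returns (loaded, failed) name lists."""
  loaded, failed = [], []
  for name in module_names:
    try:
      importlib.import_module(name)
      loaded.append(name)
    except ImportError:
      # A missing optional dependency should not block the other modules.
      failed.append(name)
  return loaded, failed


# Standard-library modules stand in for the tensor2tensor.data_generators.* entries.
loaded, failed = import_all(["json", "collections", "no_such_module_for_demo"])
print(loaded, failed)
```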
49 changes: 34 additions & 15 deletions tensor2tensor/data_generators/gym.py
@@ -28,6 +28,7 @@
 from tensor2tensor.data_generators import problem
 from tensor2tensor.data_generators import video_utils

+from tensor2tensor.models.research import autoencoders
 from tensor2tensor.models.research import rl
 from tensor2tensor.rl import collect
 from tensor2tensor.rl.envs import tf_atari_wrappers as atari
@@ -42,7 +43,9 @@
 flags = tf.flags
 FLAGS = flags.FLAGS

-flags.DEFINE_string("agent_policy_path", "", "File with model for agent")
+
+flags.DEFINE_string("agent_policy_path", "", "File with model for agent.")
+flags.DEFINE_string("autoencoder_path", "", "File with model for autoencoder.")


 class GymDiscreteProblem(video_utils.VideoProblem):
@@ -179,6 +182,7 @@ class GymPongRandom50k(GymPongRandom5k):
   def num_steps(self):
     return 50000

+
 @registry.register_problem
 class GymFreewayRandom5k(GymDiscreteProblem):
   """Freeway game, random actions."""
@@ -209,7 +213,6 @@ def num_steps(self):
     return 50000


-@registry.register_problem
 class GymDiscreteProblemWithAgent(GymDiscreteProblem):
   """Gym environment with discrete actions and rewards and an agent."""

@@ -239,7 +242,7 @@ def _setup(self):
     generator_batch_env = batch_env_factory(
         self.environment_spec, env_hparams, num_agents=1, xvfb=False)

-    with tf.variable_scope("", reuse=tf.AUTO_REUSE):
+    with tf.variable_scope(tf.get_variable_scope(), reuse=tf.AUTO_REUSE):
       if FLAGS.agent_policy_path:
         policy_lambda = self.collect_hparams.network
       else:
@@ -252,7 +255,7 @@ def _setup(self):
           create_scope_now_=True,
           unique_name_="network")

-    with tf.variable_scope("", reuse=tf.AUTO_REUSE):
+    with tf.variable_scope(tf.get_variable_scope(), reuse=tf.AUTO_REUSE):
       self.collect_hparams.epoch_length = 10
       _, self.collect_trigger_op = collect.define_collect(
           policy_factory, generator_batch_env, self.collect_hparams,
@@ -267,6 +270,22 @@ def restore_networks(self, sess):
           tf.global_variables(".*network_parameters.*"))
       model_saver.restore(sess, FLAGS.agent_policy_path)

+  def autoencode(self, image, sess):
+    with tf.Graph().as_default():
+      hparams = autoencoders.autoencoder_discrete_pong()
+      hparams.data_dir = "unused"
+      hparams.problem_hparams = self.get_hparams(hparams)
+      hparams.problem = self
+      model = autoencoders.AutoencoderOrderedDiscrete(
+          hparams, tf.estimator.ModeKeys.EVAL)
+      img = tf.constant(image)
+      img = tf.to_int32(tf.reshape(
+          img, [1, 1, self.frame_height, self.frame_width, self.num_channels]))
+      encoded = model.encode(img)
+      model_saver = tf.train.Saver(tf.global_variables())
+      model_saver.restore(sess, FLAGS.autoencoder_path)
+      return sess.run(encoded)
+
   def generate_encoded_samples(self, data_dir, tmp_dir, unused_dataset_split):
     self._setup()
     self.debug_dump_frames_path = os.path.join(
@@ -275,17 +294,14 @@ def generate_encoded_samples(self, data_dir, tmp_dir, unused_dataset_split):
     with tf.Session() as sess:
       sess.run(tf.global_variables_initializer())
       self.restore_networks(sess)
-      # Actions are shifted by 1 by MemoryWrapper, compensate here.
-      avilable_data_size = sess.run(self.avilable_data_size_op)
-      if avilable_data_size < 1:
-        sess.run(self.collect_trigger_op)
       pieces_generated = 0
-      observ, reward, _, _ = sess.run(self.data_get_op)
       while pieces_generated < self.num_steps + self.warm_up:
         avilable_data_size = sess.run(self.avilable_data_size_op)
         if avilable_data_size < 1:
           sess.run(self.collect_trigger_op)
-        next_observ, next_reward, action, _ = sess.run(self.data_get_op)
+        observ, reward, action, _, img = sess.run(self.data_get_op)
+        if FLAGS.autoencoder_path:
+          observ = self.autoencode(img, sess)
         yield {"image/encoded": [observ],
                "image/format": ["png"],
                "image/height": [self.frame_height],
@@ -294,7 +310,6 @@ def generate_encoded_samples(self, data_dir, tmp_dir, unused_dataset_split):
                "done": [int(False)],
                "reward": [int(reward) - self.min_reward]}
         pieces_generated += 1
-        observ, reward = next_observ, next_reward


 @registry.register_problem
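
The reworked loop in `generate_encoded_samples` no longer carries `observ` and `reward` over from the previous iteration. Each pass tops up the buffer if it is empty, reads one item, and optionally re-encodes the frame. The sketch below mirrors that control flow with a plain `deque` instead of TensorFlow ops; every name in it (`generate_samples`, `make_fake_collector`, the record keys) is a hypothetical stand-in rather than T2T API.

```python
import collections


def generate_samples(num_steps, collect, buffer, encode=None):
  """Yields num_steps records, refilling `buffer` whenever it is empty."""
  pieces_generated = 0
  while pieces_generated < num_steps:
    if len(buffer) < 1:   # stands in for the avilable_data_size check
      collect(buffer)     # stands in for running collect_trigger_op
    observ, reward, action = buffer.popleft()
    if encode is not None:  # stands in for the FLAGS.autoencoder_path branch
      observ = encode(observ)
    yield {"observ": observ, "reward": reward, "action": action}
    pieces_generated += 1


def make_fake_collector(batch_size):
  """Returns a collector that appends `batch_size` fake transitions per call."""
  state = {"next": 0}

  def collect(buffer):
    for _ in range(batch_size):
      i = state["next"]
      buffer.append(("frame%d" % i, i % 2, i % 3))
      state["next"] += 1

  return collect


samples = list(generate_samples(
    5, make_fake_collector(2), collections.deque(), encode=str.upper))
print(samples[0])
```

Because the refill check runs on every pass, the generator keeps producing records even when a single collection yields fewer items than `num_steps`.
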
@@ -318,20 +333,24 @@ def restore_networks(self, sess):


 @registry.register_problem
-class GymSimulatedDiscreteProblemWithAgentOnPong(GymSimulatedDiscreteProblemWithAgent, GymPongRandom5k):
+class GymSimulatedDiscreteProblemWithAgentOnPong(
+    GymSimulatedDiscreteProblemWithAgent, GymPongRandom5k):
   pass


 @registry.register_problem
-class GymDiscreteProblemWithAgentOnPong(GymDiscreteProblemWithAgent, GymPongRandom5k):
+class GymDiscreteProblemWithAgentOnPong(
+    GymDiscreteProblemWithAgent, GymPongRandom5k):
   pass


 @registry.register_problem
-class GymSimulatedDiscreteProblemWithAgentOnFreeway(GymSimulatedDiscreteProblemWithAgent, GymFreewayRandom5k):
+class GymSimulatedDiscreteProblemWithAgentOnFreeway(
+    GymSimulatedDiscreteProblemWithAgent, GymFreewayRandom5k):
   pass


 @registry.register_problem
-class GymDiscreteProblemWithAgentOnFreeway(GymDiscreteProblemWithAgent, GymFreewayRandom5k):
+class GymDiscreteProblemWithAgentOnFreeway(
+    GymDiscreteProblemWithAgent, GymFreewayRandom5k):
   pass