Skip to content
This repository was archived by the owner on Jul 7, 2023. It is now read-only.
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
25 commits
Select commit Hold shift + click to select a range
c2e4d14
Add a faster version of SRU and some hparams settings for autoencoders.
Apr 27, 2018
e07e969
Correct typo.
Apr 27, 2018
5324d11
local mixture-of-experts running efficiently on TPU
nshazeer Apr 27, 2018
0e2334e
Adding functions for calculating position and step embedding as timin…
a-googler Apr 27, 2018
3326e0b
Implement factored accumulators for >=3-dimensional varaibles in
nshazeer Apr 27, 2018
6a88e14
Add tests for TF 1.8
Apr 30, 2018
faf2178
Additional hparam for setting kernel_size in conv_relu_conv for the f…
a-googler Apr 30, 2018
dd2d606
Change to local_moe_tpu - make better choices for second-place expert…
nshazeer Apr 30, 2018
d01c325
Bypass transformer target decoder if decode_autoregressive is set to …
May 1, 2018
1c11e51
Close the environment in ExternalProcessEnv.
a-googler May 1, 2018
d957bdb
Fix if condition in squad_concat data generator.
a-googler May 1, 2018
fd391e9
internal
royaurko May 1, 2018
de60a75
Using tf.contrib.summary instead of tf.summary in r-transformer.
a-googler May 1, 2018
18c0917
Some fixes for wikisum data generation
May 2, 2018
32ee97d
internal
royaurko May 3, 2018
9c476bd
Corrections to RL and autoencoder code.
May 3, 2018
b9088c0
Adding subject-verb agreement dataset to t2t.
a-googler May 3, 2018
63bb020
Sort some dictionaries before using them to make the XLA program fing…
a-googler May 3, 2018
6972e09
Adding wikitext103 dataset to T2T.
a-googler May 3, 2018
0bd8793
Add imagenet small problem
May 3, 2018
119b756
Cost updates to Wikisum and further improvements
May 3, 2018
e466fd9
new hparam configs for tpu runs
May 3, 2018
89ff940
Fix error in t2t eager colab
May 3, 2018
b32062a
Use tf.contrib.eager.in_eager_mode()
May 3, 2018
c98fab4
Correct data sharding in video problems and add L1 loss for video pre…
May 4, 2018
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions .travis.yml
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,15 @@ env:
- TF_VERSION="1.5.*"
- TF_VERSION="1.6.*"
- TF_VERSION="1.7.*"
- TF_VERSION="1.8.*"
matrix:
exclude:
- python: "3.6"
env: TF_VERSION="1.5.*"
- python: "3.6"
env: TF_VERSION="1.6.*"
- python: "3.6"
env: TF_VERSION="1.7.*"
before_install:
- echo "deb [arch=amd64] http://storage.googleapis.com/tensorflow-serving-apt stable tensorflow-model-server tensorflow-model-server-universal" | sudo tee /etc/apt/sources.list.d/tensorflow-serving.list
- curl https://storage.googleapis.com/tensorflow-serving-apt/tensorflow-serving.release.pub.gpg | sudo apt-key add -
Expand Down
8 changes: 4 additions & 4 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -26,10 +26,10 @@ You can chat with us on

### Quick Start

[This iPython notebook](https://goo.gl/wkHexj) explains T2T and runs in your
browser using a free VM from Google, no installation needed.
Alternatively, here is a one-command version that installs T2T, downloads MNIST,
trains a model and evaluates it:
[This iPython notebook](https://colab.research.google.com/github/tensorflow/tensor2tensor/blob/master/tensor2tensor/notebooks/hello_t2t.ipynb)
explains T2T and runs in your browser using a free VM from Google,
no installation needed. Alternatively, here is a one-command version that
installs T2T, downloads MNIST, trains a model and evaluates it:

```
pip install tensor2tensor && t2t-trainer \
Expand Down
2 changes: 1 addition & 1 deletion docs/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ research](https://research.googleblog.com/2017/06/accelerating-deep-learning-res
## Basics

* [Walkthrough](walkthrough.md): Install and run.
* [IPython notebook](https://goo.gl/wkHexj): Get a hands-on experience.
* [IPython notebook](https://colab.research.google.com/github/tensorflow/tensor2tensor/blob/master/tensor2tensor/notebooks/hello_t2t.ipynb): Get a hands-on experience.
* [Overview](overview.md): How all parts of T2T code are connected.
* [New Problem](new_problem.md): Train T2T models on your data.
* [New Model](new_model.md): Create your own T2T model.
Expand Down
8 changes: 4 additions & 4 deletions docs/walkthrough.md
Original file line number Diff line number Diff line change
Expand Up @@ -26,10 +26,10 @@ You can chat with us on

### Quick Start

[This iPython notebook](https://goo.gl/wkHexj) explains T2T and runs in your
browser using a free VM from Google, no installation needed.
Alternatively, here is a one-command version that installs T2T, downloads MNIST,
trains a model and evaluates it:
[This iPython notebook](https://colab.research.google.com/github/tensorflow/tensor2tensor/blob/master/tensor2tensor/notebooks/hello_t2t.ipynb)
explains T2T and runs in your browser using a free VM from Google,
no installation needed. Alternatively, here is a one-command version that
installs T2T, downloads MNIST, trains a model and evaluates it:

```
pip install tensor2tensor && t2t-trainer \
Expand Down
2 changes: 2 additions & 0 deletions tensor2tensor/data_generators/all_problems.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@
"tensor2tensor.data_generators.ptb",
"tensor2tensor.data_generators.snli",
"tensor2tensor.data_generators.squad",
"tensor2tensor.data_generators.subject_verb_agreement",
"tensor2tensor.data_generators.translate_encs",
"tensor2tensor.data_generators.translate_ende",
"tensor2tensor.data_generators.translate_enet",
Expand All @@ -56,6 +57,7 @@
"tensor2tensor.data_generators.twentybn",
"tensor2tensor.data_generators.wiki",
"tensor2tensor.data_generators.wikisum.wikisum",
"tensor2tensor.data_generators.wikitext103",
"tensor2tensor.data_generators.wsj_parsing",
]

Expand Down
11 changes: 8 additions & 3 deletions tensor2tensor/data_generators/generator_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ def to_example(dictionary):
features = {}
for (k, v) in six.iteritems(dictionary):
if not v:
raise ValueError("Empty generated field: %s", str((k, v)))
raise ValueError("Empty generated field: %s" % str((k, v)))
if isinstance(v[0], six.integer_types):
features[k] = tf.train.Feature(int64_list=tf.train.Int64List(value=v))
elif isinstance(v[0], float):
Expand Down Expand Up @@ -130,7 +130,8 @@ def outputs_exist(filenames):
return out_fname


def generate_files(generator, output_filenames, max_cases=None):
def generate_files(generator, output_filenames,
max_cases=None, cycle_every_n=1):
"""Generate cases from a generator and save as TFRecord files.

Generated cases are transformed to tf.Example protos and saved as TFRecords
Expand All @@ -141,6 +142,8 @@ def generate_files(generator, output_filenames, max_cases=None):
output_filenames: List of output file paths.
max_cases: maximum number of cases to get from the generator;
if None (default), we use the generator until StopIteration is raised.
cycle_every_n: how many cases from the generator to take before
switching to the next shard; by default set to 1, switch every case.
"""
if outputs_exist(output_filenames):
tf.logging.info("Skipping generator because outputs files exist")
Expand All @@ -159,7 +162,8 @@ def generate_files(generator, output_filenames, max_cases=None):
break
example = to_example(case)
writers[shard].write(example.SerializeToString())
shard = (shard + 1) % num_shards
if counter % cycle_every_n == 0:
shard = (shard + 1) % num_shards

for writer in writers:
writer.close()
Expand Down Expand Up @@ -341,6 +345,7 @@ def get_or_generate_vocab(data_dir, tmp_dir, vocab_filename, vocab_size,
"""Generate a vocabulary from the datasets in sources."""

def generate():
"""Generate lines for vocabulary generation."""
tf.logging.info("Generating vocab from: %s", str(sources))
for source in sources:
url = source[0]
Expand Down
72 changes: 37 additions & 35 deletions tensor2tensor/data_generators/gym.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,11 @@
from __future__ import division
from __future__ import print_function

from collections import deque

import functools
import os

# Dependency imports

import gym

from tensor2tensor.data_generators import problem
Expand Down Expand Up @@ -62,9 +63,7 @@ def num_target_frames(self):
return 1

def eval_metrics(self):
eval_metrics = [
metrics.Metrics.ACC, metrics.Metrics.ACC_PER_SEQ,
metrics.Metrics.NEG_LOG_PERPLEXITY]
eval_metrics = [metrics.Metrics.ACC, metrics.Metrics.ACC_PER_SEQ]
return eval_metrics

@property
Expand Down Expand Up @@ -108,6 +107,10 @@ def num_rewards(self):
def num_steps(self):
raise NotImplementedError()

@property
def total_number_of_frames(self):
return self.num_steps

@property
def min_reward(self):
raise NotImplementedError()
Expand All @@ -126,13 +129,13 @@ def hparams(self, defaults, unused_model_hparams):
p.target_space_id = problem.SpaceID.IMAGE

def generate_samples(self, data_dir, tmp_dir, unused_dataset_split):
next_obs = self.env.reset()
next_observation = self.env.reset()
for _ in range(self.num_steps):
observation = next_obs
observation = next_observation
action = self.get_action(observation)
next_obs, reward, done, _ = self.env.step(action)
next_observation, reward, done, _ = self.env.step(action)
if done:
next_obs = self.env.reset()
next_observation = self.env.reset()
yield {"frame": observation,
"action": [action],
"done": [done],
Expand Down Expand Up @@ -184,23 +187,22 @@ class GymDiscreteProblemWithAgent(GymPongRandom5k):
def __init__(self, *args, **kwargs):
super(GymDiscreteProblemWithAgent, self).__init__(*args, **kwargs)
self._env = None
self.history_size = 2
self.debug_dump_frames_path = "debug_frames_env"

# defaults
self.environment_spec = lambda: gym.make("PongDeterministic-v4")
self.in_graph_wrappers = [(atari.MaxAndSkipWrapper, {"skip": 4})]
self.in_graph_wrappers = []
self.collect_hparams = rl.atari_base()
self.settable_num_steps = 1000
self.settable_num_steps = 20000
self.simulated_environment = None
self.warm_up = 70
self.warm_up = 10

@property
def num_steps(self):
return self.settable_num_steps

def _setup(self):
in_graph_wrappers = [(atari.ShiftRewardWrapper, {"add_value": 2}),
(atari.MemoryWrapper, {})] + self.in_graph_wrappers
in_graph_wrappers = [(atari.MemoryWrapper, {})] + self.in_graph_wrappers
env_hparams = tf.contrib.training.HParams(
in_graph_wrappers=in_graph_wrappers,
simulated_environment=self.simulated_environment)
Expand Down Expand Up @@ -229,41 +231,41 @@ def _setup(self):

self.avilable_data_size_op = atari.MemoryWrapper.singleton.speculum.size()
self.data_get_op = atari.MemoryWrapper.singleton.speculum.dequeue()
self.history_buffer = deque(maxlen=self.history_size+1)

def restore_networks(self, sess):
if FLAGS.agent_policy_path:
model_saver = tf.train.Saver(
tf.global_variables(".*network_parameters.*"))
tf.global_variables(".*network_parameters.*"))
model_saver.restore(sess, FLAGS.agent_policy_path)

def generate_encoded_samples(self, data_dir, tmp_dir, unused_dataset_split):
self._setup()
self.debug_dump_frames_path = os.path.join(
data_dir, self.debug_dump_frames_path)

with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
self.restore_networks(sess)

# Actions are shifted by 1 by MemoryWrapper, compensate here.
avilable_data_size = sess.run(self.avilable_data_size_op)
if avilable_data_size < 1:
sess.run(self.collect_trigger_op)
pieces_generated = 0
observ, reward, _, _ = sess.run(self.data_get_op)
while pieces_generated < self.num_steps + self.warm_up:
avilable_data_size = sess.run(self.avilable_data_size_op)
if avilable_data_size > 0:
observ, reward, action, _ = sess.run(self.data_get_op)
self.history_buffer.append(observ)

if len(self.history_buffer) == self.history_size + 1:
pieces_generated += 1
ret_dict = {"image/encoded": [observ],
"image/format": ["png"],
"image/height": [self.frame_height],
"image/width": [self.frame_width],
"action": [int(action)],
"done": [int(False)],
"reward": [int(reward) - self.min_reward]}
if pieces_generated > self.warm_up:
yield ret_dict
else:
if avilable_data_size < 1:
sess.run(self.collect_trigger_op)
next_observ, next_reward, action, _ = sess.run(self.data_get_op)
yield {"image/encoded": [observ],
"image/format": ["png"],
"image/height": [self.frame_height],
"image/width": [self.frame_width],
"action": [int(action)],
"done": [int(False)],
"reward": [int(reward) - self.min_reward]}
pieces_generated += 1
observ, reward = next_observ, next_reward


@registry.register_problem
Expand All @@ -273,7 +275,7 @@ class GymSimulatedDiscreteProblemWithAgent(GymDiscreteProblemWithAgent):
def __init__(self, *args, **kwargs):
super(GymSimulatedDiscreteProblemWithAgent, self).__init__(*args, **kwargs)
self.simulated_environment = True
self.debug_dump_frames_path = "/tmp/t2t_debug_dump_frames"
self.debug_dump_frames_path = "debug_frames_sim"

def restore_networks(self, sess):
super(GymSimulatedDiscreteProblemWithAgent, self).restore_networks(sess)
Expand Down
29 changes: 28 additions & 1 deletion tensor2tensor/data_generators/imagenet.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,7 +189,7 @@ def preprocess_example(self, example, mode, _):

@registry.register_problem
class ImageImagenet64Gen(ImageImagenet):
"""Cifar-10 Tune."""
"""Imagenet 64 from the pixen cnn paper"""

@property
def train_shards(self):
Expand Down Expand Up @@ -264,6 +264,33 @@ def preprocess_example(self, example, mode, hparams):
return example


@registry.register_problem
class ImageImagenet32Small(ImageImagenet):
"""Imagenet small from the pixel cnn paper"""

@property
def is_small(self):
return False # Modalities like for CIFAR.

@property
def num_classes(self):
return 1000

@property
def train_shards(self):
return 1024

@property
def dev_shards(self):
return 10

def preprocess_example(self, example, mode, unused_hparams):
example["inputs"].set_shape([_IMAGENET_SMALL_IMAGE_SIZE,
_IMAGENET_SMALL_IMAGE_SIZE, 3])
example["inputs"] = tf.to_int64(example["inputs"])
return example


@registry.register_problem
class ImageImagenet64(ImageImagenet32):
"""Imagenet rescaled to 64x64."""
Expand Down
2 changes: 1 addition & 1 deletion tensor2tensor/data_generators/squad.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,5 +143,5 @@ def generate_encoded_samples(self, data_dir, tmp_dir, dataset_split):
for sample in samples:
sample['targets'] = self.generate_targets(sample['targets'],
sample['context'])
if not sample['targets']:
if sample['targets']:
yield sample
Loading