This repository was archived by the owner on Jul 7, 2023. It is now read-only.
Merged
28 commits
- 6259b74: Rename RevTransformer to TransformerRevnet (Aug 17, 2017)
- 4a8a715: Bug fix in greedy infer when loss is None (Aug 17, 2017)
- 3dd2ec6: Move to Datasets for input pipeline and bump TF requirement to v1.3.0 (Aug 17, 2017)
- 0aeda96: Fix preprocess_examples signature in image problems (Aug 17, 2017)
- 014f222: Add rev_block test with conv and batch norm and add bucket_by_sequenc… (Aug 18, 2017)
- d515682: Minor updates to tensor2tensor README: fixes path of common_hparams.py (a-googler, Aug 19, 2017)
- bc713b4: Better performance for the expert mask computation. (a-googler, Aug 21, 2017)
- 775e6c7: Port CelebA dataset to Problem and add landmarks and attrs (Aug 21, 2017)
- 8ac4ca8: Reverting dataset change until more thoroughly tested (Aug 21, 2017)
- 3075477: Add store_to_file option to Text2TextEncoder and minor documentation … (a-googler, Aug 21, 2017)
- 4fc251b: Fix reshape bug when an expert receives a batch of size zero (a-googler, Aug 21, 2017)
- 0f33a8d: Parameterize the number of encoder and decoder layers for the Transfo… (alexyku, Aug 22, 2017)
- dd66a03: Port to Dataset (again) with tests (Aug 22, 2017)
- a237064: change decode_from_dataset to write outside the loop and extra cond f… (Aug 23, 2017)
- e5e79cd: Small corrections for TF 1.3, remove 8k wiki (vocab is 21k anyway), p… (Aug 23, 2017)
- be7a446: First stab at "diet" variables. (nshazeer, Aug 23, 2017)
- 31b688a: Add hparams to control the attention k,v,q size and add default base … (a-googler, Aug 23, 2017)
- 7a949f0: Add a more memory-efficient version of matmul->softmax->cross_entropy… (nshazeer, Aug 23, 2017)
- 6db9fa4: calling empty preprocessing sequence by name (a-googler, Aug 24, 2017)
- 8194a07: Fix bug in TokenTextEncoder where saving and loading a vocab wouldn't… (a-googler, Aug 25, 2017)
- bde3499: Illustrate how to call T2T models from raw TF session, small cleanup. (Aug 25, 2017)
- e75c118: Decrease loss for small batches to defend against the new dataset API. (nshazeer, Aug 25, 2017)
- 92c1fa6: Corrections for the Dataset API, play with VAEs. (Aug 25, 2017)
- 8e7a5d9: Make diet variables more generic (Aug 25, 2017)
- b52130f: Open source IPython Notebook for creating visualization from the Tran… (a-googler, Aug 25, 2017)
- f4a3ac9: v1.1.10 (Aug 25, 2017)
- a561551: Big release bump, small corrections. (Aug 25, 2017)
- ab1e766: Add beam search test. (a-googler, Aug 25, 2017)
README.md (2 changes: 1 addition & 1 deletion)
@@ -217,7 +217,7 @@ and are encoded in
[`tf.contrib.training.HParams`](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/training/python/training/hparam.py)
objects. The `HParams` are available to both the problem specification and the
model. A basic set of hyperparameters are defined in
[`common_hparams.py`](https://github.com/tensorflow/tensor2tensor/tree/master/tensor2tensor/models/common_hparams.py)
[`common_hparams.py`](https://github.com/tensorflow/tensor2tensor/tree/master/tensor2tensor/layers/common_hparams.py)
and hyperparameter set functions can compose other hyperparameter set functions.

### Trainer
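The README hunk above notes that hyperparameter set functions can compose other hyperparameter set functions. As a minimal sketch (not part of this PR), assuming the transformer_base set from tensor2tensor.models.transformer and a hypothetical set name transformer_my_experiment, composition looks roughly like this:

```python
from tensor2tensor.models import transformer
from tensor2tensor.utils import registry


@registry.register_hparams
def transformer_my_experiment():
  """Hypothetical hparams set composed from transformer_base."""
  hparams = transformer.transformer_base()  # start from an existing set
  hparams.num_hidden_layers = 8             # then override individual fields
  hparams.learning_rate = 0.05
  return hparams
```

The registered name (here the function name) can then be selected by name when launching training, e.g. via the trainer's --hparams_set flag.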
setup.py (6 changes: 3 additions & 3 deletions)
@@ -5,7 +5,7 @@

setup(
name='tensor2tensor',
version='1.1.9',
version='1.2.0',
description='Tensor2Tensor',
author='Google Inc.',
author_email='no-reply@google.com',
@@ -26,8 +26,8 @@
'six',
],
extras_require={
'tensorflow': ['tensorflow>=1.2.0rc1'],
'tensorflow_gpu': ['tensorflow-gpu>=1.2.0rc1'],
'tensorflow': ['tensorflow>=1.3.0'],
'tensorflow_gpu': ['tensorflow-gpu>=1.3.0'],
},
tests_require=['nose'],
test_suite='nose.collector',
tensor2tensor/bin/t2t-datagen (4 changes: 0 additions & 4 deletions)
@@ -42,7 +42,6 @@ from tensor2tensor.data_generators import algorithmic_math
from tensor2tensor.data_generators import all_problems # pylint: disable=unused-import
from tensor2tensor.data_generators import audio
from tensor2tensor.data_generators import generator_utils
from tensor2tensor.data_generators import image
from tensor2tensor.data_generators import lm1b
from tensor2tensor.data_generators import snli
from tensor2tensor.data_generators import wmt
@@ -106,9 +105,6 @@ _SUPPORTED_PROBLEM_GENERATORS = {
lambda: lm1b.generator(FLAGS.tmp_dir, True, characters=True),
lambda: lm1b.generator(FLAGS.tmp_dir, False, characters=True)
),
"image_celeba_tune": (
lambda: image.celeba_generator(FLAGS.tmp_dir, 162770),
lambda: image.celeba_generator(FLAGS.tmp_dir, 19867, 162770)),
"inference_snli32k": (
lambda: snli.snli_token_generator(FLAGS.tmp_dir, True, 2**15),
lambda: snli.snli_token_generator(FLAGS.tmp_dir, False, 2**15),
tensor2tensor/data_generators/algorithmic_test.py (3 changes: 2 additions & 1 deletion)
@@ -21,7 +21,8 @@

# Dependency imports

from six.moves import xrange
from six.moves import xrange # pylint: disable=redefined-builtin

from tensor2tensor.data_generators import algorithmic

import tensorflow as tf
tensor2tensor/data_generators/all_problems.py (1 change: 1 addition & 0 deletions)
@@ -43,3 +43,4 @@
pass
# pylint: enable=g-import-not-at-top
# pylint: enable=unused-import

tensor2tensor/data_generators/generator_utils.py (25 changes: 21 additions & 4 deletions)
@@ -301,9 +301,24 @@ def gunzip_file(gz_path, new_path):

def get_or_generate_vocab_inner(data_dir, vocab_filename, vocab_size,
generator_fn):
"""Inner implementation for vocab generators."""
vocab_filepath = os.path.join(data_dir, vocab_filename)
if tf.gfile.Exists(vocab_filepath):
"""Inner implementation for vocab generators.

Args:
data_dir: The base directory where data and vocab files are stored. If None,
then do not save the vocab even if it doesn't exist.
vocab_filename: relative filename where vocab file is stored
vocab_size: target size of the vocabulary constructed by SubwordTextEncoder
generator_fn: a generator that produces tokens from the vocabulary

Returns:
A SubwordTextEncoder vocabulary object.
"""
if data_dir is None:
vocab_filepath = None
else:
vocab_filepath = os.path.join(data_dir, vocab_filename)

if vocab_filepath is not None and tf.gfile.Exists(vocab_filepath):
tf.logging.info("Found vocab file: %s", vocab_filepath)
vocab = text_encoder.SubwordTextEncoder(vocab_filepath)
return vocab
@@ -316,7 +331,9 @@ def get_or_generate_vocab_inner(data_dir, vocab_filename, vocab_size,

vocab = text_encoder.SubwordTextEncoder.build_to_target_size(
vocab_size, token_counts, 1, 1e3)
vocab.store_to_file(vocab_filepath)

if vocab_filepath is not None:
vocab.store_to_file(vocab_filepath)
return vocab


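The new data_dir=None path above builds the vocabulary in memory without writing it to disk. A minimal usage sketch (the corpus, names, and token generator are made up for illustration; the docstring describes generator_fn as a generator of tokens, so a generator object is passed here):

```python
from tensor2tensor.data_generators import generator_utils


def my_token_stream():
  # Hypothetical stand-in corpus; yields one token at a time.
  for line in ["the quick brown fox", "jumps over the lazy dog"]:
    for token in line.split():
      yield token


# With data_dir=None the SubwordTextEncoder is built but never stored to a
# file, so vocab_filename is only a placeholder here. If the helper instead
# expects a callable, pass my_token_stream without calling it.
vocab = generator_utils.get_or_generate_vocab_inner(
    data_dir=None,
    vocab_filename="vocab.my_problem.8192",
    vocab_size=8192,
    generator_fn=my_token_stream())
```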
tensor2tensor/data_generators/image.py (174 changes: 133 additions & 41 deletions)
@@ -66,7 +66,137 @@ def example_reading_spec(self, label_key=None):
return data_fields, data_items_to_decoders


# French street names dataset.
@registry.register_problem("image_celeba_tune")
class ImageCeleba(ImageProblem):
"""CelebA dataset, aligned and cropped images."""
IMG_DATA = ("img_align_celeba.zip",
"https://drive.google.com/uc?export=download&"
"id=0B7EVK8r0v71pZjFTYXZWM3FlRnM")
LANDMARKS_DATA = ("celeba_landmarks_align",
"https://drive.google.com/uc?export=download&"
"id=0B7EVK8r0v71pd0FJY3Blby1HUTQ")
ATTR_DATA = ("celeba_attr", "https://drive.google.com/uc?export=download&"
"id=0B7EVK8r0v71pblRyaVFSWGxPY0U")

LANDMARK_HEADINGS = ("lefteye_x lefteye_y righteye_x righteye_y "
"nose_x nose_y leftmouth_x leftmouth_y rightmouth_x "
"rightmouth_y").split()
ATTR_HEADINGS = (
"5_o_Clock_Shadow Arched_Eyebrows Attractive Bags_Under_Eyes Bald Bangs "
"Big_Lips Big_Nose Black_Hair Blond_Hair Blurry Brown_Hair "
"Bushy_Eyebrows Chubby Double_Chin Eyeglasses Goatee Gray_Hair "
"Heavy_Makeup High_Cheekbones Male Mouth_Slightly_Open Mustache "
"Narrow_Eyes No_Beard Oval_Face Pale_Skin Pointy_Nose Receding_Hairline "
"Rosy_Cheeks Sideburns Smiling Straight_Hair Wavy_Hair Wearing_Earrings "
"Wearing_Hat Wearing_Lipstick Wearing_Necklace Wearing_Necktie Young"
).split()

def preprocess_examples(self, examples, unused_mode, unused_hparams):

def resize(img, size):
return tf.to_int64(
tf.image.resize_images(img, [size, size], tf.image.ResizeMethod.AREA))

inputs = examples["inputs"]
# Remove boundaries in CelebA images. Remove 40 pixels each side
# vertically and 20 pixels each side horizontally.
inputs = tf.image.crop_to_bounding_box(inputs, 40, 20, 218 - 80, 178 - 40)
examples["inputs"] = resize(inputs, 8)
examples["targets"] = resize(inputs, 32)
return examples

def hparams(self, defaults, model_hparams):
p = defaults
p.input_modality = {"inputs": ("image:identity_no_pad", None)}
p.target_modality = ("image:identity_no_pad", None)
p.batch_size_multiplier = 256
p.max_expected_batch_size_per_shard = 4
p.input_space_id = 1
p.target_space_id = 1

def generator(self, tmp_dir, how_many, start_from=0):
"""Image generator for CELEBA dataset.

Args:
tmp_dir: path to temporary storage directory.
how_many: how many images and labels to generate.
start_from: from which image to start.

Yields:
A dictionary representing the images with the following fields:
* image/encoded: the string encoding the image as JPEG,
* image/format: the string "jpeg" representing image format,
"""
out_paths = []
for fname, url in [self.IMG_DATA, self.LANDMARKS_DATA, self.ATTR_DATA]:
path = generator_utils.maybe_download_from_drive(tmp_dir, fname, url)
out_paths.append(path)

img_path, landmarks_path, attr_path = out_paths # pylint: disable=unbalanced-tuple-unpacking
unzipped_folder = img_path[:-4]
if not tf.gfile.Exists(unzipped_folder):
zipfile.ZipFile(img_path, "r").extractall(tmp_dir)

with tf.gfile.Open(landmarks_path) as f:
landmarks_raw = f.read()

with tf.gfile.Open(attr_path) as f:
attr_raw = f.read()

def process_landmarks(raw_data):
landmarks = {}
lines = raw_data.split("\n")
headings = lines[1].strip().split()
for line in lines[2:-1]:
values = line.strip().split()
img_name = values[0]
landmark_values = [int(v) for v in values[1:]]
landmarks[img_name] = landmark_values
return landmarks, headings

def process_attrs(raw_data):
attrs = {}
lines = raw_data.split("\n")
headings = lines[1].strip().split()
for line in lines[2:-1]:
values = line.strip().split()
img_name = values[0]
attr_values = [int(v) for v in values[1:]]
attrs[img_name] = attr_values
return attrs, headings

img_landmarks, _ = process_landmarks(landmarks_raw)
img_attrs, _ = process_attrs(attr_raw)

image_files = tf.gfile.Glob(unzipped_folder + "/*.jpg")
for filename in image_files[start_from:start_from + how_many]:
img_name = os.path.basename(filename)
landmarks = img_landmarks[img_name]
attrs = img_attrs[img_name]

with tf.gfile.Open(filename, "r") as f:
encoded_image_data = f.read()
yield {
"image/encoded": [encoded_image_data],
"image/format": ["jpeg"],
"attributes": attrs,
"landmarks": landmarks,
}

@property
def train_shards(self):
return 100

@property
def dev_shards(self):
return 10

def generate_data(self, data_dir, tmp_dir, task_id=-1):
generator_utils.generate_dataset_and_shuffle(
self.generator(tmp_dir, 162770), # train
self.training_filepaths(data_dir, self.train_shards, shuffled=False),
self.generator(tmp_dir, 19867, 162770), # dev
self.dev_filepaths(data_dir, self.dev_shards, shuffled=False))


@registry.register_problem
@@ -199,7 +329,7 @@ def generate_data(self, data_dir, tmp_dir, task_id=-1):
"instructions at https://github.com/tensorflow/models/blob/master"
"/inception/README.md#getting-started")

def preprocess_examples(self, examples, mode):
def preprocess_examples(self, examples, mode, _):
return imagenet_preprocess_examples(examples, mode)


@@ -638,7 +768,7 @@ def train_shards(self):
def dev_shards(self):
return 10

def preprocess_examples(self, examples, mode):
def preprocess_examples(self, examples, mode, _):
return imagenet_preprocess_examples(examples, mode)

def generator(self, data_dir, tmp_dir, is_training):
@@ -700,41 +830,3 @@ class ImageMsCocoTokens32k(ImageMsCocoTokens8k):
@property
def targeted_vocab_size(self):
return 2**15 # 32768


# URL and filename for CELEBA data.
_CELEBA_NAME = "img_align_celeba"
_CELEBA_URL = "https://drive.google.com/uc?export=download&id=0B7EVK8r0v71pZjFTYXZWM3FlRnM"


def _get_celeba(directory):
"""Download and extract CELEBA to directory unless it is there."""
# path = os.path.join(directory, _CELEBA_NAME)
path = generator_utils.maybe_download_from_drive(directory, _CELEBA_NAME,
_CELEBA_URL)
if not tf.gfile.Exists(path):
zipfile.ZipFile(path + ".zip", "r").extractall(directory)


def celeba_generator(tmp_dir, how_many, start_from=0):
"""Image generator for CELEBA dataset.

Args:
tmp_dir: path to temporary storage directory.
how_many: how many images and labels to generate.
start_from: from which image to start.

Yields:
A dictionary representing the images with the following fields:
* image/encoded: the string encoding the image as JPEG,
* image/format: the string "jpeg" representing image format,
"""
_get_celeba(tmp_dir)
image_files = tf.gfile.Glob(os.path.join(tmp_dir, _CELEBA_NAME) + "/*.jpg")
for filename in image_files[start_from:start_from + how_many]:
with tf.gfile.Open(filename, "r") as f:
encoded_image_data = f.read()
yield {
"image/encoded": [encoded_image_data],
"image/format": ["jpeg"],
}
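With CelebA ported to a Problem (and the standalone celeba_generator removed in the hunk above), data generation is now owned by the registered class. A minimal sketch of driving it through the registry, assuming registry.problem resolves registered names as it does elsewhere in the codebase; the directory paths are placeholders:

```python
from tensor2tensor.utils import registry
# Importing all_problems registers the available problems as a side effect.
from tensor2tensor.data_generators import all_problems  # pylint: disable=unused-import

# Look up the problem registered above as "image_celeba_tune".
celeba = registry.problem("image_celeba_tune")

# Downloads/extracts the raw CelebA archives into tmp_dir, then writes
# 100 sharded training files and 10 dev files into data_dir.
celeba.generate_data(data_dir="/tmp/t2t_data", tmp_dir="/tmp/t2t_tmp")
```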
tensor2tensor/data_generators/lm1b.py (3 changes: 2 additions & 1 deletion)
@@ -26,7 +26,8 @@

# Dependency imports

from six.moves import xrange
from six.moves import xrange # pylint: disable=redefined-builtin

from tensor2tensor.data_generators import generator_utils
from tensor2tensor.data_generators import text_encoder
from tensor2tensor.data_generators import tokenizer
tensor2tensor/data_generators/problem.py (28 changes: 15 additions & 13 deletions)
@@ -116,7 +116,10 @@ class Problem(object):
* generate_data(data_dir, tmp_dir)
- Generate training and dev datasets into data_dir.
- Additional files, e.g. vocabulary files, should also be written to
data_dir.
data_dir. Vocab files are newline-separated files with each line
containing a token. The standard convention for the filename is to
set it to be
${Problem.vocab_name}.${Problem.targeted_vocab_size}
- Downloads and other files can be written to tmp_dir
- If you have a training and dev generator, you can generate the
training and dev datasets with
@@ -200,22 +203,22 @@ def training_filepaths(self, data_dir, num_shards, shuffled):
file_basename = self.dataset_filename()
if not shuffled:
file_basename += generator_utils.UNSHUFFLED_SUFFIX
return generator_utils.train_data_filenames(
file_basename, data_dir, num_shards)
return generator_utils.train_data_filenames(file_basename, data_dir,
num_shards)

def dev_filepaths(self, data_dir, num_shards, shuffled):
file_basename = self.dataset_filename()
if not shuffled:
file_basename += generator_utils.UNSHUFFLED_SUFFIX
return generator_utils.dev_data_filenames(
file_basename, data_dir, num_shards)
return generator_utils.dev_data_filenames(file_basename, data_dir,
num_shards)

def test_filepaths(self, data_dir, num_shards, shuffled):
file_basename = self.dataset_filename()
if not shuffled:
file_basename += generator_utils.UNSHUFFLED_SUFFIX
return generator_utils.test_data_filenames(
file_basename, data_dir, num_shards)
return generator_utils.test_data_filenames(file_basename, data_dir,
num_shards)

def __init__(self, was_reversed=False, was_copy=False):
"""Create a Problem.
@@ -412,10 +415,8 @@ def generate_data(self, data_dir, tmp_dir, task_id=-1):
generator_utils.shuffle_dataset(all_paths)
else:
generator_utils.generate_dataset_and_shuffle(
self.generator(data_dir, tmp_dir, True),
self.training_filepaths(data_dir, self.num_shards, shuffled=False),
self.generator(data_dir, tmp_dir, False),
self.dev_filepaths(data_dir, self.num_dev_shards, shuffled=False))
self.generator(data_dir, tmp_dir, True), train_paths,
self.generator(data_dir, tmp_dir, False), dev_paths)

def feature_encoders(self, data_dir):
if self.is_character_level:
@@ -435,8 +436,9 @@ def hparams(self, defaults, unused_model_hparams):

if self.has_inputs:
source_vocab_size = self._encoders["inputs"].vocab_size
p.input_modality = {"inputs": (registry.Modalities.SYMBOL,
source_vocab_size)}
p.input_modality = {
"inputs": (registry.Modalities.SYMBOL, source_vocab_size)
}
target_vocab_size = self._encoders["targets"].vocab_size
p.target_modality = (registry.Modalities.SYMBOL, target_vocab_size)
if self.has_inputs:
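The expanded Problem docstring in this file's first hunk spells out the vocab filename convention, ${Problem.vocab_name}.${Problem.targeted_vocab_size}. A tiny sketch of assembling such a path, with hypothetical values standing in for a concrete problem:

```python
import os

# Hypothetical values for Problem.vocab_name and Problem.targeted_vocab_size.
vocab_name = "vocab.my_text"
targeted_vocab_size = 2**13  # 8192

vocab_filename = "%s.%d" % (vocab_name, targeted_vocab_size)
vocab_filepath = os.path.join("/tmp/t2t_data", vocab_filename)
# -> /tmp/t2t_data/vocab.my_text.8192, written alongside the generated shards.
```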