From ff360cc5431c75cfad3258a37ad1adc9dda80208 Mon Sep 17 00:00:00 2001 From: Ryan Sepassi Date: Wed, 2 Aug 2017 17:26:03 -0700 Subject: [PATCH 01/12] Add layers/__init__.py and update gitignore for nose PiperOrigin-RevId: 164063212 --- .gitignore | 1 - tensor2tensor/data_generators/all_problems.py | 1 - tensor2tensor/data_generators/cipher.py | 213 ------------------ tensor2tensor/layers/__init__.py | 15 ++ 4 files changed, 15 insertions(+), 215 deletions(-) delete mode 100644 tensor2tensor/data_generators/cipher.py diff --git a/.gitignore b/.gitignore index 362753caa..c9dd3db88 100644 --- a/.gitignore +++ b/.gitignore @@ -6,7 +6,6 @@ _pycache__/ # Python egg metadata, regenerated from source files by setuptools. /*.egg-info -/*.egg # PyPI distribution artifacts. build/ diff --git a/tensor2tensor/data_generators/all_problems.py b/tensor2tensor/data_generators/all_problems.py index 843bd0a66..9be133a61 100644 --- a/tensor2tensor/data_generators/all_problems.py +++ b/tensor2tensor/data_generators/all_problems.py @@ -22,7 +22,6 @@ from tensor2tensor.data_generators import algorithmic from tensor2tensor.data_generators import algorithmic_math from tensor2tensor.data_generators import audio -from tensor2tensor.data_generators import cipher from tensor2tensor.data_generators import image from tensor2tensor.data_generators import lm1b from tensor2tensor.data_generators import ptb diff --git a/tensor2tensor/data_generators/cipher.py b/tensor2tensor/data_generators/cipher.py deleted file mode 100644 index 546ba1739..000000000 --- a/tensor2tensor/data_generators/cipher.py +++ /dev/null @@ -1,213 +0,0 @@ -from collections import deque -import numpy as np - -from tensor2tensor.data_generators import problem, algorithmic -from tensor2tensor.utils import registry - - -@registry.register_problem -class CipherShift5(algorithmic.AlgorithmicProblem): - - @property - def num_symbols(self): - return 5 - - @property - def distribution(self): - return [0.4, 0.3, 0.2, 0.08, 0.02] - - @property - def shift(self): - return 1 - - @property - def train_generator(self): - """Generator; takes 3 args: nbr_symbols, max_length, nbr_cases.""" - - def _gen(nbr_symbols, max_length, nbr_cases): - plain_vocab = range(nbr_symbols) - indices = generate_plaintext_random(plain_vocab, self.distribution, - nbr_cases, max_length) - codes = encipher_shift(indices, plain_vocab, self.shift) - - for plain, code in zip(indices, codes): - yield { - "X": plain, - "Y": code, - } - - return _gen - - @property - def train_length(self): - return 100 - - @property - def dev_length(self): - return self.train_length - - -@registry.register_problem -class CipherVigenere5(algorithmic.AlgorithmicProblem): - - @property - def num_symbols(self): - return 5 - - @property - def distribution(self): - return [0.4, 0.3, 0.2, 0.08, 0.02] - - @property - def key(self): - return [1, 3] - - @property - def train_generator(self): - """Generator; takes 3 args: nbr_symbols, max_length, nbr_cases.""" - - def _gen(nbr_symbols, max_length, nbr_cases): - plain_vocab = range(nbr_symbols) - indices = generate_plaintext_random(plain_vocab, self.distribution, - nbr_cases, max_length) - codes = encipher_vigenere(indices, plain_vocab, self.key) - - for plain, code in zip(indices, codes): - yield { - "X": plain, - "Y": code, - } - - return _gen - - @property - def train_length(self): - return 200 - - @property - def dev_length(self): - return self.train_length - - -@registry.register_problem -class CipherShift200(CipherShift5): - - @property - def num_symbols(self): - 
return 200 - - @property - def distribution(self): - vals = range(self.num_symbols) - val_sum = sum(vals) - return [v / val_sum for v in vals] - - -@registry.register_problem -class CipherVigenere200(CipherVigenere5): - - @property - def num_symbols(self): - return 200 - - @property - def distribution(self): - vals = range(self.num_symbols) - val_sum = sum(vals) - return [v / val_sum for v in vals] - - @property - def key(self): - return [1, 3] - - -class Layer(): - """A single layer for shift""" - - def __init__(self, vocab, shift): - """Initialize shift layer - - Args: - vocab (list of String): the vocabulary - shift (Integer): the amount of shift apply to the alphabet. Positive number implies - shift to the right, negative number implies shift to the left. - """ - self.shift = shift - alphabet = vocab - shifted_alphabet = deque(alphabet) - shifted_alphabet.rotate(shift) - self.encrypt = dict(zip(alphabet, list(shifted_alphabet))) - self.decrypt = dict(zip(list(shifted_alphabet), alphabet)) - - def encrypt_character(self, character): - return self.encrypt[character] - - def decrypt_character(self, character): - return self.decrypt[character] - - -def generate_plaintext_random(plain_vocab, distribution, train_samples, - length): - """Generates samples of text from the provided vocabulary. - Returns: - train_indices (np.array of Integers): random integers generated for training. - shape = [num_samples, length] - test_indices (np.array of Integers): random integers generated for testing. - shape = [num_samples, length] - plain_vocab (list of Integers): unique vocabularies. - """ - if distribution is not None: - assert len(distribution) == len(plain_vocab) - - train_indices = np.random.choice( - range(len(plain_vocab)), (train_samples, length), p=distribution) - - return train_indices - - -def encipher_shift(plaintext, plain_vocab, shift): - """Encrypt plain text with a single shift layer - Args: - plaintext (list of list of Strings): a list of plain text to encrypt. - plain_vocab (list of Integer): unique vocabularies being used. - shift (Integer): number of shift, shift to the right if shift is positive. - Returns: - ciphertext (list of Strings): encrypted plain text. - """ - ciphertext = [] - cipher = Layer(plain_vocab, shift) - - for i, sentence in enumerate(plaintext): - cipher_sentence = [] - for j, character in enumerate(sentence): - encrypted_char = cipher.encrypt_character(character) - cipher_sentence.append(encrypted_char) - ciphertext.append(cipher_sentence) - - return ciphertext - - -def encipher_vigenere(plaintext, plain_vocab, key): - """Encrypt plain text with given key - Args: - plaintext (list of list of Strings): a list of plain text to encrypt. - plain_vocab (list of Integer): unique vocabularies being used. - key (list of Integer): key to encrypt cipher using Vigenere table. - Returns: - ciphertext (list of Strings): encrypted plain text. 
- """ - ciphertext = [] - # generate Vigenere table - layers = [] - for i in range(len(plain_vocab)): - layers.append(Layer(plain_vocab, i)) - - for i, sentence in enumerate(plaintext): - cipher_sentence = [] - for j, character in enumerate(sentence): - key_idx = key[j % len(key)] - encrypted_char = layers[key_idx].encrypt_character(character) - cipher_sentence.append(encrypted_char) - ciphertext.append(cipher_sentence) - - return ciphertext diff --git a/tensor2tensor/layers/__init__.py b/tensor2tensor/layers/__init__.py index e69de29bb..3f714ce1f 100644 --- a/tensor2tensor/layers/__init__.py +++ b/tensor2tensor/layers/__init__.py @@ -0,0 +1,15 @@ +# coding=utf-8 +# Copyright 2017 The Tensor2Tensor Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + From f47930391a88974ff5253e7068377025f3b27ccb Mon Sep 17 00:00:00 2001 From: Ashish Vaswani Date: Wed, 2 Aug 2017 22:07:37 -0700 Subject: [PATCH 02/12] Fix image problem preprocess_examples signature PiperOrigin-RevId: 164081124 --- tensor2tensor/data_generators/image.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tensor2tensor/data_generators/image.py b/tensor2tensor/data_generators/image.py index d70d9339e..f61f85b54 100644 --- a/tensor2tensor/data_generators/image.py +++ b/tensor2tensor/data_generators/image.py @@ -423,7 +423,7 @@ def cifar10_generator(tmp_dir, training, how_many, start_from=0): @registry.register_problem class ImageCifar10Tune(ImageMnistTune): - def preprocess_examples(self, examples, mode): + def preprocess_examples(self, examples, mode, hparams): if mode == tf.contrib.learn.ModeKeys.TRAIN: examples["inputs"] = common_layers.cifar_image_augmentation( examples["inputs"]) @@ -449,7 +449,7 @@ def generator(self, data_dir, tmp_dir, is_training): @registry.register_problem class ImageCifar10Plain(ImageCifar10): - def preprocess_examples(self, examples, mode): + def preprocess_examples(self, examples, mode, hparams): return examples From fec81d2f60892759b9d55f2a0cbab75d3a9ce8cb Mon Sep 17 00:00:00 2001 From: Ryan Sepassi Date: Thu, 3 Aug 2017 12:29:07 -0700 Subject: [PATCH 03/12] Rm FLAGS from input fn builder and fix placeholder logic PiperOrigin-RevId: 164164402 --- tensor2tensor/utils/input_fn_builder.py | 25 ++++++++++++++----------- tensor2tensor/utils/trainer_utils.py | 8 ++++++-- 2 files changed, 20 insertions(+), 13 deletions(-) diff --git a/tensor2tensor/utils/input_fn_builder.py b/tensor2tensor/utils/input_fn_builder.py index 1fac64c8b..79a765ca2 100644 --- a/tensor2tensor/utils/input_fn_builder.py +++ b/tensor2tensor/utils/input_fn_builder.py @@ -27,15 +27,14 @@ import tensorflow as tf -# TODO(rsepassi): Rm dep on FLAGS here -FLAGS = tf.flags.FLAGS - def build_input_fn(mode, hparams, data_file_patterns=None, num_datashards=None, - fixed_problem=None): + fixed_problem=None, + worker_replicas=None, + worker_id=None): """Provides input to the graph, either from disk or via a placeholder. 
This function produces an input function that will feed data into @@ -58,6 +57,10 @@ def build_input_fn(mode, num_datashards: An integer. fixed_problem: An integer indicating the problem to fetch data for, or None if the input is to be randomly selected. + worker_replicas: int, number of worker replicas. Used in multiproblem + setting with hparams.problem_choice == distributed. + worker_id: int, id of this worker replica. Used in multiproblem setting with + hparams.problem_choice == distributed. Returns: A function that returns a dictionary of features and the target labels. @@ -78,7 +81,7 @@ def input_fn(): Raises: ValueError: if one of the parameters has an unsupported value. """ - problem_count, batches = len(data_file_patterns), [] + problem_count, batches = len(hparams.problems), [] with tf.name_scope("input_reader"): for n in xrange(problem_count): if fixed_problem is not None and n != fixed_problem: @@ -89,9 +92,9 @@ def input_fn(): with tf.device("/cpu:0"): # Input reading on CPU capacity = p_hparams.max_expected_batch_size_per_shard capacity *= num_datashards - examples = data_reader.input_pipeline(problem_instance, - data_file_patterns[n], - capacity, mode, hparams) + examples = data_reader.input_pipeline( + problem_instance, data_file_patterns and data_file_patterns[n], + capacity, mode, hparams) feature_map = data_reader.batch_examples( examples, data_reader.hparams_to_batching_scheme( @@ -149,9 +152,9 @@ def input_fn(): tf.reshape(loss_moving_avgs, [1, -1]), 1) problem_choice = tf.to_int32(tf.squeeze(problem_choice)) elif hparams.problem_choice == "distributed": - assert FLAGS.worker_replicas >= problem_count - assert FLAGS.worker_replicas % problem_count == 0 - problem_choice = tf.to_int32(FLAGS.worker_id % problem_count) + assert worker_replicas >= problem_count + assert worker_replicas % problem_count == 0 + problem_choice = tf.to_int32(worker_id % problem_count) else: raise ValueError( "Value of hparams.problem_choice is %s and must be " diff --git a/tensor2tensor/utils/trainer_utils.py b/tensor2tensor/utils/trainer_utils.py index c5f3296ee..9e869c15c 100644 --- a/tensor2tensor/utils/trainer_utils.py +++ b/tensor2tensor/utils/trainer_utils.py @@ -177,14 +177,18 @@ def create_experiment_components(hparams, output_dir, data_dir, model_name): hparams=hparams, data_file_patterns=get_data_filepatterns(data_dir, tf.contrib.learn.ModeKeys.TRAIN), - num_datashards=num_datashards) + num_datashards=num_datashards, + worker_replicas=FLAGS.worker_replicas, + worker_id=FLAGS.worker_id) eval_input_fn = input_fn_builder.build_input_fn( mode=tf.contrib.learn.ModeKeys.EVAL, hparams=hparams, data_file_patterns=get_data_filepatterns(data_dir, tf.contrib.learn.ModeKeys.EVAL), - num_datashards=num_datashards) + num_datashards=num_datashards, + worker_replicas=FLAGS.worker_replicas, + worker_id=FLAGS.worker_id) estimator = tf.contrib.learn.Estimator( model_fn=model_builder.build_model_fn(model_name, hparams=hparams), model_dir=output_dir, From 21404237d7c12c9a650603c8ab6391cd1a5438b4 Mon Sep 17 00:00:00 2001 From: Ryan Sepassi Date: Thu, 3 Aug 2017 13:59:04 -0700 Subject: [PATCH 04/12] Add PEPTIDE SpaceID and enable TokenTextEncoder to take a list of tokens PiperOrigin-RevId: 164177483 --- tensor2tensor/data_generators/problem.py | 2 + tensor2tensor/data_generators/text_encoder.py | 98 ++++++++++++------- 2 files changed, 64 insertions(+), 36 deletions(-) diff --git a/tensor2tensor/data_generators/problem.py b/tensor2tensor/data_generators/problem.py index 3d30ec239..72334b76d 100644 --- 
a/tensor2tensor/data_generators/problem.py +++ b/tensor2tensor/data_generators/problem.py @@ -84,6 +84,8 @@ class SpaceID(object): REAL = 24 # Images IMAGE = 25 + # Peptide + PEPTIDE = 26 class Problem(object): diff --git a/tensor2tensor/data_generators/text_encoder.py b/tensor2tensor/data_generators/text_encoder.py index cd6ca0eea..ad9c04c96 100644 --- a/tensor2tensor/data_generators/text_encoder.py +++ b/tensor2tensor/data_generators/text_encoder.py @@ -48,7 +48,6 @@ else: RESERVED_TOKENS_BYTES = [bytes(PAD, "ascii"), bytes(EOS, "ascii")] - # Regular expression for unescaping token strings. # '\u' is converted to '_' # '\\' is converted to '\' @@ -154,14 +153,21 @@ def vocab_size(self): class TokenTextEncoder(TextEncoder): - """Encoder based on a user-supplied vocabulary.""" + """Encoder based on a user-supplied vocabulary (file or list).""" - def __init__(self, vocab_filename, reverse=False, + def __init__(self, + vocab_filename, + reverse=False, + vocab_list=None, num_reserved_ids=NUM_RESERVED_TOKENS): - """Initialize from a file, one token per line.""" + """Initialize from a file or list, one token per line.""" super(TokenTextEncoder, self).__init__(num_reserved_ids=num_reserved_ids) self._reverse = reverse - self._load_vocab_from_file(vocab_filename) + if vocab_filename: + self._init_vocab_from_file(vocab_filename) + else: + assert vocab_list is not None + self._init_vocab_from_list(vocab_list) def encode(self, sentence): """Converts a space-separated string of tokens to a list of ids.""" @@ -179,22 +185,40 @@ def vocab_size(self): def _safe_id_to_token(self, idx): return self._id_to_token.get(idx, "ID_%d" % idx) - def _load_vocab_from_file(self, filename): + def _init_vocab_from_file(self, filename): """Load vocab from a file.""" - self._token_to_id = {} + + def token_gen(): + with tf.gfile.Open(filename) as f: + for line in f: + token = line.strip() + yield token + + self._init_vocab(token_gen()) + + def _init_vocab_from_list(self, vocab_list): + + def token_gen(): + for token in vocab_list: + yield token + + self._init_vocab(token_gen()) + + def _init_vocab(self, token_generator): + """Initialize vocabulary with tokens from token_generator.""" self._id_to_token = {} - for idx, tok in enumerate(RESERVED_TOKENS): - self._token_to_id[tok] = idx - self._id_to_token[idx] = tok + # Add reserved tokens + self._id_to_token.update(dict(list(enumerate(RESERVED_TOKENS)))) - token_start_idx = self._num_reserved_ids - with tf.gfile.Open(filename) as f: - for i, line in enumerate(f): - idx = token_start_idx + i - tok = line.strip() - self._token_to_id[tok] = idx - self._id_to_token[idx] = tok + token_id = len(RESERVED_TOKENS) + for token in token_generator: + self._id_to_token[token_id] = token + token_id += 1 + + # _token_to_id is the reverse of _id_to_token + self._token_to_id = dict([(v, k) + for k, v in six.iteritems(self._id_to_token)]) def _escape_token(token, alphabet): @@ -218,9 +242,7 @@ def _escape_token(token, alphabet): raise ValueError("Expected string type for token, got %s" % type(token)) token = token.replace(u"\\", u"\\\\").replace(u"_", u"\\u") - ret = [ - c if c in alphabet and c != u"\n" else r"\%d;" % ord(c) - for c in token] + ret = [c if c in alphabet and c != u"\n" else r"\%d;" % ord(c) for c in token] return u"".join(ret) + "_" @@ -233,6 +255,7 @@ def _unescape_token(escaped_token): Returns: token: a unicode string """ + def match(m): if m.group(1) is None: return u"_" if m.group(0) == u"\\u" else u"\\" @@ -294,8 +317,8 @@ def encode(self, raw_text): Returns: a list of 
integers in the range [0, vocab_size) """ - return self._tokens_to_subtoken_ids(tokenizer.encode( - native_to_unicode(raw_text))) + return self._tokens_to_subtoken_ids( + tokenizer.encode(native_to_unicode(raw_text))) def decode(self, subtokens): """Converts a sequence of subtoken ids to a native string. @@ -305,8 +328,8 @@ def decode(self, subtokens): Returns: a native string """ - return unicode_to_native(tokenizer.decode( - self._subtoken_ids_to_tokens(subtokens))) + return unicode_to_native( + tokenizer.decode(self._subtoken_ids_to_tokens(subtokens))) @property def vocab_size(self): @@ -323,8 +346,9 @@ def _tokens_to_subtoken_ids(self, tokens): """ ret = [] for token in tokens: - ret.extend(self._escaped_token_to_subtoken_ids( - _escape_token(token, self._alphabet))) + ret.extend( + self._escaped_token_to_subtoken_ids( + _escape_token(token, self._alphabet))) return ret def _subtoken_ids_to_tokens(self, subtokens): @@ -386,7 +410,8 @@ def _escaped_token_to_subtoken_ids(self, escaped_token): """ return [ self._subtoken_string_to_id[subtoken] - for subtoken in self._escaped_token_to_subtoken_strings(escaped_token)] + for subtoken in self._escaped_token_to_subtoken_strings(escaped_token) + ] @classmethod def build_to_target_size(cls, @@ -414,17 +439,16 @@ def build_to_target_size(cls, ValueError: If `min_val` is greater than `max_val`. """ if min_val > max_val: - raise ValueError( - "Lower bound for the minimum token count " - "is greater than the upper bound.") + raise ValueError("Lower bound for the minimum token count " + "is greater than the upper bound.") def bisect(min_val, max_val): """Bisection to find the right size.""" present_count = (max_val + min_val) // 2 tf.logging.info("Trying min_count %d" % present_count) subtokenizer = cls() - subtokenizer.build_from_token_counts( - token_counts, present_count, num_iterations) + subtokenizer.build_from_token_counts(token_counts, present_count, + num_iterations) # If min_val == max_val, we can't do any better than this. if subtokenizer.vocab_size == target_size or min_val >= max_val: @@ -498,7 +522,7 @@ def build_from_token_counts(self, # Consider the candidates longest to shortest, so that if we accept # a longer subtoken string, we can decrement the counts of its prefixes. new_subtoken_strings = [] - for lsub in xrange(len(len_to_subtoken_strings)-1, 0, -1): + for lsub in xrange(len(len_to_subtoken_strings) - 1, 0, -1): subtoken_strings = len_to_subtoken_strings[lsub] for subtoken_string in subtoken_strings: count = subtoken_counts[subtoken_string] @@ -511,8 +535,8 @@ def build_from_token_counts(self, subtoken_counts[subtoken_string[:l]] -= count # Include the alphabet explicitly to guarantee all strings are encodable. - new_subtoken_strings.extend( - (subtoken_counts.get(a, 0), a) for a in self._alphabet) + new_subtoken_strings.extend((subtoken_counts.get(a, 0), a) + for a in self._alphabet) new_subtoken_strings.sort(reverse=True) # Reinitialize to the candidate vocabulary. @@ -535,7 +559,9 @@ def _init_subtokens_from_list(self, subtoken_strings, reserved=0): # check arbitrarily long strings. 
self._max_subtoken_len = max([len(s) for s in subtoken_strings]) self._subtoken_string_to_id = { - s: i+reserved for i, s in enumerate(subtoken_strings) if s} + s: i + reserved + for i, s in enumerate(subtoken_strings) if s + } def _init_alphabet_from_tokens(self, tokens): """Initialize alphabet from an iterable of token or subtoken strings.""" From 34a961f0d4f9fa38d8dddc9df1d3366b1d7703cf Mon Sep 17 00:00:00 2001 From: T2T Team Date: Thu, 3 Aug 2017 14:56:08 -0700 Subject: [PATCH 05/12] Add desc2code problem (From the OpenAI Description2Code dataset). PiperOrigin-RevId: 164187018 --- tensor2tensor/data_generators/all_problems.py | 1 + tensor2tensor/data_generators/desc2code.py | 246 ++++++++++++++++++ tensor2tensor/data_generators/problem.py | 2 + 3 files changed, 249 insertions(+) create mode 100644 tensor2tensor/data_generators/desc2code.py diff --git a/tensor2tensor/data_generators/all_problems.py b/tensor2tensor/data_generators/all_problems.py index 9be133a61..af2030d89 100644 --- a/tensor2tensor/data_generators/all_problems.py +++ b/tensor2tensor/data_generators/all_problems.py @@ -22,6 +22,7 @@ from tensor2tensor.data_generators import algorithmic from tensor2tensor.data_generators import algorithmic_math from tensor2tensor.data_generators import audio +from tensor2tensor.data_generators import desc2code from tensor2tensor.data_generators import image from tensor2tensor.data_generators import lm1b from tensor2tensor.data_generators import ptb diff --git a/tensor2tensor/data_generators/desc2code.py b/tensor2tensor/data_generators/desc2code.py new file mode 100644 index 000000000..52513e63c --- /dev/null +++ b/tensor2tensor/data_generators/desc2code.py @@ -0,0 +1,246 @@ +# coding=utf-8 +# Copyright 2017 The Tensor2Tensor Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Data generators for the Description2Code OpenAI data-set.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import collections +import os +import random +import zipfile + +# Dependency imports + +from tensor2tensor.data_generators import generator_utils +from tensor2tensor.data_generators import problem +from tensor2tensor.data_generators import text_encoder +from tensor2tensor.utils import registry + +import tensorflow as tf + + +# End-of-sentence marker. 
+EOS = text_encoder.EOS_ID + +_DATASET_URL = "https://drive.google.com/uc?export=download&id=0Bz3fihKG133ceWNFQTQ5S0xhZUk" +_DATASET_FILENAME = "description2code_current.zip" +_DATASET_PB_PATH = "description2code_current/" + +_DESC_DIR_NAME = "description" +_CODE_PY_DIR_NAME = "solutions_python" + +_VOCAB_EN_FILENAME = "vocab_desc2code_tok_en" +_VOCAB_PY_FILENAME = "vocab_desc2code_tok_py" + +# Struct containing a coding problem (contains the paths to the descriptions +# and code files) +CodingPbInfo = collections.namedtuple("CodingPbInfo", "desc_file, code_files") + + +class Desc2CodeProblem(problem.Text2TextProblem): + """Base class for Description2Code problems.""" + + @property + def is_character_level(self): + return False + + @property + def num_shards(self): + return 100 + + @property + def use_subword_tokenizer(self): + return True + + +@registry.register_problem("desc2code_py") +class Desc2CodePyProblem(Desc2CodeProblem): + """Description2Code for python problem.""" + + @property + def targeted_vocab_size(self): + return 2**13 # 8192 + + @property + def input_space_id(self): + return problem.SpaceID.EN_TOK + + @property + def target_space_id(self): + return problem.SpaceID.PY_TOK + + @property + def vocab_input_filename(self): + return "{}.{}".format(_VOCAB_EN_FILENAME, self.targeted_vocab_size) + + @property + def vocab_target_filename(self): + return "{}.{}".format(_VOCAB_PY_FILENAME, self.targeted_vocab_size) + + def train_generator(self, data_dir, tmp_dir, train): + # Called twice: for train and test + + # Get the list of the training samples (coding challenge samples) + samples = list(generator_samples(tmp_dir)) + + # Split between train and dev + # Suffle to get problems from diverse sources (CodeChef and CodeForces) and + # dificulties in each set. 
+ # Need to sort the samples first before shuffling (as walk() isn't + # deterministic) + samples.sort(key=lambda x: x.desc_file) # in-place + rng = random.Random(7531) # Local fixed seed + rng.shuffle(samples) # in-place + + # Train: 5019/5228 problems + # Dev: 209/5228 problems + len_samples = len(samples) + split = len_samples // 25 + samples = samples[split:] if train else samples[:split] + tf.logging.info("Number of samples for {}: {}/{}".format( + "train" if train else "dev", + len(samples), + len_samples + )) + + def generator_samples_content(get_source, get_target): + source, target = None, None + # Iterate over the coding samples + for sample in samples: + if get_source: + with tf.gfile.GFile(sample.desc_file, mode="r") as source_file: + source = source_file.read() + + if get_target: + # Each challenge can have multiple implementations (or none) + for code_file in sample.code_files: + with tf.gfile.GFile(code_file, mode="r") as target_file: + target = target_file.read() + yield source, target + elif sample.code_files: # Only take the source if a target exists + yield source, target + + def generator_source(): + for source, _ in generator_samples_content(True, False): + yield source.strip() + + def generator_target(): + for _, target in generator_samples_content(False, True): + yield target.strip() + + # Generate vocab for both source and target + + source_vocab = generator_utils.get_or_generate_vocab_inner( + data_dir=data_dir, + vocab_filename=self.vocab_input_filename, + vocab_size=self.targeted_vocab_size, + generator_fn=generator_source, + ) + + target_vocab = generator_utils.get_or_generate_vocab_inner( + data_dir=data_dir, + vocab_filename=self.vocab_target_filename, + vocab_size=self.targeted_vocab_size, + generator_fn=generator_target, + ) + + # Yield the training and testing samples + eos_list = [EOS] + for source, target in generator_samples_content(True, True): + source_ints = source_vocab.encode(source.strip()) + eos_list + target_ints = target_vocab.encode(target.strip()) + eos_list + yield { + "inputs": source_ints, + "targets": target_ints, + } + + def feature_encoders(self, data_dir): + source_vocab_filename = os.path.join(data_dir, self.vocab_input_filename) + target_vocab_filename = os.path.join(data_dir, self.vocab_target_filename) + source_token = text_encoder.SubwordTextEncoder(source_vocab_filename) + target_token = text_encoder.SubwordTextEncoder(target_vocab_filename) + return { + "inputs": source_token, + "targets": target_token, + } + + +# Utils functions + + +def generator_samples(tmp_dir): + """Generator for the dataset samples. + + If not present, download and extract the dataset. + + Args: + tmp_dir: path to the directory where to download the dataset. + + Yields: + A CodingPbInfo object containing the next challenge informations. 
+ """ + # Step1: Download dataset (eventually) + data_zip_path = generator_utils.maybe_download_from_drive( + directory=tmp_dir, + filename=_DATASET_FILENAME, + url=_DATASET_URL, + ) + tf.logging.info("Data downloaded in: {}".format(data_zip_path)) + + # Step2: Extract dataset + # We could deduce _DATASET_PB_PATH from the zip file (instead of + # hardcoded path) + data_rootdir = os.path.join(tmp_dir, _DATASET_PB_PATH) + if not tf.gfile.Exists(data_rootdir): + with zipfile.ZipFile(data_zip_path, "r") as corpus_zip: + corpus_zip.extractall(tmp_dir) + # We could remove the extracted __MACOSX folder + tf.logging.info("Data extracted in: {}".format(tmp_dir)) + else: + tf.logging.info("Data already extracted in: {}".format(tmp_dir)) + + # Step3: Extract the problems list on the extracted folder + def contains_samples(subdir, dirs, files): # pylint: disable=unused-argument + """Check that the folder contains a problem.""" + return ( + _DESC_DIR_NAME in dirs and + _CODE_PY_DIR_NAME in dirs + ) + + def next_sample(subdir, dirs, files): # pylint: disable=unused-argument + """Return the filenames of the problem.""" + # More could be extracted (like the expected inputs/outputs + # pairs, the problem difficulty, the names of the algorithmic techniques + # needed) + desc_file = os.path.join(subdir, _DESC_DIR_NAME, "description.txt") + code_rootdir = os.path.join(subdir, _CODE_PY_DIR_NAME) + code_files = [ + f for f in tf.gfile.Glob(os.path.join(code_rootdir, "*.txt")) + ] + return CodingPbInfo( + desc_file=desc_file, + code_files=code_files + ) + + # The dataset contains problem from two different sources (CodeChef + # and CodeForces). Due to the limited number of samples, all problems from + # both sources are merged + for w in tf.gfile.Walk(data_rootdir): + if contains_samples(*w): + yield next_sample(*w) + diff --git a/tensor2tensor/data_generators/problem.py b/tensor2tensor/data_generators/problem.py index 72334b76d..fb7e53cb7 100644 --- a/tensor2tensor/data_generators/problem.py +++ b/tensor2tensor/data_generators/problem.py @@ -86,6 +86,8 @@ class SpaceID(object): IMAGE = 25 # Peptide PEPTIDE = 26 + # Python + PY_TOK = 27 class Problem(object): From 95ee9e5b2e979c22ed81bf78dd62f7a6cb42de84 Mon Sep 17 00:00:00 2001 From: Noam Shazeer Date: Thu, 3 Aug 2017 15:20:06 -0700 Subject: [PATCH 06/12] added transformer_moe - a transformer model with mixtures-of-experts. PiperOrigin-RevId: 164190826 --- tensor2tensor/models/models.py | 1 + tensor2tensor/models/transformer_moe.py | 216 ++++++++++++++++++++++++ 2 files changed, 217 insertions(+) create mode 100644 tensor2tensor/models/transformer_moe.py diff --git a/tensor2tensor/models/models.py b/tensor2tensor/models/models.py index c2a904888..963975780 100644 --- a/tensor2tensor/models/models.py +++ b/tensor2tensor/models/models.py @@ -37,5 +37,6 @@ from tensor2tensor.models import slicenet from tensor2tensor.models import transformer from tensor2tensor.models import transformer_alternative +from tensor2tensor.models import transformer_moe from tensor2tensor.models import xception # pylint: enable=unused-import diff --git a/tensor2tensor/models/transformer_moe.py b/tensor2tensor/models/transformer_moe.py new file mode 100644 index 000000000..8072f2cf8 --- /dev/null +++ b/tensor2tensor/models/transformer_moe.py @@ -0,0 +1,216 @@ +# coding=utf-8 +# Copyright 2017 The Tensor2Tensor Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""transformer (attention seq-seq model) with mixtures of experts. + +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +# Dependency imports + +from six.moves import xrange # pylint: disable=redefined-builtin + +from tensor2tensor.layers import common_attention +from tensor2tensor.layers import common_hparams +from tensor2tensor.layers import common_layers +from tensor2tensor.models import transformer +from tensor2tensor.utils import registry +from tensor2tensor.utils import t2t_model + +import tensorflow as tf + + +@registry.register_model +class TransformerMoe(t2t_model.T2TModel): + """Attention net. See file docstring.""" + + def model_fn_body_sharded(self, sharded_features): + hparams = self._hparams + dp = self._data_parallelism + targets = sharded_features["targets"] + inputs = sharded_features["inputs"] + target_space = sharded_features["target_space_id"] + + inputs = dp(common_layers.flatten4d3d, inputs) + targets = dp(common_layers.flatten4d3d, targets) + + (encoder_input, encoder_self_attention_bias, + encoder_decoder_attention_bias) = dp( + transformer.transformer_prepare_encoder, + inputs, target_space, hparams) + (decoder_input, decoder_self_attention_bias) = dp( + transformer.transformer_prepare_decoder, targets, hparams) + residual_fn = transformer.get_residual_fn(hparams) + encoder_input = dp(tf.nn.dropout, encoder_input, + 1.0 - hparams.residual_dropout) + decoder_input = dp(tf.nn.dropout, decoder_input, + 1.0 - hparams.residual_dropout) + extra_loss = 0 + x = encoder_input + for layer in xrange(hparams.num_hidden_layers): + with tf.variable_scope("encoder_layer_%d" % layer): + with tf.variable_scope("encoder_self_attention"): + y = dp( + common_attention.multihead_attention, + x, + None, + encoder_self_attention_bias, + hparams.attention_key_channels or hparams.hidden_size, + hparams.attention_value_channels or hparams.hidden_size, + hparams.hidden_size, + hparams.num_heads, + hparams.attention_dropout) + x = dp(residual_fn, x, y) + with tf.variable_scope("ffn"): + if str(layer) in hparams.moe_layers_encoder.split(","): + y, loss = common_layers.moe_layer( + dp, self._ps_devices, x, + hparams.mode == tf.contrib.learn.ModeKeys.TRAIN, + hparams.hidden_size, hparams.moe_hidden_size, hparams.moe_n1, + hparams.moe_n2, hparams.moe_loss_coef) + extra_loss += loss + else: + y = dp( + common_layers.conv_hidden_relu, + x, + hparams.filter_size, + hparams.hidden_size, + dropout=hparams.relu_dropout) + x = dp(residual_fn, x, y) + encoder_output = x + x = decoder_input + for layer in xrange(hparams.num_hidden_layers): + with tf.variable_scope("decoder_layer_%d" % layer): + with tf.variable_scope("decoder_self_attention"): + y = dp( + common_attention.multihead_attention, + x, + None, + decoder_self_attention_bias, + hparams.attention_key_channels or hparams.hidden_size, + hparams.attention_value_channels or hparams.hidden_size, + hparams.hidden_size, + hparams.num_heads, + hparams.attention_dropout) + x = dp(residual_fn, x, y) + with tf.variable_scope("encoder_decoder_attention"): + y = dp( + 
common_attention.multihead_attention, + x, + encoder_output, + encoder_decoder_attention_bias, + hparams.attention_key_channels or hparams.hidden_size, + hparams.attention_value_channels or hparams.hidden_size, + hparams.hidden_size, + hparams.num_heads, + hparams.attention_dropout) + x = dp(residual_fn, x, y) + with tf.variable_scope("ffn"): + if str(layer) in hparams.moe_layers_decoder.split(","): + y, loss = common_layers.moe_layer( + dp, self._ps_devices, x, + hparams.mode == tf.contrib.learn.ModeKeys.TRAIN, + hparams.hidden_size, hparams.moe_hidden_size, hparams.moe_n1, + hparams.moe_n2, hparams.moe_loss_coef) + extra_loss += loss + else: + y = dp( + common_layers.conv_hidden_relu, + x, + hparams.filter_size, + hparams.hidden_size, + dropout=hparams.relu_dropout) + x = dp(residual_fn, x, y) + decoder_output = dp(tf.expand_dims, x, 2) + return decoder_output, extra_loss + + +@registry.register_hparams +def transformer_moe_base(): + """Set of hyperparameters.""" + hparams = common_hparams.basic_params1() + hparams.norm_type = "layer" + hparams.hidden_size = 512 + hparams.batch_size = 4096 + hparams.max_length = 2001 + hparams.max_input_seq_length = 2000 + hparams.max_target_seq_length = 2000 + hparams.dropout = 0.0 + hparams.clip_grad_norm = 0. # i.e. no gradient clipping + hparams.optimizer_adam_epsilon = 1e-9 + hparams.learning_rate_decay_scheme = "noam" + hparams.learning_rate = 0.1 + hparams.learning_rate_warmup_steps = 4000 + hparams.initializer_gain = 1.0 + hparams.num_hidden_layers = 5 + hparams.initializer = "uniform_unit_scaling" + hparams.weight_decay = 0.0 + hparams.optimizer_adam_beta1 = 0.9 + hparams.optimizer_adam_beta2 = 0.98 + hparams.num_sampled_classes = 0 + hparams.label_smoothing = 0.0 + hparams.shared_embedding_and_softmax_weights = int(True) + + hparams.add_hparam("filter_size", 2048) # Add new ones like this. + # attention-related flags + hparams.add_hparam("num_heads", 8) + hparams.add_hparam("attention_key_channels", 0) + hparams.add_hparam("attention_value_channels", 0) + hparams.add_hparam("ffn_layer", "conv_hidden_relu") + hparams.add_hparam("parameter_attention_key_channels", 0) + hparams.add_hparam("parameter_attention_value_channels", 0) + # All hyperparameters ending in "dropout" are automatically set to 0.0 + # when not in training mode. + hparams.add_hparam("attention_dropout", 0.0) + hparams.add_hparam("relu_dropout", 0.0) + hparams.add_hparam("residual_dropout", 0.1) + hparams.add_hparam("pos", "timing") # timing, none + hparams.add_hparam("nbr_decoder_problems", 1) + hparams.add_hparam("proximity_bias", int(False)) + # FLAGS RELATED TO MIXTURE-OF-EXPERTS + # comma-separated list of layer numbers. + # At each of these layers, we replace the ffn with a mixture of experts. + hparams.add_hparam("moe_layers_encoder", "2") + hparams.add_hparam("moe_layers_decoder", "2") + # If moe_n2 is None, then use a flat MoE with moe_n1 experts. + # If moe_n2 is an integer, then use a hierarchical MoE + # consisting of moe_n1 groups of moe_n2 experts each. 
+ hparams.add_hparam("moe_n1", 32) + hparams.add_hparam("moe_n2", 0) + hparams.add_hparam("moe_hidden_size", 2048) + hparams.add_hparam("moe_loss_coef", 1e-2) + return hparams + + +@registry.register_hparams +def transformer_no_moe(): + """Without the mixture of experts (for comparison).""" + hparams = transformer_moe_base() + hparams.moe_layers_encoder = "" + hparams.moe_layers_decoder = "" + return hparams + + +@registry.register_hparams +def transformer_moe_1b(): + """1-billion parameter model - requires multi-gpu sync training.""" + hparams = transformer_moe_base() + hparams.moe_n1 = 128 + hparams.moe_layers_encoder = "1,3" + hparams.moe_layers_decoder = "1,3" + return hparams From 554973f1d4d8b93b466ec1b428a58e3359356519 Mon Sep 17 00:00:00 2001 From: Alexander Ku Date: Thu, 3 Aug 2017 16:42:53 -0700 Subject: [PATCH 07/12] Adding a minimum viable DNA data encoder. PiperOrigin-RevId: 164201984 --- tensor2tensor/data_generators/dna_encoder.py | 124 ++++++++++++++++++ .../data_generators/dna_encoder_test.py | 52 ++++++++ .../data_generators/gene_expression.py | 68 +--------- .../data_generators/gene_expression_test.py | 5 +- 4 files changed, 183 insertions(+), 66 deletions(-) create mode 100644 tensor2tensor/data_generators/dna_encoder.py create mode 100644 tensor2tensor/data_generators/dna_encoder_test.py diff --git a/tensor2tensor/data_generators/dna_encoder.py b/tensor2tensor/data_generators/dna_encoder.py new file mode 100644 index 000000000..0f6a8d68f --- /dev/null +++ b/tensor2tensor/data_generators/dna_encoder.py @@ -0,0 +1,124 @@ +# coding=utf-8 +# Copyright 2017 The Tensor2Tensor Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Encoders for DNA data. + +* DNAEncoder: ACTG strings to ints and back +* DelimitedDNAEncoder: for delimited subsequences +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import itertools +# Dependency imports + +from six.moves import xrange # pylint: disable=redefined-builtin +from tensor2tensor.data_generators import text_encoder + + +class DNAEncoder(text_encoder.TextEncoder): + """ACTG strings to ints and back. Optionally chunks bases into single ids. + + To use a different character set, subclass and set BASES to the char set. UNK + and PAD must not appear in the char set, but can also be reset. + + Uses 'N' as an unknown base. 
+ """ + BASES = list("ACTG") + UNK = "N" + PAD = "0" + + def __init__(self, + chunk_size=1, + num_reserved_ids=text_encoder.NUM_RESERVED_TOKENS): + super(DNAEncoder, self).__init__(num_reserved_ids=num_reserved_ids) + # Build a vocabulary of chunks of size chunk_size + self._chunk_size = chunk_size + tokens = self._tokens() + tokens.sort() + ids = range(self._num_reserved_ids, len(tokens) + self._num_reserved_ids) + self._ids_to_tokens = dict(zip(ids, tokens)) + self._tokens_to_ids = dict(zip(tokens, ids)) + + def _tokens(self): + chunks = [] + for size in range(1, self._chunk_size + 1): + c = itertools.product(self.BASES + [self.UNK], repeat=size) + num_pad = self._chunk_size - size + padding = (self.PAD,) * num_pad + c = [el + padding for el in c] + chunks.extend(c) + return chunks + + @property + def vocab_size(self): + return len(self._ids_to_tokens) + self._num_reserved_ids + + def encode(self, s): + bases = list(s) + extra = len(bases) % self._chunk_size + if extra > 0: + pad = [self.PAD] * (self._chunk_size - extra) + bases.extend(pad) + assert (len(bases) % self._chunk_size) == 0 + num_chunks = len(bases) // self._chunk_size + ids = [] + for chunk_idx in xrange(num_chunks): + start_idx = chunk_idx * self._chunk_size + end_idx = start_idx + self._chunk_size + chunk = tuple(bases[start_idx:end_idx]) + if chunk not in self._tokens_to_ids: + raise ValueError("Unrecognized token %s" % chunk) + ids.append(self._tokens_to_ids[chunk]) + return ids + + def decode(self, ids): + bases = [] + for idx in ids: + if idx >= self._num_reserved_ids: + chunk = self._ids_to_tokens[idx] + if self.PAD in chunk: + chunk = chunk[:chunk.index(self.PAD)] + else: + chunk = [text_encoder.RESERVED_TOKENS[idx]] + bases.extend(chunk) + return "".join(bases) + + +class DelimitedDNAEncoder(DNAEncoder): + """DNAEncoder for delimiter separated subsequences. + + Uses ',' as default delimiter. + """ + + def __init__(self, delimiter=",", **kwargs): + self._delimiter = delimiter + super(DelimitedDNAEncoder, self).__init__(**kwargs) + + @property + def delimiter(self): + return self._delimiter + + def _tokens(self): + return super(DelimitedDNAEncoder, self)._tokens() + [self.delimiter] + + def encode(self, delimited_string): + ids = [] + for s in delimited_string.split(self.delimiter): + ids.extend(super(DelimitedDNAEncoder, self).encode(s)) + ids.append(self._tokens_to_ids[self.delimiter]) + return ids[:-1] diff --git a/tensor2tensor/data_generators/dna_encoder_test.py b/tensor2tensor/data_generators/dna_encoder_test.py new file mode 100644 index 000000000..a84f06442 --- /dev/null +++ b/tensor2tensor/data_generators/dna_encoder_test.py @@ -0,0 +1,52 @@ +# coding=utf-8 +# Copyright 2017 The Tensor2Tensor Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Tests for tensor2tensor.data_generators.dna_encoder.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +# Dependency imports + +from tensor2tensor.data_generators import dna_encoder +import tensorflow as tf + + +class DnaEncoderTest(tf.test.TestCase): + + def test_encode_decode(self): + original = 'TTCGCGGNNNAACCCAACGCCATCTATGTANNTTGAGTTGTTGAGTTAAA' + + # Encoding should be reversible for any reasonable chunk size. + for chunk_size in [1, 2, 4, 6, 8]: + encoder = dna_encoder.DNAEncoder(chunk_size=chunk_size) + encoded = encoder.encode(original) + decoded = encoder.decode(encoded) + self.assertEqual(original, decoded) + + def test_delimited_dna_encoder(self): + original = 'TTCGCGGNNN,AACCCAACGC,CATCTATGTA,NNTTGAGTTG,TTGAGTTAAA' + + # Encoding should be reversible for any reasonable chunk size. + for chunk_size in [1, 2, 4, 6, 8]: + encoder = dna_encoder.DelimitedDNAEncoder(chunk_size=chunk_size) + encoded = encoder.encode(original) + decoded = encoder.decode(encoded) + self.assertEqual(original, decoded) + + +if __name__ == '__main__': + tf.test.main() diff --git a/tensor2tensor/data_generators/gene_expression.py b/tensor2tensor/data_generators/gene_expression.py index 82c15414a..d314cec59 100644 --- a/tensor2tensor/data_generators/gene_expression.py +++ b/tensor2tensor/data_generators/gene_expression.py @@ -35,7 +35,6 @@ from __future__ import division from __future__ import print_function -import itertools import math import multiprocessing as mp import os @@ -47,6 +46,7 @@ from six.moves import xrange # pylint: disable=redefined-builtin +from tensor2tensor.data_generators import dna_encoder from tensor2tensor.data_generators import generator_utils from tensor2tensor.data_generators import problem from tensor2tensor.data_generators import text_encoder @@ -56,7 +56,6 @@ import tensorflow as tf MAX_CONCURRENT_PROCESSES = 10 -_bases = list("ACTG") class GeneExpressionProblem(problem.Problem): @@ -82,7 +81,7 @@ def chunk_size(self): def feature_encoders(self, data_dir): del data_dir return { - "inputs": DNAEncoder(chunk_size=self.chunk_size), + "inputs": dna_encoder.DNAEncoder(chunk_size=self.chunk_size), # TODO(rsepassi): RealEncoder? "targets": text_encoder.TextEncoder() } @@ -244,7 +243,7 @@ def dataset_generator(filepath, chunk_size=1, start_idx=None, end_idx=None): - encoder = DNAEncoder(chunk_size=chunk_size) + encoder = dna_encoder.DNAEncoder(chunk_size=chunk_size) with h5py.File(filepath, "r") as h5_file: # Get input keys from h5_file src_keys = [s % dataset for s in ["%s_in", "%s_na", "%s_out"]] @@ -278,7 +277,7 @@ def to_example_dict(encoder, inputs, mask, outputs): while idx != last_idx + 1: bases.append(encoder.UNK) last_idx += 1 - bases.append(_bases[base_id]) + bases.append(encoder.BASES[base_id]) last_idx = idx assert len(inputs) == len(bases) @@ -297,62 +296,3 @@ def to_example_dict(encoder, inputs, mask, outputs): ex_dict = dict( zip(example_keys, [input_ids, targets_mask, targets, targets_shape])) return ex_dict - - -class DNAEncoder(text_encoder.TextEncoder): - """ACTG strings to ints and back. Optionally chunks bases into single ids. - - Uses 'X' as an unknown base. 
- """ - UNK = "X" - PAD = "0" - - def __init__(self, - chunk_size=1, - num_reserved_ids=text_encoder.NUM_RESERVED_TOKENS): - super(DNAEncoder, self).__init__(num_reserved_ids=num_reserved_ids) - # Build a vocabulary of chunks of size chunk_size - self._chunk_size = chunk_size - chunks = [] - for size in range(1, chunk_size + 1): - c = itertools.product(_bases + [DNAEncoder.UNK], repeat=size) - num_pad = chunk_size - size - padding = (DNAEncoder.PAD,) * num_pad - c = [el + padding for el in c] - chunks.extend(c) - chunks.sort() - ids = range(self._num_reserved_ids, len(chunks) + self._num_reserved_ids) - self._ids_to_chunk = dict(zip(ids, chunks)) - self._chunks_to_ids = dict(zip(chunks, ids)) - - @property - def vocab_size(self): - return len(self._ids_to_chunk) + self._num_reserved_ids - - def encode(self, s): - bases = list(s) - pad = [DNAEncoder.PAD] * (len(bases) % self._chunk_size) - bases.extend(pad) - assert (len(bases) % self._chunk_size) == 0 - num_chunks = len(bases) // self._chunk_size - ids = [] - for chunk_idx in xrange(num_chunks): - start_idx = chunk_idx * self._chunk_size - end_idx = start_idx + self._chunk_size - chunk = tuple(bases[start_idx:end_idx]) - if chunk not in self._chunks_to_ids: - raise ValueError("Unrecognized chunk %s" % chunk) - ids.append(self._chunks_to_ids[chunk]) - return ids - - def decode(self, ids): - bases = [] - for idx in ids: - if idx >= self._num_reserved_ids: - chunk = self._ids_to_chunk[idx] - if DNAEncoder.PAD in chunk: - chunk = chunk[:chunk.index(DNAEncoder.PAD)] - else: - chunk = [text_encoder.RESERVED_TOKENS[idx]] - bases.extend(chunk) - return "".join(bases) diff --git a/tensor2tensor/data_generators/gene_expression_test.py b/tensor2tensor/data_generators/gene_expression_test.py index 2d7bbe832..797170070 100644 --- a/tensor2tensor/data_generators/gene_expression_test.py +++ b/tensor2tensor/data_generators/gene_expression_test.py @@ -22,6 +22,7 @@ import numpy as np +from tensor2tensor.data_generators import dna_encoder from tensor2tensor.data_generators import gene_expression import tensorflow as tf @@ -40,8 +41,8 @@ def _oneHotBases(self, bases): return np.array(one_hots) def testRecordToExample(self): - encoder = gene_expression.DNAEncoder(chunk_size=2) - raw_inputs = ["A", "C", "G", "X", "C", "T"] + encoder = dna_encoder.DNAEncoder(chunk_size=2) + raw_inputs = ["A", "C", "G", "N", "C", "T"] # Put in numpy arrays in the same format as in the h5 file inputs = self._oneHotBases(raw_inputs) From 7efdbeebe777dfcbf005e335c620db4f810ecd16 Mon Sep 17 00:00:00 2001 From: Lukasz Kaiser Date: Thu, 3 Aug 2017 17:26:41 -0700 Subject: [PATCH 08/12] Small transformer models (reasonable translations in 1h on 1080). 
PiperOrigin-RevId: 164207044 --- tensor2tensor/data_generators/all_problems.py | 1 + tensor2tensor/data_generators/cipher.py | 251 ++++++++++++++++++ tensor2tensor/layers/common_layers.py | 7 + tensor2tensor/models/transformer.py | 15 +- 4 files changed, 272 insertions(+), 2 deletions(-) create mode 100644 tensor2tensor/data_generators/cipher.py diff --git a/tensor2tensor/data_generators/all_problems.py b/tensor2tensor/data_generators/all_problems.py index af2030d89..ca6dccfda 100644 --- a/tensor2tensor/data_generators/all_problems.py +++ b/tensor2tensor/data_generators/all_problems.py @@ -22,6 +22,7 @@ from tensor2tensor.data_generators import algorithmic from tensor2tensor.data_generators import algorithmic_math from tensor2tensor.data_generators import audio +from tensor2tensor.data_generators import cipher from tensor2tensor.data_generators import desc2code from tensor2tensor.data_generators import image from tensor2tensor.data_generators import lm1b diff --git a/tensor2tensor/data_generators/cipher.py b/tensor2tensor/data_generators/cipher.py new file mode 100644 index 000000000..3a743337a --- /dev/null +++ b/tensor2tensor/data_generators/cipher.py @@ -0,0 +1,251 @@ +# coding=utf-8 +# Copyright 2017 The Tensor2Tensor Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Cipher data generators.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from collections import deque + +# Dependency imports + +import numpy as np + +from tensor2tensor.data_generators import algorithmic +from tensor2tensor.utils import registry + + +@registry.register_problem +class CipherShift5(algorithmic.AlgorithmicProblem): + """Shift cipher.""" + + @property + def num_symbols(self): + return 5 + + @property + def distribution(self): + return [0.4, 0.3, 0.2, 0.08, 0.02] + + @property + def shift(self): + return 1 + + @property + def train_generator(self): + """Generator; takes 3 args: nbr_symbols, max_length, nbr_cases.""" + + def _gen(nbr_symbols, max_length, nbr_cases): + plain_vocab = range(nbr_symbols) + indices = generate_plaintext_random(plain_vocab, self.distribution, + nbr_cases, max_length) + codes = encipher_shift(indices, plain_vocab, self.shift) + + for plain, code in zip(indices, codes): + yield { + "X": plain, + "Y": code, + } + + return _gen + + @property + def train_length(self): + return 100 + + @property + def dev_length(self): + return self.train_length + + +@registry.register_problem +class CipherVigenere5(algorithmic.AlgorithmicProblem): + """Vinegre cipher.""" + + @property + def num_symbols(self): + return 5 + + @property + def distribution(self): + return [0.4, 0.3, 0.2, 0.08, 0.02] + + @property + def key(self): + return [1, 3] + + @property + def train_generator(self): + """Generator; takes 3 args: nbr_symbols, max_length, nbr_cases.""" + + def _gen(nbr_symbols, max_length, nbr_cases): + plain_vocab = range(nbr_symbols) + indices = generate_plaintext_random(plain_vocab, self.distribution, + nbr_cases, max_length) + codes = encipher_vigenere(indices, plain_vocab, self.key) + + for plain, code in zip(indices, codes): + yield { + "X": plain, + "Y": code, + } + + return _gen + + @property + def train_length(self): + return 200 + + @property + def dev_length(self): + return self.train_length + + +@registry.register_problem +class CipherShift200(CipherShift5): + """Shift cipher.""" + + @property + def num_symbols(self): + return 200 + + @property + def distribution(self): + vals = range(self.num_symbols) + val_sum = sum(vals) + return [v / val_sum for v in vals] + + +@registry.register_problem +class CipherVigenere200(CipherVigenere5): + """Vinegre cipher.""" + + @property + def num_symbols(self): + return 200 + + @property + def distribution(self): + vals = range(self.num_symbols) + val_sum = sum(vals) + return [v / val_sum for v in vals] + + @property + def key(self): + return [1, 3] + + +class Layer(object): + """A single layer for shift.""" + + def __init__(self, vocab, shift): + """Initialize shift layer. + + Args: + vocab: (list of String) the vocabulary + shift: (Integer) the amount of shift apply to the alphabet. + Positive number implies shift to the right, negative number + implies shift to the left. + """ + self.shift = shift + alphabet = vocab + shifted_alphabet = deque(alphabet) + shifted_alphabet.rotate(shift) + self.encrypt = dict(zip(alphabet, list(shifted_alphabet))) + self.decrypt = dict(zip(list(shifted_alphabet), alphabet)) + + def encrypt_character(self, character): + return self.encrypt[character] + + def decrypt_character(self, character): + return self.decrypt[character] + + +def generate_plaintext_random(plain_vocab, distribution, train_samples, + length): + """Generates samples of text from the provided vocabulary. + + Args: + plain_vocab: vocabulary. 
+ distribution: distribution. + train_samples: samples for training. + length: length. + + Returns: + train_indices (np.array of Integers): random integers for training. + shape = [num_samples, length] + test_indices (np.array of Integers): random integers for testing. + shape = [num_samples, length] + plain_vocab (list of Integers): unique vocabularies. + """ + if distribution is not None: + assert len(distribution) == len(plain_vocab) + + train_indices = np.random.choice( + range(len(plain_vocab)), (train_samples, length), p=distribution) + + return train_indices + + +def encipher_shift(plaintext, plain_vocab, shift): + """Encrypt plain text with a single shift layer. + + Args: + plaintext (list of list of Strings): a list of plain text to encrypt. + plain_vocab (list of Integer): unique vocabularies being used. + shift (Integer): number of shift, shift to the right if shift is positive. + Returns: + ciphertext (list of Strings): encrypted plain text. + """ + ciphertext = [] + cipher = Layer(plain_vocab, shift) + + for _, sentence in enumerate(plaintext): + cipher_sentence = [] + for _, character in enumerate(sentence): + encrypted_char = cipher.encrypt_character(character) + cipher_sentence.append(encrypted_char) + ciphertext.append(cipher_sentence) + + return ciphertext + + +def encipher_vigenere(plaintext, plain_vocab, key): + """Encrypt plain text with given key. + + Args: + plaintext (list of list of Strings): a list of plain text to encrypt. + plain_vocab (list of Integer): unique vocabularies being used. + key (list of Integer): key to encrypt cipher using Vigenere table. + + Returns: + ciphertext (list of Strings): encrypted plain text. + """ + ciphertext = [] + # generate Vigenere table + layers = [] + for i in range(len(plain_vocab)): + layers.append(Layer(plain_vocab, i)) + + for i, sentence in enumerate(plaintext): + cipher_sentence = [] + for j, character in enumerate(sentence): + key_idx = key[j % len(key)] + encrypted_char = layers[key_idx].encrypt_character(character) + cipher_sentence.append(encrypted_char) + ciphertext.append(cipher_sentence) + + return ciphertext diff --git a/tensor2tensor/layers/common_layers.py b/tensor2tensor/layers/common_layers.py index 8a58cd065..ea18322e4 100644 --- a/tensor2tensor/layers/common_layers.py +++ b/tensor2tensor/layers/common_layers.py @@ -59,6 +59,13 @@ def inverse_exp_decay(max_step, min_value=0.01): return inv_base**tf.maximum(float(max_step) - step, 0.0) +def inverse_lin_decay(max_step, min_value=0.01): + """Inverse-decay linearly from 0.01 to 1.0 reached at max_step.""" + step = tf.to_float(tf.contrib.framework.get_global_step()) + progress = tf.minimum(step / float(max_step), 1.0) + return progress * (1.0 - min_value) + min_value + + def shakeshake2_py(x, y, equal=False, individual=False): """The shake-shake sum of 2 tensors, python version.""" if equal: diff --git a/tensor2tensor/models/transformer.py b/tensor2tensor/models/transformer.py index 1add44115..c9c87da07 100644 --- a/tensor2tensor/models/transformer.py +++ b/tensor2tensor/models/transformer.py @@ -386,8 +386,19 @@ def transformer_parsing_ice(): @registry.register_hparams def transformer_tiny(): hparams = transformer_base() - hparams.hidden_size = 64 - hparams.filter_size = 128 + hparams.num_hidden_layers = 2 + hparams.hidden_size = 128 + hparams.filter_size = 512 + hparams.num_heads = 4 + return hparams + + +@registry.register_hparams +def transformer_small(): + hparams = transformer_base() + hparams.num_hidden_layers = 2 + hparams.hidden_size = 256 + 
  hparams.filter_size = 1024
  hparams.num_heads = 4
  return hparams

From e8ae5894e40b8c18a37601762eefa51484bf4953 Mon Sep 17 00:00:00 2001
From: T2T Team
Date: Fri, 4 Aug 2017 13:52:42 -0700
Subject: [PATCH 09/12] Support for dictionary losses in model_fn_body to be consistent with model_fn_body_sharded. Also updated inline doc.

PiperOrigin-RevId: 164305140
---
 tensor2tensor/utils/t2t_model.py | 11 ++++++++---
 1 file changed, 8 insertions(+), 3 deletions(-)

diff --git a/tensor2tensor/utils/t2t_model.py b/tensor2tensor/utils/t2t_model.py
index 7cb484bc8..3af4f10c1 100644
--- a/tensor2tensor/utils/t2t_model.py
+++ b/tensor2tensor/utils/t2t_model.py
@@ -469,7 +469,10 @@ def model_fn_body_sharded(self, sharded_features):
         _with_timing(self.model_fn_body, "model_fn_body"),
         datashard_to_features)
     if isinstance(output, tuple):
-      loss = {"extra": tf.reduce_mean(output[1])}
+      if isinstance(output[1], dict):
+        loss = output[1]
+      else:
+        loss = {"extra": tf.reduce_mean(output[1])}
       output = output[0]
     else:
       loss = {"extra": 0.0}
@@ -483,10 +486,12 @@ def model_fn_body(self, features):

     Args:
       features: A dictionary of key to Tensor. Each Tensor has shape
-        `[batch_size, ?, ?, hidden_size]`.
+        [batch_size, ?, ?, hidden_size].

     Returns:
-      a `Tensor` of logits with shape `[batch_size, O, P, body_output_size]`.
+      output: tensor of logits with shape [batch_size, O, P, body_output_size].
+      losses: either single loss as a scalar, a list, a tensor (to be averaged)
+        or a dictionary of losses.
     """
     raise NotImplementedError("Abstract Method")

From a0bd0177bf766c953041b7451398ab1791adb1e5 Mon Sep 17 00:00:00 2001
From: Ashish Vaswani
Date: Fri, 4 Aug 2017 14:41:57 -0700
Subject: [PATCH 10/12] Reverted back to the previous masked_local_attention_1d because the current one was giving 0 losses indicating that it was peeking into the future. The way the attention bias was being added also seemed wrong. Renamed unmasked_local_attention_1d to local_attention_1d. The user can specify local_attention_1d if they want to look left and right of the query block.

PiperOrigin-RevId: 164312109
---
 tensor2tensor/layers/common_attention.py      | 132 ++++++++----------
 tensor2tensor/layers/common_attention_test.py |   4 +-
 2 files changed, 57 insertions(+), 79 deletions(-)

diff --git a/tensor2tensor/layers/common_attention.py b/tensor2tensor/layers/common_attention.py
index e343dba0a..a43afec47 100644
--- a/tensor2tensor/layers/common_attention.py
+++ b/tensor2tensor/layers/common_attention.py
@@ -361,122 +361,100 @@ def dot_product_attention(q,
     return tf.matmul(weights, v)


-def masked_local_attention_1d(q,
-                              k,
-                              v,
-                              block_length=128,
-                              look_right=True,
-                              use_whole_block=False,
-                              name=None):
-  """Attention to the source position and a neigborhood around it.
-
-  The sequence is divided into blocks of length block_size. Attention for a
-  given query position can only see memory positions within a certain number
-  of positions before and behind it.
-
-
-  If look_right is True then each query will attend to block_length//2
-  positions either side, otherwise it will attend to block_length previous
-  positions.
+def masked_local_attention_1d(
+    q, k, v, block_length=128, name=None):
+  """Attention to the source position and a neighborhood to the left of it.
+
+  The sequence is divided into blocks of length block_length.
+  Attention for a given query position can only see memory positions
+  less than or equal to the query position, in the corresponding block
+  and the previous block.
- If use_whole_block is True then no mask will be applied to the local blocks - meaning the full blocks are used (if look_right is True then the elements to - the right of the current position are still masked out). This allows to - attend to more elements without additional overhead, but means we have - inconsistent window positions and sizes. + If mask_right is True, then a target position cannot see greater source + positions. Args: - q: a Tensor with shape [batch, heads, length_q, depth_k] - k: a Tensor with shape [batch, heads, length_kv, depth_k] - v: a Tensor with shape [batch, heads, length_kv, depth_v] + q: a Tensor with shape [batch, heads, length, depth_k] + k: a Tensor with shape [batch, heads, length, depth_k] + v: a Tensor with shape [batch, heads, length, depth_v] block_length: an integer - look_right: a bool - use_whole_block: a bool name: an optional string Returns: a Tensor of shape [batch, heads, length, depth_v] """ - with tf.variable_scope( - name, default_name="local_attention_1d", values=[q, k, v]): + with tf.variable_scope(name, default_name="local_attention_1d", + values=[q, k, v]): v_shape = v.get_shape() batch = tf.shape(q)[0] heads = tf.shape(q)[1] length = tf.shape(q)[2] + # If (length < 2 * block_length), then we use only one block. + block_length = tf.where(tf.less(length, block_length * 2), + length, block_length) depth_k = tf.shape(q)[3] depth_v = tf.shape(v)[3] original_length = length - - # If (length < block_length), then we use only one block. - block_length = tf.where(tf.less(length, block_length), length, block_length) - # Pad to desired length. padding_size = tf.mod(-length, block_length) length += padding_size - num_blocks = tf.div(length, block_length) padding = [[0, 0], [0, 0], [0, padding_size], [0, 0]] q = tf.pad(q, padding) + k = tf.pad(k, padding) + v = tf.pad(v, padding) + num_blocks = tf.div(length, block_length) - if not look_right: - # Add extra padding so we son't have to do an initial query block. - extra_padding = [[0, 0], [0, 0], [block_length, padding_size], [0, 0]] - else: - # We shift everything over by half a block so query is in center. - pad_right = block_length // 2 - pad_left = block_length - pad_right - extra_padding = [[0, 0], [0, 0], [pad_left, padding_size + pad_right], - [0, 0]] - k = tf.pad(k, extra_padding) - v = tf.pad(v, extra_padding) - - # Reshape into blocks. + # compute attention for the first query block. + first_q = tf.slice(q, [0, 0, 0, 0], [-1, -1, block_length, -1]) + first_k = tf.slice(k, [0, 0, 0, 0], [-1, -1, block_length, -1]) + first_v = tf.slice(v, [0, 0, 0, 0], [-1, -1, block_length, -1]) + first_output = dot_product_attention( + first_q, first_k, first_v, attention_bias_lower_triangle(block_length), + name="fist_block") + + # compute attention for all subsequent query blocks. q = tf.reshape(q, [batch, heads, num_blocks, block_length, depth_k]) - k = tf.reshape(k, [batch, heads, num_blocks + 1, block_length, depth_k]) - v = tf.reshape(v, [batch, heads, num_blocks + 1, block_length, depth_v]) + k = tf.reshape(k, [batch, heads, num_blocks, block_length, depth_k]) + v = tf.reshape(v, [batch, heads, num_blocks, block_length, depth_v]) - # Get local blocks by slicing. 
def local(x): """Create a local version of the keys or values.""" - prev_block = tf.slice(x, [0, 0, 0, 0, 0], [-1, -1, num_blocks, -1, -1]) - cur_block = tf.slice(x, [0, 0, 1, 0, 0], [-1, -1, -1, -1, -1]) + prev_block = tf.slice( + x, [0, 0, 0, 0, 0], [-1, -1, num_blocks - 1, -1, -1]) + cur_block = tf.slice( + x, [0, 0, 1, 0, 0], [-1, -1, -1, -1, -1]) return tf.concat([prev_block, cur_block], 3) - local_k = local(k) local_v = local(v) - local_length = tf.shape(local_k)[3] + tail_q = tf.slice(q, [0, 0, 1, 0, 0], [-1, -1, -1, -1, -1]) - # [batch, heads, num_blocks, block_length, local_length] - attention = tf.matmul(q, local_k, transpose_b=True) - attention = tf.nn.softmax(attention) - - # Get local mask - if not use_whole_block: - good_part = tf.matrix_band_part( - tf.ones([block_length, local_length]), 0, tf.to_int64(block_length)) - elif not look_right: - good_part = tf.matrix_band_part( - tf.ones([block_length, local_length]), -1, tf.to_int64(block_length)) - else: - good_part = tf.ones([block_length, local_length]) + local_length = tf.shape(local_k)[3] - attention *= tf.reshape(good_part, [1, 1, 1, block_length, local_length]) + # [batch, heads, num_blocks - 1, block_length, local_length] + attention = tf.matmul(tail_q, local_k, transpose_b=True) + # make sure source_pos <= target_pos + good_part = tf.matrix_band_part( + tf.ones([block_length, local_length]), -1, tf.to_int64(block_length)) + mask = (1.0 - good_part) * -1e9 + attention += tf.reshape(mask, [1, 1, 1, block_length, local_length]) + attention = tf.nn.softmax(attention) # TODO(noam): figure out how to show a summary for the remaining blocks. # The naive way currently causes errors due to empty tensors. + # output: [batch, heads, num_blocks-1, block_length, depth_v] output = tf.matmul(attention, local_v) output = tf.reshape(output, [batch, heads, -1, depth_v]) - - # Remove added padding + output = tf.concat([first_output, output], axis=2) output = tf.slice(output, [0, 0, 0, 0], [-1, -1, original_length, -1]) output.set_shape(v_shape) return output -def unmasked_local_attention_1d(q, - k, - v, - block_length=128, - filter_width=100, - name=None): +def local_attention_1d(q, + k, + v, + block_length=128, + filter_width=100, + name=None): """strided block local self-attention. 
Args: @@ -644,7 +622,7 @@ def multihead_attention(query_antecedent, x = masked_local_attention_1d(q, k, v, block_length=block_length) else: assert attention_type == "local_unmasked" - x = unmasked_local_attention_1d( + x = local_attention_1d( q, k, v, block_length=block_length, filter_width=block_width) x = combine_heads(x) x = common_layers.conv1d(x, output_depth, 1, name="output_transform") diff --git a/tensor2tensor/layers/common_attention_test.py b/tensor2tensor/layers/common_attention_test.py index 61855b876..e846c2002 100644 --- a/tensor2tensor/layers/common_attention_test.py +++ b/tensor2tensor/layers/common_attention_test.py @@ -65,7 +65,7 @@ def testLocalUnmaskedAttention(self): x = np.random.rand(5, 4, 25, 16) y = np.random.rand(5, 4, 25, 16) with self.test_session() as session: - a = common_attention.unmasked_local_attention_1d( + a = common_attention.local_attention_1d( tf.constant(x, dtype=tf.float32), tf.constant(y, dtype=tf.float32), tf.constant(y, dtype=tf.float32), @@ -79,7 +79,7 @@ def testLocalUnmaskedAttentionMatchingBlockLength(self): x = np.random.rand(5, 4, 25, 16) y = np.random.rand(5, 4, 25, 16) with self.test_session() as session: - a = common_attention.unmasked_local_attention_1d( + a = common_attention.local_attention_1d( tf.constant(x, dtype=tf.float32), tf.constant(y, dtype=tf.float32), tf.constant(y, dtype=tf.float32), From f25af0f7eebde41ec310e3dce5759a8969d5e214 Mon Sep 17 00:00:00 2001 From: Lukasz Kaiser Date: Fri, 4 Aug 2017 15:04:18 -0700 Subject: [PATCH 11/12] Share desc2code source vocab with translation, baseline to play with VAE. PiperOrigin-RevId: 164315503 --- tensor2tensor/data_generators/desc2code.py | 68 ++++---- tensor2tensor/models/models.py | 1 + tensor2tensor/models/transformer_vae.py | 185 +++++++++++++++++++++ 3 files changed, 218 insertions(+), 36 deletions(-) create mode 100644 tensor2tensor/models/transformer_vae.py diff --git a/tensor2tensor/data_generators/desc2code.py b/tensor2tensor/data_generators/desc2code.py index 52513e63c..98c93aacd 100644 --- a/tensor2tensor/data_generators/desc2code.py +++ b/tensor2tensor/data_generators/desc2code.py @@ -44,8 +44,8 @@ _DESC_DIR_NAME = "description" _CODE_PY_DIR_NAME = "solutions_python" -_VOCAB_EN_FILENAME = "vocab_desc2code_tok_en" -_VOCAB_PY_FILENAME = "vocab_desc2code_tok_py" +_VOCAB_EN_FILENAME = "vocab.endefr" +_VOCAB_PY_FILENAME = "vocab.py" # Struct containing a coding problem (contains the paths to the descriptions # and code files) @@ -61,21 +61,43 @@ def is_character_level(self): @property def num_shards(self): - return 100 + return 10 @property def use_subword_tokenizer(self): return True + @property + def input_vocab_size(self): + return 2**15 # 32k + + @property + def target_vocab_size(self): + return 2**12 # 4k + + @property + def vocab_input_filename(self): + return "{}.{}".format(_VOCAB_EN_FILENAME, self.input_vocab_size) + + @property + def vocab_target_filename(self): + return "{}.{}".format(_VOCAB_PY_FILENAME, self.target_vocab_size) + + def feature_encoders(self, data_dir): + source_vocab_filename = os.path.join(data_dir, self.vocab_input_filename) + target_vocab_filename = os.path.join(data_dir, self.vocab_target_filename) + source_token = text_encoder.SubwordTextEncoder(source_vocab_filename) + target_token = text_encoder.SubwordTextEncoder(target_vocab_filename) + return { + "inputs": source_token, + "targets": target_token, + } + @registry.register_problem("desc2code_py") class Desc2CodePyProblem(Desc2CodeProblem): """Description2Code for python problem.""" - 
@property - def targeted_vocab_size(self): - return 2**13 # 8192 - @property def input_space_id(self): return problem.SpaceID.EN_TOK @@ -84,14 +106,6 @@ def input_space_id(self): def target_space_id(self): return problem.SpaceID.PY_TOK - @property - def vocab_input_filename(self): - return "{}.{}".format(_VOCAB_EN_FILENAME, self.targeted_vocab_size) - - @property - def vocab_target_filename(self): - return "{}.{}".format(_VOCAB_PY_FILENAME, self.targeted_vocab_size) - def train_generator(self, data_dir, tmp_dir, train): # Called twice: for train and test @@ -135,27 +149,19 @@ def generator_samples_content(get_source, get_target): elif sample.code_files: # Only take the source if a target exists yield source, target - def generator_source(): - for source, _ in generator_samples_content(True, False): - yield source.strip() - def generator_target(): for _, target in generator_samples_content(False, True): yield target.strip() # Generate vocab for both source and target - source_vocab = generator_utils.get_or_generate_vocab_inner( - data_dir=data_dir, - vocab_filename=self.vocab_input_filename, - vocab_size=self.targeted_vocab_size, - generator_fn=generator_source, - ) + source_vocab = generator_utils.get_or_generate_vocab( + data_dir, tmp_dir, self.vocab_input_filename, self.input_vocab_size) target_vocab = generator_utils.get_or_generate_vocab_inner( data_dir=data_dir, vocab_filename=self.vocab_target_filename, - vocab_size=self.targeted_vocab_size, + vocab_size=self.target_vocab_size, generator_fn=generator_target, ) @@ -169,16 +175,6 @@ def generator_target(): "targets": target_ints, } - def feature_encoders(self, data_dir): - source_vocab_filename = os.path.join(data_dir, self.vocab_input_filename) - target_vocab_filename = os.path.join(data_dir, self.vocab_target_filename) - source_token = text_encoder.SubwordTextEncoder(source_vocab_filename) - target_token = text_encoder.SubwordTextEncoder(target_vocab_filename) - return { - "inputs": source_token, - "targets": target_token, - } - # Utils functions diff --git a/tensor2tensor/models/models.py b/tensor2tensor/models/models.py index 963975780..4b1355dba 100644 --- a/tensor2tensor/models/models.py +++ b/tensor2tensor/models/models.py @@ -38,5 +38,6 @@ from tensor2tensor.models import transformer from tensor2tensor.models import transformer_alternative from tensor2tensor.models import transformer_moe +from tensor2tensor.models import transformer_vae from tensor2tensor.models import xception # pylint: enable=unused-import diff --git a/tensor2tensor/models/transformer_vae.py b/tensor2tensor/models/transformer_vae.py new file mode 100644 index 000000000..31de7bd5f --- /dev/null +++ b/tensor2tensor/models/transformer_vae.py @@ -0,0 +1,185 @@ +# coding=utf-8 +# Copyright 2017 The Tensor2Tensor Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""VAE Transformer.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +# Dependency imports + +from six.moves import xrange # pylint: disable=redefined-builtin + +from tensor2tensor.layers import common_layers +from tensor2tensor.models import transformer +from tensor2tensor.utils import registry +from tensor2tensor.utils import t2t_model + +import tensorflow as tf + + +def decompress(source, hparams, name): + """Decompression function.""" + with tf.variable_scope(name): + shape = tf.shape(source) + thicker = common_layers.conv_block( + source, hparams.hidden_size * 2, [((1, 1), (1, 1))], + name="decompress_conv") + return tf.reshape(thicker, [shape[0], shape[1] * 2, 1, hparams.hidden_size]) + + +def vae(x, hparams, name): + with tf.variable_scope(name): + mu = tf.layers.dense(x, hparams.z_size, name="mu") + log_sigma = tf.layers.dense(x, hparams.z_size, name="log_sigma") + shape = tf.shape(x) + epsilon = tf.random_normal([shape[0], shape[1], 1, hparams.z_size]) + z = mu + tf.exp(log_sigma / 2) * epsilon + dense = tf.layers.dense(z, hparams.hidden_size, name="z_to_dense") + kl = 0.5 * tf.reduce_mean( + tf.exp(log_sigma) + tf.square(mu) - 1. - log_sigma, axis=-1) + return dense, tf.reduce_mean(kl) + + +def compress_vae(inputs, hparams, name): + """Compress, then VAE.""" + with tf.variable_scope(name): + # Run compression by strided convs. + cur = tf.expand_dims(inputs, axis=2) + for i in xrange(hparams.num_compress_steps): + cur = common_layers.conv_block( + cur, hparams.hidden_size, [((1, 1), (2, 1))], + strides=(2, 1), name="compress_%d" % i) + + # Convolve and ReLu to get state. + cur = common_layers.conv_block( + cur, hparams.hidden_size, [((1, 1), (1, 1))], name="mid_conv") + + cur, kl_loss = vae(cur, hparams, name="vae") + return cur, kl_loss + + +def vae_transformer_internal(inputs, targets, target_space, hparams): + """VAE Transformer, main step used for training.""" + with tf.variable_scope("vae_transformer"): + is_training = hparams.mode == tf.contrib.learn.ModeKeys.TRAIN + # Prepare inputs, targets, and k. + inputs = common_layers.flatten4d3d(inputs) + targets = common_layers.flatten4d3d(targets) + k = 2**hparams.num_compress_steps + _, targets = common_layers.pad_to_same_length( + inputs, targets, final_length_divisible_by=k) + + # Transformer preparations and encoder. + (encoder_input, encoder_self_attention_bias, + encoder_decoder_attention_bias) = transformer.transformer_prepare_encoder( + inputs, target_space, hparams) + residual_fn = transformer.get_residual_fn(hparams) + encoder_input = tf.nn.dropout(encoder_input, 1.0 - hparams.residual_dropout) + encoder_output = transformer.transformer_encoder( + encoder_input, residual_fn, encoder_self_attention_bias, hparams) + + def get_decoder_autoregressive(): + """Decoder input for autoregressive computation.""" + (a, b) = transformer.transformer_prepare_decoder(targets, hparams) + return (a, b, tf.constant(0.0)) + + # 10% of the time we compress all-zeros, as will be at decoding start. + prob_targets = 0.9 if is_training else 1.0 + to_compress = tf.cond(tf.less(tf.random_uniform([]), prob_targets), + lambda: targets, lambda: tf.zeros_like(targets)) + z, kl_loss = compress_vae(to_compress, hparams, "vae") + # Decompress. + for i in xrange(hparams.num_compress_steps): + j = hparams.num_hidden_layers - i - 1 + z = decompress(z, hparams, "decompress_%d" % j) + + def get_decoder_from_vae(): + """Decoder input computed by VAE.""" + # Return decoder stuff. 
+      (a, b) = transformer.transformer_prepare_decoder(
+          tf.squeeze(z, axis=2), hparams)
+      return (a, b, kl_loss)
+
+    # Randomize decoder inputs.
+    prob_do_vae = common_layers.inverse_exp_decay(40000) * 0.7
+    step = tf.to_float(tf.contrib.framework.get_global_step())
+    if not is_training:
+      prob_do_vae = tf.cond(tf.less(step, 40000.0), lambda: tf.constant(0.0),
+                            lambda: tf.constant(1.0))
+    (decoder_input, decoder_self_attention_bias, kl_loss2) = tf.cond(
+        tf.less(tf.random_uniform([]), prob_do_vae),
+        get_decoder_from_vae, get_decoder_autoregressive)
+
+    # Transformer decoder.
+    decoder_output = transformer.transformer_decoder(
+        decoder_input, encoder_output, residual_fn, decoder_self_attention_bias,
+        encoder_decoder_attention_bias, hparams)
+    decoder_output = tf.expand_dims(decoder_output, 2)
+
+    cond_self = tf.cond(tf.less(step, 30000.0), lambda: tf.constant(1.0),
+                        lambda: tf.constant(0.0))
+    prob_self = 0.4 if is_training else cond_self
+    (ret, kl_loss) = tf.cond(tf.less(tf.random_uniform([]), prob_self),
+                             lambda: (z, kl_loss),
+                             lambda: (decoder_output, kl_loss2))
+
+    kl_loss *= common_layers.inverse_exp_decay(50000) * 2.0
+    return ret, kl_loss
+
+
+@registry.register_model
+class TransformerVAE(t2t_model.T2TModel):
+
+  def model_fn_body(self, features):
+    return vae_transformer_internal(
+        features["inputs"], features["targets"], features["target_space_id"],
+        self._hparams)
+
+  def infer(self, features=None, decode_length=50, beam_size=1, top_beams=1,
+            last_position_only=False, alpha=0.0):
+    """An inference method, see T2TModel."""
+    if not features:
+      features = {}
+    inputs_old = None
+    if "inputs" in features and len(features["inputs"].shape) < 4:
+      inputs_old = features["inputs"]
+      features["inputs"] = tf.expand_dims(features["inputs"], 2)
+
+    # Create an initial targets tensor.
+    if "partial_targets" in features:
+      initial_output = tf.convert_to_tensor(features["partial_targets"])
+    else:
+      batch_size = tf.shape(features["inputs"])[0]
+      initial_output = tf.zeros((batch_size, 1, 1, 1), dtype=tf.int64)
+
+    features["targets"] = initial_output
+    sharded_logits, _ = self.model_fn(
+        features, False, last_position_only=last_position_only)
+    sharded_samples = self._data_parallelism(tf.argmax, sharded_logits, 4)
+    samples = tf.concat(sharded_samples, 0)
+    if inputs_old is not None:  # Restore to not confuse Estimator.
+      features["inputs"] = inputs_old
+    return samples
+
+
+@registry.register_hparams
+def transformer_vae_small():
+  """Set of hyperparameters."""
+  hparams = transformer.transformer_small()
+  hparams.add_hparam("z_size", 128)
+  hparams.add_hparam("num_compress_steps", 4)
+  return hparams

From 932e5c2cbd46bc579917b7785b544012ba41ea2a Mon Sep 17 00:00:00 2001
From: Ryan Sepassi
Date: Fri, 4 Aug 2017 15:50:54 -0700
Subject: [PATCH 12/12] v1.1.6

PiperOrigin-RevId: 164321289
---
 setup.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/setup.py b/setup.py
index 38b2fcc48..c62b3409c 100644
--- a/setup.py
+++ b/setup.py
@@ -5,7 +5,7 @@

 setup(
     name='tensor2tensor',
-    version='1.1.5',
+    version='1.1.6',
    description='Tensor2Tensor',
    author='Google Inc.',
    author_email='no-reply@google.com',
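Note on PATCH 09/12: after the change to model_fn_body_sharded, a model's model_fn_body may return a single output Tensor, a tuple (output, loss), or a tuple (output, dict of named losses); a bare scalar or Tensor loss is still wrapped as {"extra": ...}, while a dict is passed through unchanged. A minimal sketch of the dictionary form is below; it is not part of the patch series, and the class name DictLossExample and the loss name "aux" are made up for illustration.

# Hypothetical usage sketch (not included in this patch series).
import tensorflow as tf

from tensor2tensor.utils import registry
from tensor2tensor.utils import t2t_model


@registry.register_model
class DictLossExample(t2t_model.T2TModel):
  """Toy body that returns (output, dict-of-losses)."""

  def model_fn_body(self, features):
    # Body input and output keep the [batch_size, ?, ?, hidden_size] shape.
    output = tf.layers.dense(features["inputs"], self._hparams.hidden_size)
    # Named extra losses: model_fn_body_sharded now forwards this dict
    # as-is instead of wrapping it under the "extra" key.
    aux_loss = 0.01 * tf.reduce_mean(tf.square(output))
    return output, {"aux": aux_loss}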
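Note on PATCH 10/12: the revert restores a strictly causal masked_local_attention_1d, where each query attends only to positions at or before it within its own block and the previous block; the "0 losses / peeking into the future" symptom came from the replaced version. A quick causality check along these lines is sketched below; it is an assumed test snippet under TF 1.x conventions, not part of the patch.

# Hypothetical causality check for masked_local_attention_1d.
import numpy as np
import tensorflow as tf

from tensor2tensor.layers import common_attention

q = tf.constant(np.random.rand(1, 2, 8, 4), dtype=tf.float32)
k = tf.constant(np.random.rand(1, 2, 8, 4), dtype=tf.float32)
v1 = np.random.rand(1, 2, 8, 4).astype(np.float32)
v2 = v1.copy()
v2[:, :, -1, :] += 1.0  # perturb only the last (most "future") position
out1 = common_attention.masked_local_attention_1d(
    q, k, tf.constant(v1), block_length=4)
out2 = common_attention.masked_local_attention_1d(
    q, k, tf.constant(v2), block_length=4)
with tf.Session() as session:
  a, b = session.run([out1, out2])
# Every position except the last one must be unaffected by the perturbation.
assert np.allclose(a[:, :, :-1, :], b[:, :, :-1, :])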