From b09291df5e8741c373da0b51ab5a4ad66c9a49e4 Mon Sep 17 00:00:00 2001 From: Ryan Sepassi Date: Fri, 23 Jun 2017 11:14:41 -0700 Subject: [PATCH 1/5] Bump version to 1.0.6 PiperOrigin-RevId: 159970178 --- .../data_generators/generator_utils.py | 2 +- tensor2tensor/data_generators/image.py | 3 +- tensor2tensor/data_generators/text_encoder.py | 87 +++++++------------ tensor2tensor/data_generators/tokenizer.py | 24 +++-- 4 files changed, 50 insertions(+), 66 deletions(-) mode change 100755 => 100644 tensor2tensor/data_generators/text_encoder.py mode change 100755 => 100644 tensor2tensor/data_generators/tokenizer.py diff --git a/tensor2tensor/data_generators/generator_utils.py b/tensor2tensor/data_generators/generator_utils.py index 11788df45..35e61d7cc 100644 --- a/tensor2tensor/data_generators/generator_utils.py +++ b/tensor2tensor/data_generators/generator_utils.py @@ -22,12 +22,12 @@ import io import os import tarfile +import urllib # Dependency imports import six from six.moves import xrange # pylint: disable=redefined-builtin -import six.moves.urllib_request from tensor2tensor.data_generators.text_encoder import SubwordTextEncoder from tensor2tensor.data_generators.tokenizer import Tokenizer diff --git a/tensor2tensor/data_generators/image.py b/tensor2tensor/data_generators/image.py index ee0ad26d5..55b5f2fc7 100644 --- a/tensor2tensor/data_generators/image.py +++ b/tensor2tensor/data_generators/image.py @@ -18,6 +18,7 @@ from __future__ import division from __future__ import print_function +import cPickle import gzip import io import json @@ -31,8 +32,6 @@ import numpy as np from six.moves import xrange # pylint: disable=redefined-builtin from six.moves import zip # pylint: disable=redefined-builtin -from six.moves import cPickle - from tensor2tensor.data_generators import generator_utils import tensorflow as tf diff --git a/tensor2tensor/data_generators/text_encoder.py b/tensor2tensor/data_generators/text_encoder.py old mode 100755 new mode 100644 index 74d2b73cb..b170013ea --- a/tensor2tensor/data_generators/text_encoder.py +++ b/tensor2tensor/data_generators/text_encoder.py @@ -27,7 +27,6 @@ import six from six.moves import xrange # pylint: disable=redefined-builtin -from collections import defaultdict from tensor2tensor.data_generators import tokenizer import tensorflow as tf @@ -36,10 +35,7 @@ PAD = '' EOS = '' RESERVED_TOKENS = [PAD, EOS] -if six.PY2: - RESERVED_TOKENS_BYTES = RESERVED_TOKENS -else: - RESERVED_TOKENS_BYTES = [bytes(PAD, 'ascii'), bytes(EOS, 'ascii')] + class TextEncoder(object): """Base class for converting from ints to/from human readable strings.""" @@ -91,25 +87,17 @@ class ByteTextEncoder(TextEncoder): """Encodes each byte to an id. 
For 8-bit strings only.""" def encode(self, s): - numres = self._num_reserved_ids - if six.PY2: - return [ord(c) + numres for c in s] - # Python3: explicitly convert to UTF-8 - return [c + numres for c in s.encode("utf-8")] + return [ord(c) + self._num_reserved_ids for c in s] def decode(self, ids): - numres = self._num_reserved_ids decoded_ids = [] - int2byte = six.int2byte for id_ in ids: - if 0 <= id_ < numres: - decoded_ids.append(RESERVED_TOKENS_BYTES[int(id_)]) + if 0 <= id_ < self._num_reserved_ids: + decoded_ids.append(RESERVED_TOKENS[int(id_)]) else: - decoded_ids.append(int2byte(id_ - numres)) - if six.PY2: - return ''.join(decoded_ids) - # Python3: join byte arrays and then decode string - return b''.join(decoded_ids).decode("utf-8") + decoded_ids.append(chr(id_)) + + return ''.join(decoded_ids) @property def vocab_size(self): @@ -123,16 +111,20 @@ def __init__(self, vocab_filename, reverse=False, num_reserved_ids=2): """Initialize from a file, one token per line.""" super(TokenTextEncoder, self).__init__(num_reserved_ids=num_reserved_ids) self._reverse = reverse - self._load_vocab_from_file(vocab_filename) + if vocab_filename is not None: + self._load_vocab_from_file(vocab_filename) def encode(self, sentence): """Converts a space-separated string of tokens to a list of ids.""" ret = [self._token_to_id[tok] for tok in sentence.strip().split()] - return ret[::-1] if self._reverse else ret + if self._reverse: + ret = ret[::-1] + return ret def decode(self, ids): - seq = reversed(ids) if self._reverse else ids - return ' '.join([self._safe_id_to_token(i) for i in seq]) + if self._reverse: + ids = ids[::-1] + return ' '.join([self._safe_id_to_token(i) for i in ids]) @property def vocab_size(self): @@ -251,22 +243,15 @@ def _escaped_token_to_subtokens(self, escaped_token): """ ret = [] pos = 0 - lesc = len(escaped_token) - while pos < lesc: - end = lesc - while end > pos: + while pos < len(escaped_token): + end = len(escaped_token) + while True: subtoken = self._subtoken_string_to_id.get(escaped_token[pos:end], -1) if subtoken != -1: break end -= 1 ret.append(subtoken) - if end > pos: - pos = end - else: - # This kinda should not happen, but it does. Cop out by skipping the - # nonexistent subtoken from the returned list. - # print("Unable to find subtoken in string '{0}'".format(escaped_token)) - pos += 1 + pos = end return ret @classmethod @@ -337,13 +322,13 @@ def build_from_token_counts(self, # then count the resulting potential subtokens, keeping the ones # with high enough counts for our new vocabulary. for i in xrange(num_iterations): - counts = defaultdict(int) + counts = {} for token, count in six.iteritems(token_counts): escaped_token = self._escape_token(token) # we will count all tails of the escaped_token, starting from boundaries # determined by our current segmentation. 
if i == 0: - starts = xrange(len(escaped_token)) + starts = list(range(len(escaped_token))) else: subtokens = self._escaped_token_to_subtokens(escaped_token) pos = 0 @@ -352,33 +337,31 @@ def build_from_token_counts(self, starts.append(pos) pos += len(self.subtoken_to_subtoken_string(subtoken)) for start in starts: - for end in xrange(start + 1, len(escaped_token)): + for end in xrange(start + 1, len(escaped_token) + 1): subtoken_string = escaped_token[start:end] - counts[subtoken_string] += count + counts[subtoken_string] = counts.get(subtoken_string, 0) + count # array of lists of candidate subtoken strings, by length len_to_subtoken_strings = [] for subtoken_string, count in six.iteritems(counts): - lsub = len(subtoken_string) - # all subtoken strings of length 1 are included regardless of count - if count < min_count and lsub != 1: + if count < min_count or len(subtoken_string) <= 1: continue - while len(len_to_subtoken_strings) <= lsub: + while len(len_to_subtoken_strings) <= len(subtoken_string): len_to_subtoken_strings.append([]) - len_to_subtoken_strings[lsub].append(subtoken_string) + len_to_subtoken_strings[len(subtoken_string)].append(subtoken_string) new_subtoken_strings = [] # consider the candidates longest to shortest, so that if we accept # a longer subtoken string, we can decrement the counts of its prefixes. for subtoken_strings in len_to_subtoken_strings[::-1]: for subtoken_string in subtoken_strings: count = counts[subtoken_string] - if count < min_count and len(subtoken_string) != 1: - # subtoken strings of length 1 are included regardless of count + if count < min_count: continue new_subtoken_strings.append((-count, subtoken_string)) for l in xrange(1, len(subtoken_string)): counts[subtoken_string[:l]] -= count - # Make sure to include the underscore as a subtoken string - new_subtoken_strings.append((0, '_')) + # make sure we have all single characters. + new_subtoken_strings.extend([(-counts.get(chr(i), 0), chr(i)) + for i in xrange(2**8)]) new_subtoken_strings.sort() self._init_from_list([''] * self._num_reserved_ids + [p[1] for p in new_subtoken_strings]) @@ -407,19 +390,13 @@ def _load_from_file(self, filename): subtoken_strings = [] with tf.gfile.Open(filename) as f: for line in f: - if six.PY2: - subtoken_strings.append(line.strip()[1:-1].decode('string-escape')) - else: - subtoken_strings.append(line.strip()[1:-1]) + subtoken_strings.append(line.strip()[1:-1].decode('string-escape')) self._init_from_list(subtoken_strings) def _store_to_file(self, filename): with tf.gfile.Open(filename, 'w') as f: for subtoken_string in self._all_subtoken_strings: - if six.PY2: - f.write('\'' + subtoken_string.encode('string-escape') + '\'\n') - else: - f.write('\'' + subtoken_string + '\'\n') + f.write('\'' + subtoken_string.encode('string-escape') + '\'\n') def _escape_token(self, token): r"""Translate '\'->'\\' and '_'->'\u', then append '_'. diff --git a/tensor2tensor/data_generators/tokenizer.py b/tensor2tensor/data_generators/tokenizer.py old mode 100755 new mode 100644 index 09b60ff1f..15b199907 --- a/tensor2tensor/data_generators/tokenizer.py +++ b/tensor2tensor/data_generators/tokenizer.py @@ -45,21 +45,29 @@ from __future__ import division from __future__ import print_function +import array import string # Dependency imports from six.moves import xrange # pylint: disable=redefined-builtin -from collections import defaultdict + class Tokenizer(object): """Vocab for breaking words into wordpieces. 
""" - _SEPARATOR_CHAR_SET = set(string.punctuation + string.whitespace) - def __init__(self): - self.token_counts = defaultdict(int) + self._separator_chars = string.punctuation + string.whitespace + self._separator_char_mask = array.array( + "l", [chr(i) in self._separator_chars for i in xrange(256)]) + self.token_counts = dict() + + def _increment_token_count(self, token): + if token in self.token_counts: + self.token_counts[token] += 1 + else: + self.token_counts[token] = 1 def encode(self, raw_text): """Encode a raw string as a list of tokens. @@ -79,11 +87,11 @@ def encode(self, raw_text): token = raw_text[token_start:pos] if token != " " or token_start == 0: ret.append(token) - self.token_counts[token] += 1 + self._increment_token_count(token) token_start = pos final_token = raw_text[token_start:] ret.append(final_token) - self.token_counts[final_token] += 1 + self._increment_token_count(final_token) return ret def decode(self, tokens): @@ -103,7 +111,7 @@ def decode(self, tokens): return ret def _is_separator_char(self, c): - return c in self._SEPARATOR_CHAR_SET + return self._separator_char_mask[ord(c)] def _is_word_char(self, c): - return c not in self._SEPARATOR_CHAR_SET + return not self._is_separator_char(c) From 95942139b825ba19f18e3b740e2d5c9928411668 Mon Sep 17 00:00:00 2001 From: Ryan Sepassi Date: Fri, 23 Jun 2017 11:15:13 -0700 Subject: [PATCH 2/5] gitignore update PiperOrigin-RevId: 159970261 --- .gitignore | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/.gitignore b/.gitignore index 09f934869..dd84837dd 100644 --- a/.gitignore +++ b/.gitignore @@ -3,3 +3,7 @@ # Python egg metadata, regenerated from source files by setuptools. /*.egg-info + +# PyPI distribution artificats +build/ +dist/ From c2ce7a6bdf79f05524b6c07cad1762899371ec3d Mon Sep 17 00:00:00 2001 From: Lukasz Kaiser Date: Fri, 23 Jun 2017 17:37:10 -0700 Subject: [PATCH 3/5] Play with CIFAR models and shake-shake a little. PiperOrigin-RevId: 160016542 --- tensor2tensor/data_generators/image.py | 4 - tensor2tensor/models/bluenet.py | 150 +++++++++++++++++++++ tensor2tensor/models/bluenet_test.py | 54 ++++++++ tensor2tensor/models/common_layers.py | 46 +++++++ tensor2tensor/models/common_layers_test.py | 9 ++ tensor2tensor/models/models.py | 1 + tensor2tensor/models/xception.py | 10 ++ tensor2tensor/models/xception_test.py | 2 +- 8 files changed, 271 insertions(+), 5 deletions(-) create mode 100644 tensor2tensor/models/bluenet.py create mode 100644 tensor2tensor/models/bluenet_test.py diff --git a/tensor2tensor/data_generators/image.py b/tensor2tensor/data_generators/image.py index 55b5f2fc7..88bfef4e6 100644 --- a/tensor2tensor/data_generators/image.py +++ b/tensor2tensor/data_generators/image.py @@ -200,10 +200,6 @@ def cifar10_generator(tmp_dir, training, how_many, start_from=0): ]) labels = data["labels"] all_labels.extend([labels[j] for j in xrange(num_images)]) - # Shuffle the data to make sure classes are well distributed. - data = zip(all_images, all_labels) - random.shuffle(data) - all_images, all_labels = zip(*data) return image_generator(all_images[start_from:start_from + how_many], all_labels[start_from:start_from + how_many]) diff --git a/tensor2tensor/models/bluenet.py b/tensor2tensor/models/bluenet.py new file mode 100644 index 000000000..bb7119a15 --- /dev/null +++ b/tensor2tensor/models/bluenet.py @@ -0,0 +1,150 @@ +# Copyright 2017 Google Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""BlueNet: and out of the blue network to experiment with shake-shake.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +# Dependency imports + +from six.moves import xrange # pylint: disable=redefined-builtin + +from tensor2tensor.models import common_hparams +from tensor2tensor.models import common_layers +from tensor2tensor.utils import registry +from tensor2tensor.utils import t2t_model + +import tensorflow as tf + + +def residual_module(x, hparams, train, n, sep): + """A stack of convolution blocks with residual connection.""" + k = (hparams.kernel_height, hparams.kernel_width) + dilations_and_kernels = [((1, 1), k) for _ in xrange(n)] + with tf.variable_scope("residual_module%d_sep%d" % (n, sep)): + y = common_layers.subseparable_conv_block( + x, + hparams.hidden_size, + dilations_and_kernels, + padding="SAME", + separability=sep, + name="block") + x = common_layers.layer_norm(x + y, hparams.hidden_size, name="lnorm") + return tf.nn.dropout(x, 1.0 - hparams.dropout * tf.to_float(train)) + + +def residual_module1(x, hparams, train): + return residual_module(x, hparams, train, 1, 1) + + +def residual_module1_sep(x, hparams, train): + return residual_module(x, hparams, train, 1, 0) + + +def residual_module2(x, hparams, train): + return residual_module(x, hparams, train, 2, 1) + + +def residual_module2_sep(x, hparams, train): + return residual_module(x, hparams, train, 2, 0) + + +def residual_module3(x, hparams, train): + return residual_module(x, hparams, train, 3, 1) + + +def residual_module3_sep(x, hparams, train): + return residual_module(x, hparams, train, 3, 0) + + +def norm_module(x, hparams, train): + del train # Unused. + return common_layers.layer_norm(x, hparams.hidden_size, name="norm_module") + + +def identity_module(x, hparams, train): + del hparams, train # Unused. 
+ return x + + +def run_modules(blocks, cur, hparams, train, dp): + """Run blocks in parallel using dp as data_parallelism.""" + assert len(blocks) % dp.n == 0 + res = [] + for i in xrange(len(blocks) // dp.n): + res.extend(dp(blocks[i * dp.n:(i + 1) * dp.n], cur, hparams, train)) + return res + + +@registry.register_model +class BlueNet(t2t_model.T2TModel): + + def model_fn_body_sharded(self, sharded_features, train): + dp = self._data_parallelism + dp._reuse = False # pylint:disable=protected-access + hparams = self._hparams + blocks = [identity_module, norm_module, + residual_module1, residual_module1_sep, + residual_module2, residual_module2_sep, + residual_module3, residual_module3_sep] + inputs = sharded_features["inputs"] + + cur = tf.concat(inputs, axis=0) + cur_shape = cur.get_shape() + for i in xrange(hparams.num_hidden_layers): + with tf.variable_scope("layer_%d" % i): + processed = run_modules(blocks, cur, hparams, train, dp) + cur = common_layers.shakeshake(processed) + cur.set_shape(cur_shape) + + return list(tf.split(cur, len(inputs), axis=0)), 0.0 + + +@registry.register_hparams +def bluenet_base(): + """Set of hyperparameters.""" + hparams = common_hparams.basic_params1() + hparams.batch_size = 4096 + hparams.hidden_size = 768 + hparams.dropout = 0.2 + hparams.symbol_dropout = 0.2 + hparams.label_smoothing = 0.1 + hparams.clip_grad_norm = 2.0 + hparams.num_hidden_layers = 8 + hparams.kernel_height = 3 + hparams.kernel_width = 3 + hparams.learning_rate_decay_scheme = "exp50k" + hparams.learning_rate = 0.05 + hparams.learning_rate_warmup_steps = 3000 + hparams.initializer_gain = 1.0 + hparams.weight_decay = 3.0 + hparams.num_sampled_classes = 0 + hparams.sampling_method = "argmax" + hparams.optimizer_adam_epsilon = 1e-6 + hparams.optimizer_adam_beta1 = 0.85 + hparams.optimizer_adam_beta2 = 0.997 + hparams.add_hparam("imagenet_use_2d", True) + return hparams + + +@registry.register_hparams +def bluenet_tiny(): + hparams = bluenet_base() + hparams.batch_size = 1024 + hparams.hidden_size = 128 + hparams.num_hidden_layers = 4 + hparams.learning_rate_decay_scheme = "none" + return hparams diff --git a/tensor2tensor/models/bluenet_test.py b/tensor2tensor/models/bluenet_test.py new file mode 100644 index 000000000..70996ab02 --- /dev/null +++ b/tensor2tensor/models/bluenet_test.py @@ -0,0 +1,54 @@ +# Copyright 2017 Google Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""BlueNet tests.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +# Dependency imports + +import numpy as np + +from tensor2tensor.data_generators import problem_hparams +from tensor2tensor.models import bluenet + +import tensorflow as tf + + +class BlueNetTest(tf.test.TestCase): + + def testBlueNet(self): + vocab_size = 9 + x = np.random.random_integers(1, high=vocab_size - 1, size=(3, 5, 1, 1)) + y = np.random.random_integers(1, high=vocab_size - 1, size=(3, 1, 1, 1)) + hparams = bluenet.bluenet_tiny() + p_hparams = problem_hparams.test_problem_hparams(hparams, vocab_size, + vocab_size) + with self.test_session() as session: + features = { + "inputs": tf.constant(x, dtype=tf.int32), + "targets": tf.constant(y, dtype=tf.int32), + } + model = bluenet.BlueNet(hparams, p_hparams) + sharded_logits, _, _ = model.model_fn(features, True) + logits = tf.concat(sharded_logits, 0) + session.run(tf.global_variables_initializer()) + res = session.run(logits) + self.assertEqual(res.shape, (3, 5, 1, 1, vocab_size)) + + +if __name__ == "__main__": + tf.test.main() diff --git a/tensor2tensor/models/common_layers.py b/tensor2tensor/models/common_layers.py index 30215e889..f9d63a464 100644 --- a/tensor2tensor/models/common_layers.py +++ b/tensor2tensor/models/common_layers.py @@ -58,6 +58,52 @@ def inverse_exp_decay(max_step, min_value=0.01): return inv_base**tf.maximum(float(max_step) - step, 0.0) +def shakeshake2_py(x, y, equal=False): + """The shake-shake sum of 2 tensors, python version.""" + alpha = 0.5 if equal else tf.random_uniform([]) + return alpha * x + (1.0 - alpha) * y + + +@function.Defun() +def shakeshake2_grad(x1, x2, dy): + """Overriding gradient for shake-shake of 2 tensors.""" + y = shakeshake2_py(x1, x2) + dx = tf.gradients(ys=[y], xs=[x1, x2], grad_ys=[dy]) + return dx + + +@function.Defun() +def shakeshake2_equal_grad(x1, x2, dy): + """Overriding gradient for shake-shake of 2 tensors.""" + y = shakeshake2_py(x1, x2, equal=True) + dx = tf.gradients(ys=[y], xs=[x1, x2], grad_ys=[dy]) + return dx + + +@function.Defun(grad_func=shakeshake2_grad) +def shakeshake2(x1, x2): + """The shake-shake function with a different alpha for forward/backward.""" + return shakeshake2_py(x1, x2) + + +@function.Defun(grad_func=shakeshake2_equal_grad) +def shakeshake2_eqgrad(x1, x2): + """The shake-shake function with a different alpha for forward/backward.""" + return shakeshake2_py(x1, x2) + + +def shakeshake(xs, equal_grad=False): + """Multi-argument shake-shake, currently approximated by sums of 2.""" + if len(xs) == 1: + return xs[0] + div = (len(xs) + 1) // 2 + arg1 = shakeshake(xs[:div], equal_grad=equal_grad) + arg2 = shakeshake(xs[div:], equal_grad=equal_grad) + if equal_grad: + return shakeshake2_eqgrad(arg1, arg2) + return shakeshake2(arg1, arg2) + + def standardize_images(x): """Image standardization on batches (tf.image.per_image_standardization).""" with tf.name_scope("standardize_images", [x]): diff --git a/tensor2tensor/models/common_layers_test.py b/tensor2tensor/models/common_layers_test.py index 2bd6a53ad..3839b9d36 100644 --- a/tensor2tensor/models/common_layers_test.py +++ b/tensor2tensor/models/common_layers_test.py @@ -65,6 +65,15 @@ def testEmbedding(self): res = session.run(y) self.assertEqual(res.shape, (3, 5, 16)) + def testShakeShake(self): + x = np.random.rand(5, 7) + with self.test_session() as session: + x = tf.constant(x, dtype=tf.float32) + y = common_layers.shakeshake([x, x, x, x, x]) + 
session.run(tf.global_variables_initializer()) + inp, res = session.run([x, y]) + self.assertAllClose(res, inp) + def testConv(self): x = np.random.rand(5, 7, 1, 11) with self.test_session() as session: diff --git a/tensor2tensor/models/models.py b/tensor2tensor/models/models.py index 536a58966..b8f0811e5 100644 --- a/tensor2tensor/models/models.py +++ b/tensor2tensor/models/models.py @@ -24,6 +24,7 @@ from tensor2tensor.models import attention_lm from tensor2tensor.models import attention_lm_moe +from tensor2tensor.models import bluenet from tensor2tensor.models import bytenet from tensor2tensor.models import lstm from tensor2tensor.models import modalities diff --git a/tensor2tensor/models/xception.py b/tensor2tensor/models/xception.py index b6e271c36..01b5adb78 100644 --- a/tensor2tensor/models/xception.py +++ b/tensor2tensor/models/xception.py @@ -87,3 +87,13 @@ def xception_base(): hparams.optimizer_adam_beta2 = 0.997 hparams.add_hparam("imagenet_use_2d", True) return hparams + + +@registry.register_hparams +def xception_tiny(): + hparams = xception_base() + hparams.batch_size = 1024 + hparams.hidden_size = 128 + hparams.num_hidden_layers = 4 + hparams.learning_rate_decay_scheme = "none" + return hparams diff --git a/tensor2tensor/models/xception_test.py b/tensor2tensor/models/xception_test.py index 106604659..4eabb387a 100644 --- a/tensor2tensor/models/xception_test.py +++ b/tensor2tensor/models/xception_test.py @@ -34,7 +34,7 @@ def testXception(self): vocab_size = 9 x = np.random.random_integers(1, high=vocab_size - 1, size=(3, 5, 1, 1)) y = np.random.random_integers(1, high=vocab_size - 1, size=(3, 1, 1, 1)) - hparams = xception.xception_base() + hparams = xception.xception_tiny() p_hparams = problem_hparams.test_problem_hparams(hparams, vocab_size, vocab_size) with self.test_session() as session: From b53d6df93418628096a09e203c6fe3b0daafbd62 Mon Sep 17 00:00:00 2001 From: Ryan Sepassi Date: Fri, 23 Jun 2017 18:00:22 -0700 Subject: [PATCH 4/5] Bump version to 1.0.7 PiperOrigin-RevId: 160018021 --- setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.py b/setup.py index 0bf0c8739..5b2d423f8 100644 --- a/setup.py +++ b/setup.py @@ -5,7 +5,7 @@ setup( name='tensor2tensor', - version='1.0.6', + version='1.0.7', description='Tensor2Tensor', author='Google Inc.', author_email='no-reply@google.com', From d578f5210e4f0ba345a5e83bd111b4c3b2f2ed57 Mon Sep 17 00:00:00 2001 From: Ryan Sepassi Date: Fri, 23 Jun 2017 18:06:47 -0700 Subject: [PATCH 5/5] internal merge PiperOrigin-RevId: 160018490 --- .../data_generators/generator_utils.py | 1 + tensor2tensor/data_generators/image.py | 2 +- tensor2tensor/data_generators/text_encoder.py | 87 ++++++++++++------- tensor2tensor/data_generators/tokenizer.py | 23 ++--- 4 files changed, 66 insertions(+), 47 deletions(-) diff --git a/tensor2tensor/data_generators/generator_utils.py b/tensor2tensor/data_generators/generator_utils.py index 35e61d7cc..75d319cd8 100644 --- a/tensor2tensor/data_generators/generator_utils.py +++ b/tensor2tensor/data_generators/generator_utils.py @@ -28,6 +28,7 @@ import six from six.moves import xrange # pylint: disable=redefined-builtin +import six.moves.urllib_request from tensor2tensor.data_generators.text_encoder import SubwordTextEncoder from tensor2tensor.data_generators.tokenizer import Tokenizer diff --git a/tensor2tensor/data_generators/image.py b/tensor2tensor/data_generators/image.py index 88bfef4e6..e7e740192 100644 --- a/tensor2tensor/data_generators/image.py +++ 
b/tensor2tensor/data_generators/image.py @@ -18,7 +18,6 @@ from __future__ import division from __future__ import print_function -import cPickle import gzip import io import json @@ -30,6 +29,7 @@ # Dependency imports import numpy as np +from six.moves import cPickle from six.moves import xrange # pylint: disable=redefined-builtin from six.moves import zip # pylint: disable=redefined-builtin from tensor2tensor.data_generators import generator_utils diff --git a/tensor2tensor/data_generators/text_encoder.py b/tensor2tensor/data_generators/text_encoder.py index b170013ea..a219a6b8d 100644 --- a/tensor2tensor/data_generators/text_encoder.py +++ b/tensor2tensor/data_generators/text_encoder.py @@ -23,6 +23,8 @@ from __future__ import division from __future__ import print_function +from collections import defaultdict + # Dependency imports import six @@ -35,6 +37,10 @@ PAD = '' EOS = '' RESERVED_TOKENS = [PAD, EOS] +if six.PY2: + RESERVED_TOKENS_BYTES = RESERVED_TOKENS +else: + RESERVED_TOKENS_BYTES = [bytes(PAD, 'ascii'), bytes(EOS, 'ascii')] class TextEncoder(object): @@ -87,17 +93,25 @@ class ByteTextEncoder(TextEncoder): """Encodes each byte to an id. For 8-bit strings only.""" def encode(self, s): - return [ord(c) + self._num_reserved_ids for c in s] + numres = self._num_reserved_ids + if six.PY2: + return [ord(c) + numres for c in s] + # Python3: explicitly convert to UTF-8 + return [c + numres for c in s.encode('utf-8')] def decode(self, ids): + numres = self._num_reserved_ids decoded_ids = [] + int2byte = six.int2byte for id_ in ids: - if 0 <= id_ < self._num_reserved_ids: - decoded_ids.append(RESERVED_TOKENS[int(id_)]) + if 0 <= id_ < numres: + decoded_ids.append(RESERVED_TOKENS_BYTES[int(id_)]) else: - decoded_ids.append(chr(id_)) - - return ''.join(decoded_ids) + decoded_ids.append(int2byte(id_ - numres)) + if six.PY2: + return ''.join(decoded_ids) + # Python3: join byte arrays and then decode string + return b''.join(decoded_ids).decode('utf-8') @property def vocab_size(self): @@ -111,20 +125,16 @@ def __init__(self, vocab_filename, reverse=False, num_reserved_ids=2): """Initialize from a file, one token per line.""" super(TokenTextEncoder, self).__init__(num_reserved_ids=num_reserved_ids) self._reverse = reverse - if vocab_filename is not None: - self._load_vocab_from_file(vocab_filename) + self._load_vocab_from_file(vocab_filename) def encode(self, sentence): """Converts a space-separated string of tokens to a list of ids.""" ret = [self._token_to_id[tok] for tok in sentence.strip().split()] - if self._reverse: - ret = ret[::-1] - return ret + return ret[::-1] if self._reverse else ret def decode(self, ids): - if self._reverse: - ids = ids[::-1] - return ' '.join([self._safe_id_to_token(i) for i in ids]) + seq = reversed(ids) if self._reverse else ids + return ' '.join([self._safe_id_to_token(i) for i in seq]) @property def vocab_size(self): @@ -243,15 +253,22 @@ def _escaped_token_to_subtokens(self, escaped_token): """ ret = [] pos = 0 - while pos < len(escaped_token): - end = len(escaped_token) - while True: + lesc = len(escaped_token) + while pos < lesc: + end = lesc + while end > pos: subtoken = self._subtoken_string_to_id.get(escaped_token[pos:end], -1) if subtoken != -1: break end -= 1 ret.append(subtoken) - pos = end + if end > pos: + pos = end + else: + # This kinda should not happen, but it does. Cop out by skipping the + # nonexistent subtoken from the returned list. 
+ # print("Unable to find subtoken in string '{0}'".format(escaped_token)) + pos += 1 return ret @classmethod @@ -322,13 +339,13 @@ def build_from_token_counts(self, # then count the resulting potential subtokens, keeping the ones # with high enough counts for our new vocabulary. for i in xrange(num_iterations): - counts = {} + counts = defaultdict(int) for token, count in six.iteritems(token_counts): escaped_token = self._escape_token(token) # we will count all tails of the escaped_token, starting from boundaries # determined by our current segmentation. if i == 0: - starts = list(range(len(escaped_token))) + starts = xrange(len(escaped_token)) else: subtokens = self._escaped_token_to_subtokens(escaped_token) pos = 0 @@ -337,31 +354,33 @@ def build_from_token_counts(self, starts.append(pos) pos += len(self.subtoken_to_subtoken_string(subtoken)) for start in starts: - for end in xrange(start + 1, len(escaped_token) + 1): + for end in xrange(start + 1, len(escaped_token)): subtoken_string = escaped_token[start:end] - counts[subtoken_string] = counts.get(subtoken_string, 0) + count + counts[subtoken_string] += count # array of lists of candidate subtoken strings, by length len_to_subtoken_strings = [] for subtoken_string, count in six.iteritems(counts): - if count < min_count or len(subtoken_string) <= 1: + lsub = len(subtoken_string) + # all subtoken strings of length 1 are included regardless of count + if count < min_count and lsub != 1: continue - while len(len_to_subtoken_strings) <= len(subtoken_string): + while len(len_to_subtoken_strings) <= lsub: len_to_subtoken_strings.append([]) - len_to_subtoken_strings[len(subtoken_string)].append(subtoken_string) + len_to_subtoken_strings[lsub].append(subtoken_string) new_subtoken_strings = [] # consider the candidates longest to shortest, so that if we accept # a longer subtoken string, we can decrement the counts of its prefixes. for subtoken_strings in len_to_subtoken_strings[::-1]: for subtoken_string in subtoken_strings: count = counts[subtoken_string] - if count < min_count: + if count < min_count and len(subtoken_string) != 1: + # subtoken strings of length 1 are included regardless of count continue new_subtoken_strings.append((-count, subtoken_string)) for l in xrange(1, len(subtoken_string)): counts[subtoken_string[:l]] -= count - # make sure we have all single characters. - new_subtoken_strings.extend([(-counts.get(chr(i), 0), chr(i)) - for i in xrange(2**8)]) + # Make sure to include the underscore as a subtoken string + new_subtoken_strings.append((0, '_')) new_subtoken_strings.sort() self._init_from_list([''] * self._num_reserved_ids + [p[1] for p in new_subtoken_strings]) @@ -390,13 +409,19 @@ def _load_from_file(self, filename): subtoken_strings = [] with tf.gfile.Open(filename) as f: for line in f: - subtoken_strings.append(line.strip()[1:-1].decode('string-escape')) + if six.PY2: + subtoken_strings.append(line.strip()[1:-1].decode('string-escape')) + else: + subtoken_strings.append(line.strip()[1:-1]) self._init_from_list(subtoken_strings) def _store_to_file(self, filename): with tf.gfile.Open(filename, 'w') as f: for subtoken_string in self._all_subtoken_strings: - f.write('\'' + subtoken_string.encode('string-escape') + '\'\n') + if six.PY2: + f.write('\'' + subtoken_string.encode('string-escape') + '\'\n') + else: + f.write('\'' + subtoken_string + '\'\n') def _escape_token(self, token): r"""Translate '\'->'\\' and '_'->'\u', then append '_'. 
diff --git a/tensor2tensor/data_generators/tokenizer.py b/tensor2tensor/data_generators/tokenizer.py index 15b199907..3564aee2e 100644 --- a/tensor2tensor/data_generators/tokenizer.py +++ b/tensor2tensor/data_generators/tokenizer.py @@ -45,7 +45,7 @@ from __future__ import division from __future__ import print_function -import array +from collections import defaultdict import string # Dependency imports @@ -57,17 +57,10 @@ class Tokenizer(object): """Vocab for breaking words into wordpieces. """ - def __init__(self): - self._separator_chars = string.punctuation + string.whitespace - self._separator_char_mask = array.array( - "l", [chr(i) in self._separator_chars for i in xrange(256)]) - self.token_counts = dict() + _SEPARATOR_CHAR_SET = set(string.punctuation + string.whitespace) - def _increment_token_count(self, token): - if token in self.token_counts: - self.token_counts[token] += 1 - else: - self.token_counts[token] = 1 + def __init__(self): + self.token_counts = defaultdict(int) def encode(self, raw_text): """Encode a raw string as a list of tokens. @@ -87,11 +80,11 @@ def encode(self, raw_text): token = raw_text[token_start:pos] if token != " " or token_start == 0: ret.append(token) - self._increment_token_count(token) + self.token_counts[token] += 1 token_start = pos final_token = raw_text[token_start:] ret.append(final_token) - self._increment_token_count(final_token) + self.token_counts[final_token] += 1 return ret def decode(self, tokens): @@ -111,7 +104,7 @@ def decode(self, tokens): return ret def _is_separator_char(self, c): - return self._separator_char_mask[ord(c)] + return c in self._SEPARATOR_CHAR_SET def _is_word_char(self, c): - return not self._is_separator_char(c) + return c not in self._SEPARATOR_CHAR_SET
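
Note on the shake-shake layers added in PATCH 3/5: `common_layers.shakeshake` combines any number of branch outputs by recursively applying a two-way convex combination with a random weight, and the Defun-wrapped variants give the backward pass its own, independently drawn weight. The snippet below is a minimal NumPy sketch of the forward behavior only; the function names here are illustrative and not part of the tensor2tensor API, and the custom-gradient machinery is omitted.

    import numpy as np

    def shakeshake2_forward(x, y, equal=False):
        # Convex combination with a random weight (0.5 when `equal` is set),
        # mirroring shakeshake2_py in common_layers.py.
        alpha = 0.5 if equal else np.random.uniform()
        return alpha * x + (1.0 - alpha) * y

    def shakeshake_forward(xs, equal=False):
        # Multi-argument shake-shake, approximated by recursive pairwise sums,
        # following the recursion in common_layers.shakeshake.
        if len(xs) == 1:
            return xs[0]
        div = (len(xs) + 1) // 2
        return shakeshake2_forward(shakeshake_forward(xs[:div], equal),
                                   shakeshake_forward(xs[div:], equal), equal)

    # With identical inputs the combination reproduces the input exactly,
    # which is the property common_layers_test.testShakeShake checks.
    x = np.random.rand(5, 7)
    print(np.allclose(shakeshake_forward([x, x, x, x, x]), x))  # True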