diff --git a/README.md b/README.md index 0c781d97f..236d279c2 100644 --- a/README.md +++ b/README.md @@ -217,7 +217,7 @@ and are encoded in [`tf.contrib.training.HParams`](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/training/python/training/hparam.py) objects. The `HParams` are available to both the problem specification and the model. A basic set of hyperparameters are defined in -[`common_hparams.py`](https://github.com/tensorflow/tensor2tensor/tree/master/tensor2tensor/models/common_hparams.py) +[`common_hparams.py`](https://github.com/tensorflow/tensor2tensor/tree/master/tensor2tensor/layers/common_hparams.py) and hyperparameter set functions can compose other hyperparameter set functions. ### Trainer diff --git a/setup.py b/setup.py index dd80dfd48..f32e8508c 100644 --- a/setup.py +++ b/setup.py @@ -5,7 +5,7 @@ setup( name='tensor2tensor', - version='1.1.9', + version='1.2.0', description='Tensor2Tensor', author='Google Inc.', author_email='no-reply@google.com', @@ -26,8 +26,8 @@ 'six', ], extras_require={ - 'tensorflow': ['tensorflow>=1.2.0rc1'], - 'tensorflow_gpu': ['tensorflow-gpu>=1.2.0rc1'], + 'tensorflow': ['tensorflow>=1.3.0'], + 'tensorflow_gpu': ['tensorflow-gpu>=1.3.0'], }, tests_require=['nose'], test_suite='nose.collector', diff --git a/tensor2tensor/bin/t2t-datagen b/tensor2tensor/bin/t2t-datagen index 19de46fbf..f7ea7e1f2 100644 --- a/tensor2tensor/bin/t2t-datagen +++ b/tensor2tensor/bin/t2t-datagen @@ -42,7 +42,6 @@ from tensor2tensor.data_generators import algorithmic_math from tensor2tensor.data_generators import all_problems # pylint: disable=unused-import from tensor2tensor.data_generators import audio from tensor2tensor.data_generators import generator_utils -from tensor2tensor.data_generators import image from tensor2tensor.data_generators import lm1b from tensor2tensor.data_generators import snli from tensor2tensor.data_generators import wmt @@ -106,9 +105,6 @@ _SUPPORTED_PROBLEM_GENERATORS = { lambda: lm1b.generator(FLAGS.tmp_dir, True, characters=True), lambda: lm1b.generator(FLAGS.tmp_dir, False, characters=True) ), - "image_celeba_tune": ( - lambda: image.celeba_generator(FLAGS.tmp_dir, 162770), - lambda: image.celeba_generator(FLAGS.tmp_dir, 19867, 162770)), "inference_snli32k": ( lambda: snli.snli_token_generator(FLAGS.tmp_dir, True, 2**15), lambda: snli.snli_token_generator(FLAGS.tmp_dir, False, 2**15), diff --git a/tensor2tensor/data_generators/algorithmic_test.py b/tensor2tensor/data_generators/algorithmic_test.py index 3802d1beb..bcd61216e 100644 --- a/tensor2tensor/data_generators/algorithmic_test.py +++ b/tensor2tensor/data_generators/algorithmic_test.py @@ -21,7 +21,8 @@ # Dependency imports -from six.moves import xrange +from six.moves import xrange # pylint: disable=redefined-builtin + from tensor2tensor.data_generators import algorithmic import tensorflow as tf diff --git a/tensor2tensor/data_generators/all_problems.py b/tensor2tensor/data_generators/all_problems.py index 0078eb3f9..ec3a9d0af 100644 --- a/tensor2tensor/data_generators/all_problems.py +++ b/tensor2tensor/data_generators/all_problems.py @@ -43,3 +43,4 @@ pass # pylint: enable=g-import-not-at-top # pylint: enable=unused-import + diff --git a/tensor2tensor/data_generators/generator_utils.py b/tensor2tensor/data_generators/generator_utils.py index eadca9bd6..3e1086d37 100644 --- a/tensor2tensor/data_generators/generator_utils.py +++ b/tensor2tensor/data_generators/generator_utils.py @@ -301,9 +301,24 @@ def gunzip_file(gz_path, new_path): def get_or_generate_vocab_inner(data_dir, vocab_filename, vocab_size, generator_fn): - """Inner implementation for vocab generators.""" - vocab_filepath = os.path.join(data_dir, vocab_filename) - if tf.gfile.Exists(vocab_filepath): + """Inner implementation for vocab generators. + + Args: + data_dir: The base directory where data and vocab files are stored. If None, + then do not save the vocab even if it doesn't exist. + vocab_filename: relative filename where vocab file is stored + vocab_size: target size of the vocabulary constructed by SubwordTextEncoder + generator_fn: a generator that produces tokens from the vocabulary + + Returns: + A SubwordTextEncoder vocabulary object. + """ + if data_dir is None: + vocab_filepath = None + else: + vocab_filepath = os.path.join(data_dir, vocab_filename) + + if vocab_filepath is not None and tf.gfile.Exists(vocab_filepath): tf.logging.info("Found vocab file: %s", vocab_filepath) vocab = text_encoder.SubwordTextEncoder(vocab_filepath) return vocab @@ -316,7 +331,9 @@ def get_or_generate_vocab_inner(data_dir, vocab_filename, vocab_size, vocab = text_encoder.SubwordTextEncoder.build_to_target_size( vocab_size, token_counts, 1, 1e3) - vocab.store_to_file(vocab_filepath) + + if vocab_filepath is not None: + vocab.store_to_file(vocab_filepath) return vocab diff --git a/tensor2tensor/data_generators/image.py b/tensor2tensor/data_generators/image.py index d9a6be6ff..71f4f0920 100644 --- a/tensor2tensor/data_generators/image.py +++ b/tensor2tensor/data_generators/image.py @@ -66,7 +66,137 @@ def example_reading_spec(self, label_key=None): return data_fields, data_items_to_decoders -# French street names dataset. +@registry.register_problem("image_celeba_tune") +class ImageCeleba(ImageProblem): + """CelebA dataset, aligned and cropped images.""" + IMG_DATA = ("img_align_celeba.zip", + "https://drive.google.com/uc?export=download&" + "id=0B7EVK8r0v71pZjFTYXZWM3FlRnM") + LANDMARKS_DATA = ("celeba_landmarks_align", + "https://drive.google.com/uc?export=download&" + "id=0B7EVK8r0v71pd0FJY3Blby1HUTQ") + ATTR_DATA = ("celeba_attr", "https://drive.google.com/uc?export=download&" + "id=0B7EVK8r0v71pblRyaVFSWGxPY0U") + + LANDMARK_HEADINGS = ("lefteye_x lefteye_y righteye_x righteye_y " + "nose_x nose_y leftmouth_x leftmouth_y rightmouth_x " + "rightmouth_y").split() + ATTR_HEADINGS = ( + "5_o_Clock_Shadow Arched_Eyebrows Attractive Bags_Under_Eyes Bald Bangs " + "Big_Lips Big_Nose Black_Hair Blond_Hair Blurry Brown_Hair " + "Bushy_Eyebrows Chubby Double_Chin Eyeglasses Goatee Gray_Hair " + "Heavy_Makeup High_Cheekbones Male Mouth_Slightly_Open Mustache " + "Narrow_Eyes No_Beard Oval_Face Pale_Skin Pointy_Nose Receding_Hairline " + "Rosy_Cheeks Sideburns Smiling Straight_Hair Wavy_Hair Wearing_Earrings " + "Wearing_Hat Wearing_Lipstick Wearing_Necklace Wearing_Necktie Young" + ).split() + + def preprocess_examples(self, examples, unused_mode, unused_hparams): + + def resize(img, size): + return tf.to_int64( + tf.image.resize_images(img, [size, size], tf.image.ResizeMethod.AREA)) + + inputs = examples["inputs"] + # Remove boundaries in CelebA images. Remove 40 pixels each side + # vertically and 20 pixels each side horizontally. + inputs = tf.image.crop_to_bounding_box(inputs, 40, 20, 218 - 80, 178 - 40) + examples["inputs"] = resize(inputs, 8) + examples["targets"] = resize(inputs, 32) + return examples + + def hparams(self, defaults, model_hparams): + p = defaults + p.input_modality = {"inputs": ("image:identity_no_pad", None)} + p.target_modality = ("image:identity_no_pad", None) + p.batch_size_multiplier = 256 + p.max_expected_batch_size_per_shard = 4 + p.input_space_id = 1 + p.target_space_id = 1 + + def generator(self, tmp_dir, how_many, start_from=0): + """Image generator for CELEBA dataset. + + Args: + tmp_dir: path to temporary storage directory. + how_many: how many images and labels to generate. + start_from: from which image to start. + + Yields: + A dictionary representing the images with the following fields: + * image/encoded: the string encoding the image as JPEG, + * image/format: the string "jpeg" representing image format, + """ + out_paths = [] + for fname, url in [self.IMG_DATA, self.LANDMARKS_DATA, self.ATTR_DATA]: + path = generator_utils.maybe_download_from_drive(tmp_dir, fname, url) + out_paths.append(path) + + img_path, landmarks_path, attr_path = out_paths # pylint: disable=unbalanced-tuple-unpacking + unzipped_folder = img_path[:-4] + if not tf.gfile.Exists(unzipped_folder): + zipfile.ZipFile(img_path, "r").extractall(tmp_dir) + + with tf.gfile.Open(landmarks_path) as f: + landmarks_raw = f.read() + + with tf.gfile.Open(attr_path) as f: + attr_raw = f.read() + + def process_landmarks(raw_data): + landmarks = {} + lines = raw_data.split("\n") + headings = lines[1].strip().split() + for line in lines[2:-1]: + values = line.strip().split() + img_name = values[0] + landmark_values = [int(v) for v in values[1:]] + landmarks[img_name] = landmark_values + return landmarks, headings + + def process_attrs(raw_data): + attrs = {} + lines = raw_data.split("\n") + headings = lines[1].strip().split() + for line in lines[2:-1]: + values = line.strip().split() + img_name = values[0] + attr_values = [int(v) for v in values[1:]] + attrs[img_name] = attr_values + return attrs, headings + + img_landmarks, _ = process_landmarks(landmarks_raw) + img_attrs, _ = process_attrs(attr_raw) + + image_files = tf.gfile.Glob(unzipped_folder + "/*.jpg") + for filename in image_files[start_from:start_from + how_many]: + img_name = os.path.basename(filename) + landmarks = img_landmarks[img_name] + attrs = img_attrs[img_name] + + with tf.gfile.Open(filename, "r") as f: + encoded_image_data = f.read() + yield { + "image/encoded": [encoded_image_data], + "image/format": ["jpeg"], + "attributes": attrs, + "landmarks": landmarks, + } + + @property + def train_shards(self): + return 100 + + @property + def dev_shards(self): + return 10 + + def generate_data(self, data_dir, tmp_dir, task_id=-1): + generator_utils.generate_dataset_and_shuffle( + self.generator(tmp_dir, 162770), # train + self.training_filepaths(data_dir, self.train_shards, shuffled=False), + self.generator(tmp_dir, 19867, 162770), # dev + self.dev_filepaths(data_dir, self.dev_shards, shuffled=False)) @registry.register_problem @@ -199,7 +329,7 @@ def generate_data(self, data_dir, tmp_dir, task_id=-1): "instructions at https://github.com/tensorflow/models/blob/master" "/inception/README.md#getting-started") - def preprocess_examples(self, examples, mode): + def preprocess_examples(self, examples, mode, _): return imagenet_preprocess_examples(examples, mode) @@ -638,7 +768,7 @@ def train_shards(self): def dev_shards(self): return 10 - def preprocess_examples(self, examples, mode): + def preprocess_examples(self, examples, mode, _): return imagenet_preprocess_examples(examples, mode) def generator(self, data_dir, tmp_dir, is_training): @@ -700,41 +830,3 @@ class ImageMsCocoTokens32k(ImageMsCocoTokens8k): @property def targeted_vocab_size(self): return 2**15 # 32768 - - -# URL and filename for CELEBA data. -_CELEBA_NAME = "img_align_celeba" -_CELEBA_URL = "https://drive.google.com/uc?export=download&id=0B7EVK8r0v71pZjFTYXZWM3FlRnM" - - -def _get_celeba(directory): - """Download and extract CELEBA to directory unless it is there.""" - # path = os.path.join(directory, _CELEBA_NAME) - path = generator_utils.maybe_download_from_drive(directory, _CELEBA_NAME, - _CELEBA_URL) - if not tf.gfile.Exists(path): - zipfile.ZipFile(path + ".zip", "r").extractall(directory) - - -def celeba_generator(tmp_dir, how_many, start_from=0): - """Image generator for CELEBA dataset. - - Args: - tmp_dir: path to temporary storage directory. - how_many: how many images and labels to generate. - start_from: from which image to start. - - Yields: - A dictionary representing the images with the following fields: - * image/encoded: the string encoding the image as JPEG, - * image/format: the string "jpeg" representing image format, - """ - _get_celeba(tmp_dir) - image_files = tf.gfile.Glob(os.path.join(tmp_dir, _CELEBA_NAME) + "/*.jpg") - for filename in image_files[start_from:start_from + how_many]: - with tf.gfile.Open(filename, "r") as f: - encoded_image_data = f.read() - yield { - "image/encoded": [encoded_image_data], - "image/format": ["jpeg"], - } diff --git a/tensor2tensor/data_generators/lm1b.py b/tensor2tensor/data_generators/lm1b.py index f9f220571..a3771e124 100644 --- a/tensor2tensor/data_generators/lm1b.py +++ b/tensor2tensor/data_generators/lm1b.py @@ -26,7 +26,8 @@ # Dependency imports -from six.moves import xrange +from six.moves import xrange # pylint: disable=redefined-builtin + from tensor2tensor.data_generators import generator_utils from tensor2tensor.data_generators import text_encoder from tensor2tensor.data_generators import tokenizer diff --git a/tensor2tensor/data_generators/problem.py b/tensor2tensor/data_generators/problem.py index 60b1e842b..e4424e73e 100644 --- a/tensor2tensor/data_generators/problem.py +++ b/tensor2tensor/data_generators/problem.py @@ -116,7 +116,10 @@ class Problem(object): * generate_data(data_dir, tmp_dir) - Generate training and dev datasets into data_dir. - Additonal files, e.g. vocabulary files, should also be written to - data_dir. + data_dir. Vocab files are newline-separated files with each line + containing a token. The standard convention for the filename is to + set it to be + ${Problem.vocab_name}.${Problem.targeted_vocab_size} - Downloads and other files can be written to tmp_dir - If you have a training and dev generator, you can generate the training and dev datasets with @@ -200,22 +203,22 @@ def training_filepaths(self, data_dir, num_shards, shuffled): file_basename = self.dataset_filename() if not shuffled: file_basename += generator_utils.UNSHUFFLED_SUFFIX - return generator_utils.train_data_filenames( - file_basename, data_dir, num_shards) + return generator_utils.train_data_filenames(file_basename, data_dir, + num_shards) def dev_filepaths(self, data_dir, num_shards, shuffled): file_basename = self.dataset_filename() if not shuffled: file_basename += generator_utils.UNSHUFFLED_SUFFIX - return generator_utils.dev_data_filenames( - file_basename, data_dir, num_shards) + return generator_utils.dev_data_filenames(file_basename, data_dir, + num_shards) def test_filepaths(self, data_dir, num_shards, shuffled): file_basename = self.dataset_filename() if not shuffled: file_basename += generator_utils.UNSHUFFLED_SUFFIX - return generator_utils.test_data_filenames( - file_basename, data_dir, num_shards) + return generator_utils.test_data_filenames(file_basename, data_dir, + num_shards) def __init__(self, was_reversed=False, was_copy=False): """Create a Problem. @@ -412,10 +415,8 @@ def generate_data(self, data_dir, tmp_dir, task_id=-1): generator_utils.shuffle_dataset(all_paths) else: generator_utils.generate_dataset_and_shuffle( - self.generator(data_dir, tmp_dir, True), - self.training_filepaths(data_dir, self.num_shards, shuffled=False), - self.generator(data_dir, tmp_dir, False), - self.dev_filepaths(data_dir, self.num_dev_shards, shuffled=False)) + self.generator(data_dir, tmp_dir, True), train_paths, + self.generator(data_dir, tmp_dir, False), dev_paths) def feature_encoders(self, data_dir): if self.is_character_level: @@ -435,8 +436,9 @@ def hparams(self, defaults, unused_model_hparams): if self.has_inputs: source_vocab_size = self._encoders["inputs"].vocab_size - p.input_modality = {"inputs": (registry.Modalities.SYMBOL, - source_vocab_size)} + p.input_modality = { + "inputs": (registry.Modalities.SYMBOL, source_vocab_size) + } target_vocab_size = self._encoders["targets"].vocab_size p.target_modality = (registry.Modalities.SYMBOL, target_vocab_size) if self.has_inputs: diff --git a/tensor2tensor/data_generators/problem_hparams.py b/tensor2tensor/data_generators/problem_hparams.py index 4a6053613..63b835f38 100644 --- a/tensor2tensor/data_generators/problem_hparams.py +++ b/tensor2tensor/data_generators/problem_hparams.py @@ -461,18 +461,6 @@ def img2img_imagenet(unused_model_hparams): return p -def image_celeba(unused_model_hparams): - """Image CelebA dataset.""" - p = default_problem_hparams() - p.input_modality = {"inputs": ("image:identity_no_pad", None)} - p.target_modality = ("image:identity_no_pad", None) - p.batch_size_multiplier = 256 - p.max_expected_batch_size_per_shard = 4 - p.input_space_id = 1 - p.target_space_id = 1 - return p - - # Dictionary of named hyperparameter settings for various problems. # This is only accessed through the problem_hparams function below. PROBLEM_HPARAMS_MAP = { @@ -503,8 +491,6 @@ def image_celeba(unused_model_hparams): p, "wsj", 2**14, 2**9), "translate_ende_wmt_bpe32k": wmt_ende_bpe32k, - "image_celeba_tune": - image_celeba, "img2img_imagenet": img2img_imagenet, } diff --git a/tensor2tensor/data_generators/text_encoder.py b/tensor2tensor/data_generators/text_encoder.py index f6897d04d..c8a3bd1f9 100644 --- a/tensor2tensor/data_generators/text_encoder.py +++ b/tensor2tensor/data_generators/text_encoder.py @@ -56,13 +56,21 @@ _ESCAPE_CHARS = set(u"\\_u;0123456789") +# Conversion between Unicode and UTF-8, if required (on Python2). if six.PY2: - def native_to_unicode(s): return s if isinstance(s, unicode) else s.decode("utf8") # noqa: F821 - def unicode_to_native(s): return s.encode("utf-8") -else: - # No conversion required on Python >= 3 - def native_to_unicode(s): return s - def unicode_to_native(s): return s + + def native_to_unicode(s): + return s if isinstance(s, unicode) else s.decode("utf8") + + def unicode_to_native(s): + return s.encode("utf-8") +else: # No conversion required on Python >= 3. + + def native_to_unicode(s): + return s + + def unicode_to_native(s): + return s class TextEncoder(object): @@ -154,7 +162,22 @@ def __init__(self, reverse=False, vocab_list=None, num_reserved_ids=NUM_RESERVED_TOKENS): - """Initialize from a file or list, one token per line.""" + """Initialize from a file or list, one token per line. + + Handling of reserved tokens works as follows: + - When initializing from a list, we add reserved tokens to the vocab. + - When initializing from a file, we do not add reserved tokens to the vocab. + - When saving vocab files, we save reserved tokens to the file. + + Args: + vocab_filename: If not None, the full filename to read vocab from. If this + is not None, then vocab_list should be None. + reverse: Boolean indicating if tokens should be reversed during encoding + and decoding. + vocab_list: If not None, a list of elements of the vocabulary. If this is + not None, then vocab_filename should be None. + num_reserved_ids: Number of IDs to save for reserved tokens like . + """ super(TokenTextEncoder, self).__init__(num_reserved_ids=num_reserved_ids) self._reverse = reverse if vocab_filename: @@ -180,39 +203,64 @@ def _safe_id_to_token(self, idx): return self._id_to_token.get(idx, "ID_%d" % idx) def _init_vocab_from_file(self, filename): - """Load vocab from a file.""" + """Load vocab from a file. + Args: + filename: The file to load vocabulary from. + """ def token_gen(): with tf.gfile.Open(filename) as f: for line in f: token = line.strip() yield token - self._init_vocab(token_gen()) + self._init_vocab(token_gen(), add_reserved_tokens=False) def _init_vocab_from_list(self, vocab_list): + """Initialize tokens from a list of tokens. + + It is ok if reserved tokens appear in the vocab list. They will be + removed. The set of tokens in vocab_list should be unique. + Args: + vocab_list: A list of tokens. + """ def token_gen(): for token in vocab_list: - yield token + if token not in RESERVED_TOKENS: + yield token self._init_vocab(token_gen()) - def _init_vocab(self, token_generator): + def _init_vocab(self, token_generator, add_reserved_tokens=True): """Initialize vocabulary with tokens from token_generator.""" + self._id_to_token = {} + non_reserved_start_index = 0 - # Add reserved tokens - self._id_to_token.update(dict(list(enumerate(RESERVED_TOKENS)))) + if add_reserved_tokens: + self._id_to_token.update(enumerate(RESERVED_TOKENS)) + non_reserved_start_index = len(RESERVED_TOKENS) - token_id = len(RESERVED_TOKENS) - for token in token_generator: - self._id_to_token[token_id] = token - token_id += 1 + self._id_to_token.update( + enumerate(token_generator, start=non_reserved_start_index)) # _token_to_id is the reverse of _id_to_token - self._token_to_id = dict([(v, k) - for k, v in six.iteritems(self._id_to_token)]) + self._token_to_id = dict((v, k) + for k, v in six.iteritems(self._id_to_token)) + + def store_to_file(self, filename): + """Write vocab file to disk. + + Vocab files have one token per line. The file ends in a newline. Reserved + tokens are written to the vocab file as well. + + Args: + filename: Full path of the file to store the vocab to. + """ + with tf.gfile.Open(filename, "w") as f: + for i in xrange(len(self._id_to_token)): + f.write(self._id_to_token[i] + "\n") def _escape_token(token, alphabet): @@ -297,7 +345,12 @@ class SubwordTextEncoder(TextEncoder): """ def __init__(self, filename=None): - """Initialize and read from a file, if provided.""" + """Initialize and read from a file, if provided. + + Args: + filename: filename from which to read vocab. If None, do not load a + vocab + """ self._alphabet = set() if filename is not None: self._load_from_file(filename) @@ -551,8 +604,26 @@ def dump(self): for i, s in sorted(subtoken_strings))) def _init_subtokens_from_list(self, subtoken_strings, reserved=0): - """Initialize token information from a list of subtoken strings.""" - self._all_subtoken_strings = [u""] * reserved + subtoken_strings + """Initialize token information from a list of subtoken strings. + + Args: + subtoken_strings: a list of subtokens + reserved: number of spaces to save at the beginning for reserved tokens + + Raises: + ValueError: if reserved is not 0 or len(RESERVED_TOKENS). In this case, it + is not clear what the space is being reserved for, or when it will be + filled in. + """ + if reserved == 0: + self._all_subtoken_strings = subtoken_strings + elif reserved == len(RESERVED_TOKENS): + self._all_subtoken_strings = RESERVED_TOKENS + subtoken_strings + else: + # TODO(dtarlow): or should we fall back to the previous behavior and + # insert copies of "" for each reserved count? + raise ValueError("Unexpected value for reserved. What is being reserved?") + # we remember the maximum length of any subtoken to avoid having to # check arbitrarily long strings. self._max_subtoken_len = max([len(s) for s in subtoken_strings]) @@ -569,7 +640,11 @@ def _init_alphabet_from_tokens(self, tokens): self._alphabet |= _ESCAPE_CHARS def _load_from_file(self, filename): - """Load from a file.""" + """Load from a file. + + Args: + filename: filename to load vocabulary from + """ subtoken_strings = [] with tf.gfile.Open(filename) as f: for line in f: diff --git a/tensor2tensor/data_generators/text_encoder_test.py b/tensor2tensor/data_generators/text_encoder_test.py index 4142f8699..eadfcfb5e 100644 --- a/tensor2tensor/data_generators/text_encoder_test.py +++ b/tensor2tensor/data_generators/text_encoder_test.py @@ -21,6 +21,8 @@ from __future__ import unicode_literals import collections +import os +import shutil # Dependency imports import mock @@ -47,6 +49,50 @@ def test_unescape_token(self): 'Foo! Bar.\nunder_score back\\slash', unescaped) +class TokenTextEncoderTest(tf.test.TestCase): + + @classmethod + def setUpClass(cls): + """Make sure the test dir exists and is empty.""" + cls.test_temp_dir = os.path.join(tf.test.get_temp_dir(), 'encoder_test') + shutil.rmtree(cls.test_temp_dir, ignore_errors=True) + os.mkdir(cls.test_temp_dir) + + def test_save_and_reload(self): + """Test that saving and reloading doesn't change the vocab. + + Note that this test reads and writes to the filesystem, which necessitates + that this test size be "large". + """ + + corpus = 'A B C D E F G H I J K L M N O P Q R S T U V W X Y Z' + vocab_filename = os.path.join(self.test_temp_dir, 'abc.vocab') + + # Make text encoder from a list and store vocab to fake filesystem. + encoder = text_encoder.TokenTextEncoder(None, vocab_list=corpus.split()) + encoder.store_to_file(vocab_filename) + + # Load back the saved vocab file from the fake_filesystem. + new_encoder = text_encoder.TokenTextEncoder(vocab_filename) + + self.assertEqual(encoder._id_to_token, new_encoder._id_to_token) + self.assertEqual(encoder._token_to_id, new_encoder._token_to_id) + + def test_reserved_tokens_in_corpus(self): + """Test that we handle reserved tokens appearing in the corpus.""" + corpus = 'A B {} D E F {} G {}'.format(text_encoder.EOS, + text_encoder.EOS, + text_encoder.PAD) + + encoder = text_encoder.TokenTextEncoder(None, vocab_list=corpus.split()) + + all_tokens = encoder._id_to_token.values() + + # If reserved tokens are removed correctly, then the set of tokens will + # be unique. + self.assertEqual(len(all_tokens), len(set(all_tokens))) + + class SubwordTextEncoderTest(tf.test.TestCase): def test_encode_decode(self): diff --git a/tensor2tensor/data_generators/wiki.py b/tensor2tensor/data_generators/wiki.py index 3cdbac5db..d1c80f2e1 100644 --- a/tensor2tensor/data_generators/wiki.py +++ b/tensor2tensor/data_generators/wiki.py @@ -129,12 +129,3 @@ def generator(self, data_dir, tmp_dir, _): encoded = encoder.encode(page) + [EOS] encoded_title = encoder.encode(title) + [EOS] yield {"inputs": encoded_title, "targets": encoded} - - -@registry.register_problem -class LanguagemodelWikiFull8k(problem.Text2TextProblem): - """A language model on full English Wikipedia.""" - - @property - def targeted_vocab_size(self): - return 2**13 # 8192 diff --git a/tensor2tensor/layers/common_attention.py b/tensor2tensor/layers/common_attention.py index 77636ff6c..d69e68f80 100644 --- a/tensor2tensor/layers/common_attention.py +++ b/tensor2tensor/layers/common_attention.py @@ -23,7 +23,8 @@ # Dependency imports -from six.moves import xrange +from six.moves import xrange # pylint: disable=redefined-builtin + from tensor2tensor.layers import common_layers from tensor2tensor.utils import expert_utils @@ -976,7 +977,13 @@ def coordinate_tensor(shape, axis): return tf.zeros(shape, dtype=tf.int32) + tf.reshape(r, r_shape) -def self_attention_expert(x, batch_coordinate, mask_right=True): +def self_attention_expert( + x, + batch_coordinate, + mask_right=True, + attention_kq_size=None, + attention_v_size=None, +): """Implementing attention that runs inside each expert. Args: @@ -987,6 +994,8 @@ def self_attention_expert(x, batch_coordinate, mask_right=True): positions from different sequences don't attend to each other. mask_right: A bool. If true, we will not attend to positions on the right, just as decoder self attention. + attention_kq_size (int): dimension used for the attention key, and query + attention_v_size (int): dimension used for the attention value Returns: out: A tensor of shape [batch, depth]. @@ -998,32 +1007,60 @@ def self_attention_expert(x, batch_coordinate, mask_right=True): """ depth = x.get_shape().as_list()[-1] length = tf.shape(batch_coordinate)[0] - batch_coordinate = tf.squeeze(batch_coordinate, 1) - bias = tf.to_float( - tf.not_equal(tf.expand_dims(batch_coordinate, 1), - tf.expand_dims(batch_coordinate, 0))) * -1e9 - if mask_right: - bias += tf.reshape( - attention_bias_lower_triangle(length), [length, length]) - # bias has shape [length, length] - bias = tf.reshape(bias, [1, 1, length, length]) - x = tf.reshape(x, [1, length, depth]) - out = multihead_attention(x, - None, - bias, - total_key_depth=depth, - total_value_depth=depth, - output_depth=depth, - num_heads=1, - dropout_rate=0.0) - out = tf.squeeze(out, 0) + + attention_kq_size = attention_kq_size or depth + attention_v_size = attention_v_size or depth + + def length_not_null(x, batch_coordinate): + """Branch of the graph only evaluated when length isn't null.""" + with tf.name_scope("expert_mask"): + batch_coordinate = tf.squeeze(batch_coordinate, 1) + # Convert to float first because of b/25387198 + batch_coordinate = tf.to_float(batch_coordinate) + bc_v = tf.expand_dims(batch_coordinate, 1) + bc_h = tf.expand_dims(batch_coordinate, 0) + bias = bc_v - bc_h # Broadcast to create [length, length] mask + bias = tf.minimum(1.0, tf.abs(bias)) # Theshold non zeros to 1.0 + bias *= -1e9 # Set non zeros to -infinity + + if mask_right: + bias += tf.reshape( + attention_bias_lower_triangle(length), [length, length]) + # bias has shape [length, length] + bias = tf.reshape(bias, [1, 1, length, length]) + x = tf.reshape(x, [1, length, depth]) + out = multihead_attention(x, + None, + bias, + total_key_depth=attention_kq_size, + total_value_depth=attention_v_size, + output_depth=depth, + num_heads=1, + dropout_rate=0.0) + out = tf.squeeze(out, 0) + + return out + + # If the length is empty, just forward an empty tensor (avoid having to + # evaluate multihead_attention with tensor having dim equal to zeros) + out = tf.cond( + tf.equal(length, 0), + lambda: tf.zeros(shape=[0, depth], dtype=tf.float32, name="empty_out"), + lambda: length_not_null(x, batch_coordinate), + ) return out # functools.partial(self_attention_expert, mask_right=, depth=) -def local_expert_attention(x, k, loss_coef, attention_num_experts, train=True, - mask_right=True): +def local_expert_attention( + x, + k, + loss_coef, + attention_num_experts, + train=True, + **kwargs +): """Attention using a mixture of experts. Positions sent to the same expert can attend to each other. @@ -1036,8 +1073,7 @@ def local_expert_attention(x, k, loss_coef, attention_num_experts, train=True, loss_coef: a scalar. A multiplier for the expert loss attention_num_experts: The number of experts to use train: a boolean for the current mode - mask_right: A boolean. If true, we will mask out positions to the right - for self-attention. + **kwargs: Arguments to forward to self_attention_expert Returns: y: a Tensor with shape [batch, length, depth] @@ -1051,7 +1087,7 @@ def local_expert_attention(x, k, loss_coef, attention_num_experts, train=True, return expert_utils.local_moe( x, train, - partial(self_attention_expert, mask_right=mask_right), + partial(self_attention_expert, **kwargs), attention_num_experts, k=k, loss_coef=loss_coef, diff --git a/tensor2tensor/layers/common_hparams.py b/tensor2tensor/layers/common_hparams.py index 6bb4d3e9d..d4751bb0d 100644 --- a/tensor2tensor/layers/common_hparams.py +++ b/tensor2tensor/layers/common_hparams.py @@ -68,6 +68,8 @@ def basic_params1(): learning_rate=0.1, sampling_method="argmax", # "argmax" or "random" problem_choice="adaptive", # "uniform", "adaptive", "distributed" + # expand the logits a piece at a time - saves memory. + factored_logits=int(False), multiply_embedding_mode="sqrt_depth", # Parameters related to mixtures of experts. moe_hidden_sizes="2048", # hidden layer sizes (comma-separated) @@ -77,13 +79,17 @@ def basic_params1(): # Sequences of operations to perform on layer input and layer output. # Used by common_layers.layer_preprocess, common_layers.layer_postprocess # Each character repsesnts an operation: - # d: apply dropout - # n: apply normalization (see norm_type and norm_epsilon) - # a: add layer input (residual connection - only during postprocess) + # none: no preprocessing + # d: apply dropout + # n: apply normalization (see norm_type and norm_epsilon) + # a: add layer input (residual connection - only during postprocess) + # The special string "none" is used instead of the empty string + # to indicate no pre/postprocesisng, since the empty string causes + # trouble for hyperparameter tuning. # TODO(noam): The current settings ("", "dan") are the published version # of the transformer. ("n", "da") seems better for harder-to-learn # models, so it should probably be the default. - layer_preprocess_sequence="", + layer_preprocess_sequence="none", layer_postprocess_sequence="dan", # dropout rate to use during layer_preprocess and layer_postprocess layer_prepostprocess_dropout=0.1, diff --git a/tensor2tensor/layers/common_layers.py b/tensor2tensor/layers/common_layers.py index 8621ddcb1..ad899bfbf 100644 --- a/tensor2tensor/layers/common_layers.py +++ b/tensor2tensor/layers/common_layers.py @@ -18,7 +18,10 @@ from __future__ import division from __future__ import print_function +from collections import defaultdict +import contextlib import math +import random # Dependency imports @@ -29,6 +32,7 @@ import tensorflow as tf from tensorflow.python.framework import function +from tensorflow.python.framework import ops # This is a global setting. When turned off, no @function.Defun is used. allow_defun = False @@ -55,13 +59,13 @@ def hard_tanh(x, saturation_limit=0.9): def inverse_exp_decay(max_step, min_value=0.01): """Inverse-decay exponentially from 0.01 to 1.0 reached at max_step.""" inv_base = tf.exp(tf.log(min_value) / float(max_step)) - step = tf.to_float(tf.contrib.framework.get_global_step()) + step = tf.to_float(tf.train.get_global_step()) return inv_base**tf.maximum(float(max_step) - step, 0.0) def inverse_lin_decay(max_step, min_value=0.01): """Inverse-decay linearly from 0.01 to 1.0 reached at max_step.""" - step = tf.to_float(tf.contrib.framework.get_global_step()) + step = tf.to_float(tf.train.get_global_step()) progress = tf.minimum(step / float(max_step), 1.0) return progress * (1.0 - min_value) + min_value @@ -485,14 +489,8 @@ def apply_norm(x, norm_type, depth, epsilon): "'noam', 'none'.") -def layer_prepostprocess(previous_value, - x, - sequence, - dropout_rate, - norm_type, - depth, - epsilon, - name): +def layer_prepostprocess(previous_value, x, sequence, dropout_rate, norm_type, + depth, epsilon, name): """Apply a sequence of functions to the input or output of a layer. The sequence is specified as a string which may contain the following @@ -518,6 +516,8 @@ def layer_prepostprocess(previous_value, a Tensor """ with tf.variable_scope(name): + if sequence == "none": + return x for c in sequence: if c == "a": x += previous_value @@ -553,7 +553,8 @@ def layer_preprocess(layer_input, hparams): assert "a" not in hparams.layer_preprocess_sequence, ( "No residual connections allowed in hparams.layer_preprocess_sequence") return layer_prepostprocess( - None, layer_input, + None, + layer_input, sequence=hparams.layer_preprocess_sequence, dropout_rate=hparams.layer_prepostprocess_dropout, norm_type=hparams.norm_type, @@ -585,7 +586,8 @@ def layer_postprocess(layer_input, layer_output, hparams): a Tensor """ return layer_prepostprocess( - layer_input, layer_output, + layer_input, + layer_output, sequence=hparams.layer_postprocess_sequence, dropout_rate=hparams.layer_prepostprocess_dropout, norm_type=hparams.norm_type, @@ -1424,6 +1426,7 @@ def padded_cross_entropy(logits, Args: logits: a `Tensor` with shape `[batch, timesteps, vocab_size]`. + optionally a FactoredTensor. labels: an integer `Tensor` with shape `[batch, timesteps]`. label_smoothing: a floating point `Scalar`. weights_fn: A function from labels to weights. @@ -1433,6 +1436,13 @@ def padded_cross_entropy(logits, loss_numerator: a `Scalar`. Sum of losses. loss_denominator: a `Scalar. The number of non-padding target tokens. """ + if isinstance(logits, FactoredTensor): + return padded_cross_entropy_factored( + logits, + labels, + label_smoothing, + weights_fn=weights_fn, + reduce_sum=reduce_sum) confidence = 1.0 - label_smoothing vocab_size = tf.shape(logits)[-1] with tf.name_scope("padded_cross_entropy", [logits, labels]): @@ -1625,3 +1635,325 @@ def ravanbakhsh_set_layer(layer_size, inputs - tf.expand_dims(global_pool_1d(inputs, mask=mask), axis=1), activation_fn=activation_fn, name=name) + + +def fn_device_dependency_dict(): + """State container for fn_device_dependency.""" + if not hasattr(tf.get_default_graph(), "dependency_dict"): + setattr(tf.get_default_graph(), "dependency_dict", defaultdict(list)) + return tf.get_default_graph().dependency_dict + + +@contextlib.contextmanager +def fn_device_dependency(name, device=""): + """Add control deps for name and device.""" + key = name + "_" + device + outs = [] + + def body(): + with tf.control_dependencies(fn_device_dependency_dict()[key]): + yield outs + assert outs + + deps = outs + if isinstance(outs[0], list) or isinstance(outs[0], tuple): + assert len(outs) == 1 + deps = outs[0] + fn_device_dependency_dict()[key] = deps + + if device: + with tf.device(device): + return body() + else: + return body() + + +def underlying_variable_ref(t): + """Find the underlying variable ref, ignoring Identity ops. + + Args: + t: a Tensor + + Returns: + a Tensor that is a variable ref, or None on error. + """ + while t.op.type == "Identity": + t = t.op.inputs[0] + if "Variable" in t.op.type: + return t + else: + return None + + +def underlying_variable(t): + """Find the underlying tf.Variable object. + + Args: + t: a Tensor + + Returns: + a tf.Varaible object. + """ + t = underlying_variable_ref(t) + assert t is not None + # make sure that the graph has a variable index and that it is up-to-date + if not hasattr(tf.get_default_graph(), "var_index"): + tf.get_default_graph().var_index = {} + var_index = tf.get_default_graph().var_index + for v in tf.global_variables()[len(var_index):]: + var_index[v.name] = v + return var_index[t.name] + + +def approximate_split(x, num_splits, axis=0): + """Split approximately equally into num_splits parts. + + Args: + x: a Tensor + num_splits: an integer + axis: an integer. + + Returns: + a list of num_splits Tensors. + """ + size = tf.shape(x)[axis] + size_splits = [tf.div(size + i, num_splits) for i in xrange(num_splits)] + return tf.split(x, size_splits, axis=axis) + + +class FactoredTensor(object): + """A concise factored representation of Tensor as two tensors. + + This class represents the tensor tf.matmul(a, b, transpose_b=True) + by storing the values of Tensors a and b. + + The reason for this is that the product may be too big to fully realize at + once, so it can be realized a part at a time. + + "a" may have extra leading dimensions, in which case they are flattened out + before computing the matrix product, then re-expanded afterwards. + """ + + def __init__(self, a, b): + self._a = a + self._b = b + + @property + def a(self): + return self._a + + @property + def b(self): + return self._b + + def to_tensor(self): + inner_dim = tf.shape(self.b)[1] + result_dim = tf.shape(self.b)[0] + flat_a = tf.reshape(self.a, [-1, inner_dim]) + product = tf.matmul(flat_a, self.b, transpose_b=True) + product_shape = tf.concat([tf.shape(self.a)[:-1], [result_dim]], 0) + product = tf.reshape(product, product_shape) + product.set_shape(self.a.get_shape().as_list()[:-1] + + [self.b.get_shape()[0]]) + return product + + +def _convert_factored_tensor_to_tensor(value, *args, **kwargs): + # call ops.convert_to_tensor to handle optional arguments appropriately + return ops.internal_convert_to_tensor(value.to_tensor(), *args, **kwargs) + + +tf.register_tensor_conversion_function(FactoredTensor, + _convert_factored_tensor_to_tensor) + + +def smoothing_cross_entropy_factored_grad(op, dy): + """Gradient function for smoothing_cross_entropy_factored.""" + a = op.inputs[0] + b = op.inputs[1] + labels = op.inputs[2] + confidence = op.inputs[3] + num_splits = 32 + vocab_size = tf.shape(b)[0] + labels = approximate_split(labels, num_splits) + a = approximate_split(a, num_splits) + dy = approximate_split(dy, num_splits) + b_grad = None + a_grad_parts = [] + deps = [] + for part in xrange(num_splits): + with tf.control_dependencies(deps): + logits = tf.matmul(a[part], b, transpose_b=True) + output_part = smoothing_cross_entropy(logits, labels[part], vocab_size, + confidence) + a_grad_part, b_grad_part = tf.gradients( + ys=[output_part], xs=[a[part], b], grad_ys=[dy[part]]) + a_grad_parts.append(a_grad_part) + if part > 0: + b_grad += b_grad_part + else: + b_grad = b_grad_part + deps = [b_grad, a_grad_part] + a_grad = tf.concat(a_grad_parts, 0) + return a_grad, b_grad, None, None + + +@function.Defun( + noinline=True, + python_grad_func=smoothing_cross_entropy_factored_grad, + compiled=True, + separate_compiled_gradients=True) +def smoothing_cross_entropy_factored(a, b, labels, confidence): + """Memory-efficient computation of smoothing cross-entropy. + + Avoids realizing the entire logits matrix at once. + + Args: + a: a Tensor with shape [batch, inner_dim] + b: a Tensor with shape [vocab_size, inner_dim] + labels: an integer Tensor with shape [batch] + confidence: a float + + Returns: + A Tensor with shape [batch] + """ + num_splits = 32 + vocab_size = tf.shape(b)[0] + labels = approximate_split(labels, num_splits) + a = approximate_split(a, num_splits) + parts = [] + for part in xrange(num_splits): + with tf.control_dependencies(parts[-1:]): + logits = tf.matmul(a[part], b, transpose_b=True) + parts.append( + smoothing_cross_entropy(logits, labels[part], vocab_size, confidence)) + return tf.concat(parts, 0) + + +def padded_cross_entropy_factored(factored_logits, + labels, + label_smoothing, + weights_fn=weights_nonzero, + reduce_sum=True): + """Memory-efficient computation of smoothing cross-entropy. + + Avoids realizing the entire logits matrix at once. + + Args: + factored_logits: a `FactoredTensor` representing a Tensor + with shape `[batch, timesteps, vocab_size]`. + labels: an integer `Tensor` with shape `[batch, timesteps]`. + label_smoothing: a floating point `Scalar`. + weights_fn: A function from labels to weights. + reduce_sum: a Boolean, whether to sum at the end or not. + + Returns: + loss_numerator: a `Scalar`. Sum of losses. + loss_denominator: a `Scalar. The number of non-padding target tokens. + """ + a = factored_logits.a + b = factored_logits.b + confidence = 1.0 - label_smoothing + with tf.name_scope("padded_cross_entropy_factored", [a, b, labels]): + labels_flat = tf.reshape(labels, [-1]) + a_flat = tf.reshape(a, [-1, tf.shape(b)[1]]) + xent = smoothing_cross_entropy_factored(a_flat, b, labels_flat, + tf.convert_to_tensor(confidence)) + xent = tf.reshape(xent, tf.shape(labels)) + weights = weights_fn(labels) + if not reduce_sum: + return xent * weights, weights + return tf.reduce_sum(xent * weights), tf.reduce_sum(weights) + + +def fn_with_custom_grad(grad_fn, use_global_vars=False): + """Decorator to create a subgraph with a custom gradient function. + + The subgraph created by the decorated function is NOT put in a Defun and so + does not suffer from the limitations of the Defun (all subgraph ops on the + same device, no summaries). + + Args: + grad_fn: function with signature + (inputs, variables, outputs, output_grads) -> (grad_inputs, grad_vars), + all of which are lists of Tensors. + use_global_vars: if True, variables will be the global variables created. + If False, will be the trainable variables. + + Returns: + Decorator for function such that the gradient is defined by grad_fn. + """ + + def dec(fn): + + def wrapped(*args): + return _fn_with_custom_grad( + fn, args, grad_fn, use_global_vars=use_global_vars) + + return wrapped + + return dec + + +def _fn_with_custom_grad(fn, inputs, grad_fn, use_global_vars=False): + """Create a subgraph with a custom gradient. + + Args: + fn: function that takes inputs as arguments and produces 1 or more Tensors. + inputs: list, will be passed as fn(*inputs). + grad_fn: function with signature + (inputs, vars, outputs, output_grads) -> (grad_inputs, grad_vars), + all of which are lists of Tensors. + use_global_vars: if True, variables will be the global variables created. + If False, will be the trainable variables. + + Returns: + fn(*inputs) + """ + with tf.variable_scope(None, default_name="fn_with_custom_grad") as vs: + inputs = list(inputs) + outputs = fn(*inputs) + if use_global_vars: + train_vars = list(vs.global_variables()) + else: + train_vars = list(vs.trainable_variables()) + + if grad_fn is None: + return outputs + else: + if not (isinstance(outputs, tuple) or isinstance(outputs, list)): + outputs = [outputs] + outputs = list(outputs) + + in_types = [t.dtype for t in inputs] + out_types = [t.dtype for t in outputs] + var_types = [t.dtype for t in train_vars] + + def custom_grad_fn(op, *dys): + """Custom grad fn applying grad_fn for identity Defun.""" + dys = list(dys) + fn_inputs = op.inputs[:len(inputs)] + fn_vars = op.inputs[len(inputs):len(inputs) + len(train_vars)] + fn_outputs = op.inputs[len(inputs) + len(train_vars):] + assert len(fn_outputs) == len(outputs) + assert len(fn_outputs) == len(dys) + + grad_inputs, grad_vars = grad_fn(fn_inputs, fn_vars, fn_outputs, dys) + grad_outputs = [None] * len(fn_outputs) + return tuple(grad_inputs + grad_vars + grad_outputs) + + # The Defun takes as input the original inputs, the trainable variables + # created in fn, and the outputs. In the forward it passes through the + # outputs. In the backwards, it produces gradients for the original inputs + # and the trainable variables. + @function.Defun( + *(in_types + var_types + out_types), + func_name="identity_custom_grad%d" % random.randint(1, 10**9), + python_grad_func=custom_grad_fn, + shape_func=lambda _: [t.get_shape() for t in outputs]) + def identity(*args): + outs = args[len(inputs) + len(train_vars):] + return tuple([tf.identity(t) for t in outs]) + + id_out = identity(*(inputs + train_vars + outputs)) + return id_out diff --git a/tensor2tensor/layers/common_layers_test.py b/tensor2tensor/layers/common_layers_test.py index 3cf3f3374..61023938f 100644 --- a/tensor2tensor/layers/common_layers_test.py +++ b/tensor2tensor/layers/common_layers_test.py @@ -392,6 +392,165 @@ def testRavanbakhshSetLayer(self): actual = session.run(layer) self.assertEqual(actual.shape, (5, 4, 32)) + def testPaddingCrossEntropyFactored(self): + vocab_size = 19 + rows = 5 + cols = 4 + depth = 11 + label_smoothing = 0.1 + features = np.random.rand(rows, cols, depth) + weights = np.random.rand(vocab_size, depth) + labels = np.random.randint(0, vocab_size - 1, size=(rows, cols)) + with self.test_session() as session: + features = tf.to_float(features) + weights = tf.to_float(weights) + labels = tf.to_int32(labels) + logits = tf.matmul( + tf.reshape(features, [rows * cols, depth]), weights, transpose_b=True) + logits = tf.reshape(logits, [rows, cols, vocab_size]) + loss_num, loss_den = common_layers.padded_cross_entropy( + logits, labels, label_smoothing=label_smoothing, reduce_sum=False) + factored_logits = common_layers.FactoredTensor(features, weights) + loss_num_f, loss_den_f = common_layers.padded_cross_entropy_factored( + factored_logits, + labels=labels, + label_smoothing=label_smoothing, + reduce_sum=False) + num, den, num_f, den_f = session.run( + [loss_num, loss_den, loss_num_f, loss_den_f]) + self.assertEqual(num.shape, (rows, cols)) + self.assertEqual(den.shape, (rows, cols)) + self.assertEqual(num_f.shape, (rows, cols)) + self.assertEqual(den_f.shape, (rows, cols)) + self.assertAllClose(num, num_f) + self.assertAllClose(den, den_f) + + def testPaddingCrossEntropyFactoredGrad(self): + vocab_size = 19 + rows = 5 + cols = 4 + depth = 11 + label_smoothing = 0.1 + features = np.random.rand(rows, cols, depth) + weights = np.random.rand(vocab_size, depth) + labels = np.random.randint(0, vocab_size - 1, size=(rows, cols)) + with self.test_session() as session: + features = tf.to_float(features) + weights = tf.to_float(weights) + labels = tf.to_int32(labels) + logits = tf.matmul( + tf.reshape(features, [rows * cols, depth]), weights, transpose_b=True) + logits = tf.reshape(logits, [rows, cols, vocab_size]) + loss_num, loss_den = common_layers.padded_cross_entropy( + logits, labels, label_smoothing=label_smoothing, reduce_sum=False) + factored_logits = common_layers.FactoredTensor(features, weights) + loss_num_factored, loss_den_factored = ( + common_layers.padded_cross_entropy_factored( + factored_logits, + labels=labels, + label_smoothing=label_smoothing, + reduce_sum=False)) + df, dw = tf.gradients(ys=[loss_num, loss_den], xs=[features, weights]) + df_factored, dw_factored = tf.gradients( + ys=[loss_num_factored, loss_den_factored], xs=[features, weights]) + actual_df, actual_dw, actual_df_factored, actual_dw_factored = ( + session.run([df, dw, df_factored, dw_factored])) + self.assertEqual(actual_df.shape, (rows, cols, depth)) + self.assertEqual(actual_dw.shape, (vocab_size, depth)) + self.assertEqual(actual_df_factored.shape, (rows, cols, depth)) + self.assertEqual(actual_dw_factored.shape, (vocab_size, depth)) + self.assertAllClose(actual_df, actual_df_factored) + self.assertAllClose(actual_dw, actual_dw_factored) + + def testFactoredTensorImplicitConversion(self): + a = np.random.rand(3, 4, 5) + b = np.random.rand(6, 5) + c = np.random.rand(3, 4, 6) + with self.test_session() as session: + # a factored representation of a Tensor of shape (3, 4, 6) + factored = common_layers.FactoredTensor(tf.to_float(a), tf.to_float(b)) + # implicitly converts factored to a Tensor (performing the matmul) + d = factored + tf.to_float(c) + out = session.run(d) + self.assertEqual(out.shape, (3, 4, 6)) + + +class FnWithCustomGradTest(tf.test.TestCase): + + def testCorrectness(self): + + w = tf.random_uniform([6, 10]) + + def fn(a, b, c): + return tf.layers.dense( + a, + 10, + use_bias=False, + kernel_initializer=lambda shape, dtype, partition_info: w + ) + tf.matmul(b, c) + + def grad_fn(inputs, variables, outputs, grad_outputs): + outputs = outputs[0] + grad_outputs = grad_outputs[0] + grad_inputs = tf.gradients(outputs, inputs, grad_ys=grad_outputs) + grad_vars = tf.gradients(outputs, variables, grad_ys=grad_outputs) + return grad_inputs, grad_vars + + custom_fn = common_layers.fn_with_custom_grad(grad_fn)(fn) + + a = tf.random_uniform([11, 6]) + b = tf.random_uniform([11, 7]) + c = tf.random_uniform([7, 10]) + + out = fn(a, b, c) + custom_out = custom_fn(a, b, c) + self.assertEqual(out.get_shape().as_list(), + custom_out.get_shape().as_list()) + + loss = tf.reduce_mean(out) + custom_loss = tf.reduce_mean(custom_out) + + grads = tf.gradients(loss, [a, b, c] + [tf.trainable_variables()[0]]) + custom_grads = tf.gradients(custom_loss, + [a, b, c] + [tf.trainable_variables()[1]]) + + with self.test_session() as sess: + sess.run(tf.global_variables_initializer()) + out_val, custom_out_val, grads_val, custom_grads_val = sess.run( + [out, custom_out, grads, custom_grads]) + self.assertAllClose(out_val, custom_out_val) + for g1, g2 in zip(grads_val, custom_grads_val): + self.assertAllClose(g1, g2) + + def testCustomGrad(self): + + def fn(a, b, c): + return tf.layers.dense(a, 10, use_bias=False) + tf.matmul(b, c) + + def grad_fn(inputs, variables, unused_outputs, unused_grad_outputs): + grad_inputs = [tf.ones_like(t) * (i + 1.) for i, t in enumerate(inputs)] + grad_vars = [ + tf.ones_like(t) * (i + len(inputs) + 1.) + for i, t in enumerate(variables) + ] + return grad_inputs, grad_vars + + a = tf.random_uniform([11, 6]) + b = tf.random_uniform([11, 7]) + c = tf.random_uniform([7, 10]) + w = tf.random_uniform([6, 10]) + out = common_layers.fn_with_custom_grad(grad_fn)(fn)(a, b, c) + loss = tf.reduce_mean(out) + grads = tf.gradients(loss, [a, b, c, tf.trainable_variables()[0]]) + expected_grads = [ + tf.ones_like(t) * (i + 1.) for i, t in enumerate([a, b, c, w]) + ] + with self.test_session() as sess: + sess.run(tf.global_variables_initializer()) + g_val, eg_val = sess.run([grads, expected_grads]) + for g1, g2 in zip(g_val, eg_val): + self.assertAllClose(g1, g2) + if __name__ == "__main__": tf.test.main() diff --git a/tensor2tensor/layers/modalities.py b/tensor2tensor/layers/modalities.py index 01728ba24..57652dbec 100644 --- a/tensor2tensor/layers/modalities.py +++ b/tensor2tensor/layers/modalities.py @@ -112,12 +112,18 @@ def top(self, body_output, _): reuse = False with tf.variable_scope(scope_name, reuse=reuse): var = self._get_weights() - shape = tf.shape(body_output)[:-1] - body_output = tf.reshape(body_output, [-1, self._body_input_depth]) - logits = tf.matmul(body_output, var, transpose_b=True) - logits = tf.reshape(logits, tf.concat([shape, [self._vocab_size]], 0)) - # insert a channels dimension - return tf.expand_dims(logits, 3) + if (self._model_hparams.factored_logits and + self._model_hparams.mode == tf.contrib.learn.ModeKeys.TRAIN): + # insert channels dimension + body_output = tf.expand_dims(body_output, 3) + logits = common_layers.FactoredTensor(body_output, var) + else: + shape = tf.shape(body_output)[:-1] + body_output = tf.reshape(body_output, [-1, self._body_input_depth]) + logits = tf.matmul(body_output, var, transpose_b=True) + logits = tf.reshape( + logits, tf.concat([shape, [1, self._vocab_size]], 0)) + return logits @registry.register_image_modality diff --git a/tensor2tensor/layers/modalities_test.py b/tensor2tensor/layers/modalities_test.py index 0ccd13777..5813422ab 100644 --- a/tensor2tensor/layers/modalities_test.py +++ b/tensor2tensor/layers/modalities_test.py @@ -65,7 +65,43 @@ def testSymbolModalityTargets(self): symbol_modality_num_shards=4, hidden_size=hidden_size, label_smoothing=0.2, - shared_embedding_and_softmax_weights=0) + shared_embedding_and_softmax_weights=0, + factored_logits=0, + mode=tf.contrib.learn.ModeKeys.TRAIN) + body_output = -1 + np.random.random_integers( + 100, size=(batch_size, length, height, hidden_size)) + targets = -1 + np.random.random_integers( + vocab_size, size=(batch_size, length, height, 1)) + m = modalities.SymbolModality(model_hparams, vocab_size) + data_parallelism = expert_utils.Parallelism( + ["/device:CPU:0"] * num_datashards, reuse=True) + with self.test_session() as session: + sharded_body_output = tf.split(tf.to_float(body_output), num_datashards) + sharded_targets = tf.split(targets, num_datashards) + sharded_logits = m.top_sharded(sharded_body_output, sharded_targets, + data_parallelism) + train_loss = m.loss_sharded(sharded_logits, sharded_targets, + data_parallelism) + logits = tf.concat(sharded_logits, 0) + session.run(tf.global_variables_initializer()) + res1, res2 = session.run((logits, train_loss)) + self.assertEqual(res1.shape, (batch_size, length, height, 1, vocab_size)) + self.assertEqual(res2.shape, ()) + + def testSymbolModalityTargetsFactored(self): + batch_size = 10 + num_datashards = 5 + length = 6 + height = 7 + hidden_size = 9 + vocab_size = 11 + model_hparams = tf.contrib.training.HParams( + symbol_modality_num_shards=4, + hidden_size=hidden_size, + label_smoothing=0.2, + shared_embedding_and_softmax_weights=0, + factored_logits=1, + mode=tf.contrib.learn.ModeKeys.TRAIN) body_output = -1 + np.random.random_integers( 100, size=(batch_size, length, height, hidden_size)) targets = -1 + np.random.random_integers( diff --git a/tensor2tensor/layers/rev_block.py b/tensor2tensor/layers/rev_block.py index f85385e68..9def9f481 100644 --- a/tensor2tensor/layers/rev_block.py +++ b/tensor2tensor/layers/rev_block.py @@ -23,15 +23,14 @@ from __future__ import division from __future__ import print_function -import random import re # Dependency imports -from six.moves import xrange +from six.moves import xrange # pylint: disable=redefined-builtin + +from tensor2tensor.layers import common_layers import tensorflow as tf -from tensorflow.python.framework import dtypes -from tensorflow.python.framework import function LAYER_RE = re.compile(".*revlayer_([0-9]*)/([fg])/.*") @@ -150,101 +149,6 @@ def _rev_block_forward(x1, return y1, y2 -def _underlying_variable(t): - """Find the underlying variable ref, ignoring Identity ops.""" - while t.op.type == "Identity": - t = t.op.inputs[0] - if t.dtype == dtypes.float32_ref and "Variable" in t.op.type: - return t - else: - return None - - -def fn_with_custom_grad(grad_fn): - """Decorator to create a subgraph with a custom gradient function. - - The subgraph created by the decorated function is NOT put in a Defun and so - does not suffer from the limitations of the Defun (all subgraph ops on the - same device, no summaries). - - Args: - grad_fn: function with signature - (inputs, variables, outputs, output_grads) -> (grad_inputs, grad_vars), - all of which are lists of Tensors. - - Returns: - Decorator for function such that the gradient is defined by grad_fn. - """ - - def dec(fn): - - def wrapped(*args): - return _fn_with_custom_grad(fn, args, grad_fn) - - return wrapped - - return dec - - -def _fn_with_custom_grad(fn, inputs, grad_fn): - """Create a subgraph with a custom gradient. - - Args: - fn: function that takes inputs as arguments and produces 1 or more Tensors. - inputs: list, will be passed as fn(*inputs). - grad_fn: function with signature - (inputs, vars, outputs, output_grads) -> (grad_inputs, grad_vars), - all of which are lists of Tensors. - - Returns: - fn(*inputs) - """ - with tf.variable_scope(None, default_name="fn_with_custom_grad") as vs: - inputs = list(inputs) - outputs = fn(*inputs) - train_vars = list(vs.trainable_variables()) - - if grad_fn is None: - return outputs - else: - if not (isinstance(outputs, tuple) or isinstance(outputs, list)): - outputs = [outputs] - outputs = list(outputs) - - in_types = [t.dtype for t in inputs] - out_types = [t.dtype for t in outputs] - var_types = [t.dtype for t in train_vars] - - def custom_grad_fn(op, *dys): - """Custom grad fn applying grad_fn for identity Defun.""" - dys = list(dys) - fn_inputs = op.inputs[:len(inputs)] - fn_vars = op.inputs[len(inputs):len(inputs) + len(train_vars)] - fn_outputs = op.inputs[len(inputs) + len(train_vars):] - assert len(fn_outputs) == len(outputs) - assert len(fn_outputs) == len(dys) - - grad_inputs, grad_vars = grad_fn(fn_inputs, fn_vars, fn_outputs, dys) - grad_outputs = [None] * len(fn_outputs) - return tuple(grad_inputs + grad_vars + grad_outputs) - - # The Defun takes as input the original inputs, the trainable variables - # created in fn, and the outputs. In the forward it passes through the - # outputs. In the backwards, it produces gradients for the original inputs - # and the trainable variables. - @function.Defun( - *(in_types + var_types + out_types), - func_name="identity_custom_grad%d" % random.randint(1, 10**9), - python_grad_func=custom_grad_fn, - shape_func=lambda _: [t.get_shape() for t in outputs]) - def identity(*args): - outs = args[len(inputs) + len(train_vars):] - return tuple([tf.identity(t) for t in outs]) - - id_out = identity(*(inputs + train_vars + outputs)) - return id_out - - def rev_block(x1, x2, f, @@ -330,7 +234,7 @@ def custom_grad_fn(inputs, variables, ys, grad_ys): g_vars_idxs = [[] for _ in range(num_layers)] for i, t in enumerate(variables): - ref = _underlying_variable(t) + ref = common_layers.underlying_variable_ref(t) # Use the name to identify the layer number and function (f or g) regex = LAYER_RE.match(ref.name) @@ -396,7 +300,7 @@ def custom_grad_fn(inputs, variables, ys, grad_ys): return [grad_x1, grad_x2] + side_input_grads, variable_grads # Need a forward function with positional arguments - @fn_with_custom_grad(custom_grad_fn if is_training else None) + @common_layers.fn_with_custom_grad(custom_grad_fn if is_training else None) def forward(x1, x2, *side_inputs): f_side = side_inputs[:len(f_side_input)] g_side = side_inputs[len(f_side_input):] diff --git a/tensor2tensor/layers/rev_block_test.py b/tensor2tensor/layers/rev_block_test.py index dd4a62993..5aecc8ea3 100644 --- a/tensor2tensor/layers/rev_block_test.py +++ b/tensor2tensor/layers/rev_block_test.py @@ -55,8 +55,9 @@ def g(x): # pylint: disable=function-redefined if g_side_input is None: g_side_input = [] - x = tf.random_uniform([self.BATCH_SIZE, self.CHANNELS], dtype=tf.float32) - x1, x2 = tf.split(x, 2, axis=1) + if x is None: + x = tf.random_uniform([self.BATCH_SIZE, self.CHANNELS], dtype=tf.float32) + x1, x2 = tf.split(x, 2, axis=-1) with tf.variable_scope("rev_test") as vs: y1_rev, y2_rev = rev_block.rev_block( @@ -121,82 +122,19 @@ def f2(x): self._testRevBlock(f=[f1, f2, f1, f2]) + def testConvAndBatchNorm(self): -class FnWithCustomGradTest(tf.test.TestCase): + x = tf.random_uniform( + [self.BATCH_SIZE, 10, self.CHANNELS], dtype=tf.float32) - def testCorrectness(self): + def f(x): + x = tf.layers.conv1d(x, self.CHANNELS // 2, 3, padding="same") + x = tf.layers.batch_normalization(x, training=True) + x = tf.layers.conv1d(x, self.CHANNELS // 2, 3, padding="same") + x = tf.layers.batch_normalization(x, training=True) + return x - w = tf.random_uniform([6, 10]) - - def fn(a, b, c): - return tf.layers.dense( - a, - 10, - use_bias=False, - kernel_initializer=lambda shape, dtype, partition_info: w - ) + tf.matmul(b, c) - - def grad_fn(inputs, variables, outputs, grad_outputs): - outputs = outputs[0] - grad_outputs = grad_outputs[0] - grad_inputs = tf.gradients(outputs, inputs, grad_ys=grad_outputs) - grad_vars = tf.gradients(outputs, variables, grad_ys=grad_outputs) - return grad_inputs, grad_vars - - custom_fn = rev_block.fn_with_custom_grad(grad_fn)(fn) - - a = tf.random_uniform([11, 6]) - b = tf.random_uniform([11, 7]) - c = tf.random_uniform([7, 10]) - - out = fn(a, b, c) - custom_out = custom_fn(a, b, c) - self.assertEqual(out.get_shape().as_list(), - custom_out.get_shape().as_list()) - - loss = tf.reduce_mean(out) - custom_loss = tf.reduce_mean(custom_out) - - grads = tf.gradients(loss, [a, b, c] + [tf.trainable_variables()[0]]) - custom_grads = tf.gradients(custom_loss, - [a, b, c] + [tf.trainable_variables()[1]]) - - with self.test_session() as sess: - sess.run(tf.global_variables_initializer()) - out_val, custom_out_val, grads_val, custom_grads_val = sess.run( - [out, custom_out, grads, custom_grads]) - self.assertAllClose(out_val, custom_out_val) - for g1, g2 in zip(grads_val, custom_grads_val): - self.assertAllClose(g1, g2) - - def testCustomGrad(self): - - def fn(a, b, c): - return tf.layers.dense(a, 10, use_bias=False) + tf.matmul(b, c) - - def grad_fn(inputs, variables, unused_outputs, unused_grad_outputs): - grad_inputs = [tf.ones_like(t) * (i + 1.) for i, t in enumerate(inputs)] - grad_vars = [ - tf.ones_like(t) * (i + len(inputs) + 1.) - for i, t in enumerate(variables) - ] - return grad_inputs, grad_vars - - a = tf.random_uniform([11, 6]) - b = tf.random_uniform([11, 7]) - c = tf.random_uniform([7, 10]) - w = tf.random_uniform([6, 10]) - out = rev_block.fn_with_custom_grad(grad_fn)(fn)(a, b, c) - loss = tf.reduce_mean(out) - grads = tf.gradients(loss, [a, b, c, tf.trainable_variables()[0]]) - expected_grads = [ - tf.ones_like(t) * (i + 1.) for i, t in enumerate([a, b, c, w]) - ] - with self.test_session() as sess: - sess.run(tf.global_variables_initializer()) - g_val, eg_val = sess.run([grads, expected_grads]) - for g1, g2 in zip(g_val, eg_val): - self.assertAllClose(g1, g2) + self._testRevBlock(x=x, f=f) if __name__ == "__main__": diff --git a/tensor2tensor/models/attention_lm_moe.py b/tensor2tensor/models/attention_lm_moe.py index 9c55eadd6..5bb63c303 100644 --- a/tensor2tensor/models/attention_lm_moe.py +++ b/tensor2tensor/models/attention_lm_moe.py @@ -32,6 +32,7 @@ from tensor2tensor.layers import common_attention from tensor2tensor.layers import common_hparams from tensor2tensor.layers import common_layers +from tensor2tensor.utils import diet from tensor2tensor.utils import expert_utils from tensor2tensor.utils import registry from tensor2tensor.utils import t2t_model @@ -76,9 +77,20 @@ def postprocess(x, y): 1.0 - hparams.layer_prepostprocess_dropout) extra_loss = 0.0 moe_hidden_sizes = [int(s) for s in hparams.moe_hidden_sizes.split(",")] + if hparams.diet_experts: + hsize, = moe_hidden_sizes + + def _diet_expert(x): + return diet.diet_expert(x, hsize, diet.diet_adam_optimizer_params()) + + expert_fn = _diet_expert + else: + expert_fn = expert_utils.ffn_expert_fn( + hparams.hidden_size, moe_hidden_sizes, hparams.hidden_size) for layer in xrange(hparams.num_hidden_layers): with tf.variable_scope("layer_%d" % layer): - with tf.variable_scope("attention"): + with tf.variable_scope( + "attention_{}".format(hparams.attention_moe_type)): x = preprocess(x) if hparams.attention_moe_type == AttentionMoeType.NONE: y = dp( @@ -100,9 +112,11 @@ def postprocess(x, y): loss_coef=1e-2, attention_num_experts=hparams.attention_num_experts, train=hparams.mode == tf.contrib.learn.ModeKeys.TRAIN, - mask_right=True) + mask_right=True, + attention_kq_size=hparams.attention_kq_size, + attention_v_size=hparams.attention_v_size) # TODO(avaswani, epot, noam): Do we need to divide by num shards ? - extra_loss += tf.add_n(loss)/dp.n + extra_loss += tf.add_n(loss) / dp.n else: raise ValueError("Only {} supported for now.".format( AttentionMoeType.get_choices())) @@ -115,9 +129,7 @@ def postprocess(x, y): preprocess(x), hparams.mode == tf.contrib.learn.ModeKeys.TRAIN, input_size=hparams.hidden_size, - expert_fn=expert_utils.ffn_expert_fn( - hparams.hidden_size, moe_hidden_sizes, - hparams.hidden_size), + expert_fn=expert_fn, num_experts=hparams.moe_num_experts, k=hparams.moe_k, loss_coef=hparams.moe_loss_coef) @@ -148,9 +160,8 @@ def attention_lm_moe_prepare_decoder(targets, hparams): to implement masked attention and possibly baises for diagonal alignments """ if hparams.prepend_mode == "prepend_inputs_full_attention": - decoder_self_attention_bias = ( - common_attention.attention_bias_prepended( - common_attention.embedding_to_padding(targets))) + decoder_self_attention_bias = (common_attention.attention_bias_prepended( + common_attention.embedding_to_padding(targets))) else: decoder_self_attention_bias = ( common_attention.attention_bias_lower_triangle(tf.shape(targets)[1])) @@ -206,6 +217,20 @@ def attention_lm_moe_base(): # moe params. local attention moe. hparams.add_hparam("attention_moe_type", AttentionMoeType.NONE) hparams.add_hparam("attention_num_experts", 16) + # Key, query and value dimensions for the attention + hparams.add_hparam("attention_kq_size", 64) + hparams.add_hparam("attention_v_size", 64) + hparams.add_hparam("diet_experts", int(False)) + return hparams + + +@registry.register_hparams +def attention_lm_moe_base_ae(): + """Base model with attention expert.""" + hparams = attention_lm_moe_base() + hparams.attention_moe_type = AttentionMoeType.LOCAL + hparams.max_length = hparams.batch_size + hparams.eval_drop_long_sequences = int(True) return hparams @@ -252,8 +277,8 @@ def attention_lm_attention_moe_tiny(): """ hparams = attention_lm_moe_small() hparams.moe_layers = "" - hparams.attention_num_experts = 16 - hparams.filter_size = 512 + hparams.attention_num_experts = 128 + hparams.filter_size = 8192 hparams.attention_moe_type = AttentionMoeType.LOCAL return hparams @@ -303,6 +328,32 @@ def attention_lm_moe_large(): return hparams +@registry.register_hparams +def attention_lm_moe_large_diet(): + hparams = attention_lm_moe_large() + hparams.diet_experts = int(True) + return hparams + + +@registry.register_hparams +def attention_lm_moe_32b_diet(): + """Unnecessarily large model with 32B params - because we can.""" + hparams = attention_lm_moe_large_diet() + hparams.moe_hidden_sizes = "16384" + hparams.moe_num_experts = 1024 + return hparams + + +@registry.register_hparams +def attention_lm_moe_24b_diet(): + """Unnecessarily large model with 24B params - because we can.""" + hparams = attention_lm_moe_large_diet() + hparams.moe_hidden_sizes = "12288" + hparams.moe_num_experts = 1024 + hparams.batch_size = 4096 + return hparams + + @registry.register_hparams def attention_lm_moe_translation(): """Version to use for seq2seq.""" diff --git a/tensor2tensor/models/models.py b/tensor2tensor/models/models.py index af609e22c..7c31f4e05 100644 --- a/tensor2tensor/models/models.py +++ b/tensor2tensor/models/models.py @@ -33,12 +33,12 @@ from tensor2tensor.models import lstm from tensor2tensor.models import multimodel from tensor2tensor.models import neural_gpu -from tensor2tensor.models import rev_transformer from tensor2tensor.models import shake_shake from tensor2tensor.models import slicenet from tensor2tensor.models import transformer from tensor2tensor.models import transformer_alternative from tensor2tensor.models import transformer_moe +from tensor2tensor.models import transformer_revnet from tensor2tensor.models import transformer_vae from tensor2tensor.models import xception # pylint: enable=unused-import diff --git a/tensor2tensor/models/transformer.py b/tensor2tensor/models/transformer.py index 0eed2dbdb..47db28c30 100644 --- a/tensor2tensor/models/transformer.py +++ b/tensor2tensor/models/transformer.py @@ -190,7 +190,7 @@ def transformer_encoder(encoder_input, """ x = encoder_input with tf.variable_scope(name): - for layer in xrange(hparams.num_hidden_layers): + for layer in xrange(hparams.num_encoder_layers): with tf.variable_scope("layer_%d" % layer): with tf.variable_scope("self_attention"): y = common_attention.multihead_attention( @@ -233,7 +233,7 @@ def transformer_decoder(decoder_input, """ x = decoder_input with tf.variable_scope(name): - for layer in xrange(hparams.num_hidden_layers): + for layer in xrange(hparams.num_decoder_layers): with tf.variable_scope("layer_%d" % layer): with tf.variable_scope("self_attention"): y = common_attention.multihead_attention( @@ -324,6 +324,9 @@ def transformer_base(): hparams.shared_embedding_and_softmax_weights = int(True) hparams.add_hparam("filter_size", 2048) # Add new ones like this. + # layer-related flags + hparams.add_hparam("num_encoder_layers", hparams.num_hidden_layers) + hparams.add_hparam("num_decoder_layers", hparams.num_hidden_layers) # attention-related flags hparams.add_hparam("num_heads", 8) hparams.add_hparam("attention_key_channels", 0) diff --git a/tensor2tensor/models/rev_transformer.py b/tensor2tensor/models/transformer_revnet.py similarity index 88% rename from tensor2tensor/models/rev_transformer.py rename to tensor2tensor/models/transformer_revnet.py index d1392a1ee..942a00660 100644 --- a/tensor2tensor/models/rev_transformer.py +++ b/tensor2tensor/models/transformer_revnet.py @@ -31,7 +31,7 @@ @registry.register_model -class RevTransformer(transformer.Transformer): +class TransformerRevnet(transformer.Transformer): """Reversible Residual Transformer. Layers are reversible and are recomputed on the backward pass. @@ -63,10 +63,10 @@ def model_fn_body(self, features): 1.0 - hparams.layer_prepostprocess_dropout) decoder_input = tf.nn.dropout(decoder_input, 1.0 - hparams.layer_prepostprocess_dropout) - encoder_output = rev_transformer_encoder( + encoder_output = transformer_revnet_encoder( encoder_input, encoder_self_attention_bias, hparams) - decoder_output = rev_transformer_decoder( + decoder_output = transformer_revnet_decoder( decoder_input, encoder_output, decoder_self_attention_bias, encoder_decoder_attention_bias, hparams) decoder_output = tf.expand_dims(decoder_output, 2) @@ -74,10 +74,10 @@ def model_fn_body(self, features): return decoder_output -def rev_transformer_encoder(encoder_input, - encoder_self_attention_bias, - hparams, - name="encoder"): +def transformer_revnet_encoder(encoder_input, + encoder_self_attention_bias, + hparams, + name="encoder"): """A stack of transformer layers. Args: @@ -137,12 +137,12 @@ def g(x): return common_layers.layer_preprocess(y, hparams) -def rev_transformer_decoder(decoder_input, - encoder_output, - decoder_self_attention_bias, - encoder_decoder_attention_bias, - hparams, - name="decoder"): +def transformer_revnet_decoder(decoder_input, + encoder_output, + decoder_self_attention_bias, + encoder_decoder_attention_bias, + hparams, + name="decoder"): """A stack of transformer layers. Args: @@ -218,8 +218,8 @@ def g(x): @registry.register_hparams -def rev_transformer_base(): - """Base hparams for RevTransformer.""" +def transformer_revnet_base(): + """Base hparams for TransformerRevnet.""" hparams = transformer.transformer_big() # Use settings from transformer_n_da @@ -231,11 +231,11 @@ def rev_transformer_base(): @registry.register_hparams -def rev_transformer_big(): - """Base hparams for RevTransformer.""" - hparams = rev_transformer_base() +def transformer_revnet_big(): + """Base hparams for TransformerRevnet.""" + hparams = transformer_revnet_base() - # The RevTransformer uses significantly less memory than the Transformer. + # The TransformerRevnet uses significantly less memory than the Transformer. # Increase batch size and model size. hparams.batch_size *= 2 hparams.hidden_size *= 2 diff --git a/tensor2tensor/models/rev_transformer_test.py b/tensor2tensor/models/transformer_revnet_test.py similarity index 87% rename from tensor2tensor/models/rev_transformer_test.py rename to tensor2tensor/models/transformer_revnet_test.py index da9e15f72..66b493b0b 100644 --- a/tensor2tensor/models/rev_transformer_test.py +++ b/tensor2tensor/models/transformer_revnet_test.py @@ -13,7 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""Tests for RevTransformer.""" +"""Tests for TransformerRevnet.""" from __future__ import absolute_import from __future__ import division @@ -24,13 +24,13 @@ import numpy as np from tensor2tensor.data_generators import problem_hparams -from tensor2tensor.models import rev_transformer +from tensor2tensor.models import transformer_revnet import tensorflow as tf -def rev_transformer_test(): - hparams = rev_transformer.rev_transformer_base() +def transformer_revnet_test(): + hparams = transformer_revnet.transformer_revnet_base() hparams.num_hidden_layers = 2 hparams.hidden_size = 128 hparams.filter_size = 512 @@ -38,14 +38,14 @@ def rev_transformer_test(): return hparams -class RevTransformerTest(tf.test.TestCase): +class TransformerRevnetTest(tf.test.TestCase): def testTransformer(self): batch_size = 3 input_length = 5 target_length = 7 vocab_size = 9 - hparams = rev_transformer_test() + hparams = transformer_revnet_test() p_hparams = problem_hparams.test_problem_hparams(hparams, vocab_size, vocab_size) hparams.problems = [p_hparams] @@ -58,7 +58,7 @@ def testTransformer(self): "targets": tf.constant(targets, dtype=tf.int32), "target_space_id": tf.constant(1, dtype=tf.int32), } - model = rev_transformer.RevTransformer( + model = transformer_revnet.TransformerRevnet( hparams, tf.contrib.learn.ModeKeys.TRAIN, p_hparams) sharded_logits, _ = model.model_fn(features) logits = tf.concat(sharded_logits, 0) diff --git a/tensor2tensor/models/transformer_test.py b/tensor2tensor/models/transformer_test.py index 8f4d26339..391824524 100644 --- a/tensor2tensor/models/transformer_test.py +++ b/tensor2tensor/models/transformer_test.py @@ -29,36 +29,65 @@ import tensorflow as tf +BATCH_SIZE = 3 +INPUT_LENGTH = 5 +TARGET_LENGTH = 7 +VOCAB_SIZE = 9 + + class TransformerTest(tf.test.TestCase): - def _testTransformer(self, net): - batch_size = 3 - input_length = 5 - target_length = 7 - vocab_size = 9 - hparams = transformer.transformer_tiny() - p_hparams = problem_hparams.test_problem_hparams(hparams, vocab_size, - vocab_size) + def getModel(self): + hparams = transformer.transformer_small() + p_hparams = problem_hparams.test_problem_hparams( + hparams, VOCAB_SIZE, VOCAB_SIZE) hparams.problems = [p_hparams] inputs = -1 + np.random.random_integers( - vocab_size, size=(batch_size, input_length, 1, 1)) + VOCAB_SIZE, size=(BATCH_SIZE, INPUT_LENGTH, 1, 1)) targets = -1 + np.random.random_integers( - vocab_size, size=(batch_size, target_length, 1, 1)) + VOCAB_SIZE, size=(BATCH_SIZE, TARGET_LENGTH, 1, 1)) + features = { + "inputs": tf.constant(inputs, dtype=tf.int32), + "targets": tf.constant(targets, dtype=tf.int32), + "target_space_id": tf.constant(1, dtype=tf.int32), + } + + return transformer.Transformer( + hparams, tf.contrib.learn.ModeKeys.INFER, p_hparams), features + + def testTransformer(self): + model, features = self.getModel() + shadred_logits, _ = model.model_fn(features) + logits = tf.concat(shadred_logits, 0) with self.test_session() as session: - features = { - "inputs": tf.constant(inputs, dtype=tf.int32), - "targets": tf.constant(targets, dtype=tf.int32), - "target_space_id": tf.constant(1, dtype=tf.int32), - } - model = net(hparams, tf.contrib.learn.ModeKeys.TRAIN, p_hparams) - shadred_logits, _ = model.model_fn(features) - logits = tf.concat(shadred_logits, 0) session.run(tf.global_variables_initializer()) res = session.run(logits) - self.assertEqual(res.shape, (batch_size, target_length, 1, 1, vocab_size)) + self.assertEqual(res.shape, (BATCH_SIZE, TARGET_LENGTH, 1, 1, VOCAB_SIZE)) - def testTransformer(self): - self._testTransformer(transformer.Transformer) + def testBeamDecodeVsGreedy(self): + model, features = self.getModel() + + decode_length = 20 + + greedy_result, _, _ = model._greedy_infer( + features, decode_length, last_position_only=True) + greedy_result = tf.squeeze(greedy_result, axis=[2, 3]) + + with tf.variable_scope(tf.get_variable_scope(), reuse=True): + beam_res = model._beam_decode( + features, + decode_length, + beam_size=1, + top_beams=1, + last_position_only=True, + alpha=1.0) + + with self.test_session() as session: + session.run(tf.global_variables_initializer()) + greedy_res, beam_res = session.run([greedy_result, beam_res]) + + self.assertEqual(beam_res.shape, (BATCH_SIZE, INPUT_LENGTH + decode_length)) + self.assertAllClose(greedy_res, beam_res) if __name__ == "__main__": diff --git a/tensor2tensor/models/transformer_vae.py b/tensor2tensor/models/transformer_vae.py index 6a3f3afdf..fa6b3f397 100644 --- a/tensor2tensor/models/transformer_vae.py +++ b/tensor2tensor/models/transformer_vae.py @@ -26,7 +26,6 @@ from tensor2tensor.layers import common_attention from tensor2tensor.layers import common_layers from tensor2tensor.models import transformer -from tensor2tensor.utils import expert_utils from tensor2tensor.utils import registry from tensor2tensor.utils import t2t_model @@ -85,37 +84,26 @@ def decompress_step(source, c, hparams, first_relu, name): return tf.reshape(thicker, [shape[0], shape[1] * 2, 1, hparams.hidden_size]) -def top_k_softmax(x, k): - """Calculate softmax(x), select top-k and rescale to sum to 1.""" - x = tf.nn.softmax(x) - top_x, _ = tf.nn.top_k(x, k=k+1) - min_top = tf.reduce_min(top_x, axis=-1, keep_dims=True) - x = tf.nn.relu((x - min_top) + 1e-12) - x /= tf.reduce_sum(x, axis=-1, keep_dims=True) - return x, tf.reduce_max(top_x, axis=-1) +def gumbel_sample(shape): + """Sample from the Gumbel distribution, protect from overflows.""" + uniform_samples = tf.random_uniform(shape, minval=0.00001, maxval=0.99998) + return -tf.log(-tf.log(uniform_samples)) -def top_k_experts(x, k, hparams): - x_shape = tf.shape(x) - x_flat = tf.reshape(x, [-1, x.get_shape().as_list()[-1]]) - is_training = hparams.mode == tf.contrib.learn.ModeKeys.TRAIN - gates, load = expert_utils.noisy_top_k_gating( - x_flat, hparams.v_size, is_training, k) - gates_shape = [x_shape[0], x_shape[1], x_shape[2], hparams.v_size] - gates = tf.reshape(gates, gates_shape) - load_loss = expert_utils.cv_squared(load) - return gates, load_loss - - -def dvae(x, k, hparams, name): +def dvae(x, hparams, name): with tf.variable_scope(name): m = tf.layers.dense(x, hparams.v_size, name="mask") - if k is None: - m = tf.nn.softmax(m) - kl = - tf.reduce_max(m, axis=-1) - else: - m, kl = top_k_softmax(m, k) - return m, 1.0 - tf.reduce_mean(kl) + logsm = tf.nn.log_softmax(m) + # Gumbel-softmax sample. + gumbel_samples = gumbel_sample(tf.shape(m)) + steps = hparams.kl_warmup_steps + gumbel_samples *= common_layers.inverse_exp_decay(steps) * 0.1 + temperature = 1.2 - common_layers.inverse_lin_decay(steps) + s = tf.nn.softmax((logsm + gumbel_samples) / temperature) + m = tf.nn.softmax(m) + kl = - tf.reduce_max(logsm, axis=-1) + tf.summary.histogram("max-log", tf.reshape(kl, [-1])) + return m, s, tf.reduce_mean(kl) def vae(x, hparams, name): @@ -130,6 +118,28 @@ def vae(x, hparams, name): return z, tf.reduce_mean(kl), mu, log_sigma +def nearest(x, means, hparams): + """Find the nearest means to elements in x.""" + x, means = tf.stop_gradient(x), tf.stop_gradient(means) + means = tf.nn.l2_normalize(means, dim=1) + x_flat = tf.reshape(x, [-1, hparams.hidden_size]) + # dist = tf.reduce_sum(tf.square(x_flat - tf.expand_dims(means, 0)), axis=2) + dist = - tf.matmul(x_flat, means, transpose_b=True) + _, nearest_idx = tf.nn.top_k(- dist, k=1) + nearest_hot = tf.one_hot(tf.squeeze(nearest_idx, axis=1), hparams.v_size) + nearest_hot = tf.reshape(nearest_hot, [tf.shape(x)[0], tf.shape(x)[1], + 1, hparams.v_size]) + return tf.stop_gradient(nearest_hot) + + +def kmeans(x, means, hparams, name): + with tf.variable_scope(name): + x_means_hot = nearest(x, means, hparams) + x_means = tf.gather(means, tf.argmax(x_means_hot, axis=-1)) + kl = tf.reduce_sum(tf.square(x - x_means), axis=-1) + return x_means_hot, x_means_hot, tf.reduce_mean(kl) * 100.0 + + def compress(x, c, hparams, name): """Compress.""" with tf.variable_scope(name): @@ -145,79 +155,112 @@ def compress(x, c, hparams, name): return cur -def mix(x1, x2, steps, min_prob=0.0, max_prob=1.0, mode="lin"): +def mix(x1, x2, steps, min_prob=0.0, max_prob=1.0, mode="lin", simple=False): + """Mix starting with x2, mixing mixing, going towards x1.""" if mode == "lin": - alpha_p = common_layers.inverse_lin_decay(steps) + 0.001 + alpha_p = common_layers.inverse_lin_decay(steps) else: - alpha_p = common_layers.inverse_exp_decay(steps) + 0.001 + alpha_p = common_layers.inverse_exp_decay(steps) alpha_p = alpha_p * (max_prob - min_prob) + min_prob + if simple: + return alpha_p * x1 + (1.0 - alpha_p) * x2 alpha = tf.random_uniform(tf.shape(x1)) alpha = tf.to_float(tf.less(alpha, alpha_p)) return alpha * x1 + (1.0 - alpha) * x2 -def vae_compress(x, c, hparams, compress_name, decompress_name, reuse=None): +def encode(x, x_space, hparams, name): + """Transformer preparations and encoder.""" + with tf.variable_scope(name): + (encoder_input, encoder_self_attention_bias, + ed) = transformer.transformer_prepare_encoder(x, x_space, hparams) + encoder_input = tf.nn.dropout(encoder_input, 1.0 - hparams.dropout) + return transformer.transformer_encoder( + encoder_input, encoder_self_attention_bias, hparams), ed + + +def decode(cond_vec, cond_add, gold, c, ed, hparams): + """Transformer decoder.""" + drop_gold = tf.nn.dropout(gold, 1.0 - hparams.layer_prepostprocess_dropout) + decoder_input = common_layers.shift_left(drop_gold, pad_value=cond_vec) + if cond_add is not None: + decoder_input += cond_add + decoder_input = tf.squeeze(decoder_input, axis=2) + decoder_input = common_attention.add_timing_signal_1d(decoder_input) + bias = common_attention.attention_bias_lower_triangle(tf.shape(gold)[1]) + if c is not None: + c = tf.squeeze(c, axis=2) + return transformer.transformer_decoder(decoder_input, c, bias, ed, hparams) + + +def expand_batch(x, mul): + """Expand on batch by mul times.""" + cx = tf.expand_dims(x, axis=1) + x_shape = x.get_shape().as_list() + batch_mul = tf.to_int32(mul) + cx += tf.zeros([1, batch_mul, 1, 1, 1]) + mid_shape = [tf.shape(x)[2]] if len(x_shape) > 3 else [] + end_shape = [x_shape[-1]] if x_shape[-1] else [tf.shape(x)[-1]] + res_shape = [-1, tf.shape(x)[1]] + mid_shape + end_shape + return tf.reshape(cx, res_shape) + + +def vae_compress(x, c, ed, hparams, compress_name, decompress_name, reuse=None): """Compress, then VAE.""" - mix_k = 8 with tf.variable_scope(compress_name, reuse=reuse): cur = compress(x, None, hparams, "compress") # Convolve and ReLu to get state. cur = common_layers.conv_block( cur, hparams.hidden_size, [((1, 1), (1, 1))], name="mid_conv") + cur = tf.nn.l2_normalize(cur, dim=3) + means = tf.get_variable("z_to_dense", [hparams.v_size, hparams.hidden_size]) # z, kl_loss, mu, log_sigma = vae(cur, hparams, name="vae") - z, kl_loss = dvae(cur, None, hparams, name="dvae") - z1, kl_loss1 = top_k_experts(cur, mix_k, hparams) - mu, log_sigma = None, None - - # Mix expert-selection and flat selection. - alpha_p = common_layers.inverse_lin_decay(60000) + 0.001 - z = alpha_p * z1 + (1 - alpha_p) * z - kl_loss += kl_loss1 + # z_true, z_sample, kl_loss = dvae(cur, hparams, name="dvae") + z_true, z_sample, kl_loss = kmeans(cur, means, hparams, name="kmeans") # Compress context. with tf.variable_scope(compress_name, reuse=reuse): compress_c = compress(c, None, hparams, "compress_context") - c_z = tf.layers.dense(compress_c, hparams.v_size, name="mask_context") + dec_c = decode(None, compress_c, cur, None, None, hparams) + c_z = tf.layers.dense(dec_c, hparams.v_size, name="mask_context") reconstruct_loss = tf.nn.softmax_cross_entropy_with_logits( - labels=z, logits=c_z) + labels=z_true, logits=c_z) # If not training, use the predicted z instead of the autoregressive one. - # if hparams.mode != tf.contrib.learn.ModeKeys.TRAIN: - # z = mix(c_z, z, 50000, max_prob=0.3, mode="exp") - # z, _ = top_k_softmax(c_z, mix_k) + if hparams.mode == tf.contrib.learn.ModeKeys.INFER: + z = tf.one_hot(tf.argmax(c_z, axis=-1), hparams.v_size) with tf.variable_scope(decompress_name, reuse=reuse): # Decompress. - z = tf.layers.dense(z, hparams.hidden_size, name="z_to_dense") + z_sample_flat = tf.reshape(z_sample, [-1, hparams.v_size]) + z = tf.matmul(z_sample_flat, means) + z = tf.reshape(z, [tf.shape(z_sample)[0], tf.shape(z_sample)[1], + 1, hparams.hidden_size]) # Leak at the beginning to help train. - z = mix(z, cur, 30000) + z = mix(z, cur, hparams.startup_steps) + # Dropout for better autoencoding. + z = tf.nn.dropout(z, keep_prob=0.9) + + # Decompress. + d = z for i in xrange(hparams.num_compress_steps): j = hparams.num_compress_steps - i - 1 - z = residual_conv(z, 1, hparams, "decompress_rc_%d" % j) - z = decompress_step(z, c, hparams, i > 0, "decompress_step_%d" % j) - return z, kl_loss + 0.0001 * reconstruct_loss, mu, log_sigma - - -def encode(x, x_space, hparams, name): - """Transformer preparations and encoder.""" - with tf.variable_scope(name): - (encoder_input, encoder_self_attention_bias, - _) = transformer.transformer_prepare_encoder(x, x_space, hparams) - encoder_input = tf.nn.dropout(encoder_input, 1.0 - hparams.dropout) - return transformer.transformer_encoder( - encoder_input, encoder_self_attention_bias, hparams) + d = residual_conv(d, 1, hparams, "decompress_rc_%d" % j) + d = decompress_step(d, c, hparams, i > 0, "decompress_step_%d" % j) + k = 2**hparams.num_compress_steps + z_batch = tf.reshape(z, [-1, 1, 1, hparams.hidden_size]) + x_batch = tf.reshape(x, [-1, k, 1, hparams.hidden_size]) + d_batch = tf.reshape(d, [-1, k, 1, hparams.hidden_size]) + # dec_batch = decode(z_batch, d_batch, x_batch, None, None, hparams) + c = expand_batch(c, tf.shape(x_batch)[0] / tf.shape(x)[0]) + ed = expand_batch(ed, tf.shape(x_batch)[0] / tf.shape(x)[0]) + dec_batch = decode(z_batch, d_batch, x_batch, c, ed, hparams) + z = tf.reshape(dec_batch, [-1, tf.shape(x)[1], 1, hparams.hidden_size]) -def dropmask(targets, targets_dropout_max, is_training): - if not is_training: - return targets - targets_drop_prob = tf.random_uniform([]) * targets_dropout_max - drop_mask = tf.random_uniform(tf.shape(targets)[:-1]) - drop_mask = tf.to_float(tf.less(drop_mask, targets_drop_prob)) - keep_mask = tf.expand_dims(1.0 - drop_mask, axis=2) - return targets * keep_mask + return z, kl_loss, reconstruct_loss def ffn(x, hparams, name): @@ -239,29 +282,16 @@ def vae_transformer_internal(inputs, targets, target_space, hparams): k = 2**hparams.num_compress_steps inputs, targets = common_layers.pad_to_same_length( inputs, targets, final_length_divisible_by=k) - inputs = encode(inputs, target_space, hparams, "input_enc") + inputs, ed_bias = encode(inputs, target_space, hparams, "input_enc") # Compress and vae. - z, kl_loss, _, _ = vae_compress(tf.expand_dims(targets, axis=2), - tf.expand_dims(inputs, axis=2), - hparams, "vae_compress", "vae_decompress") - - # Join z with inputs, run decoder. - to_decode = common_layers.conv_block( - tf.concat([z, tf.expand_dims(inputs, axis=2)], axis=3), - hparams.hidden_size, [((1, 1), (1, 1))], name="join_z") - ret = encode(tf.squeeze(to_decode, axis=2), target_space, hparams, "dec") - - # For experiments with one-sided decoder: - # decoder_in = tf.squeeze(to_decode, axis=2) - # (decoder_input, decoder_self_attention_bias) = ( - # transformer.transformer_prepare_decoder(decoder_in, hparams)) - # ret = transformer.transformer_decoder( - # decoder_input, inputs, decoder_self_attention_bias, None, hparams) - - kl_loss *= common_layers.inverse_exp_decay(hparams.kl_warmup_steps) * 3.0 - losses = {"kl": kl_loss} - return tf.expand_dims(ret, axis=2), losses + z, kl, r = vae_compress(tf.expand_dims(targets, axis=2), + tf.expand_dims(inputs, axis=2), + ed_bias, hparams, "vae_compress", "vae_decompress") + kl *= common_layers.inverse_exp_decay(int(hparams.startup_steps * 0.5)) + r *= common_layers.inverse_exp_decay(int(hparams.startup_steps * 2.0)) + losses = {"kl": kl, "reconstruction": r} + return z, losses @registry.register_model @@ -296,7 +326,7 @@ def infer(self, features=None, decode_length=50, beam_size=1, top_beams=1, samples = tf.concat(sharded_samples, 0) # More steps. - how_many_more_steps = 20 + how_many_more_steps = 2 for _ in xrange(how_many_more_steps): with tf.variable_scope(tf.get_variable_scope(), reuse=True): features["targets"] = samples @@ -317,9 +347,10 @@ def transformer_vae_small(): hparams.batch_size = 2048 hparams.learning_rate_warmup_steps = 4000 hparams.add_hparam("z_size", 128) - hparams.add_hparam("v_size", 1024*8) + hparams.add_hparam("v_size", 1024*32) hparams.add_hparam("num_compress_steps", 4) - hparams.add_hparam("kl_warmup_steps", 50000) + hparams.add_hparam("kl_warmup_steps", 60000) + hparams.add_hparam("startup_steps", 30000) return hparams diff --git a/tensor2tensor/utils/data_reader.py b/tensor2tensor/utils/data_reader.py index 03e7720b6..dbbd8e936 100644 --- a/tensor2tensor/utils/data_reader.py +++ b/tensor2tensor/utils/data_reader.py @@ -18,11 +18,15 @@ from __future__ import division from __future__ import print_function +import fractions import math import os +import random # Dependency imports +import numpy as np + import six from six.moves import zip # pylint: disable=redefined-builtin @@ -88,8 +92,7 @@ def examples_reader(data_sources, by default (if this is None), we decode all items. Returns: - A dictionary mapping each data_field to a corresponding 1D int64 tensor - read from the created Dataset. + A tf.contrib.data.Dataset of dict """ def decode_record(record): @@ -113,18 +116,17 @@ def decode_record(record): return dict(zip(decode_items, decoded)) with tf.name_scope("examples_in"): - # Read serialized examples using slim parallel_reader. data_files = tf.contrib.slim.parallel_reader.get_data_files(data_sources) - num_readers = min(4 if training else 1, len(data_files)) - _, example_serialized = tf.contrib.slim.parallel_reader.parallel_read( - data_sources, - tf.TFRecordReader, - num_epochs=None if training else 1, - shuffle=training, - capacity=2 * capacity, - min_after_dequeue=capacity, - num_readers=num_readers) - return decode_record(example_serialized) + if training: + random.shuffle(data_files) + dataset = tf.contrib.data.TFRecordDataset(data_files) + num_threads = min(4 if training else 1, len(data_files)) + dataset = dataset.map(decode_record, num_threads=num_threads) + if training: + dataset = dataset.shuffle(capacity) + # Loop inifinitely if training, just once otherwise + dataset = dataset.repeat(None if training else 1) + return dataset def preprocessing(examples, data_file_pattern): @@ -132,21 +134,15 @@ def preprocessing(examples, data_file_pattern): # This function is for obsolete problems only, as we're porting them # all to the Problem class and its preprocess_examples method. Don't add. if "image" in data_file_pattern: + def resize(img, size): - return tf.to_int64(tf.image.resize_images( - img, [size, size], tf.image.ResizeMethod.AREA)) + return tf.to_int64( + tf.image.resize_images(img, [size, size], tf.image.ResizeMethod.AREA)) if "img2img" in data_file_pattern: inputs = examples["inputs"] examples["inputs"] = resize(inputs, 16) examples["targets"] = resize(inputs, 64) - elif "image_celeba" in data_file_pattern: - inputs = examples["inputs"] - # Remove boundaries in CelebA images. Remove 40 pixels each side - # vertically and 20 pixels each side horizontally. - inputs = tf.image.crop_to_bounding_box(inputs, 40, 20, 218-80, 178-40) - examples["inputs"] = resize(inputs, 8) - examples["targets"] = resize(inputs, 32) elif "audio" in data_file_pattern: # Reshape audio to proper shape sample_count = tf.to_int32(examples.pop("audio/sample_count")) @@ -218,8 +214,11 @@ def default_example_reading_spec(data_file_pattern): return data_fields, data_items_to_decoders -def input_pipeline(problem, data_file_pattern, capacity, mode, hparams): - """Input pipeline, returns a dictionary of tensors from queues.""" +def read_examples(problem, + data_file_pattern, + capacity, + mode=tf.contrib.learn.ModeKeys.TRAIN): + """Create Dataset of Example for problem and data_file_pattern.""" if problem is None: data_fields, data_items_to_decoders = default_example_reading_spec( data_file_pattern) @@ -230,73 +229,170 @@ def input_pipeline(problem, data_file_pattern, capacity, mode, hparams): # Create placeholders for input, rather than reading data from disk. return feature_placeholders(data_fields) - examples = examples_reader( + is_training = mode == tf.contrib.learn.ModeKeys.TRAIN + dataset = examples_reader( [data_file_pattern], data_fields, - training=(mode == tf.contrib.learn.ModeKeys.TRAIN), + training=is_training, capacity=capacity, data_items_to_decoders=data_items_to_decoders) - - if problem is None: - examples = preprocess_examples_common(examples, hparams) - examples = preprocessing(examples, data_file_pattern) - else: - examples = problem.preprocess_examples(examples, mode, hparams) - - # We do not want int64s as they are not supported on GPUs. - examples = cast_int64_to_int32(examples) - - return examples + return dataset -def batch_examples(examples, batching_scheme): - """Given a queue of examples, create batches of examples with similar lengths. - - We assume that examples is a dictionary with string keys and tensor values, - possibly coming from a queue, e.g., constructed by examples_reader above. - Each tensor in examples is assumed to be 1D. We will put tensors of similar - length into batches togeter. We return a dictionary with the same keys as - examples, and with values being batches of size batch_size. If elements have - different lengths, they are padded with 0s. This function is based on - tf.contrib.training.bucket_by_sequence_length so see there for details. - - For example, if examples is a queue containing [1, 2, 3] and [4], then - this function with batch_size=2 will return a batch [[1, 2, 3], [4, 0, 0]]. +def input_pipeline(problem, data_file_pattern, capacity, mode, hparams, + batching_scheme): + """Input pipeline, returns a dictionary of batched and padded tensors. Args: - examples: a dictionary with string keys and 1D tensor values. + problem: Problem instance for which to build the input pipeline. + data_file_pattern: file pattern for input files. + capacity: int, data pipeline buffer capacity. + mode: tf.contrib.learn.ModeKeys entry. + hparams: an HParams object. batching_scheme: a dictionary containing "boundaries": a list of integers for the boundaries that will be - used for bucketing; see tf.contrib.training.bucket_by_sequence_length - for more details. + used for bucketing; see bucket_by_sequence_length for more details. "batch_sizes": a list of batch sizes corresponding to the buckets "max_length": an integer. We drop sequences which are longer. Returns: - A dictionary with the same keys as examples and with values being batches - of examples padded with 0s, i.e., [batch_size x length] tensors. + dict + """ + is_training = mode == tf.contrib.learn.ModeKeys.TRAIN + num_threads = 4 if is_training else 1 + + with tf.name_scope("input_pipeline"): + dataset = read_examples(problem, data_file_pattern, capacity, mode=mode) + dataset = dataset.map( + lambda ex: _preprocess(ex, problem, data_file_pattern, hparams, mode), + num_threads=num_threads) + dataset = dataset.filter( + lambda ex: _example_too_big(ex, batching_scheme["max_length"])) + + dataset = bucket_by_sequence_length(dataset, _example_length, + batching_scheme["boundaries"], + batching_scheme["batch_sizes"]) + max_batch_size = max(batching_scheme["batch_sizes"]) + # We reshuffle the batches to prevent many long-sequence batches at once. + dataset = dataset.shuffle(max_batch_size * 3) + batched_examples = dataset.make_one_shot_iterator().get_next() + return batched_examples + + +def _preprocess(example, problem, data_file_pattern, hparams, mode): + """Preprocessing for example.""" + if problem is None: + example = preprocess_examples_common(example, hparams) + example = preprocessing(example, data_file_pattern) + else: + example = problem.preprocess_examples(example, mode, hparams) + + # We do not want int64s as they are not supported on GPUs. + example = cast_int64_to_int32(example) + + return example + + +def _example_length(example): + length = 0 + # Length of the example is the maximum length of the feature lengths + for v in example.values(): + # For images the sequence length is the size of the spatial dimensions. + feature_length = (tf.shape(v)[0] if len(v.get_shape()) < 3 else + tf.shape(v)[0] * tf.shape(v)[1]) + length = tf.maximum(length, feature_length) + return length + + +def _example_too_big(example, max_length): + return tf.less_equal(_example_length(example), max_length) + + +def _lcm(l): + """Least common multiple of integers in a list.""" + if not l: + raise ValueError("LCD of an empty list.") + if len(l) == 1: + return l[0] + x = l[0] + y = _lcm(l[1:]) + return x * y // fractions.gcd(x, y) + + +def _closest_small_primes(x): + """Closest number to x which has only 2, 3, 5 as prime factors, 3,5 once.""" + assert x > 0 + def is_small_primes(x, covered3, covered5): + if x % 2 == 0: + return is_small_primes(x // 2, covered3, covered5) + if x % 3 == 0 and not covered3: + return is_small_primes(x // 3, True, covered5) + if x % 5 == 0 and not covered5: + return is_small_primes(x // 5, covered3, True) + return x == 1 + for i in xrange(x): + if is_small_primes(x - i, False, False): + return x - i + # We search for higher numbers too, but only 8 of them to not increase much. + if i < 9 and is_small_primes(x + i, False, False): + return x + i + + +def bucket_by_sequence_length(dataset, example_length_fn, bucket_boundaries, + bucket_batch_sizes): + """Bucket entries in dataset by length. + + Args: + dataset: Dataset of dict. + example_length_fn: function from example to int, determines the length of + the example, which will determine the bucket it goes into. + bucket_boundaries: list, boundaries of the buckets. + bucket_batch_sizes: list, batch size per bucket. + + Returns: + Dataset of padded and batched examples. """ - with tf.name_scope("batch_examples"): - # The queue to bucket on will be chosen based on maximum length. - max_length = 0 - for v in examples.values(): - # For images the sequence length is the size of the spatial dimensions. - sequence_length = (tf.shape(v)[0] if len(v.get_shape()) < 3 else - tf.shape(v)[0] * tf.shape(v)[1]) - max_length = tf.maximum(max_length, sequence_length) - (_, outputs) = tf.contrib.training.bucket_by_sequence_length( - max_length, - examples, - batching_scheme["batch_sizes"], - [b + 1 for b in batching_scheme["boundaries"]], - capacity=2, # Number of full batches to store, we don't need many. - bucket_capacities=[2 * b for b in batching_scheme["batch_sizes"]], - dynamic_pad=True, - keep_input=(max_length <= batching_scheme["max_length"])) - return outputs - - -def bucket_boundaries(max_length, min_length=8, mantissa_bits=2): + # Since the Datasets API only allows a single constant for window_size, + # and it needs divide all bucket_batch_sizes, we first make sure they only + # have a few primes in them so that their LCM doesn't explode quickly. + # TODO(lukaszkaiser): remove this adjustment when Dataset API improves. + bucket_batch_sizes1 = [_closest_small_primes(b) for b in bucket_batch_sizes] + tf.logging.info("Corrected bucket_batch_sizes from %s to %s." + % (str(bucket_batch_sizes), str(bucket_batch_sizes1))) + bucket_batch_sizes = bucket_batch_sizes1 + with tf.name_scope("bucket_by_seq_length"): + + def example_to_bucket_id(example): + """Return int64 id of the length bucket for this example.""" + seq_length = example_length_fn(example) + + boundaries = list(bucket_boundaries) + buckets_min = [np.iinfo(np.int32).min] + boundaries + buckets_max = boundaries + [np.iinfo(np.int32).max] + conditions_c = tf.logical_and( + tf.less_equal(buckets_min, seq_length), + tf.less(seq_length, buckets_max)) + bucket_id = tf.reduce_min(tf.where(conditions_c)) + + return bucket_id + + def batching_fn(bucket_id, grouped_dataset): + batch_sizes = tf.constant(bucket_batch_sizes, dtype=tf.int64) + batch_size = batch_sizes[bucket_id] + + # Pad each dimension of each feature so that they match. + padded_shapes = dict( + [(name, [None] * len(shape)) + for name, shape in grouped_dataset.output_shapes.items()]) + return grouped_dataset.padded_batch(batch_size, padded_shapes) + + window_size = _lcm(bucket_batch_sizes) + dataset = dataset.group_by_window(example_to_bucket_id, batching_fn, + window_size) + return dataset + + +def _bucket_boundaries(max_length, min_length=8, mantissa_bits=2): """A default set of length-bucket boundaries.""" x = min_length boundaries = [] @@ -306,39 +402,46 @@ def bucket_boundaries(max_length, min_length=8, mantissa_bits=2): return boundaries -def hparams_to_batching_scheme(hparams, - drop_long_sequences=False, - shard_multiplier=1, - length_multiplier=1): +def _batching_scheme(batch_size=16 * 256, + max_length=None, + batching_mantissa_bits=1, + drop_long_sequences=False, + shard_multiplier=1, + length_multiplier=1): """A batching scheme based on model hyperparameters. Every batch containins a number of sequences divisible by `shard_multiplier`. - If `drop_long_sequences` is True, then sequences longer than - `hparams.batch_size` are dropped. This prevents generating batches with - more than the usual number of tokens, which can cause out-of-memory errors. - Args: - hparams: a hyperparameters. - drop_long_sequences: a boolean. + batch_size: int, total number of tokens in a batch. + max_length: int, sequences longer than this will be skipped. Defaults to + batch_size. + batching_mantissa_bits: int, ??. + drop_long_sequences: bool, if True, then sequences longer than + `max_length` are dropped. This prevents generating batches with + more than the usual number of tokens, which can cause out-of-memory + errors. shard_multiplier: an integer increasing the batch_size to suit splitting across datashards. length_multiplier: an integer multiplier that is used to increase the batch sizes and sequence length tolerance. Returns: - a dictionary + A dictionary with parameters that can be passed to input_pipeline: + * boundaries: list of bucket boundaries + * batch_sizes: list of batch sizes for each length bucket + * max_length: int, maximum length of an example """ - max_length = hparams.max_length or hparams.batch_size - boundaries = bucket_boundaries( - max_length, mantissa_bits=hparams.batching_mantissa_bits) + max_length = max_length or batch_size + boundaries = _bucket_boundaries( + max_length, mantissa_bits=batching_mantissa_bits) + boundaries = [boundary * length_multiplier for boundary in boundaries] + max_length *= length_multiplier + batch_sizes = [ - max(1, hparams.batch_size // length) + max(1, batch_size // length) * shard_multiplier for length in boundaries + [max_length] ] - batch_sizes = [b * shard_multiplier for b in batch_sizes] - max_length *= length_multiplier - boundaries = [boundary * length_multiplier for boundary in boundaries] return { "boundaries": boundaries, "batch_sizes": batch_sizes, @@ -346,6 +449,20 @@ def hparams_to_batching_scheme(hparams, } +def hparams_to_batching_scheme(hparams, + drop_long_sequences=False, + shard_multiplier=1, + length_multiplier=1): + """Wrapper around _batching_scheme with hparams.""" + return _batching_scheme( + max_length=hparams.max_length, + batch_size=hparams.batch_size, + batching_mantissa_bits=hparams.batching_mantissa_bits, + drop_long_sequences=drop_long_sequences, + shard_multiplier=shard_multiplier, + length_multiplier=length_multiplier) + + def constant_batching_scheme(constant_batch_size_in_sequences): """A batching scheme with constant batch size. @@ -355,7 +472,7 @@ def constant_batching_scheme(constant_batch_size_in_sequences): Returns: a dictionary """ - boundaries = bucket_boundaries(1024) + boundaries = _bucket_boundaries(1024) batch_sizes = [constant_batch_size_in_sequences] * (1 + len(boundaries)) return { "boundaries": boundaries, diff --git a/tensor2tensor/utils/data_reader_test.py b/tensor2tensor/utils/data_reader_test.py index ea98da06d..318fb1cab 100644 --- a/tensor2tensor/utils/data_reader_test.py +++ b/tensor2tensor/utils/data_reader_test.py @@ -28,123 +28,223 @@ from six.moves import xrange # pylint: disable=redefined-builtin from tensor2tensor.data_generators import generator_utils +from tensor2tensor.data_generators import problem as problem_mod from tensor2tensor.utils import data_reader +from tensor2tensor.utils import registry import tensorflow as tf -class DataReaderTest(tf.test.TestCase): +@registry.register_problem +class TestProblem(problem_mod.Problem): - def testExamplesQueue(self): - tf.set_random_seed(1) - tmp_dir = self.get_temp_dir() - (_, tmp_file_path) = tempfile.mkstemp(dir=tmp_dir) - tmp_file_name = os.path.basename(tmp_file_path) - - # Generate a file with 100 examples. - def test_generator(): - for i in xrange(100): - yield {"inputs": [i], "targets": [i], "floats": [i + 0.5]} - - filenames = generator_utils.train_data_filenames(tmp_file_name, tmp_dir, 1) - generator_utils.generate_files(test_generator(), filenames) - self.assertTrue(tf.gfile.Exists(tmp_file_path + "-train-00000-of-00001")) - - examples_train = data_reader.examples_reader( - [tmp_file_path + "*"], { - "inputs": tf.VarLenFeature(tf.int64), - "targets": tf.VarLenFeature(tf.int64) - }, - training=True) - examples_eval = data_reader.examples_reader( - [tmp_file_path + "*"], { - "inputs": tf.VarLenFeature(tf.int64), - "targets": tf.VarLenFeature(tf.int64), - "floats": tf.VarLenFeature(tf.float32) - }, - training=False) - with tf.train.MonitoredSession() as session: - # Evaluation data comes in the same order as in the file, check 10. - for i in xrange(10): - examples = session.run(examples_eval) - self.assertEqual(len(examples["inputs"]), 1) - self.assertEqual(len(examples["targets"]), 1) - self.assertEqual(examples["inputs"][0], i) - self.assertEqual(examples["targets"][0], i) - self.assertEqual(examples["floats"][0], i + 0.5) - # Training data is shuffled. - is_shuffled = False - for i in xrange(10): - examples = session.run(examples_train) - self.assertEqual(len(examples["inputs"]), 1) - self.assertEqual(len(examples["targets"]), 1) - self.assertEqual(examples["inputs"][0], examples["targets"][0]) - if examples["inputs"][0] != i: - is_shuffled = True - self.assertTrue(is_shuffled) - - # Clean up. - os.remove(tmp_file_path + "-train-00000-of-00001") - os.remove(tmp_file_path) - - # TODO(rsepassi): fix and reenable test - def _testBatchExamples(self): - tf.set_random_seed(1) - tmp_dir = self.get_temp_dir() - (_, tmp_file_path) = tempfile.mkstemp(dir=tmp_dir) - tmp_file_name = os.path.basename(tmp_file_path) + def generator(self, data_dir, tmp_dir, is_training): + for i in xrange(30): + yield {"inputs": [i] * (i + 1), "targets": [i], "floats": [i + 0.5]} - # Generate a file with 100 examples, n-th example of length n + 1. - def test_generator(): - for i in xrange(100): - yield {"inputs": [i + 1 for _ in xrange(i + 1)], "targets": [i + 1]} + def generate_data(self, data_dir, tmp_dir, task_id=-1): + train_paths = self.training_filepaths(data_dir, 1, shuffled=True) + dev_paths = self.dev_filepaths(data_dir, 1, shuffled=True) + generator_utils.generate_files( + self.generator(data_dir, tmp_dir, True), train_paths) + generator_utils.generate_files( + self.generator(data_dir, tmp_dir, False), dev_paths) - filenames = generator_utils.train_data_filenames(tmp_file_name, tmp_dir, 1) - generator_utils.generate_files(test_generator(), filenames) - self.assertTrue(tf.gfile.Exists(tmp_file_path + "-train-00000-of-00001")) + def hparams(self, defaults, model_hparams): + pass - examples_train = data_reader.examples_reader([tmp_file_path + "*"], { + def example_reading_spec(self): + data_fields = { "inputs": tf.VarLenFeature(tf.int64), - "targets": tf.VarLenFeature(tf.int64) - }, True) - batch_train = data_reader.batch_examples(examples_train, 4) - examples_eval = data_reader.examples_reader([tmp_file_path + "*"], { - "inputs": tf.VarLenFeature(tf.int64), - "targets": tf.VarLenFeature(tf.int64) - }, False) - batch_eval = data_reader.batch_examples(examples_eval, 2) - session, coord = tf.Session(), tf.train.Coordinator() - with session.as_default(): - tf.train.start_queue_runners(coord=coord) - - # Evaluation data comes in the same order as in the file. - # The first batch will be inputs=[[1, 0], [2, 2]], targets=[[1], [2]]. - examples = session.run(batch_eval) - self.assertAllClose(examples["inputs"], np.array([[1, 0], [2, 2]])) - self.assertAllClose(examples["targets"], np.array([[1], [2]])) - # Check the second batch too. - examples = session.run(batch_eval) - self.assertAllClose(examples["inputs"], - np.array([[3, 3, 3, 0], [4, 4, 4, 4]])) - self.assertAllClose(examples["targets"], np.array([[3], [4]])) - - # Training data is shuffled but shouldn't have too many pads. + "targets": tf.VarLenFeature(tf.int64), + "floats": tf.VarLenFeature(tf.float32), + } + data_items_to_decoders = None + return (data_fields, data_items_to_decoders) + + def preprocess_examples(self, examples, unused_mode, unused_hparams): + examples["new_field"] = tf.constant([42.42]) + return examples + + +def generate_test_data(problem, tmp_dir): + problem.generate_data(tmp_dir, tmp_dir) + filepatterns = data_reader.get_data_filepatterns( + problem.name, tmp_dir, tf.contrib.learn.ModeKeys.TRAIN) + assert tf.gfile.Glob(filepatterns[0]) + return filepatterns + + +class DataReaderTest(tf.test.TestCase): + + @classmethod + def setUpClass(cls): + tf.set_random_seed(1) + cls.problem = registry.problem("test_problem") + cls.filepatterns = generate_test_data(cls.problem, tempfile.gettempdir()) + + @classmethod + def tearDownClass(cls): + # Clean up files + for fp in cls.filepatterns: + files = tf.gfile.Glob(fp) + for f in files: + os.remove(f) + + def testBasicExampleReading(self): + dataset = data_reader.read_examples(self.problem, self.filepatterns[0], 32) + examples = dataset.make_one_shot_iterator().get_next() + with tf.train.MonitoredSession() as sess: + # Check that there are multiple examples that have the right fields of the + # right type (lists of int/float). for _ in xrange(10): - examples = session.run(batch_train) - inputs = examples["inputs"] - # Only 3 out of 4 examples in a batch have padding zeros at all. - pad_per_example = (inputs.size - np.count_nonzero(inputs)) // 3 - # Default bucketing is in steps of 8 until 64 and 32 later. - if int(max(examples["targets"])) < 64: - self.assertLess(pad_per_example, 8) - else: - self.assertLess(pad_per_example, 32) - - # Clean up. - coord.request_stop() - coord.join() - os.remove(tmp_file_path + "-train-00000-of-00001") - os.remove(tmp_file_path) + ex_val = sess.run(examples) + inputs, targets, floats = (ex_val["inputs"], ex_val["targets"], + ex_val["floats"]) + self.assertEqual(np.int64, inputs.dtype) + self.assertEqual(np.int64, targets.dtype) + self.assertEqual(np.float32, floats.dtype) + for field in [inputs, targets, floats]: + self.assertGreater(len(field), 0) + + def testTrainEvalBehavior(self): + train_dataset = data_reader.read_examples(self.problem, + self.filepatterns[0], 16) + train_examples = train_dataset.make_one_shot_iterator().get_next() + eval_dataset = data_reader.read_examples( + self.problem, + self.filepatterns[0], + 16, + mode=tf.contrib.learn.ModeKeys.EVAL) + eval_examples = eval_dataset.make_one_shot_iterator().get_next() + + eval_idxs = [] + with tf.train.MonitoredSession() as sess: + # Train should be shuffled and run through infinitely + for i in xrange(30): + self.assertNotEqual(i, sess.run(train_examples)["inputs"][0]) + + # Eval should not be shuffled and only run through once + for i in xrange(30): + self.assertEqual(i, sess.run(eval_examples)["inputs"][0]) + eval_idxs.append(i) + + with self.assertRaises(tf.errors.OutOfRangeError): + sess.run(eval_examples) + # Should never run because above line should error + eval_idxs.append(30) + + # Ensuring that the above exception handler actually ran and we didn't + # exit the MonitoredSession context. + eval_idxs.append(-1) + + self.assertAllEqual(list(range(30)) + [-1], eval_idxs) + + def testPreprocess(self): + dataset = data_reader.read_examples(self.problem, self.filepatterns[0], 32) + examples = dataset.make_one_shot_iterator().get_next() + examples = data_reader._preprocess(examples, self.problem, None, None, None) + with tf.train.MonitoredSession() as sess: + ex_val = sess.run(examples) + # problem.preprocess_examples has been run + self.assertAllClose([42.42], ex_val["new_field"]) + + # int64 has been cast to int32 + self.assertEqual(np.int32, ex_val["inputs"].dtype) + self.assertEqual(np.int32, ex_val["targets"].dtype) + self.assertEqual(np.float32, ex_val["floats"].dtype) + + def testLengthFilter(self): + max_len = 15 + dataset = data_reader.read_examples(self.problem, self.filepatterns[0], 32) + dataset = dataset.filter( + lambda ex: data_reader._example_too_big(ex, max_len)) + examples = dataset.make_one_shot_iterator().get_next() + with tf.train.MonitoredSession() as sess: + ex_lens = [] + for _ in xrange(max_len): + ex_lens.append(len(sess.run(examples)["inputs"])) + + self.assertAllEqual(list(range(1, max_len + 1)), sorted(ex_lens)) + + def testBatchingSchemeMaxLength(self): + scheme = data_reader._batching_scheme( + batch_size=20, max_length=None, drop_long_sequences=False) + self.assertGreater(scheme["max_length"], 10000) + + scheme = data_reader._batching_scheme( + batch_size=20, max_length=None, drop_long_sequences=True) + self.assertEqual(scheme["max_length"], 20) + + scheme = data_reader._batching_scheme( + batch_size=20, max_length=15, drop_long_sequences=True) + self.assertEqual(scheme["max_length"], 15) + + scheme = data_reader._batching_scheme( + batch_size=20, max_length=15, drop_long_sequences=False) + self.assertGreater(scheme["max_length"], 10000) + + def testBatchingSchemeBuckets(self): + scheme = data_reader._batching_scheme(batch_size=128) + boundaries, batch_sizes = scheme["boundaries"], scheme["batch_sizes"] + self.assertEqual(len(boundaries), len(batch_sizes) - 1) + expected_boundaries = [8, 12, 16, 24, 32, 48, 64, 96] + self.assertEqual(expected_boundaries, boundaries) + expected_batch_sizes = [16, 10, 8, 5, 4, 2, 2, 1, 1] + self.assertEqual(expected_batch_sizes, batch_sizes) + + scheme = data_reader._batching_scheme(batch_size=128, shard_multiplier=2) + boundaries, batch_sizes = scheme["boundaries"], scheme["batch_sizes"] + self.assertAllEqual([bs * 2 for bs in expected_batch_sizes], batch_sizes) + self.assertEqual(expected_boundaries, boundaries) + + scheme = data_reader._batching_scheme(batch_size=128, length_multiplier=2) + boundaries, batch_sizes = scheme["boundaries"], scheme["batch_sizes"] + self.assertAllEqual([b * 2 for b in expected_boundaries], boundaries) + self.assertEqual([max(1, bs // 2) + for bs in expected_batch_sizes], batch_sizes) + + def testBucketBySeqLength(self): + + def example_len(ex): + return tf.shape(ex["inputs"])[0] + + boundaries = [10, 20, 30] + batch_sizes = [10, 8, 4, 2] + + dataset = data_reader.read_examples( + self.problem, + self.filepatterns[0], + 32, + mode=tf.contrib.learn.ModeKeys.EVAL) + dataset = data_reader.bucket_by_sequence_length(dataset, example_len, + boundaries, batch_sizes) + batch = dataset.make_one_shot_iterator().get_next() + + input_vals = [] + obs_batch_sizes = [] + with tf.train.MonitoredSession() as sess: + # Until OutOfRangeError + while True: + batch_val = sess.run(batch) + batch_inputs = batch_val["inputs"] + batch_size, max_len = batch_inputs.shape + obs_batch_sizes.append(batch_size) + for inputs in batch_inputs: + input_val = inputs[0] + input_vals.append(input_val) + # The inputs were constructed such that they were repeated value+1 + # times (i.e. if the inputs value is 7, the example has 7 repeated 8 + # times). + repeat = input_val + 1 + # Check padding + self.assertAllEqual([input_val] * repeat + [0] * (max_len - repeat), + inputs) + + # Check that all inputs came through + self.assertEqual(list(range(30)), sorted(input_vals)) + # Check that we saw variable batch size + self.assertTrue(len(set(obs_batch_sizes)) > 1) if __name__ == "__main__": diff --git a/tensor2tensor/utils/decoding.py b/tensor2tensor/utils/decoding.py index 4ba8dc71a..2e430a204 100644 --- a/tensor2tensor/utils/decoding.py +++ b/tensor2tensor/utils/decoding.py @@ -74,22 +74,19 @@ def log_fn(inputs, decoded_outputs = " ".join(map(str, outputs.flatten())) decoded_targets = " ".join(map(str, targets.flatten())) else: - decoded_outputs = targets_vocab.decode( - _save_until_eos(outputs.flatten())) - decoded_targets = targets_vocab.decode( - _save_until_eos(targets.flatten())) + decoded_outputs = " ".join(map( + str, targets_vocab.decode(_save_until_eos(outputs.flatten())))) + decoded_targets = " ".join(map( + str, targets_vocab.decode(_save_until_eos(targets.flatten())))) tf.logging.info("Inference results OUTPUT: %s" % decoded_outputs) tf.logging.info("Inference results TARGET: %s" % decoded_targets) - if FLAGS.decode_to_file: - output_filepath = FLAGS.decode_to_file + ".outputs." + problem - output_file = tf.gfile.Open(output_filepath, "a") - output_file.write(decoded_outputs + "\n") - target_filepath = FLAGS.decode_to_file + ".targets." + problem - target_file = tf.gfile.Open(target_filepath, "a") - target_file.write(decoded_targets + "\n") + return decoded_outputs, decoded_targets + result_iter = estimator.predict(input_fn=infer_input_fn, as_iterable=True) count = 0 + agg_outputs = [] + agg_targets = [] for result in result_iter: # predictions from the test input. We use it to log inputs and decodes. inputs = result["inputs"] @@ -99,13 +96,25 @@ def log_fn(inputs, output_beams = np.split(outputs, FLAGS.decode_beam_size, axis=0) for k, beam in enumerate(output_beams): tf.logging.info("BEAM %d:" % k) - log_fn(inputs, targets, beam, problem, count) + o, t = log_fn(inputs, targets, beam, problem, count) + agg_outputs.append(o) + agg_targets.append(t) else: - log_fn(inputs, targets, outputs, problem, count) + o, t = log_fn(inputs, targets, outputs, problem, count) + agg_outputs.append(o) + agg_targets.append(t) count += 1 if FLAGS.decode_num_samples != -1 and count >= FLAGS.decode_num_samples: break + if FLAGS.decode_to_file: + output_filepath = FLAGS.decode_to_file + ".outputs." + problem + output_file = tf.gfile.Open(output_filepath, "w") + target_filepath = FLAGS.decode_to_file + ".targets." + problem + target_file = tf.gfile.Open(target_filepath, "w") + for o, t in zip(agg_outputs, agg_targets): + output_file.write(str(o)+"\n") + target_file.write(str(t)+"\n") tf.logging.info("Completed inference on %d samples." % count) diff --git a/tensor2tensor/utils/diet.py b/tensor2tensor/utils/diet.py new file mode 100644 index 000000000..4ff44de5b --- /dev/null +++ b/tensor2tensor/utils/diet.py @@ -0,0 +1,360 @@ +# coding=utf-8 +# Copyright 2017 The Tensor2Tensor Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Diet variables are much more memory-efficient than regular variables. + +Using diet variables, we can reduce memory overhead per parameter from +16 bytes to 2 bytes, allowing for up to 4B parameters per GPU. + +Functions that build subgraphs with variables can be made to use diet variables +by using the fn_with_diet_vars decorator. +""" + +from collections import defaultdict +import copy +import math +# Dependency imports +from tensor2tensor.layers import common_layers +import tensorflow as tf + + +def diet_adam_optimizer_params(): + """Default hyperparameters for a DietAdamOptimizer. + + Returns: + a hyperparameters object. + """ + return tf.contrib.training.HParams( + quantize=int(True), # use 16-bit fixed-point + quantization_scale=10.0 / tf.int16.max, + optimizer="DietAdam", + learning_rate=1.0, + learning_rate_warmup_steps=2000, + learning_rate_decay_scheme="noam", # "noam" or "none" + epsilon=1e-10, + beta1=0.0, # we can save memory if beta1=0 + beta2=0.98, + factored_second_moment_accumulator=int(True), # this saves memory + ) + + +def diet_expert(x, hidden_size, params): + """A two-layer feed-forward network with relu activation on hidden layer. + + Uses diet variables. + Recompuets hidden layer on backprop to save activation memory. + + Args: + x: a Tensor with shape [batch, io_size] + hidden_size: an integer + params: a diet variable HParams object. + + Returns: + a Tensor with shape [batch, io_size] + """ + + @fn_with_diet_vars(params) + def diet_expert_internal(x): + dim = x.get_shape().as_list()[-1] + h = tf.layers.dense(x, hidden_size, activation=tf.nn.relu, use_bias=False) + y = tf.layers.dense(h, dim, use_bias=False) + y *= tf.rsqrt(tf.to_float(dim * hidden_size)) + return y + + return diet_expert_internal(x) + + +class DietVariableOptimizer(object): + """Base class for Diet variable optimizers.""" + + def __init__(self, params): + self._params = params + self._global_step = tf.train.get_or_create_global_step() + + @property + def params(self): + return self._params + + @property + def global_step(self): + return self._global_step + + def create_slots(self, var): + raise NotImplementedError() + + def update_variable(self, var, grad_var): + raise NotImplementedError() + + +class DietAdamOptimizer(DietVariableOptimizer): + """A memory efficient optimizer for memory-efficient variables. + + We employ the following techniques: + - 16-bit fixed-point quantization + - inline updates during backprop, instead of through the optimizer. This + keeps the gradients from staying around in memory. + - momentum is optional - saves a slot if it is off (beta1=0.0). + - "factored second-moment accumulator" + (keep row-wise and col-wise averages instead of full accumulator) + - tighter control over operation ordering to make sure that only a small + portion of the decompressed variables and of the variable gradients + are resident in memory at any given time. + + All together these techniques reduce the memory footprint per parameter to + a little over 2 bytes, allowing for roughly 4B parameters per GPU. This is + roughly an 8x improvement over the naive version. + + Usage: + + Diet variables should be created with the + DietAdamOptimizer.get_variable() method. The resulting variables + have extra fields pointing to the otpimizer and to the accumulator + slots. + + The variable is kept in quantized form, so you need to call + var.optimizer.dequantize(var) to get the value. + + The variables are created with trainable=False, so that they will + not be optimized by an ordinary optimizer. Instead, the user is + responsible for making sure that var.optimizer.update(var, grad) is + called during backprop. The reason for this inline update is to + avoid keeping around the gradients for all variables at once. This + is done with the clever use of defuns and control dependencies. See + diet_expert() for an example of how all of this is done. + + To facilitate fixed-point quantization and to make it easier to + choose a learning rate, all varaibles are initialized with unit + normal initialization. If you want smaller values, downscale on the + outside. + """ + + def create_slots(self, var): + """Create the factorized Adam accumulators for diet variables.""" + params = self.params + shape = var.get_shape().as_list() + + if not hasattr(params, "slots"): + params.slots = defaultdict(dict) + + name = var.op.name + slots = params.slots[name] + + if params.factored_second_moment_accumulator and len(shape) == 2: + slots["adam_vr"] = tf.get_variable( + name + "_adam_vr", [shape[0], 1], + trainable=False, + initializer=tf.zeros_initializer()) + slots["adam_vc"] = tf.get_variable( + name + "_adam_vc", [1, shape[1]], + trainable=False, + initializer=tf.zeros_initializer()) + else: + slots["adam_v"] = tf.get_variable( + name + "_adam_v", + shape, + trainable=False, + initializer=tf.zeros_initializer()) + if params.beta1 != 0.0: + slots["adam_m"] = tf.get_variable( + name + "_adam_m", + shape, + trainable=False, + initializer=tf.zeros_initializer()) + + def update_variable(self, var, grad_var): + """Update the variable and its slots.""" + params = self.params + global_step = tf.to_float(self.global_step) + 1 + + # compute learning rate + lrate = params.learning_rate + if params.learning_rate_decay_scheme == "noam": + lrate *= tf.minimum(global_step * params.learning_rate_warmup_steps**-1.5, + global_step**-0.5) + else: + assert params.learning_rate_decay_scheme == "none" + lrate *= tf.minumum(global_step / params.learning_rate_warmup_steps, 1.0) + + # compute adjustment due to second moment + slots = params.slots[var.op.name] + grad_squared = tf.square(grad_var) + beta2_pow = tf.pow(params.beta2, global_step) + if params.factored_second_moment_accumulator and len(var.shape) == 2: + vr_update = tf.assign(slots["adam_vr"], slots["adam_vr"] * params.beta2 + + tf.reduce_mean(grad_squared, 1, keep_dims=True) * + (1.0 - params.beta2)) + vc_update = tf.assign(slots["adam_vc"], slots["adam_vc"] * params.beta2 + + tf.reduce_mean(grad_squared, 0, keep_dims=True) * + (1.0 - params.beta2)) + with tf.control_dependencies([vr_update, vc_update]): + vr = tf.sqrt(slots["adam_vr"] / (1.0 - beta2_pow)) + params.epsilon + vc = tf.sqrt(slots["adam_vc"] / (1.0 - beta2_pow)) + params.epsilon + vc /= tf.reduce_mean(vc) + denom = vr * vc + else: + v_update = tf.assign(slots["adam_v"], + slots["adam_v"] * params.beta2 + grad_squared * + (1.0 - params.beta2)) + with tf.control_dependencies([v_update]): + denom = tf.sqrt(slots["adam_v"] / (1.0 - beta2_pow)) + params.epsilon + + # compute momentum if applicable + if params.beta1 != 0.0: + m_update = tf.assign(slots["adam_m"], + slots["adam_m"] * params.beta1 + grad_var * + (1.0 - params.beta1)) + with tf.control_dependencies([m_update]): + grad_var = slots["adam_m"] + + # update var + subtrahend = lrate * grad_var / denom + new_val = _quantize(_dequantize(var, params) - subtrahend, params) + return tf.assign(var, new_val) + + +def _create_diet_optimizer(params): + if params.optimizer == "DietAdam": + return DietAdamOptimizer(params) + else: + raise ValueError("Unrecognized diet optimizer") + + +def _quantize(x, params, randomize=True): + """Quantize x according to params, optionally randomizing the rounding.""" + if not params.quantize: + return x + + if not randomize: + return tf.bitcast( + tf.cast(x / params.quantization_scale, tf.int16), tf.float16) + + abs_x = tf.abs(x) + sign_x = tf.sign(x) + y = abs_x / params.quantization_scale + y = tf.floor(y + tf.random_uniform(tf.shape(x))) + y = tf.minimum(y, tf.int16.max) * sign_x + q = tf.bitcast(tf.cast(y, tf.int16), tf.float16) + return q + + +def _dequantize(q, params): + """Dequantize q according to params.""" + if not params.quantize: + return q + return tf.to_float(tf.bitcast(q, tf.int16)) * params.quantization_scale + + +def make_diet_var_getter(params): + """Create a custom variable getter for diet variables according to params.""" + + def diet_var_initializer(shape, dtype, partition_info=None): + del dtype + del partition_info + + with common_layers.fn_device_dependency("diet_init") as out_deps: + float_range = math.sqrt(3) + ret = tf.random_uniform(shape, -float_range, float_range) + if params.quantize: + ret = _quantize(ret, params, randomize=False) + out_deps.append(ret) + return ret + + def diet_var_getter(getter, **kwargs): + """Get diet variable and return it dequantized.""" + if params.quantize: + kwargs["dtype"] = tf.float16 + kwargs["initializer"] = diet_var_initializer + kwargs["trainable"] = False + + base_var = getter(**kwargs) + + dequantized = _dequantize(base_var, params) + + if not hasattr(params, "dequantized"): + params.dequantized = defaultdict(list) + params.dequantized[base_var.name].append(dequantized) + + return dequantized + + return diet_var_getter + + +def _fn_with_diet_vars(fn, args, params): + """Call function with args; use diet variables according to params.""" + + vs_ctr = [] + + def grad_fn(inputs, variables, outputs, output_grads): + del outputs # recomputing below + with common_layers.fn_device_dependency("diet_grad", + output_grads[0].device) as out_dep: + with tf.variable_scope(vs_ctr[0], reuse=True): + outputs = fn(*inputs) + + variables = [common_layers.underlying_variable_ref(v) for v in variables] + dequantized_variables = [ + params.dequantized[v.name][-1] for v in variables + ] + + grads = tf.gradients(outputs, inputs + dequantized_variables, + output_grads) + grad_inputs = grads[:len(inputs)] + grad_variables = grads[len(inputs):] + + opt = _create_diet_optimizer(params) + + # Apply grad_variables here + var_updates = [] + for v, dv in zip(variables, grad_variables): + with tf.variable_scope(vs_ctr[0].name): + opt.create_slots(v) + update_op = opt.update_variable(v, dv) + var_updates.append(update_op) + + with tf.control_dependencies(var_updates): + grad_inputs = [tf.identity(dx) for dx in grad_inputs] + + out_dep.append(grad_inputs) + + return grad_inputs, [None] * len(variables) + + @common_layers.fn_with_custom_grad(grad_fn, use_global_vars=True) + def forward(*inputs): + with tf.variable_scope( + None, default_name="diet", + custom_getter=make_diet_var_getter(params)) as vs: + vs_ctr.append(vs) + outputs = fn(*inputs) + return outputs + + with common_layers.fn_device_dependency("diet_forward", + args[0].device) as out_dep: + outputs = forward(*args) + out_dep.append(outputs) + return outputs + + +def fn_with_diet_vars(params): + """Decorator for graph-building function to use diet variables.""" + params = copy.copy(params) + + def dec(fn): + + def wrapped(*args): + return _fn_with_diet_vars(fn, args, params) + + return wrapped + + return dec diff --git a/tensor2tensor/utils/diet_test.py b/tensor2tensor/utils/diet_test.py new file mode 100644 index 000000000..9c0c570cc --- /dev/null +++ b/tensor2tensor/utils/diet_test.py @@ -0,0 +1,70 @@ +# coding=utf-8 +# Copyright 2017 The Tensor2Tensor Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for common layers.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +# Dependency imports + +from tensor2tensor.utils import diet + +import tensorflow as tf + + +class DietVarTest(tf.test.TestCase): + + def testDiet(self): + + params = diet.diet_adam_optimizer_params() + + @diet.fn_with_diet_vars(params) + def model_fn(x): + y = tf.layers.dense(x, 10, use_bias=False) + return y + + @diet.fn_with_diet_vars(params) + def model_fn2(x): + y = tf.layers.dense(x, 10, use_bias=False) + return y + + x = tf.random_uniform((10, 10)) + y = model_fn(x) + 10. + y = model_fn2(y) + 10. + grads = tf.gradients(y, [x]) + with tf.control_dependencies(grads): + incr_step = tf.assign_add(tf.train.get_or_create_global_step(), 1) + + train_op = tf.group(incr_step, *grads) + with self.test_session() as sess: + sess.run(tf.global_variables_initializer()) + orig_vals = sess.run(tf.global_variables()) + for _ in xrange(10): + sess.run(train_op) + new_vals = sess.run(tf.global_variables()) + + different = [] + for old, new in zip(orig_vals, new_vals): + try: + self.assertAllClose(old, new) + except AssertionError: + different.append(True) + self.assertEqual(len(different), len(tf.global_variables())) + + +if __name__ == "__main__": + tf.test.main() diff --git a/tensor2tensor/utils/input_fn_builder.py b/tensor2tensor/utils/input_fn_builder.py index d1b68aa02..c31ba0f31 100644 --- a/tensor2tensor/utils/input_fn_builder.py +++ b/tensor2tensor/utils/input_fn_builder.py @@ -90,13 +90,11 @@ def input_fn(): p_hparams = hparams.problems[n] with tf.name_scope("problem_%d" % n): with tf.device("/cpu:0"): # Input reading on CPU - capacity = p_hparams.max_expected_batch_size_per_shard - capacity *= num_datashards - examples = data_reader.input_pipeline( + capacity = ( + p_hparams.max_expected_batch_size_per_shard * num_datashards) + feature_map = data_reader.input_pipeline( problem_instance, data_file_patterns and data_file_patterns[n], - capacity, mode, hparams) - feature_map = data_reader.batch_examples( - examples, + capacity, mode, hparams, data_reader.hparams_to_batching_scheme( hparams, shard_multiplier=num_datashards, diff --git a/tensor2tensor/utils/model_builder.py b/tensor2tensor/utils/model_builder.py index da33e1e40..24c17ca9e 100644 --- a/tensor2tensor/utils/model_builder.py +++ b/tensor2tensor/utils/model_builder.py @@ -46,6 +46,24 @@ IMAGE_DECODE_LENGTH = 100 +def log_variable_sizes(var_list, tag): + """Log the sizes and shapes of variables, and the total size. + + Args: + var_list: a list of varaibles + tag: a string + """ + name_to_var = {v.name: v for v in var_list} + total_size = 0 + for v_name in sorted(list(name_to_var)): + v = name_to_var[v_name] + v_size = int(np.prod(np.array(v.shape.as_list()))) + tf.logging.info("Weight %s\tshape %s\tsize %d", + v.name[:-2].ljust(80), str(v.shape).ljust(20), v_size) + total_size += v_size + tf.logging.info("%s Total size: %d", tag, total_size) + + def build_model_fn(model, hparams): """Returns a function to build the model. @@ -150,6 +168,9 @@ def model_fn(features, targets, mode): dp = devices.data_parallelism() + tf.get_variable_scope().set_initializer(initializer()) + is_training = mode == tf.contrib.learn.ModeKeys.TRAIN + # Add input statistics for incoming features. with tf.name_scope("input_stats"): for (k, v) in six.iteritems(features): @@ -157,13 +178,28 @@ def model_fn(features, targets, mode): tf.summary.scalar("%s_batch" % k, tf.shape(v)[0] // dp.n) tf.summary.scalar("%s_length" % k, tf.shape(v)[1]) nonpadding = tf.to_float(tf.not_equal(v, 0)) - tf.summary.scalar("%s_nonpadding_tokens" % k, - tf.reduce_sum(nonpadding)) + nonpadding_tokens = tf.reduce_sum(nonpadding) + if k == "targets": + targets_nonpadding_tokens = nonpadding_tokens + tf.summary.scalar("%s_nonpadding_tokens" % k, nonpadding_tokens) tf.summary.scalar("%s_nonpadding_fraction" % k, tf.reduce_mean(nonpadding)) - tf.get_variable_scope().set_initializer(initializer()) - train = mode == tf.contrib.learn.ModeKeys.TRAIN + # The new data reader occasionally emits very small batches, which + # cause the examples in those batches to be grossly overweighted. + # We decrease the loss proportionally to the ratio of the size of this + # batch to the size of the largest training batch ever. + # TODO(noam): to be more sophisticated, we could keep separate + # maxima based on problem choice. + max_nonpadding_var = tf.get_variable( + "max_nonpadding", shape=[], + initializer=tf.ones_initializer(), trainable=False) + max_nonpadding = tf.maximum(max_nonpadding_var, targets_nonpadding_tokens) + if is_training: + with tf.control_dependencies( + [tf.assign(max_nonpadding_var, max_nonpadding)]): + small_batch_multiplier = targets_nonpadding_tokens / max_nonpadding + tf.summary.scalar("small_batch_multiplier", small_batch_multiplier) # Get multi-problem logits and loss based on features["problem_choice"]. loss_variable_names = [] @@ -186,7 +222,7 @@ def nth_model(n): alpha=FLAGS.decode_alpha, decode_length=FLAGS.decode_extra_length) # In distributed mode, we build graph for problem=0 and problem=worker_id. - skipping_is_on = my_hp.problem_choice == "distributed" and train + skipping_is_on = my_hp.problem_choice == "distributed" and is_training problem_worker_id = FLAGS.worker_id % len(my_hp.problems) skip_this_one = n != 0 and n % FLAGS.worker_replicas != problem_worker_id # On worker 0 also build graph for problems <= 1. @@ -208,9 +244,13 @@ def nth_model(n): ops.append( loss_moving_avg.assign(loss_moving_avg * 0.9 + loss_value * 0.1)) total_loss += loss_value - with tf.variable_scope(tf.get_variable_scope(), reuse=True): - # Total loss was already constructed on input. - loss_moving_avg = tf.get_variable("problem_%d/total_loss" % n) + try: # Total loss avg might be reused or not, we try both. + with tf.variable_scope(tf.get_variable_scope(), reuse=True): + # Total loss was already constructed on input. + loss_moving_avg = tf.get_variable("problem_%d/total_loss" % n) + except ValueError: + loss_moving_avg = tf.get_variable("problem_%d/total_loss" % n, + initializer=100.0, trainable=False) ops.append( loss_moving_avg.assign(loss_moving_avg * 0.9 + total_loss * 0.1)) with tf.variable_scope("train_stats"): # Count steps for this problem. @@ -248,9 +288,6 @@ def nth_model(n): sharded_logits, total_loss = result_list[1:], result_list[0] if mode == tf.contrib.learn.ModeKeys.EVAL: logits = tf.concat(sharded_logits, 0) - if FLAGS.eval_print: - logits = tf.Print( - logits, [features["inputs"], logits], "EVAL PRINT", summarize=10000) # For evaluation, return the logits layer as our predictions. run_info["predictions"] = logits train_op = None @@ -288,8 +325,6 @@ def nth_model(n): for v_name in sorted(list(all_weights)): v = all_weights[v_name] v_size = int(np.prod(np.array(v.shape.as_list()))) - tf.logging.info("Weight %s\tshape %s\tsize %d", - v.name[:-2].ljust(80), str(v.shape).ljust(20), v_size) total_size += v_size if my_hp.weight_decay > 0.0 and len(v.shape.as_list()) > 1: # Add weight regularization if set and the weight is not a bias (dim>1). @@ -305,11 +340,14 @@ def nth_model(n): noise_op = v.assign_add(noise) with tf.control_dependencies([noise_op]): total_loss = tf.identity(total_loss) - tf.logging.info("Total trainable variables size: %d", total_size) if my_hp.weight_decay > 0.0: total_loss += weight_decay_loss * my_hp.weight_decay + if is_training: + total_loss *= small_batch_multiplier total_loss = tf.identity(total_loss, name="total_loss") - + log_variable_sizes(tf.trainable_variables(), "Trainable Variables") + diet_vars = [v for v in tf.global_variables() if hasattr(v, "optimizer")] + log_variable_sizes(diet_vars, "Diet Varaibles") # Define the train_op for the TRAIN mode. opt = _ConditionalOptimizer(my_hp.optimizer, learning_rate, my_hp) tf.logging.info("Computing gradients for global model_fn.") @@ -319,7 +357,7 @@ def nth_model(n): train_op = tf.contrib.layers.optimize_loss( name="training", loss=total_loss, - global_step=tf.contrib.framework.get_global_step(), + global_step=tf.train.get_global_step(), learning_rate=learning_rate, clip_gradients=my_hp.clip_grad_norm or None, gradient_noise_scale=hparams.grad_noise_scale or None, diff --git a/tensor2tensor/utils/registry.py b/tensor2tensor/utils/registry.py index f5d83cbf1..f1db2f36c 100644 --- a/tensor2tensor/utils/registry.py +++ b/tensor2tensor/utils/registry.py @@ -90,8 +90,30 @@ def _reset(): ctr.clear() -def _default_name(obj): - return _convert_camel_to_snake(obj.__name__) +def _default_name(obj_class): + """Convert a class name to the registry's default name for the class. + + Args: + obj_class: the name of a class + + Returns: + The registry's default name for the class. + """ + + return _convert_camel_to_snake(obj_class.__name__) + + +def default_object_name(obj): + """Convert an object to the registry's default name for the object class. + + Args: + obj: an object instance + + Returns: + The registry's default name for the class of the object. + """ + + return _default_name(obj.__class__) def register_model(name=None): diff --git a/tensor2tensor/utils/t2t_model.py b/tensor2tensor/utils/t2t_model.py index 8fcf2482d..d3fc6dac1 100644 --- a/tensor2tensor/utils/t2t_model.py +++ b/tensor2tensor/utils/t2t_model.py @@ -327,7 +327,7 @@ def infer_step(recent_output, recent_logits, unused_loss): # Assuming we have one shard for logits. logits = tf.concat([recent_logits, logits[0][:, -1:]], 1) - loss = sum(losses.values()) + loss = sum([l for l in losses.values() if l is not None]) return samples, logits, loss # Create an initial output tensor. This will be passed diff --git a/tensor2tensor/utils/trainer_utils.py b/tensor2tensor/utils/trainer_utils.py index 5682ae820..fa9d9233e 100644 --- a/tensor2tensor/utils/trainer_utils.py +++ b/tensor2tensor/utils/trainer_utils.py @@ -62,7 +62,6 @@ flags.DEFINE_integer("train_steps", 250000, "The number of steps to run training for.") flags.DEFINE_integer("eval_steps", 10, "Number of steps in evaluation.") -flags.DEFINE_bool("eval_print", False, "Print eval logits and predictions.") flags.DEFINE_bool("eval_run_autoregressive", False, "Run eval autoregressively where we condition on previous" "generated output instead of the actual target.") @@ -72,11 +71,11 @@ "Optimize ops placement with experimental session options.") flags.DEFINE_integer("keep_checkpoint_every_n_hours", 10000, "Number of hours between each checkpoint to be saved. " - "The default value of 10,000 hours effectively disables the feature.") + "The default value 10,000 hours effectively disables it.") flags.DEFINE_integer("save_checkpoints_secs", 0, - "Save checkpoints every this many seconds. " - "Default=0 means let tensorflow.contrib.learn.python.learn decide, " - "which is currently equivalent to 600, i.e. 10 minutes.") + "Save checkpoints every this many seconds. " + "Default=0 means let tensorflow.contrib.learn.python.learn" + " decide, which is currently set to 600 = 10 minutes.") # Distributed training flags flags.DEFINE_string("master", "", "Address of TensorFlow master.") @@ -150,7 +149,8 @@ def experiment_fn(output_dir): def create_experiment(output_dir, data_dir, model_name, train_steps, eval_steps): """Create Experiment.""" - hparams = create_hparams(FLAGS.hparams_set, data_dir) + hparams = create_hparams(FLAGS.hparams_set, FLAGS.problems, data_dir, + passed_hparams=FLAGS.hparams) estimator, input_fns = create_experiment_components( hparams=hparams, output_dir=output_dir, @@ -203,7 +203,7 @@ def create_experiment_components(hparams, output_dir, data_dir, model_name): worker_replicas=FLAGS.worker_replicas, worker_id=FLAGS.worker_id) estimator = tf.contrib.learn.Estimator( - model_fn=model_builder.build_model_fn(model_name, hparams=hparams), + model_fn=model_builder.build_model_fn(model_name, hparams), model_dir=output_dir, config=tf.contrib.learn.RunConfig( master=FLAGS.master, @@ -212,7 +212,8 @@ def create_experiment_components(hparams, output_dir, data_dir, model_name): session_config=session_config(), keep_checkpoint_max=FLAGS.keep_checkpoint_max, keep_checkpoint_every_n_hours=FLAGS.keep_checkpoint_every_n_hours, - save_checkpoints_secs=FLAGS.save_checkpoints_secs,)) + save_checkpoints_secs=FLAGS.save_checkpoints_secs)) + # Store the hparams in the estimator as well estimator.hparams = hparams return estimator, { @@ -248,7 +249,7 @@ def add_problem_hparams(hparams, problems): return hparams -def create_hparams(params_id, data_dir): +def create_hparams(params_id, problems, data_dir, passed_hparams=None): """Returns hyperparameters, including any flag value overrides. If the hparams FLAG is set, then it will use any values specified in @@ -257,7 +258,9 @@ def create_hparams(params_id, data_dir): Args: params_id: which set of parameters to choose (must be in _PARAMS above). + problems: the string with problem names to get problem_hparams from. data_dir: the directory containing the training data. + passed_hparams: command-line overrides for some hparams. Returns: The hyperparameters as a tf.contrib.training.HParams object. @@ -265,10 +268,10 @@ def create_hparams(params_id, data_dir): hparams = registry.hparams(params_id)() hparams.add_hparam("data_dir", data_dir) # Command line flags override any of the preceding hyperparameter values. - if FLAGS.hparams: - hparams = hparams.parse(FLAGS.hparams) + if passed_hparams: + hparams = hparams.parse(passed_hparams) - return add_problem_hparams(hparams, FLAGS.problems) + return add_problem_hparams(hparams, problems) def run(data_dir, model, output_dir, train_steps, eval_steps, schedule): @@ -352,7 +355,7 @@ def get_data_filepatterns(data_dir, mode): def decode(estimator): if FLAGS.decode_interactive: decoding.decode_interactively(estimator) - elif FLAGS.decode_from_file is not None: + elif FLAGS.decode_from_file is not None and FLAGS.decode_from_file is not "": decoding.decode_from_file(estimator, FLAGS.decode_from_file) elif FLAGS.decode_from_dataset: decoding.decode_from_dataset(estimator) diff --git a/tensor2tensor/utils/trainer_utils_test.py b/tensor2tensor/utils/trainer_utils_test.py index 61156f227..6cc654d26 100644 --- a/tensor2tensor/utils/trainer_utils_test.py +++ b/tensor2tensor/utils/trainer_utils_test.py @@ -19,11 +19,15 @@ from __future__ import division from __future__ import print_function +import os +import shutil + # Dependency imports from tensor2tensor.data_generators import algorithmic from tensor2tensor.data_generators import generator_utils from tensor2tensor.models import transformer +from tensor2tensor.utils import model_builder from tensor2tensor.utils import registry from tensor2tensor.utils import trainer_utils @@ -60,9 +64,13 @@ class TrainerUtilsTest(tf.test.TestCase): @classmethod def setUpClass(cls): + tmp_dir = tf.test.get_temp_dir() + shutil.rmtree(tmp_dir) + os.mkdir(tmp_dir) + # Generate a small test dataset FLAGS.problems = "tiny_algo" - TrainerUtilsTest.data_dir = tf.test.get_temp_dir() + TrainerUtilsTest.data_dir = tmp_dir registry.problem(FLAGS.problems).generate_data(TrainerUtilsTest.data_dir, None) @@ -85,6 +93,51 @@ def testSingleStep(self): eval_steps=1) exp.test() + def testSingleEvalStepRawSession(self): + """Illustrate how to run a T2T model in a raw session.""" + + # Set model name, hparams, problems as would be set on command line. + model_name = "transformer" + FLAGS.hparams_set = "transformer_test" + FLAGS.problems = "tiny_algo" + data_dir = "/tmp" # Used only when a vocab file or such like is needed. + + # Create the problem object, hparams, model_fn, placeholders, features dict. + encoders = registry.problem(FLAGS.problems).feature_encoders(data_dir) + hparams = trainer_utils.create_hparams( + FLAGS.hparams_set, FLAGS.problems, data_dir) + model_fn = model_builder.build_model_fn(model_name, hparams) + inputs_ph = tf.placeholder(dtype=tf.int32) # Just length dimension. + batch_inputs = tf.reshape(inputs_ph, [1, -1, 1, 1]) # Make it 4D. + targets_ph = tf.placeholder(dtype=tf.int32) # Just length dimension. + batch_targets = tf.reshape(targets_ph, [1, -1, 1, 1]) # Make it 4D. + features = {"inputs": batch_inputs, + "problem_choice": 0, # We run on the first problem here. + "input_space_id": hparams.problems[0].input_space_id, + "target_space_id": hparams.problems[0].target_space_id} + + # Now set a mode and create the graph by invoking model_fn. + mode = tf.contrib.learn.ModeKeys.EVAL + predictions_dict, _, _ = model_fn( # In INFER mode targets can be None. + features, batch_targets, mode) + predictions = tf.squeeze( # These are not images, axis=2,3 are not needed. + predictions_dict["predictions"], axis=[2, 3]) + + # Having the graph, let's run it on some data. + with self.test_session() as sess: + sess.run(tf.global_variables_initializer()) + inputs = "0 1 0" + targets = "0 1 0" + # Encode from raw string to numpy input array using problem encoders. + inputs_numpy = encoders["inputs"].encode(inputs) + targets_numpy = encoders["targets"].encode(targets) + # Feed the encoded inputs and targets and run session. + feed = {inputs_ph: inputs_numpy, targets_ph: targets_numpy} + np_predictions = sess.run(predictions, feed) + # Check that the result has the correct shape: batch x length x vocab_size + # where, for us, batch = 1, length = 3, vocab_size = 4. + self.assertEqual(np_predictions.shape, (1, 3, 4)) + if __name__ == "__main__": tf.test.main() diff --git a/tensor2tensor/utils/yellowfin.py b/tensor2tensor/utils/yellowfin.py index c90d43b83..450875fa5 100644 --- a/tensor2tensor/utils/yellowfin.py +++ b/tensor2tensor/utils/yellowfin.py @@ -20,8 +20,8 @@ from __future__ import print_function # Dependency imports + import tensorflow as tf -from tensorflow.python.framework import ops # Values for gate_gradients. @@ -49,6 +49,7 @@ def __init__(self, name="YellowFin", use_nesterov=False): """Construct a new YellowFin optimizer. + Implemented as a wrapper around tf.train.MomentumOptimizer Args: @@ -113,11 +114,11 @@ def __init__(self, # Gradient Clipping Threshold. if clip_thresh is not None: - self._clip_thresh_var = \ - tf.get_variable("YF_clip_thresh", - dtype=tf.float32, - trainable=False, - initializer=tf.constant(clip_thresh)) + self._clip_thresh_var = tf.get_variable( + "YF_clip_thresh", + dtype=tf.float32, + trainable=False, + initializer=tf.constant(clip_thresh)) else: self._clip_thresh_var = None @@ -201,7 +202,7 @@ def _curvature_range(self): self._curv_win = tf.get_variable("curv_win", dtype=tf.float32, trainable=False, - shape=[self.curvature_window_width, ], + shape=[self.curvature_window_width,], initializer=tf.zeros_initializer) # We use log smoothing for curvature range self._curv_win = tf.scatter_update(self._curv_win, @@ -226,12 +227,11 @@ def _curvature_range(self): self._h_max = tf.exp( tf.identity(self._moving_averager.average(self._h_max_t))) if self._sparsity_debias: - self._h_min = self._h_min * self._sparsity_avg - self._h_max = self._h_max * self._sparsity_avg + self._h_min *= self._sparsity_avg + self._h_max *= self._sparsity_avg curv_range_ops.append(avg_op) return curv_range_ops # h_max_t, h_min_t - def _grad_variance(self): """Estimate of gradient Variance. @@ -241,7 +241,7 @@ def _grad_variance(self): grad_var_ops = [] tensor_to_avg = [] for t, g in zip(self._vars, self._grad): - if isinstance(g, ops.IndexedSlices): + if isinstance(g, tf.IndexedSlices): tensor_to_avg.append( tf.reshape(tf.unsorted_segment_sum(g.values, g.indices, @@ -265,7 +265,6 @@ def _grad_variance(self): self._grad_var *= self._sparsity_avg return grad_var_ops # C_t - def _dist_to_opt(self): """Distance to optimum. @@ -292,26 +291,23 @@ def _dist_to_opt(self): self._dist_to_opt_avg /= tf.sqrt(self._sparsity_avg) return dist_to_opt_ops # D_t - def _grad_sparsity(self): - """ + """Gradient sparsity.""" # If the sparse minibatch gradient has 10 percent of its entries # non-zero, its sparsity is 0.1. # The norm of dense gradient averaged from full dataset # are roughly estimated norm of minibatch # sparse gradient norm * sqrt(sparsity) # An extension maybe only correct the sparse blob. - """ non_zero_cnt = tf.add_n([tf.count_nonzero(g) for g in self._grad]) all_entry_cnt = tf.add_n([tf.size(g) for g in self._grad]) - self._sparsity = tf.cast(non_zero_cnt, self._grad[0].dtype) \ - / tf.cast(all_entry_cnt, self._grad[0].dtype) - avg_op = self._moving_averager.apply([self._sparsity, ]) + self._sparsity = tf.cast(non_zero_cnt, self._grad[0].dtype) + self._sparsity /= tf.cast(all_entry_cnt, self._grad[0].dtype) + avg_op = self._moving_averager.apply([self._sparsity,]) with tf.control_dependencies([avg_op]): self._sparsity_avg = self._moving_averager.average(self._sparsity) return avg_op - def _prepare_variables(self): """Prepare Variables for YellowFin. @@ -331,7 +327,7 @@ def _prepare_variables(self): # Gradient squared for v, g in zip(self._vars, self._grad): if g is None: continue - with ops.colocate_with(v): + with tf.colocate_with(v): self._grad_squared.append(tf.square(g)) # Norm squared. @@ -355,9 +351,8 @@ def _prepare_variables(self): prepare_variables_op.append(avg_op) return tf.group(*prepare_variables_op) - def _get_cubic_root(self): - """ + """Get the cubic root.""" # We have the equation x^2 D^2 + (1-x)^4 * C / h_min^2 # where x = sqrt(mu). # We substitute x, which is sqrt(mu), with x = y + 1. @@ -366,27 +361,26 @@ def _get_cubic_root(self): # We use the Vieta's substution to compute the root. # There is only one real solution y (which is in [0, 1] ). # http://mathworld.wolfram.com/VietasSubstitution.html - """ - assert_array = \ - [tf.Assert( - tf.logical_not(tf.is_nan(self._dist_to_opt_avg)), - [self._dist_to_opt_avg, ]), \ - tf.Assert( - tf.logical_not(tf.is_nan(self._h_min)), - [self._h_min,]), \ - tf.Assert( - tf.logical_not(tf.is_nan(self._grad_var)), - [self._grad_var,]), \ - tf.Assert( - tf.logical_not(tf.is_inf(self._dist_to_opt_avg)), - [self._dist_to_opt_avg, ]), \ - tf.Assert( - tf.logical_not(tf.is_inf(self._h_min)), - [self._h_min,]), \ - tf.Assert( - tf.logical_not(tf.is_inf(self._grad_var)), - [self._grad_var,])] - + assert_array = [ + tf.Assert( + tf.logical_not(tf.is_nan(self._dist_to_opt_avg)), + [self._dist_to_opt_avg,]), + tf.Assert( + tf.logical_not(tf.is_nan(self._h_min)), + [self._h_min,]), + tf.Assert( + tf.logical_not(tf.is_nan(self._grad_var)), + [self._grad_var,]), + tf.Assert( + tf.logical_not(tf.is_inf(self._dist_to_opt_avg)), + [self._dist_to_opt_avg,]), + tf.Assert( + tf.logical_not(tf.is_inf(self._h_min)), + [self._h_min,]), + tf.Assert( + tf.logical_not(tf.is_inf(self._grad_var)), + [self._grad_var,]) + ] with tf.control_dependencies(assert_array): p = self._dist_to_opt_avg**2 * self._h_min**2 / 2 / self._grad_var w3 = (-tf.sqrt(p**2 + 4.0 / 27.0 * p**3) - p) / 2.0 @@ -395,7 +389,6 @@ def _get_cubic_root(self): x = y + 1 return x - def _get_lr_tensor(self): """Get lr minimzing the surrogate. @@ -405,7 +398,6 @@ def _get_lr_tensor(self): lr = (1.0 - tf.sqrt(self._mu))**2 / self._h_min return lr - def _get_mu_tensor(self): """Get the min mu which minimize the surrogate. @@ -415,10 +407,9 @@ def _get_mu_tensor(self): root = self._get_cubic_root() dr = self._h_max / self._h_min mu = tf.maximum( - root**2, ((tf.sqrt(dr) - 1) / (tf.sqrt(dr) + 1))**2) + root**2, ((tf.sqrt(dr) - 1) / (tf.sqrt(dr) + 1))**2) return mu - def _yellowfin(self): """YellowFin auto-tuning optimizer based on momentum SGD. @@ -448,11 +439,11 @@ def _yellowfin(self): # approximation after a single step while keeping all directions in the # robust region. self._mu = tf.identity(tf.cond(self._do_tune, - lambda: self._get_mu_tensor(), + self._get_mu_tensor, lambda: self._mu_var)) with tf.control_dependencies([self._mu]): self._lr = tf.identity(tf.cond(self._do_tune, - lambda: self._get_lr_tensor(), + self._get_lr_tensor, lambda: self._lr_var)) # Tune learning rate and momentum. @@ -465,12 +456,10 @@ def _yellowfin(self): yellowfin_ops = tf.group(*yellowfin_ops) return yellowfin_ops - def get_name(self): - """Get Optimizer Name""" + """Get optimizer name.""" return self._momentum_optimizer.get_name() - def apply_gradients(self, grads_and_vars, global_step=None, name=None): """Applying gradients aand tune hyperparams with YellowFin. @@ -488,7 +477,6 @@ def apply_gradients(self, grads_and_vars, global_step=None, name=None): YellowFin ops(Curvature, Variance, Distance) ops, SingleStep and lr_mu tuning ops, Step increment ops. - """ self._grad, self._vars = zip(*[(g, t) for g, t in grads_and_vars if g is not None]) @@ -517,7 +505,7 @@ def apply_gradients(self, grads_and_vars, global_step=None, name=None): # does not support indexed slice for sparse gradients. # The alternative dependencies here might be slightly slower due # to less parallelization. - with tf.control_dependencies([apply_grad_op, ]): + with tf.control_dependencies([apply_grad_op,]): prepare_variables_op = self._prepare_variables() with tf.variable_scope("yellowfin"): @@ -533,7 +521,6 @@ def apply_gradients(self, grads_and_vars, global_step=None, name=None): yellowfin_op, self._increment_step_op) - def compute_gradients(self, loss, var_list, @@ -566,6 +553,7 @@ def compute_gradients(self, A list of (gradient, variable) pairs. Variable is always present, but gradient can be None. """ + del global_step, name # Unused for now. return self._momentum_optimizer.compute_gradients( loss, var_list=var_list, @@ -574,7 +562,6 @@ def compute_gradients(self, colocate_gradients_with_ops=colocate_gradients_with_ops, grad_loss=grad_loss) - def minimize(self, loss, global_step=None, @@ -637,11 +624,8 @@ def minimize(self, global_step=global_step, name=name) - def get_slot(self, var, name): - """ - Return a slot named `name` created for `var` by - the underlying MomentumOptimizer. + """Return a slot named `name` created for `var`. Args: var: A variable passed to `minimize()` or `apply_gradients()`. @@ -653,9 +637,7 @@ def get_slot(self, var, name): return self._momentum_optimizer.get_slot(var, name) def get_slot_names(self): - """ - Return a list of the names of the slots created by the - underlying MomentumOptimizer. + """Return a list of the names of the slots using MomentumOptimizer. Returns: A list of strings. diff --git a/tensor2tensor/visualization/TransformerVisualization.ipynb b/tensor2tensor/visualization/TransformerVisualization.ipynb new file mode 100644 index 000000000..ef1c7b45d --- /dev/null +++ b/tensor2tensor/visualization/TransformerVisualization.ipynb @@ -0,0 +1,477 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Create Your Own Visualizations!\n", + "Instructions:\n", + "1. Install tensor2tensor and train up a Transformer model following the instruction in the repository https://github.com/tensorflow/tensor2tensor.\n", + "2. Update cell 3 to point to your checkpoint, it is currently set up to read from the default checkpoint location that would be created from following the instructions above.\n", + "3. If you used custom hyper parameters then update cell 4.\n", + "4. Run the notebook!" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "from __future__ import absolute_import\n", + "from __future__ import division\n", + "from __future__ import print_function\n", + "\n", + "import json\n", + "\n", + "import tensorflow as tf\n", + "import numpy as np\n", + "\n", + "from tensor2tensor.utils import trainer_utils as utils\n", + "from tensor2tensor.visualization import attention" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [ + { + "data": { + "application/javascript": [ + "require.config({\n", + " paths: {\n", + " d3: '//cdnjs.cloudflare.com/ajax/libs/d3/3.4.8/d3.min'\n", + " }\n", + "});" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "%%javascript\n", + "require.config({\n", + " paths: {\n", + " d3: '//cdnjs.cloudflare.com/ajax/libs/d3/3.4.8/d3.min'\n", + " }\n", + "});" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Data" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "/home/llion/t2t_train/wmt_ende_tokens_32k/transformer-transformer_base_single_gpu\n" + ] + } + ], + "source": [ + "import os\n", + "# PUT THE MODEL YOU WANT TO LOAD HERE!\n", + "\n", + "PROBLEM = 'wmt_ende_tokens_32k'\n", + "MODEL = 'transformer'\n", + "HPARAMS = 'transformer_base_single_gpu'\n", + "\n", + "DATA_DIR=os.path.expanduser('~/t2t_data')\n", + "TRAIN_DIR=os.path.expanduser('~/t2t_train/%s/%s-%s' % (PROBLEM, MODEL, HPARAMS))\n", + "print(TRAIN_DIR)\n", + "\n", + "FLAGS = tf.flags.FLAGS\n", + "FLAGS.problems = PROBLEM\n", + "FLAGS.hparams_set = HPARAMS\n", + "FLAGS.data_dir = DATA_DIR\n", + "FLAGS.model = MODEL" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": { + "scrolled": true + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "INFO:tensorflow:datashard_devices: ['gpu:0']\n", + "INFO:tensorflow:caching_devices: None\n" + ] + } + ], + "source": [ + "hparams = utils.create_hparams(HPARAMS, DATA_DIR)\n", + "\n", + "# SET EXTRA HYPER PARAMS HERE!\n", + "# e.g.\n", + "# hparams.batch_size = 1024\n", + "\n", + "num_datashards = utils.devices.data_parallelism().n\n", + "\n", + "problems_data = utils.get_data_filepatterns(\n", + " DATA_DIR, tf.contrib.learn.ModeKeys.EVAL)\n", + "input_fn = utils.input_fn_builder.build_input_fn(\n", + " mode=tf.contrib.learn.ModeKeys.EVAL,\n", + " hparams=hparams,\n", + " data_file_patterns=problems_data,\n", + " num_datashards=num_datashards)\n", + "\n", + "inputs, target = input_fn()\n", + "features = inputs\n", + "features['targets'] = target" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": { + "collapsed": true + }, + "outputs": [], + "source": [ + "def encode(string):\n", + " subtokenizer = hparams.problems[0].vocabulary['inputs']\n", + " return [subtokenizer.encode(string) + [1] + [0]]\n", + "\n", + "def decode(ids):\n", + " return hparams.problems[0].vocabulary['targets'].decode(np.squeeze(ids))\n", + "\n", + "def to_tokens(ids):\n", + " ids = np.squeeze(ids)\n", + " subtokenizer = hparams.problems[0].vocabulary['targets']\n", + " tokens = []\n", + " for _id in ids:\n", + " if _id == 0:\n", + " tokens.append('')\n", + " elif _id == 1:\n", + " tokens.append('')\n", + " else:\n", + " tokens.append(subtokenizer._subtoken_id_to_subtoken_string(_id))\n", + " return tokens" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Model" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "INFO:tensorflow:datashard_devices: ['gpu:0']\n", + "INFO:tensorflow:caching_devices: None\n", + "INFO:tensorflow:Doing model_fn_body took 1.881 sec.\n", + "INFO:tensorflow:This model_fn took 2.023 sec.\n" + ] + } + ], + "source": [ + "model_fn=utils.model_builder.build_model_fn(MODEL, hparams=hparams)\n", + "sharded_logits, training_loss, extra_loss = model_fn(features, target, tf.contrib.learn.ModeKeys.EVAL)" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "INFO:tensorflow:datashard_devices: ['gpu:0']\n", + "INFO:tensorflow:caching_devices: None\n", + "INFO:tensorflow:Beam Decoding with beam size 4\n", + "INFO:tensorflow:Doing model_fn_body took 1.393 sec.\n", + "INFO:tensorflow:This model_fn took 1.504 sec.\n" + ] + } + ], + "source": [ + "with tf.variable_scope(tf.get_variable_scope(), reuse=True):\n", + " beam_out = model_fn(features, target, tf.contrib.learn.ModeKeys.INFER)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Session" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "INFO:tensorflow:Restoring parameters from /home/llion/t2t_train/wmt_ende_tokens_32k/transformer-transformer_base_single_gpu/model.ckpt-250000\n", + "INFO:tensorflow:Starting standard services.\n", + "INFO:tensorflow:Saving checkpoint to path /home/llion/t2t_train/wmt_ende_tokens_32k/transformer-transformer_base_single_gpu/model.ckpt\n", + "INFO:tensorflow:Starting queue runners.\n" + ] + }, + { + "data": { + "text/plain": [ + "[]" + ] + }, + "execution_count": 8, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "sv = tf.train.Supervisor(\n", + " logdir=TRAIN_DIR,\n", + " global_step=tf.Variable(0, dtype=tf.int64, trainable=False, name='global_step'))\n", + "sess = sv.PrepareSession(config=tf.ConfigProto(allow_soft_placement=True))\n", + "sv.StartQueueRunners(\n", + " sess,\n", + " tf.get_default_graph().get_collection(tf.GraphKeys.QUEUE_RUNNERS))" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Visualization" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": { + "collapsed": true + }, + "outputs": [], + "source": [ + "# Get the attention tensors from the graph.\n", + "# This need to be done using the training graph since the inference uses a tf.while_loop\n", + "# and you cant fetch tensors from inside a while_loop.\n", + "\n", + "enc_atts = []\n", + "dec_atts = []\n", + "encdec_atts = []\n", + "\n", + "for i in range(hparams.num_hidden_layers):\n", + " enc_att = tf.get_default_graph().get_operation_by_name(\n", + " \"body/model/parallel_0/body/encoder/layer_%i/self_attention/multihead_attention/dot_product_attention/attention_weights\" % i).values()[0]\n", + " dec_att = tf.get_default_graph().get_operation_by_name(\n", + " \"body/model/parallel_0/body/decoder/layer_%i/self_attention/multihead_attention/dot_product_attention/attention_weights\" % i).values()[0]\n", + " encdec_att = tf.get_default_graph().get_operation_by_name(\n", + " \"body/model/parallel_0/body/decoder/layer_%i/encdec_attention/multihead_attention/dot_product_attention/attention_weights\" % i).values()[0]\n", + "\n", + " enc_atts.append(enc_att)\n", + " dec_atts.append(dec_att)\n", + " encdec_atts.append(encdec_att)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Test translation from the dataset" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": { + "scrolled": false + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "INFO:tensorflow:global_step/sec: 0\n", + "Input: For example, during the 2008 general election in Florida, 33% of early voters were African-Americans, who accounted however for only 13% of voters in the State.\n", + "Gold: Beispielsweise waren bei den allgemeinen Wahlen 2008 in Florida 33% der Wähler, die im Voraus gewählt haben, Afro-Amerikaner, obwohl sie nur 13% der Wähler des Bundesstaates ausmachen.\n", + "Gold out: So waren 33 den allgemeinen Wahlen im in der a 33 % der Frühjungdie nur Land die wurden, die ro- Amerikaner, die sie nur 13 % der Wähler im Staates staats betra.\n", + "INFO:tensorflow:Recording summary at step 250000.\n" + ] + } + ], + "source": [ + "inp, out, logits = sess.run([inputs['inputs'], target, sharded_logits['predictions']])\n", + "\n", + "print(\"Input: \", decode(inp[0]))\n", + "print(\"Gold: \", decode(out[0]))\n", + "logits = np.squeeze(logits[0])\n", + "tokens = np.argmax(logits, axis=1)\n", + "print(\"Gold out: \", decode(tokens))" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Visualize Custom Sentence" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": { + "collapsed": true + }, + "outputs": [], + "source": [ + "eng = \"I have three dogs.\"" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Ich habe drei Hunde.\n" + ] + } + ], + "source": [ + "inp_ids = encode(eng)\n", + "beam_decode = sess.run(beam_out[0]['outputs'], {\n", + " inputs['inputs']: np.expand_dims(np.expand_dims(inp_ids, axis=2), axis=3),\n", + "})\n", + "trans = decode(beam_decode[0])\n", + "print(trans)" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "metadata": { + "collapsed": true + }, + "outputs": [], + "source": [ + "der = decode(beam_decode[0])\n", + "output_ids = encode(der)\n", + "\n", + "# Get attentions\n", + "np_enc_atts, np_dec_atts, np_encdec_atts = sess.run([enc_atts, dec_atts, encdec_atts], {\n", + " inputs['inputs']: np.expand_dims(np.expand_dims(inp_ids, axis=2), axis=3),\n", + " target: np.expand_dims(np.expand_dims(output_ids, axis=2), axis=3),\n", + "})" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "metadata": {}, + "outputs": [ + { + "data": { + "application/javascript": [ + "IPython.OutputArea.prototype._should_scroll = function(lines) {\n", + " return false;\n", + "}" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "%%javascript\n", + "IPython.OutputArea.prototype._should_scroll = function(lines) {\n", + " return false;\n", + "}" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Interpreting the Visualizations\n", + "- The layers drop down allow you to view the different Transformer layers, 0-indexed of course.\n", + " - Tip: The first layer, last layer and 2nd to last layer are usually the most interpretable.\n", + "- The attention dropdown allows you to select different pairs of encoder-decoder attentions:\n", + " - All: Shows all types of attentions together. NOTE: There is no relation between heads of the same color - between the decoder self attention and decoder-encoder attention since they do not share parameters.\n", + " - Input - Input: Shows only the encoder self-attention.\n", + " - Input - Output: Shows the decoder’s attention on the encoder. NOTE: Every decoder layer attends to the final layer of encoder so the visualization will show the attention on the final encoder layer regardless of what layer is selected in the drop down.\n", + " - Output - Output: Shows only the decoder self-attention. NOTE: The visualization might be slightly misleading in the first layer since the text shown is the target of the decoder, the input to the decoder at layer 0 is this text with a GO symbol prepreded.\n", + "- The colored squares represent the different attention heads.\n", + " - You can hide or show a given head by clicking on it’s color.\n", + " - Double clicking a color will hide all other colors, double clicking on a color when it’s the only head showing will show all the heads again.\n", + "- You can hover over a word to see the individual attention weights for just that position.\n", + " - Hovering over the words on the left will show what that position attended to.\n", + " - Hovering over the words on the right will show what positions attended to it.\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "scrolled": true + }, + "outputs": [], + "source": [ + "inp_text = to_tokens(inp_ids)\n", + "out_text = to_tokens(output_ids)\n", + "\n", + "attention.show(inp_text, out_text, np_enc_atts, np_dec_atts, np_encdec_atts)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 2", + "language": "python", + "name": "python2" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 2 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython2", + "version": "2.7.12" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} \ No newline at end of file diff --git a/tensor2tensor/visualization/attention.js b/tensor2tensor/visualization/attention.js new file mode 100644 index 000000000..ae2deb6bd --- /dev/null +++ b/tensor2tensor/visualization/attention.js @@ -0,0 +1,363 @@ +/** + * @fileoverview Transformer Visualization D3 javascript code. + */ + +requirejs(['jquery', 'd3'], +function($, d3) { + +var attention = window.attention; + +const TEXT_SIZE = 15; +const BOXWIDTH = TEXT_SIZE * 8; +const BOXHEIGHT = TEXT_SIZE * 1.5; +const WIDTH = 2000; +const HEIGHT = attention.all.bot_text.length * BOXHEIGHT * 2 + 100; +const MATRIX_WIDTH = 150; +const head_colours = d3.scale.category10(); +const CHECKBOX_SIZE = 20; + +function lighten(colour) { + var c = d3.hsl(colour); + var increment = (1 - c.l) * 0.6; + c.l += increment; + c.s -= increment; + return c; +} + +function transpose(mat) { + return mat[0].map(function(col, i) { + return mat.map(function(row) { + return row[i]; + }); + }); +} + +function zip(a, b) { + return a.map(function (e, i) { + return [e, b[i]]; + }); +} + + +function renderVis(id, top_text, bot_text, attention_heads, config) { + $(id).empty(); + var svg = d3.select(id) + .append('svg') + .attr("width", WIDTH) + .attr("height", HEIGHT); + + var att_data = []; + for (var i=0; i < attention_heads.length; i++) { + var att_trans = transpose(attention_heads[i]); + att_data.push(zip(attention_heads[i], att_trans)); + } + + renderText(svg, top_text, true, att_data, 0); + renderText(svg, bot_text, false, att_data, MATRIX_WIDTH + BOXWIDTH); + + renderAttentionHighlights(svg, att_data); + + svg.append("g").classed("attention_heads", true); + + renderAttention(svg, attention_heads); + + draw_checkboxes(config, 0, svg, attention_heads); +} + + +function renderText(svg, text, is_top, att_data, left_pos) { + var id = is_top ? "top" : "bottom"; + var textContainer = svg.append("svg:g") + .attr("id", id); + + textContainer.append("g").classed("attention_boxes", true) + .selectAll("g") + .data(att_data) + .enter() + .append("g") + .selectAll("rect") + .data(function(d) {return d;}) + .enter() + .append("rect") + .attr("x", function(d, i, j) { + return left_pos + box_offset(j); + }) + .attr("y", function(d, i) { + return (+1) * BOXHEIGHT; + }) + .attr("width", BOXWIDTH/active_heads()) + .attr("height", function() { return BOXHEIGHT; }) + .attr("fill", function(d, i, j) { + return head_colours(j); + }) + .style("opacity", 0.0); + + + var tokenContainer = textContainer.append("g").selectAll("g") + .data(text) + .enter() + .append("g"); + + tokenContainer.append("rect") + .classed("background", true) + .style("opacity", 0.0) + .attr("fill", "lightgray") + .attr("x", left_pos) + .attr("y", function(d, i) { + return (i+1) * BOXHEIGHT; + }) + .attr("width", BOXWIDTH) + .attr("height", BOXHEIGHT); + + var theText = tokenContainer.append("text") + .text(function(d) { return d; }) + .attr("font-size", TEXT_SIZE + "px") + .style("cursor", "default") + .style("-webkit-user-select", "none") + .attr("x", left_pos) + .attr("y", function(d, i) { + return (i+1) * BOXHEIGHT; + }); + + if (is_top) { + theText.style("text-anchor", "end") + .attr("dx", BOXWIDTH - TEXT_SIZE) + .attr("dy", TEXT_SIZE); + } else { + theText.style("text-anchor", "start") + .attr("dx", + TEXT_SIZE) + .attr("dy", TEXT_SIZE); + } + + tokenContainer.on("mouseover", function(d, index) { + textContainer.selectAll(".background") + .style("opacity", function(d, i) { + return i == index ? 1.0 : 0.0; + }); + + svg.selectAll(".attention_heads").style("display", "none"); + + svg.selectAll(".line_heads") // To get the nesting to work. + .selectAll(".att_lines") + .attr("stroke-opacity", function(d) { + return 1.0; + }) + .attr("y1", function(d, i) { + if (is_top) { + return (index+1) * BOXHEIGHT + (BOXHEIGHT/2); + } else { + return (i+1) * BOXHEIGHT + (BOXHEIGHT/2); + } + }) + .attr("x1", BOXWIDTH) + .attr("y2", function(d, i) { + if (is_top) { + return (i+1) * BOXHEIGHT + (BOXHEIGHT/2); + } else { + return (index+1) * BOXHEIGHT + (BOXHEIGHT/2); + } + }) + .attr("x2", BOXWIDTH + MATRIX_WIDTH) + .attr("stroke-width", 2) + .attr("stroke", function(d, i, j) { + return head_colours(j); + }) + .attr("stroke-opacity", function(d, i, j) { + if (is_top) {d = d[0];} else {d = d[1];} + if (config.head_vis[j]) { + if (d) { + return d[index]; + } else { + return 0.0; + } + } else { + return 0.0; + } + }); + + + function updateAttentionBoxes() { + var id = is_top ? "bottom" : "top"; + var the_left_pos = is_top ? MATRIX_WIDTH + BOXWIDTH : 0; + svg.select("#" + id) + .selectAll(".attention_boxes") + .selectAll("g") + .selectAll("rect") + .attr("x", function(d, i, j) { return the_left_pos + box_offset(j); }) + .attr("y", function(d, i) { return (i+1) * BOXHEIGHT; }) + .attr("width", BOXWIDTH/active_heads()) + .attr("height", function() { return BOXHEIGHT; }) + .style("opacity", function(d, i, j) { + if (is_top) {d = d[0];} else {d = d[1];} + if (config.head_vis[j]) + if (d) { + return d[index]; + } else { + return 0.0; + } + else + return 0.0; + + }); + } + + updateAttentionBoxes(); + }); + + textContainer.on("mouseleave", function() { + d3.select(this).selectAll(".background") + .style("opacity", 0.0); + + svg.selectAll(".att_lines").attr("stroke-opacity", 0.0); + svg.selectAll(".attention_heads").style("display", "inline"); + svg.selectAll(".attention_boxes") + .selectAll("g") + .selectAll("rect") + .style("opacity", 0.0); + }); +} + +function renderAttentionHighlights(svg, attention) { + var line_container = svg.append("g"); + line_container.selectAll("g") + .data(attention) + .enter() + .append("g") + .classed("line_heads", true) + .selectAll("line") + .data(function(d){return d;}) + .enter() + .append("line").classed("att_lines", true); +} + +function renderAttention(svg, attention_heads) { + var line_container = svg.selectAll(".attention_heads"); + line_container.html(null); + for(var h=0; h").val(i).text(i)); +} + +$("#layer").on('change', function(e) { + config.layer = +e.currentTarget.value; + render(); +}); + +$("#att_type").on('change', function(e) { + config.att_type = e.currentTarget.value; + render(); +}); + +$("button").on('click', visualize); + +visualize(); + +}); diff --git a/tensor2tensor/visualization/attention.py b/tensor2tensor/visualization/attention.py new file mode 100644 index 000000000..2c1f61c9c --- /dev/null +++ b/tensor2tensor/visualization/attention.py @@ -0,0 +1,150 @@ +# coding=utf-8 +# Copyright 2017 The Tensor2Tensor Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Module for postprocessing and displaying tranformer attentions. + +This module is deigned to be called from an ipython notebook. +""" + +import json +import os + +from IPython.display import HTML +from IPython.display import Javascript + +import numpy as np + +vis_html = """ + + Layer: + Attention: + +
+""" + + +__location__ = os.path.realpath( + os.path.join(os.getcwd(), os.path.dirname(__file__))) +vis_js = open(os.path.join(__location__, 'attention.js')).read() + + +def show(inp_text, out_text, enc_atts, dec_atts, encdec_atts): + attention = _get_attention( + inp_text, out_text, enc_atts, dec_atts, encdec_atts) + att_json = json.dumps(attention) + _show_attention(att_json) + + +def _show_attention(att_json): + display(HTML(vis_html)) # pylint: disable=undefined-variable + display(Javascript('window.attention = %s' % att_json)) # pylint: disable=undefined-variable + display(Javascript(vis_js)) # pylint: disable=undefined-variable + + +def _get_attention(inp_text, out_text, enc_atts, dec_atts, encdec_atts): + """Compute representation of the attention ready for the d3 visualization. + + Args: + inp_text: list of strings, words to be displayed on the left of the vis + out_text: list of strings, words to be displayed on the right of the vis + enc_atts: numpy array, encoder self-attentions + [num_layers, batch_size, num_heads, enc_length, enc_length] + dec_atts: numpy array, decoder self-attentions + [num_layers, batch_size, num_heads, dec_length, dec_length] + encdec_atts: numpy array, encoder-decoder attentions + [num_layers, batch_size, num_heads, enc_length, dec_length] + + Returns: + Dictionary of attention representations with the structure: + { + 'all': Representations for showing all attentions at the same time. + 'inp_inp': Representations for showing encoder self-attentions + 'inp_out': Representations for showing encoder-decoder attentions + 'out_out': Representations for showing decoder self-attentions + } + and each sub-dictionary has structure: + { + 'att': list of inter attentions matrices, one for each attention head + 'top_text': list of strings, words to be displayed on the left of the vis + 'bot_text': list of strings, words to be displayed on the right of the vis + } + """ + def get_full_attention(layer): + """Get the full input+output - input+output attentions.""" + enc_att = enc_atts[layer][0], + dec_att = dec_atts[layer][0], + encdec_att = encdec_atts[layer][0] + enc_att = np.transpose(enc_att, [0, 2, 1]) + dec_att = np.transpose(dec_att, [0, 2, 1]) + encdec_att = np.transpose(encdec_att, [0, 2, 1]) + # [heads, query_length, memory_length] + enc_length = enc_att.shape[1] + dec_length = dec_att.shape[1] + num_heads = enc_att.shape[0] + first = np.concatenate([enc_att, encdec_att], axis=2) + second = np.concatenate( + [np.zeros((num_heads, dec_length, enc_length)), dec_att], axis=2) + full_att = np.concatenate([first, second], axis=1) + return [ha.T.tolist() for ha in full_att] + + def get_inp_inp_attention(layer): + att = np.transpose(enc_atts[layer][0], (0, 2, 1)) + return [ha.T.tolist() for ha in att] + + def get_out_inp_attention(layer): + att = np.transpose(encdec_atts[layer][0], (0, 2, 1)) + return [ha.T.tolist() for ha in att] + + def get_out_out_attention(layer): + att = np.transpose(dec_atts[layer][0], (0, 2, 1)) + return [ha.T.tolist() for ha in att] + + def get_attentions(get_attention_fn): + num_layers = len(enc_atts) + attentions = [] + for i in range(num_layers): + attentions.append(get_attention_fn(i)) + + return attentions + + attentions = { + 'all': { + 'att': get_attentions(get_full_attention), + 'top_text': inp_text + out_text, + 'bot_text': inp_text + out_text, + }, + 'inp_inp': { + 'att': get_attentions(get_inp_inp_attention), + 'top_text': inp_text, + 'bot_text': inp_text, + }, + 'inp_out': { + 'att': get_attentions(get_out_inp_attention), + 'top_text': inp_text, + 'bot_text': out_text, + }, + 'out_out': { + 'att': get_attentions(get_out_out_attention), + 'top_text': out_text, + 'bot_text': out_text, + }, + } + + return attentions