diff --git a/.travis.yml b/.travis.yml index 02e7e0768..2ae1acf65 100644 --- a/.travis.yml +++ b/.travis.yml @@ -11,12 +11,15 @@ env: - TF_VERSION="1.5.*" - TF_VERSION="1.6.*" - TF_VERSION="1.7.*" + - TF_VERSION="1.8.*" matrix: exclude: - python: "3.6" env: TF_VERSION="1.5.*" - python: "3.6" env: TF_VERSION="1.6.*" + - python: "3.6" + env: TF_VERSION="1.7.*" before_install: - echo "deb [arch=amd64] http://storage.googleapis.com/tensorflow-serving-apt stable tensorflow-model-server tensorflow-model-server-universal" | sudo tee /etc/apt/sources.list.d/tensorflow-serving.list - curl https://storage.googleapis.com/tensorflow-serving-apt/tensorflow-serving.release.pub.gpg | sudo apt-key add - diff --git a/README.md b/README.md index 31b25562f..72c29c8d0 100644 --- a/README.md +++ b/README.md @@ -26,10 +26,10 @@ You can chat with us on ### Quick Start -[This iPython notebook](https://goo.gl/wkHexj) explains T2T and runs in your -browser using a free VM from Google, no installation needed. -Alternatively, here is a one-command version that installs T2T, downloads MNIST, -trains a model and evaluates it: +[This iPython notebook](https://colab.research.google.com/github/tensorflow/tensor2tensor/blob/master/tensor2tensor/notebooks/hello_t2t.ipynb) +explains T2T and runs in your browser using a free VM from Google, +no installation needed. Alternatively, here is a one-command version that +installs T2T, downloads MNIST, trains a model and evaluates it: ``` pip install tensor2tensor && t2t-trainer \ diff --git a/docs/index.md b/docs/index.md index 9262461c7..2ffbb956d 100644 --- a/docs/index.md +++ b/docs/index.md @@ -19,7 +19,7 @@ research](https://research.googleblog.com/2017/06/accelerating-deep-learning-res ## Basics * [Walkthrough](walkthrough.md): Install and run. -* [IPython notebook](https://goo.gl/wkHexj): Get a hands-on experience. +* [IPython notebook](https://colab.research.google.com/github/tensorflow/tensor2tensor/blob/master/tensor2tensor/notebooks/hello_t2t.ipynb): Get a hands-on experience. * [Overview](overview.md): How all parts of T2T code are connected. * [New Problem](new_problem.md): Train T2T models on your data. * [New Model](new_model.md): Create your own T2T model. diff --git a/docs/walkthrough.md b/docs/walkthrough.md index 31b25562f..72c29c8d0 100644 --- a/docs/walkthrough.md +++ b/docs/walkthrough.md @@ -26,10 +26,10 @@ You can chat with us on ### Quick Start -[This iPython notebook](https://goo.gl/wkHexj) explains T2T and runs in your -browser using a free VM from Google, no installation needed. -Alternatively, here is a one-command version that installs T2T, downloads MNIST, -trains a model and evaluates it: +[This iPython notebook](https://colab.research.google.com/github/tensorflow/tensor2tensor/blob/master/tensor2tensor/notebooks/hello_t2t.ipynb) +explains T2T and runs in your browser using a free VM from Google, +no installation needed. 
Alternatively, here is a one-command version that +installs T2T, downloads MNIST, trains a model and evaluates it: ``` pip install tensor2tensor && t2t-trainer \ diff --git a/tensor2tensor/data_generators/all_problems.py b/tensor2tensor/data_generators/all_problems.py index 7e3b3e008..87ce4df86 100644 --- a/tensor2tensor/data_generators/all_problems.py +++ b/tensor2tensor/data_generators/all_problems.py @@ -46,6 +46,7 @@ "tensor2tensor.data_generators.ptb", "tensor2tensor.data_generators.snli", "tensor2tensor.data_generators.squad", + "tensor2tensor.data_generators.subject_verb_agreement", "tensor2tensor.data_generators.translate_encs", "tensor2tensor.data_generators.translate_ende", "tensor2tensor.data_generators.translate_enet", @@ -56,6 +57,7 @@ "tensor2tensor.data_generators.twentybn", "tensor2tensor.data_generators.wiki", "tensor2tensor.data_generators.wikisum.wikisum", + "tensor2tensor.data_generators.wikitext103", "tensor2tensor.data_generators.wsj_parsing", ] diff --git a/tensor2tensor/data_generators/generator_utils.py b/tensor2tensor/data_generators/generator_utils.py index e6ab96c04..713ecce69 100644 --- a/tensor2tensor/data_generators/generator_utils.py +++ b/tensor2tensor/data_generators/generator_utils.py @@ -44,7 +44,7 @@ def to_example(dictionary): features = {} for (k, v) in six.iteritems(dictionary): if not v: - raise ValueError("Empty generated field: %s", str((k, v))) + raise ValueError("Empty generated field: %s" % str((k, v))) if isinstance(v[0], six.integer_types): features[k] = tf.train.Feature(int64_list=tf.train.Int64List(value=v)) elif isinstance(v[0], float): @@ -130,7 +130,8 @@ def outputs_exist(filenames): return out_fname -def generate_files(generator, output_filenames, max_cases=None): +def generate_files(generator, output_filenames, + max_cases=None, cycle_every_n=1): """Generate cases from a generator and save as TFRecord files. Generated cases are transformed to tf.Example protos and saved as TFRecords @@ -141,6 +142,8 @@ def generate_files(generator, output_filenames, max_cases=None): output_filenames: List of output file paths. max_cases: maximum number of cases to get from the generator; if None (default), we use the generator until StopIteration is raised. + cycle_every_n: how many cases from the generator to take before + switching to the next shard; by default set to 1, switch every case. 
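For readers scanning this hunk: the new `cycle_every_n` argument changes the writer rotation from "one case per shard" to "a block of consecutive cases per shard". A minimal sketch of the intended behavior, using illustrative names rather than the library code:

```python
def shard_assignment(num_cases, num_shards, cycle_every_n=1):
  """Toy model of how generate_files rotates shards (illustrative only)."""
  shard, assignment = 0, []
  for counter in range(1, num_cases + 1):
    assignment.append(shard)            # case `counter` goes to `shard`
    if counter % cycle_every_n == 0:    # rotate only every cycle_every_n cases
      shard = (shard + 1) % num_shards
  return assignment

print(shard_assignment(8, 3))                    # [0, 1, 2, 0, 1, 2, 0, 1]
print(shard_assignment(8, 3, cycle_every_n=4))   # [0, 0, 0, 0, 1, 1, 1, 1]
```

The video problems later in this diff pass `cycle_every_n=self.total_number_of_frames // len(all_paths)`, so consecutive frames of one rollout end up in the same shard file.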
""" if outputs_exist(output_filenames): tf.logging.info("Skipping generator because outputs files exist") @@ -159,7 +162,8 @@ def generate_files(generator, output_filenames, max_cases=None): break example = to_example(case) writers[shard].write(example.SerializeToString()) - shard = (shard + 1) % num_shards + if counter % cycle_every_n == 0: + shard = (shard + 1) % num_shards for writer in writers: writer.close() @@ -341,6 +345,7 @@ def get_or_generate_vocab(data_dir, tmp_dir, vocab_filename, vocab_size, """Generate a vocabulary from the datasets in sources.""" def generate(): + """Generate lines for vocabulary generation.""" tf.logging.info("Generating vocab from: %s", str(sources)) for source in sources: url = source[0] diff --git a/tensor2tensor/data_generators/gym.py b/tensor2tensor/data_generators/gym.py index 3c70e9c49..9ea820233 100644 --- a/tensor2tensor/data_generators/gym.py +++ b/tensor2tensor/data_generators/gym.py @@ -18,10 +18,11 @@ from __future__ import division from __future__ import print_function -from collections import deque - import functools +import os + # Dependency imports + import gym from tensor2tensor.data_generators import problem @@ -62,9 +63,7 @@ def num_target_frames(self): return 1 def eval_metrics(self): - eval_metrics = [ - metrics.Metrics.ACC, metrics.Metrics.ACC_PER_SEQ, - metrics.Metrics.NEG_LOG_PERPLEXITY] + eval_metrics = [metrics.Metrics.ACC, metrics.Metrics.ACC_PER_SEQ] return eval_metrics @property @@ -108,6 +107,10 @@ def num_rewards(self): def num_steps(self): raise NotImplementedError() + @property + def total_number_of_frames(self): + return self.num_steps + @property def min_reward(self): raise NotImplementedError() @@ -126,13 +129,13 @@ def hparams(self, defaults, unused_model_hparams): p.target_space_id = problem.SpaceID.IMAGE def generate_samples(self, data_dir, tmp_dir, unused_dataset_split): - next_obs = self.env.reset() + next_observation = self.env.reset() for _ in range(self.num_steps): - observation = next_obs + observation = next_observation action = self.get_action(observation) - next_obs, reward, done, _ = self.env.step(action) + next_observation, reward, done, _ = self.env.step(action) if done: - next_obs = self.env.reset() + next_observation = self.env.reset() yield {"frame": observation, "action": [action], "done": [done], @@ -184,23 +187,22 @@ class GymDiscreteProblemWithAgent(GymPongRandom5k): def __init__(self, *args, **kwargs): super(GymDiscreteProblemWithAgent, self).__init__(*args, **kwargs) self._env = None - self.history_size = 2 + self.debug_dump_frames_path = "debug_frames_env" # defaults self.environment_spec = lambda: gym.make("PongDeterministic-v4") - self.in_graph_wrappers = [(atari.MaxAndSkipWrapper, {"skip": 4})] + self.in_graph_wrappers = [] self.collect_hparams = rl.atari_base() - self.settable_num_steps = 1000 + self.settable_num_steps = 20000 self.simulated_environment = None - self.warm_up = 70 + self.warm_up = 10 @property def num_steps(self): return self.settable_num_steps def _setup(self): - in_graph_wrappers = [(atari.ShiftRewardWrapper, {"add_value": 2}), - (atari.MemoryWrapper, {})] + self.in_graph_wrappers + in_graph_wrappers = [(atari.MemoryWrapper, {})] + self.in_graph_wrappers env_hparams = tf.contrib.training.HParams( in_graph_wrappers=in_graph_wrappers, simulated_environment=self.simulated_environment) @@ -229,41 +231,41 @@ def _setup(self): self.avilable_data_size_op = atari.MemoryWrapper.singleton.speculum.size() self.data_get_op = atari.MemoryWrapper.singleton.speculum.dequeue() - 
self.history_buffer = deque(maxlen=self.history_size+1) def restore_networks(self, sess): if FLAGS.agent_policy_path: model_saver = tf.train.Saver( - tf.global_variables(".*network_parameters.*")) + tf.global_variables(".*network_parameters.*")) model_saver.restore(sess, FLAGS.agent_policy_path) def generate_encoded_samples(self, data_dir, tmp_dir, unused_dataset_split): self._setup() + self.debug_dump_frames_path = os.path.join( + data_dir, self.debug_dump_frames_path) with tf.Session() as sess: sess.run(tf.global_variables_initializer()) self.restore_networks(sess) - + # Actions are shifted by 1 by MemoryWrapper, compensate here. + avilable_data_size = sess.run(self.avilable_data_size_op) + if avilable_data_size < 1: + sess.run(self.collect_trigger_op) pieces_generated = 0 + observ, reward, _, _ = sess.run(self.data_get_op) while pieces_generated < self.num_steps + self.warm_up: avilable_data_size = sess.run(self.avilable_data_size_op) - if avilable_data_size > 0: - observ, reward, action, _ = sess.run(self.data_get_op) - self.history_buffer.append(observ) - - if len(self.history_buffer) == self.history_size + 1: - pieces_generated += 1 - ret_dict = {"image/encoded": [observ], - "image/format": ["png"], - "image/height": [self.frame_height], - "image/width": [self.frame_width], - "action": [int(action)], - "done": [int(False)], - "reward": [int(reward) - self.min_reward]} - if pieces_generated > self.warm_up: - yield ret_dict - else: + if avilable_data_size < 1: sess.run(self.collect_trigger_op) + next_observ, next_reward, action, _ = sess.run(self.data_get_op) + yield {"image/encoded": [observ], + "image/format": ["png"], + "image/height": [self.frame_height], + "image/width": [self.frame_width], + "action": [int(action)], + "done": [int(False)], + "reward": [int(reward) - self.min_reward]} + pieces_generated += 1 + observ, reward = next_observ, next_reward @registry.register_problem @@ -273,7 +275,7 @@ class GymSimulatedDiscreteProblemWithAgent(GymDiscreteProblemWithAgent): def __init__(self, *args, **kwargs): super(GymSimulatedDiscreteProblemWithAgent, self).__init__(*args, **kwargs) self.simulated_environment = True - self.debug_dump_frames_path = "/tmp/t2t_debug_dump_frames" + self.debug_dump_frames_path = "debug_frames_sim" def restore_networks(self, sess): super(GymSimulatedDiscreteProblemWithAgent, self).restore_networks(sess) diff --git a/tensor2tensor/data_generators/imagenet.py b/tensor2tensor/data_generators/imagenet.py index 109d37c5d..06206ce6f 100644 --- a/tensor2tensor/data_generators/imagenet.py +++ b/tensor2tensor/data_generators/imagenet.py @@ -189,7 +189,7 @@ def preprocess_example(self, example, mode, _): @registry.register_problem class ImageImagenet64Gen(ImageImagenet): - """Cifar-10 Tune.""" + """Imagenet 64 from the pixen cnn paper""" @property def train_shards(self): @@ -264,6 +264,33 @@ def preprocess_example(self, example, mode, hparams): return example +@registry.register_problem +class ImageImagenet32Small(ImageImagenet): + """Imagenet small from the pixel cnn paper""" + + @property + def is_small(self): + return False # Modalities like for CIFAR. 
+ + @property + def num_classes(self): + return 1000 + + @property + def train_shards(self): + return 1024 + + @property + def dev_shards(self): + return 10 + + def preprocess_example(self, example, mode, unused_hparams): + example["inputs"].set_shape([_IMAGENET_SMALL_IMAGE_SIZE, + _IMAGENET_SMALL_IMAGE_SIZE, 3]) + example["inputs"] = tf.to_int64(example["inputs"]) + return example + + @registry.register_problem class ImageImagenet64(ImageImagenet32): """Imagenet rescaled to 64x64.""" diff --git a/tensor2tensor/data_generators/squad.py b/tensor2tensor/data_generators/squad.py index 178a3a4f4..e19307242 100644 --- a/tensor2tensor/data_generators/squad.py +++ b/tensor2tensor/data_generators/squad.py @@ -143,5 +143,5 @@ def generate_encoded_samples(self, data_dir, tmp_dir, dataset_split): for sample in samples: sample['targets'] = self.generate_targets(sample['targets'], sample['context']) - if not sample['targets']: + if sample['targets']: yield sample diff --git a/tensor2tensor/data_generators/subject_verb_agreement.py b/tensor2tensor/data_generators/subject_verb_agreement.py new file mode 100644 index 000000000..587ff7da6 --- /dev/null +++ b/tensor2tensor/data_generators/subject_verb_agreement.py @@ -0,0 +1,258 @@ +# coding=utf-8 +# Copyright 2018 The Tensor2Tensor Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Data generators for subject-verb agreement dataset. + +https://arxiv.org/pdf/1611.01368.pdf + +Based on he main paper, predicting verb's number can be done in two setups: +- Language Modeling +- Binary Classification + +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import collections +from collections import defaultdict +import csv +import gzip +import os +import random + +# Dependency imports + +from tensor2tensor.data_generators import generator_utils +from tensor2tensor.data_generators import problem +from tensor2tensor.data_generators import text_encoder +from tensor2tensor.data_generators import text_problems +from tensor2tensor.utils import metrics +from tensor2tensor.utils import registry + +import tensorflow as tf + +_FILE_NAME = 'agr_50_mostcommon_10K' +_TAR = _FILE_NAME + '.tsv.gz' +_URL = 'http://tallinzen.net/media/rnn_agreement/' + _TAR +_LABEL_DICT = {'VBZ': 0, 'VBP': 1} + + +def _build_vocab(examples, example_field, vocab_dir, vocab_name): + """Build a vocabulary from examples. + + Args: + examples: a dict containing all the examples. + example_field: field of example from which the vocabulary is built. + vocab_dir: directory where to save the vocabulary. + vocab_name: vocab file name. + + Returns: + text encoder. 
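The vocabulary order produced below is "most frequent first, ties broken alphabetically". A toy, self-contained illustration with made-up tokens (not dataset contents):

```python
import collections

data = "the cat saw the dog the dog ran".split()
counter = collections.Counter(data)
count_pairs = sorted(counter.items(), key=lambda x: (-x[1], x[0]))
words, _ = list(zip(*count_pairs))
print(list(words))  # ['the', 'dog', 'cat', 'ran', 'saw']
```

The resulting word list is what gets written to the vocab file and later read back by `TokenTextEncoder`.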
+ """ + vocab_path = os.path.join(vocab_dir, vocab_name) + if not tf.gfile.Exists(vocab_path): + data = [] + for e in examples: + data.extend(e[example_field].split()) + counter = collections.Counter(data) + count_pairs = sorted(counter.items(), key=lambda x: (-x[1], x[0])) + words, _ = list(zip(*count_pairs)) + encoder = text_encoder.TokenTextEncoder(None, vocab_list=words) + encoder.store_to_file(vocab_path) + else: + encoder = text_encoder.TokenTextEncoder(vocab_path) + return encoder + + +def load_examples(tmp_dir, equalize_classes=False): + """Loads exampls from the tsv file. + + Args: + tmp_dir: temp directory. + equalize_classes: if equalize number of examples in the classes. + + Returns: + All examples in the dataset. + + """ + + infile = generator_utils.maybe_download(tmp_dir, _TAR, _URL) + tf.logging.info('Loading examples') + + all_examples = [] + for i, d in enumerate(csv.DictReader(gzip.open(infile), delimiter='\t')): + if i % 100000 == 0: + tf.logging.info('%d examples have been loaded....' % i) + ex = {x: int(y) if y.isdigit() else y for x, y in d.items()} + all_examples.append(ex) + + classes = defaultdict(list) + for ex in all_examples: + classes[ex['verb_pos']].append(ex) + + del all_examples[:] + assert len(classes) == 2 + + c1 = classes.values()[0] + c2 = classes.values()[1] + random.seed(1) + random.shuffle(c1) + random.shuffle(c2) + if equalize_classes: + l = min(len(c1), len(c2)) + all_examples = c1[:l] + c2[:l] + else: + all_examples = c1 + c2 + random.shuffle(all_examples) + + return all_examples + + +@registry.register_problem +class SvaNumberPrediction(text_problems.Text2ClassProblem): + """Subject verb agreement as verb number predicion (binary classification).""" + + @property + def is_generate_per_split(self): + # generate_data will shard the data into TRAIN and EVAL for us. + return False + + @property + def dataset_splits(self): + """Splits of data to produce and number of output shards for each. + + This is the setup of the main paper. 10% train/ 90% eval + + Returns: + A dict containing splits information. + + """ + return [{ + 'split': problem.DatasetSplit.TRAIN, + 'shards': 1, + }, { + 'split': problem.DatasetSplit.EVAL, + 'shards': 9, + }] + + @property + def vocab_type(self): + return text_problems.VocabType.TOKEN + + @property + def num_classes(self): + return 2 + + def class_labels(self, data_dir): + """Class labels.""" + del data_dir + return ['VBZ', 'VBP'] + + def generate_samples(self, data_dir, tmp_dir, dataset_split): + """Generate samples of text and label pairs. + + Each yielded dict will be a single example. The inputs should be raw text. + The label should be an int in [0, self.num_classes). + + Args: + data_dir: final data directory. Typically only used in this method to copy + over user-supplied vocab files (for example, if vocab_type == + VocabType.TOKEN). + tmp_dir: temporary directory that you can use for downloading and scratch. + dataset_split: problem.DatasetSplit, which data split to generate samples + for (for example, training and evaluation). + + Returns: + sample generator. 
+ """ + example_filed = 'sentence' + examples = load_examples(tmp_dir) + _build_vocab(examples, example_filed, data_dir, self.vocab_filename) + + def _generate_samples(): + for example in examples: + index = int(example['verb_index']) - 1 + inputs = example[example_filed].split()[:index] + yield { + 'inputs': ' '.join(inputs), + 'label': _LABEL_DICT[example['verb_pos']] + } + + return _generate_samples() + + def eval_metrics(self): + """Specify the set of evaluation metrics for this problem. + + Returns: + List of evaluation metrics of interest. + """ + return [metrics.Metrics.ACC] + + +@registry.register_problem +class SvaLanguageModeling(text_problems.Text2SelfProblem): + """Subject verb agreement as language modeling task.""" + + @property + def is_generate_per_split(self): + # generate_data will shard the data into TRAIN and EVAL for us. + return False + + @property + def dataset_splits(self): + """Splits of data to produce and number of output shards for each. + + This is the setup of the main paper. 10% train/ 90% eval + + Returns: + A dict containing splits information. + + """ + return [{ + 'split': problem.DatasetSplit.TRAIN, + 'shards': 1, + }, { + 'split': problem.DatasetSplit.EVAL, + 'shards': 9, + }] + + @property + def vocab_type(self): + return text_problems.VocabType.TOKEN + + def generate_samples(self, data_dir, tmp_dir, dataset_split): + """Generates samples. + + Args: + data_dir: data directory + tmp_dir: temp directory + dataset_split: dataset split + + Returns: + sample generator. + + """ + example_filed = 'sentence' + examples = load_examples(tmp_dir) + _build_vocab(examples, example_filed, data_dir, self.vocab_filename) + + def _generate_samples(): + for example in examples: + index = int(example['verb_index']) - 1 + targets = example[example_filed].split()[:index + 1] + yield {'targets': ' '.join(targets)} + + return _generate_samples() diff --git a/tensor2tensor/data_generators/video_utils.py b/tensor2tensor/data_generators/video_utils.py index db965795d..a5f35db75 100644 --- a/tensor2tensor/data_generators/video_utils.py +++ b/tensor2tensor/data_generators/video_utils.py @@ -67,6 +67,11 @@ def frame_width(self): """Width of each frame.""" raise NotImplementedError + @property + def total_number_of_frames(self): + """The total number of frames, needed for sharding.""" + raise NotImplementedError + @property def num_input_frames(self): """Number of frames to batch on one input.""" @@ -188,7 +193,9 @@ def _preprocess(example): preprocessed_dataset = dataset.map(_preprocess) num_frames = self.num_input_frames + self.num_target_frames - # TODO(lukaszkaiser): should jump by a random position at the beginning. + # We jump by a random position at the beginning to add variety. 
+ random_skip = tf.random_uniform([], maxval=num_frames, dtype=tf.int64) + preprocessed_dataset = preprocessed_dataset.skip(random_skip) batch_dataset = preprocessed_dataset.apply( tf.contrib.data.batch_and_drop_remainder(num_frames)) dataset = batch_dataset.map(features_from_batch).shuffle(8) @@ -265,6 +272,8 @@ def generate_encoded_samples_debug(self, data_dir, tmp_dir, dataset_split): for sample in self.generate_encoded_samples( data_dir, tmp_dir, dataset_split): if self.debug_dump_frames_path: + if not tf.gfile.Exists(self.debug_dump_frames_path): + tf.gfile.MkDir(self.debug_dump_frames_path) path = os.path.join(self.debug_dump_frames_path, "frame_%05d.png" % counter) with tf.gfile.Open(path, "wb") as f: @@ -296,7 +305,9 @@ def generate_data(self, data_dir, tmp_dir, task_id=-1): else: generator_utils.generate_files( self.generate_encoded_samples_debug( - data_dir, tmp_dir, problem.DatasetSplit.TRAIN), all_paths) + data_dir, tmp_dir, problem.DatasetSplit.TRAIN), + all_paths, + cycle_every_n=self.total_number_of_frames // len(all_paths)) # TODO(lukaszkaiser): remove this version after everything is ported. diff --git a/tensor2tensor/data_generators/wikisum/README.md b/tensor2tensor/data_generators/wikisum/README.md index 4713a8d37..20287a760 100644 --- a/tensor2tensor/data_generators/wikisum/README.md +++ b/tensor2tensor/data_generators/wikisum/README.md @@ -116,10 +116,10 @@ Pricing is taken from * `WikisumCommoncrawl` * `get_references_commoncrawl`: $50 (1k machines, 1 CPU, 2G memory, 1 hour) - * `produce_examples`: $350 (1k machines, 1 CPU, 2G memory, 8 hours) + * `produce_examples`: $25 (1k machines, 1 CPU, 2G memory, 30 minutes) * `WikisumWeb` - * `get_references_web`: $750 (1k machines, 4 CPU, 4G memory, 5 hours) - * `produce_examples`: $350 (1k machines, 1 CPU, 2G memory, 8 hours) + * `get_references_web`: $600 (1k machines, 4 CPU, 4G memory, 4 hours) + * `produce_examples`: $25 (1k machines, 1 CPU, 2G memory, 30 minutes) ## Commands to generate `WikisumCommoncrawl` @@ -130,11 +130,11 @@ pip install tensor2tensor -U --user BUCKET=gs://my-gcs-bucket/wikisum_commoncrawl # Extract references from CommonCrawl -python -m tensor2tensor.data_generators.wikisum.parallel_launch.py \ +python -m tensor2tensor.data_generators.wikisum.parallel_launch \ --num_instances=1000 \ --cpu=1 --mem=2 \ - --name=wikisum-refs-cc \ - --log_dir=$BUCKET/refs_logs \ + --name=wikisum-cc-refs \ + --log_dir=$BUCKET/logs \ --setup_command="pip install tensor2tensor tensorflow -U -q --user" \ --command_prefix="python -m tensor2tensor.data_generators.wikisum.get_references_commoncrawl --num_tasks=1000 --out_dir=$BUCKET/wiki_references --task_id" @@ -145,13 +145,13 @@ python -m tensor2tensor.data_generators.wikisum.generate_vocab \ --for_commoncrawl # Produce examples -python -m tensor2tensor.data_generators.wikisum.parallel_launch.py \ +python -m tensor2tensor.data_generators.wikisum.parallel_launch \ --num_instances=1000 \ --cpu=1 --mem=2 \ --name=wikisum-cc-produce \ - --log_dir=$BUCKET/produce_logs \ + --log_dir=$BUCKET/logs \ --setup_command="pip install tensor2tensor tensorflow -U -q --user" \ - --command_prefix="python -m tensor2tensor.data_generators.wikisum.produce_examples.py --out_dir=$BUCKET/data --refs_dir=$BUCKET/wiki_references --num_tasks=1000 --for_commoncrawl --task_id" + --command_prefix="python -m tensor2tensor.data_generators.wikisum.produce_examples --out_dir=$BUCKET/data --refs_dir=$BUCKET/wiki_references --num_tasks=1000 --for_commoncrawl --task_id" ``` ## Commands to generate `WikisumWeb` @@ 
-163,11 +163,11 @@ pip install tensor2tensor -U --user BUCKET=gs://my-gcs-bucket/wikisum_web # Fetch references from web -python -m tensor2tensor.data_generators.wikisum.parallel_launch.py \ +python -m tensor2tensor.data_generators.wikisum.parallel_launch \ --num_instances=1000 \ --cpu=4 --mem=4 \ - --name=wikisum-refs-web \ - --log_dir=$BUCKET/refs_logs \ + --name=wikisum-web-refs \ + --log_dir=$BUCKET/logs \ --setup_command="pip3 install tensorflow tensor2tensor aiohttp cchardet aiodns bs4 -U -q --user" \ --command_prefix="python3 wikisum/get_references_web.py --out_dir=$BUCKET/wiki_references --shard_id" @@ -177,13 +177,13 @@ python -m tensor2tensor.data_generators.wikisum.generate_vocab \ --refs_dir=$BUCKET/wiki_references # Produce examples -python -m tensor2tensor.data_generators.wikisum.parallel_launch.py \ +python -m tensor2tensor.data_generators.wikisum.parallel_launch \ --num_instances=1000 \ --cpu=1 --mem=2 \ --name=wikisum-web-produce \ - --log_dir=$BUCKET/produce_logs \ + --log_dir=$BUCKET/logs \ --setup_command="pip install tensor2tensor tensorflow -U -q --user" \ - --command_prefix="python -m tensor2tensor.data_generators.wikisum.produce_examples.py --out_dir=$BUCKET/data --refs_dir=$BUCKET/wiki_references --num_tasks=1000 --task_id" + --command_prefix="python -m tensor2tensor.data_generators.wikisum.produce_examples --out_dir=$BUCKET/data --refs_dir=$BUCKET/wiki_references --num_tasks=1000 --task_id" ``` ## Training diff --git a/tensor2tensor/data_generators/wikisum/generate_vocab.py b/tensor2tensor/data_generators/wikisum/generate_vocab.py index b8e64702f..33fbdbb73 100644 --- a/tensor2tensor/data_generators/wikisum/generate_vocab.py +++ b/tensor2tensor/data_generators/wikisum/generate_vocab.py @@ -43,4 +43,5 @@ def main(_): if __name__ == "__main__": + tf.logging.set_verbosity(tf.logging.INFO) tf.app.run() diff --git a/tensor2tensor/data_generators/wikisum/get_references_web_single_group.py b/tensor2tensor/data_generators/wikisum/get_references_web_single_group.py index ca1717c73..5198fec5a 100644 --- a/tensor2tensor/data_generators/wikisum/get_references_web_single_group.py +++ b/tensor2tensor/data_generators/wikisum/get_references_web_single_group.py @@ -206,7 +206,7 @@ def write_ref_content(url, ref, f): async def fetch_url(url, session, side_data): text = None try: - async with session.get(url, timeout=30, verify_ssl=False) as response: + async with session.get(url, timeout=10, verify_ssl=False) as response: if response.status == 200: text = await response.text() else: diff --git a/tensor2tensor/data_generators/wikisum/parallel_launch.py b/tensor2tensor/data_generators/wikisum/parallel_launch.py index 9cec88d07..f398218b0 100644 --- a/tensor2tensor/data_generators/wikisum/parallel_launch.py +++ b/tensor2tensor/data_generators/wikisum/parallel_launch.py @@ -44,8 +44,10 @@ from __future__ import division from __future__ import print_function +import contextlib import multiprocessing as mp import os +import socket import subprocess as sp import time @@ -88,12 +90,12 @@ COPY_CODE = "gcloud compute scp --recurse {local_dir} {instance_name}:~/" SSH = "gcloud compute ssh {instance_name} --command" SCREEN = "screen -dmS test bash -c \"{command}\"" -SSH_CHECK = "nc -w 1 -z {ip} 22" DEFAULT_ZONE = "gcloud config get-value compute/zone" LOGS = "> ~/logs-{task_id}.txt 2>&1; gsutil cp ~/logs-{task_id}.txt {bucket}" def remote_run(cmd, instance_name, detach=False, retries=1): + """Run command on GCS instance, optionally detached.""" if detach: cmd = SCREEN.format(command=cmd) args 
= SSH.format(instance_name=instance_name).split() @@ -112,19 +114,27 @@ def default_zone(): return cloud.shell_output(DEFAULT_ZONE).strip() +@contextlib.contextmanager +def safe_socket(timeout=2): + s = socket.socket() + s.settimeout(timeout) + try: + yield s + finally: + s.close() + + def wait_for_ssh(ip): """Wait for SSH to be available at given IP address.""" - i = 0 - while True: - try: - cloud.shell_run(SSH_CHECK, ip=ip) - break - except sp.CalledProcessError: - if i > 12: # ~2m - return False - time.sleep(10) - i += 1 - return True + for _ in range(12): + with safe_socket() as s: + try: + s.connect((ip, 22)) + return True + except socket.timeout: + pass + time.sleep(10) + return False def create_instance(instance_name, cpu=1, mem=4): @@ -206,18 +216,22 @@ def main(_): assert len(suffixes) == FLAGS.num_instances vm_info = list_vm_names_and_ips() - vm_names = zip(*vm_info)[0] if vm_info else [] + vm_names = list(zip(*vm_info))[0] if vm_info else [] pool = mp.Pool(FLAGS.num_threads) async_results = [] - log_dir = None - if FLAGS.log_dir: - log_dir = os.path.join(FLAGS.log_dir, FLAGS.name) - tf.gfile.MakeDirs(log_dir) - assert log_dir.startswith("gs://") - if not log_dir.endswith("/"): - log_dir += "/" + assert FLAGS.log_dir + log_dir = os.path.join(FLAGS.log_dir, FLAGS.name) + tf.gfile.MakeDirs(log_dir) + assert log_dir.startswith("gs://") + if not log_dir.endswith("/"): + log_dir += "/" + # Write a test file to make sure gcloud GCS APIs are enabled + test_filename = os.path.join(log_dir, "check_write") + with tf.gfile.Open(test_filename, "w") as f: + f.write("testing GCS write") + tf.gfile.Remove(test_filename) instance_ids = list(range(FLAGS.num_instances)) if FLAGS.instance_ids: @@ -242,23 +256,25 @@ def main(_): FLAGS.cpu, FLAGS.mem, code_dir, FLAGS.setup_command) res = pool.apply_async(launch_instance, args) - async_results.append(res) + async_results.append((res, instance_name, i)) failed = [] - for i, res in enumerate(async_results): + for res, instance_name, i in async_results: try: res.get() - except: # pylint: disable=bare-except - failed.append(i) - tf.logging.error("Failed to launch task %d", i) + except Exception as e: # pylint: disable=broad-except + failed.append((instance_name, i)) + tf.logging.error("Failed to launch task %s due to exception %s", + instance_name, str(e)) results = [] if failed: - tf.logging.error("Failed to launch %d jobs. Task ids: %s. " - "Attempting delete in case they are still up.", - len(failed), str(failed)) - for i in failed: - instance_name = "%s-%d" % (FLAGS.name, i) + ids_for_flag = ",".join([str(i) for i in list(zip(*failed))[1]]) + tf.logging.error("Failed to launch %d jobs. Tasks: %s. " + "Attempting delete in case they are still up. 
Rerun with " + "--instance_ids='%s' to attempt relaunch.", + len(failed), str(failed), ids_for_flag) + for instance_name, _ in failed: res = pool.apply_async(delete_instance, (instance_name,)) results.append(res) diff --git a/tensor2tensor/data_generators/wikisum/wikisum.py b/tensor2tensor/data_generators/wikisum/wikisum.py index 2dac504d6..c0664f328 100644 --- a/tensor2tensor/data_generators/wikisum/wikisum.py +++ b/tensor2tensor/data_generators/wikisum/wikisum.py @@ -278,7 +278,7 @@ def _parse_example(ex_ser): except tf.errors.OutOfRangeError: break - data[ex["url"]] = ex["content"] + data[ex["url"]] = text_encoder.to_unicode(ex["content"]) i += 1 return data @@ -339,11 +339,14 @@ def _parse_example(ex_ser): break sections = [ - WikipediaSection(title=title, text=text) + WikipediaSection(title=text_encoder.to_unicode(title), + text=text_encoder.to_unicode(text)) for title, text in zip(ex["section_titles"], ex["section_texts"]) ] yield WikipediaArticle( - url=ex["url"], title=ex["title"], sections=sections) + url=text_encoder.to_unicode(ex["url"]), + title=text_encoder.to_unicode(ex["title"]), + sections=sections) def _token_counts(text, token_set=None): diff --git a/tensor2tensor/data_generators/wikitext103.py b/tensor2tensor/data_generators/wikitext103.py new file mode 100644 index 000000000..5e1d4f310 --- /dev/null +++ b/tensor2tensor/data_generators/wikitext103.py @@ -0,0 +1,165 @@ +# coding=utf-8 +# Copyright 2018 The Tensor2Tensor Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Data generators for wikitext-103. + +Wikitext-103: Long term dependency language modeling dataset +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import collections +import os +import zipfile + +# Dependency imports + +from tensor2tensor.data_generators import generator_utils +from tensor2tensor.data_generators import problem +from tensor2tensor.data_generators import text_encoder +from tensor2tensor.data_generators import text_problems +from tensor2tensor.utils import registry + +import tensorflow as tf + + +def _build_vocab(filename, vocab_dir, vocab_name): + """Reads a file to build a vocabulary. + + Args: + filename: file to read list of words from. + vocab_dir: directory where to save the vocabulary. + vocab_name: vocab file name. + + Returns: + text encoder. + """ + vocab_path = os.path.join(vocab_dir, vocab_name) + if not tf.gfile.Exists(vocab_path): + with tf.gfile.GFile(filename, "r") as f: + data = f.read().split() + counter = collections.Counter(data) + count_pairs = sorted(counter.items(), key=lambda x: (-x[1], x[0])) + words, _ = list(zip(*count_pairs)) + encoder = text_encoder.TokenTextEncoder(None, vocab_list=words) + encoder.store_to_file(vocab_path) + else: + encoder = text_encoder.TokenTextEncoder(vocab_path) + return encoder + + +def _maybe_download_corpus(tmp_dir, vocab_type): + """Download and unpack the corpus. + + Args: + tmp_dir: directory containing dataset. + vocab_type: which vocabulary are we using. 
+ + Returns: + The list of names of files. + """ + if vocab_type == text_problems.VocabType.CHARACTER: + + dataset_url = ("https://s3.amazonaws.com/research.metamind.io/wikitext" + "/wikitext-103-raw-v1.zip") + dir_name = "wikitext-103-raw" + else: + dataset_url = ("https://s3.amazonaws.com/research.metamind.io/wikitext" + "/wikitext-103-v1.zip") + dir_name = "wikitext-103" + + fname = os.path.basename(dataset_url) + compressed_filepath = generator_utils.maybe_download(tmp_dir, fname, + dataset_url) + zip_ref = zipfile.ZipFile(compressed_filepath, "r") + zip_ref.extractall(tmp_dir) + zip_ref.close() + + files = os.path.join(tmp_dir, dir_name, "*") + train_file, valid_file, test_file = None, None, None + for f in tf.gfile.Glob(files): + fname = os.path.basename(f) + if "train" in fname: + train_file = f + elif "valid" in fname: + valid_file = f + elif "test" in fname: + test_file = f + + assert train_file, "Training file not found" + assert valid_file, "Validation file not found" + assert test_file, "Testing file not found" + + return train_file, valid_file, test_file + + +@registry.register_problem +class LanguagemodelWikitext103(text_problems.Text2SelfProblem): + """Wikitext103 dataset token-level.""" + + @property + def dataset_splits(self): + return [{ + "split": problem.DatasetSplit.TRAIN, + "shards": 10, + }, { + "split": problem.DatasetSplit.EVAL, + "shards": 1, + }, { + "split": problem.DatasetSplit.TEST, + "shards": 1, + }] + + @property + def is_generate_per_split(self): + return True + + @property + def vocab_type(self): + return text_problems.VocabType.TOKEN + + def generate_samples(self, data_dir, tmp_dir, dataset_split): + train_file, valid_file, test_file = _maybe_download_corpus( + tmp_dir, self.vocab_type) + + if dataset_split == problem.DatasetSplit.TRAIN: + filepath = train_file + if self.vocab_type == text_problems.VocabType.TOKEN: + _build_vocab(train_file, data_dir, self.vocab_filename) + + elif dataset_split == problem.DatasetSplit.EVAL: + filepath = valid_file + + elif dataset_split == problem.DatasetSplit.TEST: + filepath = test_file + + def _generate_samples(): + with tf.gfile.GFile(filepath, "r") as f: + for line in f: + line = " ".join(line.strip().split()) + if line: + yield {"targets": line} + + return _generate_samples() + + +@registry.register_problem +class LanguagemodelWikitext103Characters(LanguagemodelWikitext103): + """Wikitext-103, character-level.""" + + @property + def vocab_type(self): + return text_problems.VocabType.CHARACTER diff --git a/tensor2tensor/layers/common_attention.py b/tensor2tensor/layers/common_attention.py index 1545f477f..e94b470de 100644 --- a/tensor2tensor/layers/common_attention.py +++ b/tensor2tensor/layers/common_attention.py @@ -387,7 +387,8 @@ def mse_loss(expected_logits, actual_weights): def get_timing_signal_1d(length, channels, min_timescale=1.0, - max_timescale=1.0e4): + max_timescale=1.0e4, + start_index=0): """Gets a bunch of sinusoids of different frequencies. Each channel of the input Tensor is incremented by a sinusoid of a different @@ -413,11 +414,12 @@ def get_timing_signal_1d(length, different timescales is equal to channels / 2. 
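A NumPy sketch of what the `start_index` argument added in this hunk buys: shifting `start_index` reproduces the corresponding rows of a longer signal, which is what incremental decoding needs. This is a simplified re-statement for even `channels`, not the library code:

```python
import numpy as np

def timing_signal_1d(length, channels, min_timescale=1.0,
                     max_timescale=1.0e4, start_index=0):
  position = np.arange(length, dtype=np.float32) + start_index
  num_timescales = channels // 2
  log_inc = np.log(max_timescale / min_timescale) / max(num_timescales - 1, 1)
  inv_timescales = min_timescale * np.exp(
      -np.arange(num_timescales, dtype=np.float32) * log_inc)
  scaled = position[:, None] * inv_timescales[None, :]
  return np.concatenate([np.sin(scaled), np.cos(scaled)], axis=1)[None]

# Positions 5..9 of a length-10 signal match a length-5 signal started at 5.
assert np.allclose(timing_signal_1d(10, 8)[0, 5:],
                   timing_signal_1d(5, 8, start_index=5)[0])
```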
min_timescale: a float max_timescale: a float + start_index: index of first position Returns: a Tensor of timing signals [1, length, channels] """ - position = tf.to_float(tf.range(length)) + position = tf.to_float(tf.range(length) + start_index) num_timescales = channels // 2 log_timescale_increment = ( math.log(float(max_timescale) / float(min_timescale)) / @@ -432,7 +434,10 @@ def get_timing_signal_1d(length, @expert_utils.add_name_scope() -def add_timing_signal_1d(x, min_timescale=1.0, max_timescale=1.0e4): +def add_timing_signal_1d(x, + min_timescale=1.0, + max_timescale=1.0e4, + start_index=0): """Adds a bunch of sinusoids of different frequencies to a Tensor. Each channel of the input Tensor is incremented by a sinusoid of a different @@ -456,16 +461,66 @@ def add_timing_signal_1d(x, min_timescale=1.0, max_timescale=1.0e4): x: a Tensor with shape [batch, length, channels] min_timescale: a float max_timescale: a float + start_index: index of first position Returns: a Tensor the same shape as x. """ length = common_layers.shape_list(x)[1] channels = common_layers.shape_list(x)[2] - signal = get_timing_signal_1d(length, channels, min_timescale, max_timescale) + signal = get_timing_signal_1d(length, channels, min_timescale, max_timescale, + start_index) return x + signal +@expert_utils.add_name_scope() +def add_layer_timing_signal_learned_1d(x, layer, num_layers): + """Add n-dimensional embedding as the layer (vertical) timing signal. + + Adds embeddings to represent the position of the layer in the tower. + + Args: + x: a tensor with shape [batch, length, depth] + layer: layer num + num_layers: total number of layers + + Returns: + a Tensor the same shape as x. + """ + x_shape = common_layers.shape_list(x) + depth = x_shape[-1] + + shape = [num_layers, 1, 1, depth] + layer_embedding = ( + tf.get_variable( + "layer_embedding", + shape, + initializer=tf.random_normal_initializer(0, depth**-0.5)) * (depth** + 0.5)) + x += layer_embedding[layer, :, :, :] + return x + + +@expert_utils.add_name_scope() +def add_layer_timing_signal_sinusoid_1d(x, layer, num_layers): + """Add sinusoids of different frequencies as layer (vertical) timing signal. + + Args: + x: a Tensor with shape [batch, length, channels] + layer: layer num + num_layers: total number of layers + + Returns: + a Tensor the same shape as x. + """ + + channels = common_layers.shape_list(x)[-1] + signal = get_timing_signal_1d(num_layers, channels) + layer_signal = tf.expand_dims(signal[:, layer, :], axis=1) + + return x + layer_signal + + @expert_utils.add_name_scope() def add_timing_signal_1d_given_position(x, position, @@ -1242,7 +1297,7 @@ def grouped_attention_multihead(query_antecedent, extra_loss *= extra_loss_multiplier # Show a bunch of summaries. 
- if expert_utils.should_generate_summaries() and make_image_summary: + if common_layers.should_generate_summaries() and make_image_summary: tf.summary.histogram("q_group_size", q_group_size) tf.summary.histogram("m_group_size", m_group_size) tf.summary.scalar("q_loss", q_loss) @@ -1336,7 +1391,7 @@ def dot_product_attention(q, # dropping out the attention links for each of the heads weights = common_layers.dropout_with_broadcast_dims( weights, 1.0 - dropout_rate, broadcast_dims=dropout_broadcast_dims) - if expert_utils.should_generate_summaries() and make_image_summary: + if common_layers.should_generate_summaries() and make_image_summary: attention_image_summary(weights, image_shapes) return tf.matmul(weights, v) @@ -1583,7 +1638,7 @@ def dot_product_self_attention_relative_v2(q, # dropping out the attention links for each of the heads weights = common_layers.dropout_with_broadcast_dims( weights, 1.0 - dropout_rate, broadcast_dims=dropout_broadcast_dims) - if expert_utils.should_generate_summaries() and make_image_summary: + if common_layers.should_generate_summaries() and make_image_summary: attention_image_summary(weights, image_shapes) ret = tf.matmul(weights, v) # [batch, num_heads, query_length, memory_length] @@ -3796,7 +3851,7 @@ def scaled_dot_product_attention_simple(q, k, v, bias, name=None): if bias is not None: logits += bias weights = tf.nn.softmax(logits, name="attention_weights") - if expert_utils.should_generate_summaries(): + if common_layers.should_generate_summaries(): tf.summary.image( "attention", tf.expand_dims(tf.pow(weights, 0.2), 3), max_outputs=1) return tf.matmul(weights, v) diff --git a/tensor2tensor/layers/common_image_attention.py b/tensor2tensor/layers/common_image_attention.py index 30959398f..ec844b611 100644 --- a/tensor2tensor/layers/common_image_attention.py +++ b/tensor2tensor/layers/common_image_attention.py @@ -284,6 +284,7 @@ def transformer_decoder_layers(inputs, self_attention_bias=None, encoder_decoder_attention_bias=None, attention_type=AttentionType.LOCAL_2D, + losses=None, name="transformer"): """Multi layer transformer.""" x = inputs @@ -335,7 +336,8 @@ def transformer_decoder_layers(inputs, hparams) x = common_layers.layer_postprocess(x, y, hparams) # feed-fwd layers + skip connections - y = ffn_layer(common_layers.layer_preprocess(x, hparams), hparams) + y = ffn_layer(common_layers.layer_preprocess(x, hparams), hparams, + losses=losses) x = common_layers.layer_postprocess(x, y, hparams) return common_layers.layer_preprocess(x, hparams) @@ -375,7 +377,7 @@ def transformer_encoder_layers(inputs, return common_layers.layer_preprocess(x, hparams) -def ffn_layer(x, hparams): +def ffn_layer(x, hparams, losses=None): """ffn layer transformer.""" with tf.variable_scope("ffn"): if hparams.ffn_layer == "none": @@ -402,6 +404,23 @@ def ffn_layer(x, hparams): x, hparams.filter_size, hparams.hidden_size, hparams.num_parts, hparams.attention_dropout, hparams.share_kv) y = tf.reshape(y, x_shape) + elif hparams.ffn_layer == "local_moe_tpu": + overhead = (hparams.moe_overhead_train + if hparams.mode == tf.estimator.ModeKeys.TRAIN + else hparams.moe_overhead_eval) + x, x_shape, is_4d = maybe_reshape_4d_to_3d(x) + y, loss = expert_utils.local_moe_tpu( + x, hparams.filter_size // 2, + hparams.hidden_size, + hparams.moe_num_experts, overhead=overhead, + loss_coef=hparams.moe_loss_coef) + if is_4d: + y = tf.reshape(y, x_shape) + if losses is None: + raise ValueError( + "transformer_ffn_layer with type local_moe_tpu must pass in " + "a losses list") + 
losses.append(loss) else: assert hparams.ffn_layer == "glu_ffn" y = common_layers.gated_linear_unit_layer(x) diff --git a/tensor2tensor/layers/common_layers.py b/tensor2tensor/layers/common_layers.py index 1a7c2500b..449e19bda 100644 --- a/tensor2tensor/layers/common_layers.py +++ b/tensor2tensor/layers/common_layers.py @@ -30,18 +30,39 @@ import numpy as np from six.moves import range # pylint: disable=redefined-builtin -from tensor2tensor.utils import expert_utils as eu import tensorflow as tf from tensorflow.python.framework import function from tensorflow.python.framework import ops - # This is a global setting. When turned off, no @function.Defun is used. allow_defun = False +@function.Defun( + python_grad_func=lambda x, dy: tf.convert_to_tensor(dy), + shape_func=lambda op: [op.inputs[0].get_shape()]) +def convert_gradient_to_tensor(x): + """Identity operation whose gradient is converted to a `Tensor`. + + Currently, the gradient to `tf.concat` is particularly expensive to + compute if dy is an `IndexedSlices` (a lack of GPU implementation + forces the gradient operation onto CPU). This situation occurs when + the output of the `tf.concat` is eventually passed to `tf.gather`. + It is sometimes faster to convert the gradient to a `Tensor`, so as + to get the cheaper gradient for `tf.concat`. To do this, replace + `tf.concat(x)` with `convert_gradient_to_tensor(tf.concat(x))`. + + Args: + x: A `Tensor`. + + Returns: + The input `Tensor`. + """ + return x + + def is_on_tpu(): # Support TF versions 1.5+ try: @@ -185,9 +206,10 @@ def convert_rgb_to_real(x): """Conversion of pixel values to real numbers.""" with tf.name_scope("rgb_to_real", values=[x]): x = tf.to_float(x) - # Use the formula (value/128) - 1 to convert each channel value into a - # real number in the range -1 to 1. - x = (x / 128) - 1 + # Use the formula (value/127.5) - 1 to convert each channel value into a + # real number in the range -1 to 1. We use 127.5 instead of 128 because + # the intensities are in the range 0 to 255 + x = (x / 127.5) - 1 return x @@ -230,10 +252,40 @@ def gather(params, indices, dtype=tf.float32): vocab_size = params.get_shape().as_list()[0] indices_flat = tf.reshape(indices, [-1]) out = tf.matmul(tf.one_hot(indices_flat, vocab_size, dtype=dtype), params) - out = eu.reshape_like(out, tf.expand_dims(indices, -1)) + out = reshape_like(out, tf.expand_dims(indices, -1)) return out +# TODO(noam): remove this function after TPUs do cumsum faster. +def cumsum(x, axis=0, exclusive=False): + """TPU hack for tf.cumsum. + + This is equivalent to tf.cumsum and is faster on TPU as of 04/2018 unless + the axis dimension is very large. + + Args: + x: a Tensor + axis: an integer + exclusive: a boolean + Returns: + a Tensor with the same shape as x + """ + if not is_on_tpu(): + return tf.cumsum(x, axis=axis, exclusive=exclusive) + x_shape = shape_list(x) + rank = len(x_shape) + length = x_shape[axis] + my_range = tf.range(length) + comparator = tf.less if exclusive else tf.less_equal + mask = tf.to_float( + comparator(tf.expand_dims(my_range, 1), tf.expand_dims(my_range, 0))) + ret = tf.tensordot(x, mask, axes=[[axis], [0]]) + if axis != rank - 1: + ret = tf.transpose( + ret, list(range(axis)) + [rank - 1] + list(range(axis, rank - 1))) + return ret + + def dropout_no_scaling(x, keep_prob): """Like tf.nn.dropout, but does not scale up. Works on integers also. @@ -267,7 +319,7 @@ def embedding(x, # an indexed-slices to a regular tensor before sending it back to the # parameter server. 
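Stepping back to the `cumsum` TPU workaround added earlier in this file: the running sum is expressed as one matmul with a triangular 0/1 mask. A small NumPy illustration of the trick (illustrative only):

```python
import numpy as np

# mask[i, j] = 1 iff i <= j, so column j sums all entries up to position j.
x = np.array([1., 2., 3., 4.])
idx = np.arange(len(x))
mask_incl = (idx[:, None] <= idx[None, :]).astype(x.dtype)
mask_excl = (idx[:, None] < idx[None, :]).astype(x.dtype)
print(x @ mask_incl)  # [ 1.  3.  6. 10.]  == np.cumsum(x)
print(x @ mask_excl)  # [ 0.  1.  3.  6.]  == exclusive cumsum
```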
This avoids excess computation on the parameter server. if not tf.contrib.eager.in_eager_mode(): - embedding_var = eu.convert_gradient_to_tensor(embedding_var) + embedding_var = convert_gradient_to_tensor(embedding_var) x = dropout_no_scaling(x, 1.0 - symbol_dropout_rate) emb_x = gather(embedding_var, x, dtype) if multiplier != 1.0: @@ -1029,7 +1081,7 @@ def simple_attention(target, source, bias=None): if bias is not None: attention += tf.expand_dims(tf.squeeze(bias, axis=[2, 3]), axis=1) attention = tf.nn.softmax(attention) - if eu.should_generate_summaries(): + if should_generate_summaries(): tf.summary.image("attention", tf.expand_dims(attention, 3), max_outputs=5) attended = tf.matmul(attention, source) return tf.reshape(attended, target_shape) @@ -1324,7 +1376,7 @@ def _maybe_transform(t, size, should_transform, name): mask = (1.0 - mask) * -1e9 attention += mask attention = tf.nn.softmax(attention) - if eu.should_generate_summaries(): + if should_generate_summaries(): # Compute a color image summary. image = tf.reshape(attention, [batch, num_heads, target_length, source_length]) @@ -1437,13 +1489,23 @@ def conv_relu_conv(inputs, padding="SAME", nonpadding_mask=None, dropout=0.0, - name=None): + name=None, + cache=None): """Hidden layer with RELU activation followed by linear projection.""" with tf.variable_scope(name, "conv_relu_conv", [inputs]): inputs = maybe_zero_out_padding( inputs, first_kernel_size, nonpadding_mask) + + if cache: + inputs = cache["f"] = tf.concat([cache["f"], inputs], axis=1) + inputs = cache["f"] = inputs[:, -first_kernel_size:, :] + h = tpu_conv1d(inputs, filter_size, first_kernel_size, padding=padding, name="conv1") + + if cache: + h = h[:, -1:, :] + h = tf.nn.relu(h) if dropout != 0.0: h = tf.nn.dropout(h, 1.0 - dropout) @@ -1715,6 +1777,7 @@ def padded_cross_entropy(logits, label_smoothing, weights_fn=weights_nonzero, reduce_sum=True, + cutoff=0.0, gaussian=False): """Compute cross-entropy assuming 0s are padding. @@ -1728,6 +1791,7 @@ def padded_cross_entropy(logits, label_smoothing: a floating point `Scalar`. weights_fn: A function from labels to weights. reduce_sum: a Boolean, whether to sum at the end or not. + cutoff: a float, at which point to have no loss. gaussian: If true, use a Gaussian distribution for label smoothing Returns: @@ -1761,6 +1825,8 @@ def padded_cross_entropy(logits, xent = smoothing_cross_entropy(logits, labels, vocab_size, confidence, gaussian=gaussian) weights = weights_fn(labels) + if cutoff > 0.0: + xent = tf.nn.relu(xent - cutoff) if not reduce_sum: return xent * weights, weights return tf.reduce_sum(xent * weights), tf.reduce_sum(weights) @@ -1894,7 +1960,6 @@ def gated_linear_unit_layer(x, name=None): Returns: x: A tensor """ - with tf.variable_scope( name, default_name="glu_layer", values=[x]): depth = shape_list(x)[-1] @@ -1903,16 +1968,12 @@ def gated_linear_unit_layer(x, name=None): return x * tf.nn.sigmoid(gating_x) -def sru(x, num_layers=2, - activation=None, initial_state=None, name=None, reuse=None): +def sru_with_scan(x, num_layers=2, + activation=None, initial_state=None, name=None, reuse=None): """SRU cell as in https://arxiv.org/abs/1709.02755. - As defined in the paper: - (1) x'_t = W x_t - (2) f_t = sigmoid(Wf x_t + bf) - (3) r_t = sigmoid(Wr x_t + br) - (4) c_t = f_t * c_{t-1} + (1 - f_t) * x'_t - (5) h_t = r_t * activation(c_t) + (1 - r_t) * x_t + This implementation uses tf.scan and can incur overhead, see the full SRU + function doc for details and an implementation that is sometimes faster. 
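Both `sru_with_scan` and the new functional-ops `sru` added below compute the same per-timestep recurrence: x'_t = W x_t, f_t = sigmoid(Wf x_t + bf), r_t = sigmoid(Wr x_t + br), c_t = f_t * c_{t-1} + (1 - f_t) * x'_t, h_t = r_t * act(c_t) + (1 - r_t) * x_t. A scalar toy version of one step (illustrative names, not library code):

```python
import numpy as np

def sru_step(x_t, c_prev, w, w_f, b_f, w_r, b_r, act=lambda v: v):
  """One SRU timestep on scalars, following the five equations above."""
  sigmoid = lambda z: 1.0 / (1.0 + np.exp(-z))
  x_proj = w * x_t                       # x'_t = W x_t
  f = sigmoid(w_f * x_t + b_f)           # forget gate
  r = sigmoid(w_r * x_t + b_r)           # reset gate
  c = f * c_prev + (1.0 - f) * x_proj    # recurrent state
  h = r * act(c) + (1.0 - r) * x_t       # highway-style output
  return h, c
```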
Args: x: A tensor of shape [batch, ..., channels] ; ... is treated as time. @@ -1962,6 +2023,90 @@ def next_state(cur_state, args_tup): return tf.reshape(x, x_shape) +class CumsumprodCell(object): + """Cumulative sum and product object for use with functional_rnn API.""" + + def __init__(self, initializer): + self._initializer = initializer + + @property + def output_size(self): + return int(shape_list(self._initializer)[-1]) + + def zero_state(self, batch_size, dtype): + dtype = dtype or tf.float32 + return tf.zeros([batch_size, self.output_size], dtype=dtype) + + def __call__(self, inputs_t, state_t): + cur_x_times_one_minus_f, cur_f = tf.split(inputs_t, 2, axis=-1) + state_next = cur_f * state_t + cur_x_times_one_minus_f + outputs_t = state_next + return outputs_t, state_next + + +def sru(x, num_layers=2, + activation=None, initial_state=None, name=None, reuse=None): + """SRU cell as in https://arxiv.org/abs/1709.02755. + + As defined in the paper: + (1) x'_t = W x_t + (2) f_t = sigmoid(Wf x_t + bf) + (3) r_t = sigmoid(Wr x_t + br) + (4) c_t = f_t * c_{t-1} + (1 - f_t) * x'_t + (5) h_t = r_t * activation(c_t) + (1 - r_t) * x_t + + This version uses functional ops to be faster on GPUs with TF-1.9+. + + Args: + x: A tensor of shape [batch, ..., channels] ; ... is treated as time. + num_layers: How many SRU layers; default is 2 as results for 1 disappoint. + activation: Optional activation function, try tf.nn.tanh or tf.nn.relu. + initial_state: Optional initial c-state, set to zeros if None. + name: Optional name, "sru" by default. + reuse: Optional reuse. + + Returns: + A tensor of the same shape as x. + + Raises: + ValueError: if num_layers is not positive. + """ + if num_layers < 1: + raise ValueError("Number of layers must be positive: %d" % num_layers) + if is_on_tpu(): # On TPU the XLA does a good job with while. + return sru_with_scan(x, num_layers, activation, initial_state, name, reuse) + try: + from tensorflow.contrib.recurrent.python.ops import functional_rnn # pylint: disable=g-import-not-at-top + except ImportError: + tf.logging.info("functional_rnn not found, using sru_with_scan instead") + return sru_with_scan(x, num_layers, activation, initial_state, name, reuse) + + with tf.variable_scope(name, default_name="sru", values=[x], reuse=reuse): + # We assume x is [batch, ..., channels] and treat all ... as time. + x_shape = shape_list(x) + x = tf.reshape(x, [x_shape[0], -1, x_shape[-1]]) + initial_state = initial_state or tf.zeros([x_shape[0], x_shape[-1]]) + cell = CumsumprodCell(initial_state) + # Calculate SRU on each layer. + for i in range(num_layers): + # The parallel part of the SRU. + x_orig = x + x, f, r = tf.split(tf.layers.dense(x, 3 * x_shape[-1], + name="kernel_%d" % i), 3, axis=-1) + f, r = tf.sigmoid(f), tf.sigmoid(r) + x_times_one_minus_f = x * (1.0 - f) # Compute in parallel for speed. + # Calculate states. + concat = tf.concat([x_times_one_minus_f, f], axis=-1) + c_states, _ = functional_rnn.functional_rnn( + cell, concat, time_major=False) + # Final output. + if activation is not None: + c_states = activation(c_states) + h = c_states * r + (1.0 - r) * x_orig + x = h # Next layer. 
+ return tf.reshape(x, x_shape) + + def linear_set_layer(layer_size, inputs, context=None, @@ -2448,6 +2593,7 @@ def forward_internal(x, f1, f2, scale, bias): @function.Defun(compiled=True) def grad_fn(x, f1, f2, scale, bias, dy): + """Gradient for efficiency.""" with tf.control_dependencies([dy]): num_splits = 4 x_shape = shape_list(x) @@ -2578,130 +2724,6 @@ def reshape_like_all_dims(a, b): return ret -def reduce_by_device(parallelism, data, reduce_fn): - """Reduces data per device. - - This can be useful, for example, if we want to all-reduce n tensors on k100k samples in length and have a width # of 2 or 4. Mono audio has a single channel while stereo has 2. @@ -422,6 +427,7 @@ def bottom(self, inputs): with tf.variable_scope(self.name): # TODO(aidangomez): Will need to sort out a better audio pipeline def xnet_resblock(x, filters, res_relu, name): + """Xception-like block.""" with tf.variable_scope(name): # We only stride along the length dimension to preserve the spectral # bins (which are tiny in dimensionality relative to length) @@ -464,8 +470,15 @@ def bottom(self, inputs): "[batch, time, height, width, channels] but got one " "of shape: %s" % str(inputs_shape)) if not tf.contrib.eager.in_eager_mode(): - tf.summary.image("inputs", tf.cast(inputs[:, -1, :, :, :], tf.uint8), - max_outputs=1) + if inputs.get_shape().as_list()[1] is None: + tf.summary.image( + "inputs_last_frame", tf.cast(inputs[:, -1, :, :, :], tf.uint8), + max_outputs=1) + else: + for k in range(inputs_shape[1]): + tf.summary.image( + "inputs_frame_%d" % k, tf.cast(inputs[:, k, :, :, :], tf.uint8), + max_outputs=1) # Standardize frames. inputs = tf.reshape(inputs, [-1] + inputs_shape[2:]) inputs = common_layers.standardize_images(inputs) @@ -533,9 +546,56 @@ def loss(self, logits, targets): logits, targets, self._model_hparams.label_smoothing, + cutoff=0.001, weights_fn=self.targets_weights_fn) +@registry.register_video_modality("l1") +class VideoModalityL1(VideoModality): + """Video modality that predicts a scalar per channel with an L1 loss.""" + + def top(self, body_output, _): + num_channels = self._model_hparams.problem.num_channels + num_frames = self._model_hparams.problem.num_target_frames + with tf.variable_scope("rgb"): + body_output_shape = common_layers.shape_list(body_output) + res = tf.layers.dense(body_output, num_channels * num_frames, name="cast") + res = tf.reshape(res, body_output_shape[:3] + [num_channels, num_frames]) + res = tf.transpose(res, [0, 4, 1, 2, 3]) # Move frames next to batch. + if not tf.get_variable_scope().reuse: + res_argmax = tf.cast(res[:, -1, :, :, :], tf.uint8) + tf.summary.image("result", res_argmax, max_outputs=1) + return tf.expand_dims(res, axis=-1) # Add an axis like in perplexity. + + @property + def cutoff(self): + return 0.2 + + def internal_loss(self, logits, targets): + return tf.nn.relu(tf.abs(logits - targets) - self.cutoff) + + def loss(self, logits, targets): + """Compute loss numerator and denominator for one shard of output.""" + logits = tf.reshape(logits, [-1] + common_layers.shape_list(logits)[2:-1]) + targets = tf.reshape(targets, [-1] + common_layers.shape_list(targets)[2:]) + weights = self.targets_weights_fn(targets) + # Shift targets by 0.5 so later just casting to int gives the prediction. + # So for int targets, say 0 and 7, we actually train to predict 0.5 and 7.5. + # Later (in merics or infer) this is cast to int anyway. Also, we have no + # loss beyond self.cutoff = 0.2 as these are already correct predictions. 
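A quick numeric check of the comment above, with toy values (not from the code):

```python
cutoff = 0.2
target = 7 + 0.5          # integer pixel value 7, shifted by 0.5
for pred in [7.45, 7.6, 8.2]:
  loss = max(abs(pred - target) - cutoff, 0.0)
  print(pred, int(pred), round(loss, 2))
# 7.45 -> casts to 7 (correct), loss 0.0
# 7.6  -> casts to 7 (correct), loss 0.0   (|7.6 - 7.5| = 0.1 < cutoff)
# 8.2  -> casts to 8 (wrong),   loss 0.5   (|8.2 - 7.5| - 0.2)
```

The L2 variant below applies the same idea with a squared cutoff on the squared error.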
+ targets = tf.to_float(targets) + 0.5 + loss = self.internal_loss(logits, targets) + return tf.reduce_sum(loss * weights), tf.reduce_sum(weights) + + +@registry.register_video_modality("l2") +class VideoModalityL2(VideoModalityL1): + """Modality for videos with L2 loss.""" + + def internal_loss(self, logits, targets): + return tf.nn.relu((logits - targets)**2 - self.cutoff * self.cutoff) + + @registry.register_class_label_modality("default") class ClassLabelModality(modality.Modality): """Used for label data.""" diff --git a/tensor2tensor/models/image_transformer.py b/tensor2tensor/models/image_transformer.py index 834a80d34..ed4a25692 100644 --- a/tensor2tensor/models/image_transformer.py +++ b/tensor2tensor/models/image_transformer.py @@ -48,6 +48,8 @@ def body(self, features): hparams.mode == tf.contrib.learn.ModeKeys.INFER): tf.summary.image("targets", targets, max_outputs=1) + # Extra losses list if we want to use moe. + losses = [] # Prepare decoder inputs and bias. decoder_input, rows, cols = cia.prepare_decoder(targets, hparams) # Add class label to decoder input. @@ -61,9 +63,14 @@ def body(self, features): hparams.num_decoder_layers or hparams.num_hidden_layers, hparams, attention_type=hparams.dec_attention_type, + losses=losses, name="decoder") output = cia.create_output(decoder_output, rows, cols, targets, hparams) - return output + + if losses: + return output, {"extra_loss": tf.add_n(losses)} + else: + return output @registry.register_model @@ -171,6 +178,11 @@ def image_transformer_base(): hparams.add_hparam("unconditional", False) # unconditional generation + # These parameters are only used when ffn_layer=="local_moe_tpu" + hparams.add_hparam("moe_overhead_train", 1.0) + hparams.add_hparam("moe_overhead_eval", 2.0) + hparams.moe_num_experts = 8 + hparams.moe_loss_coef = 1e-3 return hparams @@ -678,9 +690,209 @@ def imagetransformer_sep_channels_8l_tpu(): @registry.register_hparams -def imagetransformer_bas8l_8h_big_uncond_dr03_imgnet_tpu(): +def imagetransformer_b10l_4h_big_uncond_dr03_tpu(): hparams = imagetransformer_bas8l_8h_big_uncond_dr03_imgnet() update_hparams_for_tpu(hparams) - hparams.batch_size = 1 + hparams.batch_size = 4 + hparams.num_heads = 4 # heads are expensive on tpu + hparams.num_decoder_layers = 10 + hparams.layer_preprocess_sequence = "none" + hparams.layer_postprocess_sequence = "dan" + return hparams + + +@registry.register_hparams +def imagetransformer_b10l_dr03_moe_tpu(): + """Moe tpu params.""" + hparams = imagetransformer_b10l_4h_big_uncond_dr03_tpu() + update_hparams_for_tpu(hparams) + hparams.batch_size = 4 + hparams.num_heads = 4 # heads are expensive on tpu + hparams.num_decoder_layers = 10 + hparams.layer_preprocess_sequence = "none" + hparams.layer_postprocess_sequence = "dan" + hparams.ffn_layer = "local_moe_tpu" + return hparams + + +@registry.register_hparams +def imagetransformer_b10l_4h_big_uncond_dr03_lr025_tpu(): + """TPU related small model.""" + hparams = imagetransformer_bas8l_8h_big_uncond_dr03_imgnet() + update_hparams_for_tpu(hparams) + hparams.batch_size = 4 + hparams.num_heads = 4 # heads are expensive on tpu + hparams.num_decoder_layers = 10 + hparams.learning_rate = 0.25 + hparams.learning_rate_warmup_steps = 8000 + hparams.layer_preprocess_sequence = "none" + hparams.layer_postprocess_sequence = "dan" + # hparams.unconditional = True + return hparams + + +@registry.register_hparams +def imagetransformer_b12l_4h_big_uncond_dr03_tpu(): + """TPU 12 layer model.""" + hparams = 
imagetransformer_bas8l_8h_big_uncond_dr03_imgnet() + update_hparams_for_tpu(hparams) + hparams.batch_size = 4 + hparams.num_heads = 4 # heads are expensive on tpu + hparams.num_decoder_layers = 12 + hparams.block_length = 128 + hparams.layer_preprocess_sequence = "none" + hparams.layer_postprocess_sequence = "dan" + hparams.layer_prepostprocess_dropout = 0.3 + return hparams + + +@registry.register_hparams +def imagetransformer_b12l_4h_big_uncond_dr03_lr025_tpu(): + hparams = imagetransformer_b12l_4h_big_uncond_dr03_tpu() + update_hparams_for_tpu(hparams) + hparams.learning_rate = 0.25 + hparams.learning_rate_warmup_steps = 5000 + return hparams + + +@registry.register_hparams +def imagetransformer_b12l_4h_b256_uncond_dr03_tpu(): + """works very well on 4x4.""" + hparams = imagetransformer_bas8l_8h_big_uncond_dr03_imgnet() + update_hparams_for_tpu(hparams) + hparams.batch_size = 4 + hparams.num_heads = 4 # heads are expensive on tpu + hparams.num_decoder_layers = 12 + hparams.block_length = 256 + hparams.hidden_size = 512 + hparams.filter_size = 2048 + hparams.learning_rate = 0.5 + hparams.learning_rate_warmup_steps = 4000 + hparams.layer_preprocess_sequence = "none" + hparams.layer_postprocess_sequence = "dan" + hparams.layer_prepostprocess_dropout = 0.3 + hparams.unconditional = True + return hparams + + +@registry.register_hparams +def imagetransformer_b12l_4h_b128_h512_uncond_dr03_tpu(): + """TPU related big model.""" + hparams = imagetransformer_bas8l_8h_big_uncond_dr03_imgnet() + update_hparams_for_tpu(hparams) + hparams.batch_size = 4 + hparams.num_heads = 4 # heads are expensive on tpu + hparams.num_decoder_layers = 12 + hparams.block_length = 128 + hparams.hidden_size = 512 + hparams.filter_size = 2048 + hparams.learning_rate = 0.2 + hparams.learning_rate_warmup_steps = 6000 + hparams.layer_preprocess_sequence = "none" + hparams.layer_postprocess_sequence = "dan" + hparams.layer_prepostprocess_dropout = 0.3 + return hparams + + +@registry.register_hparams +def imagetransformer_b12l_4h_b128_h512_uncond_dr03_im(): + """TPU related imagenet model.""" + hparams = imagetransformer_bas8l_8h_big_uncond_dr03_imgnet() + update_hparams_for_tpu(hparams) + hparams.batch_size = 4 + hparams.num_heads = 4 # heads are expensive on tpu + hparams.num_decoder_layers = 12 + hparams.block_length = 128 + hparams.hidden_size = 512 + hparams.filter_size = 2048 + hparams.learning_rate = 0.2 + hparams.learning_rate_warmup_steps = 6000 + hparams.layer_preprocess_sequence = "none" + hparams.layer_postprocess_sequence = "dan" + hparams.layer_prepostprocess_dropout = 0.1 + return hparams + + +@registry.register_hparams +def imagetransformer_b12l_4h_small_uncond_dr03_tpu(): + """TPU related small model.""" + hparams = imagetransformer_bas8l_8h_big_uncond_dr03_imgnet() + update_hparams_for_tpu(hparams) + hparams.batch_size = 4 + hparams.num_heads = 4 # heads are expensive on tpu + hparams.num_decoder_layers = 8 + hparams.block_length = 256 + hparams.hidden_size = 512 + hparams.filter_size = 2048 + hparams.learning_rate = 0.2 + hparams.learning_rate_warmup_steps = 4000 + hparams.layer_preprocess_sequence = "none" + hparams.layer_postprocess_sequence = "dan" + hparams.layer_prepostprocess_dropout = 0.3 + return hparams + + +@registry.register_hparams +def imagetransformer_b12l_4h_b128_uncond_dr03_tpu(): + """TPU config for cifar 10.""" + hparams = imagetransformer_bas8l_8h_big_uncond_dr03_imgnet() + update_hparams_for_tpu(hparams) + hparams.batch_size = 2 + hparams.num_heads = 4 # heads are expensive on tpu + 
hparams.num_decoder_layers = 12 + hparams.block_length = 128 + hparams.hidden_size = 256 + hparams.filter_size = 2048 + hparams.layer_preprocess_sequence = "none" + hparams.layer_postprocess_sequence = "dan" + hparams.layer_prepostprocess_dropout = 0.3 + return hparams + + +@registry.register_hparams +def imagetransformer_b12l_8h_b256_uncond_dr03_tpu(): + """TPU related 12 layer 8 heads model.""" + hparams = imagetransformer_bas8l_8h_big_uncond_dr03_imgnet() + update_hparams_for_tpu(hparams) + hparams.batch_size = 2 hparams.num_heads = 8 # heads are expensive on tpu + hparams.num_decoder_layers = 12 + hparams.block_length = 256 + hparams.hidden_size = 512 + hparams.filter_size = 2048 + hparams.layer_preprocess_sequence = "none" + hparams.layer_postprocess_sequence = "dan" + hparams.layer_prepostprocess_dropout = 0.3 + return hparams + + +@registry.register_hparams +def imagetransformer_b12l_4h_b256_uncond_dr03_lr025_tpu(): + hparams = imagetransformer_b12l_4h_b256_uncond_dr03_tpu() + update_hparams_for_tpu(hparams) + hparams.learning_rate = 0.25 + hparams.learning_rate_warmup_steps = 10000 + return hparams + + +@registry.register_hparams +def imagetransformer_b10l_4h_big_uncond_dr03_lr05_tpu(): + hparams = imagetransformer_b10l_4h_big_uncond_dr03_lr025_tpu() + update_hparams_for_tpu(hparams) + hparams.learning_rate = 0.5 + hparams.learning_rate_warmup_steps = 16000 + return hparams + + +@registry.register_hparams +def imagetransformer_b10l_4h_big_uncond_dr01_tpu(): + """big 1d model for conditional image generation.""" + hparams = imagetransformer_b12l_4h_big_uncond_dr03_tpu() + # num_hidden_layers + hparams.num_decoder_layers = 10 + hparams.num_heads = 4 + hparams.hidden_size = 1024 + hparams.filter_size = 4096 + hparams.batch_size = 1 + hparams.layer_prepostprocess_dropout = 0.1 return hparams diff --git a/tensor2tensor/models/research/aligned.py b/tensor2tensor/models/research/aligned.py index ed048a68c..7aa731bd9 100644 --- a/tensor2tensor/models/research/aligned.py +++ b/tensor2tensor/models/research/aligned.py @@ -229,7 +229,7 @@ def _pseudolocal_bias(x): expert_utils.ffn_expert_fn(hparams.hidden_size, ffn_hidden_sizes, hparams.hidden_size), dp(expert_utils.flatten_all_but_last, x)) - y = dp(expert_utils.reshape_like, y, x) + y = dp(common_layers.reshape_like, y, x) elif layer_type == "conv": y = dp( common_layers.conv1d, diff --git a/tensor2tensor/models/research/attention_lm_moe.py b/tensor2tensor/models/research/attention_lm_moe.py index 6fd549cbe..dd0163bfb 100644 --- a/tensor2tensor/models/research/attention_lm_moe.py +++ b/tensor2tensor/models/research/attention_lm_moe.py @@ -453,7 +453,7 @@ def restore_pad(x, ref_x, pad_remover, mode): x = tf.squeeze(x, axis=0) if mode != ModeKeys.PREDICT: x = pad_remover.restore(x) - x = expert_utils.reshape_like(x, ref_x) + x = common_layers.reshape_like(x, ref_x) return x diff --git a/tensor2tensor/models/research/autoencoders.py b/tensor2tensor/models/research/autoencoders.py index 4a05af024..aee732e42 100644 --- a/tensor2tensor/models/research/autoencoders.py +++ b/tensor2tensor/models/research/autoencoders.py @@ -54,6 +54,11 @@ def body(self, features): features["targets"], tf.zeros_like(basic_result), hparams.bottleneck_warmup_steps, is_training, max_prob=1.0 - hparams.autoregressive_dropout, broadcast_last=True) + # Sometimes it's useful to look at non-autoregressive evals. 
+ if (hparams.mode == tf.estimator.ModeKeys.EVAL and + hparams.autoregressive_eval_pure_autoencoder): + targets_dropout = tf.zeros_like(basic_result) + # Now combine the basic reconstruction with shifted targets. targets1d = tf.reshape(targets_dropout, [shape[0], -1, shape[3]]) targets_shifted = common_layers.shift_right_3d(targets1d) concat1d = tf.concat([basic1d, targets_shifted], axis=-1) @@ -85,7 +90,7 @@ def body(self, features): raise ValueError("Unsupported autoregressive mode: %s" % hparams.autoregressive_mode) - def infer(self, features=None, *args, **kwargs): + def infer(self, features, *args, **kwargs): """Produce predictions from the model by sampling.""" # Inputs and features preparation needed to handle edge cases. if not features: @@ -109,7 +114,7 @@ def infer(self, features=None, *args, **kwargs): shape = common_layers.shape_list(samples) # Sample again if requested for the autoregressive part. - extra_samples = 0 + extra_samples = self.hparams.autoregressive_decode_steps self.hparams.autoregressive_dropout = 0.2 for i in range(extra_samples): if i == extra_samples - 2: @@ -141,6 +146,17 @@ def infer(self, features=None, *args, **kwargs): class AutoencoderResidual(AutoencoderAutoregressive): """Residual autoencoder.""" + def dropout(self, x): + if self.hparams.dropout <= 0.0: + return x + # For simple dropout just do this: + # return tf.nn.dropout(x, 1.0 - self.hparams.dropout) + is_training = self.hparams.mode == tf.estimator.ModeKeys.TRAIN + return common_layers.mix( + tf.zeros_like(x), x, + self.hparams.bottleneck_warmup_steps, is_training, + max_prob=self.hparams.dropout, broadcast_last=True) + def encoder(self, x): with tf.variable_scope("encoder"): hparams = self.hparams @@ -156,7 +172,7 @@ def encoder(self, x): for i in range(hparams.num_hidden_layers): with tf.variable_scope("layer_%d" % i): x = self.make_even_size(x) - x = tf.nn.dropout(x, 1.0 - hparams.dropout) + x = self.dropout(x) filters = hparams.hidden_size * 2**(i + 1) filters = min(filters, hparams.max_hidden_size) x = tf.layers.conv2d( @@ -189,7 +205,6 @@ def decoder(self, x): residual_conv = tf.layers.separable_conv2d # Up-convolutions. 
for i in range(hparams.num_hidden_layers): - x = tf.nn.dropout(x, 1.0 - hparams.dropout) j = hparams.num_hidden_layers - i - 1 filters = hparams.hidden_size * 2**j filters = min(filters, hparams.max_hidden_size) @@ -405,6 +420,8 @@ def autoencoder_autoregressive(): hparams.add_hparam("autoregressive_forget_base", False) hparams.add_hparam("autoregressive_mode", "conv3") hparams.add_hparam("autoregressive_dropout", 0.4) + hparams.add_hparam("autoregressive_decode_steps", 0) + hparams.add_hparam("autoregressive_eval_pure_autoencoder", False) return hparams @@ -473,6 +490,7 @@ def autoencoder_residual_discrete_big(): def autoencoder_ordered_discrete(): """Basic autoencoder model.""" hparams = autoencoder_residual_discrete() + hparams.bottleneck_noise = 1.0 return hparams diff --git a/tensor2tensor/models/research/basic_conv_gen.py b/tensor2tensor/models/research/basic_conv_gen.py index 9d4a810bb..77729f0ca 100644 --- a/tensor2tensor/models/research/basic_conv_gen.py +++ b/tensor2tensor/models/research/basic_conv_gen.py @@ -34,33 +34,53 @@ class BasicConvGen(t2t_model.T2TModel): """Basic convolutional next-frame model.""" + def make_even_size(self, x): + """Pad x to be even-sized on axis 1 and 2, but only if necessary.""" + shape = [dim if dim is not None else -1 for dim in x.get_shape().as_list()] + if shape[1] % 2 == 0 and shape[2] % 2 == 0: + return x + if shape[1] % 2 == 0: + x, _ = common_layers.pad_to_same_length( + x, x, final_length_divisible_by=2, axis=2) + return x + if shape[2] % 2 == 0: + x, _ = common_layers.pad_to_same_length( + x, x, final_length_divisible_by=2, axis=1) + return x + x, _ = common_layers.pad_to_same_length( + x, x, final_length_divisible_by=2, axis=1) + x, _ = common_layers.pad_to_same_length( + x, x, final_length_divisible_by=2, axis=2) + return x + def body(self, features): hparams = self.hparams filters = hparams.hidden_size kernel1, kernel2 = (3, 3), (4, 4) - # Pad to make size powers of 2 as needed. - x = features["inputs"] - inputs_shape = common_layers.shape_list(x) - x, _ = common_layers.pad_to_same_length( - x, x, final_length_divisible_by=2**hparams.num_compress_steps, axis=1) - x, _ = common_layers.pad_to_same_length( - x, x, final_length_divisible_by=2**hparams.num_compress_steps, axis=2) + # Embed the inputs. + inputs_shape = common_layers.shape_list(features["inputs"]) + x = tf.layers.dense(features["inputs"], filters, name="inputs_embed") # Down-stride. + layer_inputs = [x] for i in range(hparams.num_compress_steps): with tf.variable_scope("downstride%d" % i): + layer_inputs.append(x) + x = self.make_even_size(x) + if i < hparams.filter_double_steps: + filters *= 2 x = tf.layers.conv2d(x, filters, kernel2, activation=common_layers.belu, strides=(2, 2), padding="SAME") x = common_layers.layer_norm(x) - filters *= 2 # Add embedded action. action = tf.reshape(features["input_action"][:, 1, :], [-1, 1, 1, hparams.hidden_size]) - zeros = tf.zeros(common_layers.shape_list(x)[:-1] + [hparams.hidden_size], - dtype=tf.float32) - x = tf.concat([x, action + zeros], axis=-1) + action_mask = tf.layers.dense(action, filters, name="action_mask") + zeros_mask = tf.zeros(common_layers.shape_list(x)[:-1] + [filters], + dtype=tf.float32) + x *= action_mask + zeros_mask # Run a stack of convolutions. for i in range(hparams.num_hidden_layers): @@ -74,14 +94,18 @@ def body(self, features): x = common_layers.layer_norm(x + y) # Up-convolve. 
+ layer_inputs = list(reversed(layer_inputs)) for i in range(hparams.num_compress_steps): with tf.variable_scope("upstride%d" % i): - filters //= 2 + if i >= hparams.num_compress_steps - hparams.filter_double_steps: + filters //= 2 x = tf.layers.conv2d_transpose( x, filters, kernel2, activation=common_layers.belu, strides=(2, 2), padding="SAME") - x = common_layers.layer_norm(x) - x = tf.nn.dropout(x, 1.0 - hparams.dropout) + y = layer_inputs[i] + shape = common_layers.shape_list(y) + x = x[:, :shape[1], :shape[2], :] + x = common_layers.layer_norm(x + y) # Cut down to original size. x = x[:, :inputs_shape[1], :inputs_shape[2], :] @@ -90,7 +114,7 @@ def body(self, features): reward_pred = tf.reduce_mean(x, axis=[1, 2], keep_dims=True) return {"targets": x, "target_reward": reward_pred} - def infer(self, features=None, *args, **kwargs): + def infer(self, features, *args, **kwargs): """Produce predictions from the model by running it.""" # Inputs and features preparation needed to handle edge cases. if not features: @@ -100,6 +124,16 @@ def infer(self, features=None, *args, **kwargs): inputs_old = features["inputs"] features["inputs"] = tf.expand_dims(features["inputs"], 2) + def logits_to_samples(logits): + """Get samples from logits.""" + # If the last dimension is 1 then we're using L1/L2 loss. + if common_layers.shape_list(logits)[-1] == 1: + return tf.to_int32(tf.squeeze(logits, axis=-1)) + # Argmax in TF doesn't handle more than 5 dimensions yet. + logits_shape = common_layers.shape_list(logits) + argmax = tf.argmax(tf.reshape(logits, [-1, logits_shape[-1]]), axis=-1) + return tf.reshape(argmax, logits_shape[:-1]) + # Get predictions. try: num_channels = self._hparams.problem.num_channels @@ -113,15 +147,9 @@ def infer(self, features=None, *args, **kwargs): if isinstance(logits, dict): results = {} for k, v in six.iteritems(logits): - # Argmax in TF doesn't handle more than 5 dimensions yet. - v_shape = common_layers.shape_list(v) - argmax = tf.argmax(tf.reshape(v, [-1, v_shape[-1]]), axis=-1) - results[k] = tf.reshape(argmax, v_shape[:-1]) + results[k] = logits_to_samples(v) else: - # Argmax in TF doesn't handle more than 5 dimensions yet. - logits_shape = common_layers.shape_list(logits) - argmax = tf.argmax(tf.reshape(logits, [-1, logits_shape[-1]]), axis=-1) - results = tf.reshape(argmax, logits_shape[:-1]) + results = logits_to_samples(logits) # Restore inputs to not confuse Estimator in edge cases. 
if inputs_old is not None: @@ -135,19 +163,20 @@ def infer(self, features=None, *args, **kwargs): def basic_conv(): """Basic 2-frame conv model.""" hparams = common_hparams.basic_params1() - hparams.hidden_size = 64 + hparams.hidden_size = 32 hparams.batch_size = 8 - hparams.num_hidden_layers = 3 - hparams.optimizer = "Adam" - hparams.learning_rate_constant = 0.0002 - hparams.learning_rate_warmup_steps = 500 - hparams.learning_rate_schedule = "constant * linear_warmup" + hparams.num_hidden_layers = 2 + hparams.optimizer = "Adafactor" + hparams.learning_rate_constant = 0.5 + hparams.learning_rate_warmup_steps = 1500 + hparams.learning_rate_schedule = "linear_warmup * constant * rsqrt_decay" hparams.label_smoothing = 0.0 hparams.initializer = "uniform_unit_scaling" hparams.initializer_gain = 1.0 hparams.weight_decay = 0.0 hparams.dropout = 0.2 - hparams.add_hparam("num_compress_steps", 5) + hparams.add_hparam("num_compress_steps", 6) + hparams.add_hparam("filter_double_steps", 5) return hparams @@ -160,58 +189,16 @@ def basic_conv_small(): @registry.register_hparams -def basic_conv_small_per_image_standardization(): - """Small conv model.""" - hparams = common_hparams.basic_params1() - hparams.kernel_sizes = [(3, 3), (5, 5)] - hparams.filter_numbers = [32, 3*256] - hparams.batch_size = 2 - hparams.add_hparam("per_image_standardization", True) +def basic_conv_l1(): + """Basic conv model with L1 modality.""" + hparams = basic_conv() + hparams.target_modality = "video:l1" return hparams -@registry.register_model -class MichiganBasicConvGen(t2t_model.T2TModel): - - def body(self, features): - def deconv2d(cur, i, kernel_size, output_filters, activation=tf.nn.relu): - thicker = common_layers.conv( - cur, - output_filters * 4, - kernel_size, - padding="SAME", - activation=activation, - name="deconv2d" + str(i)) - return tf.depth_to_space(thicker, 2) - - cur_frame = common_layers.standardize_images(features["inputs_0"]) - prev_frame = common_layers.standardize_images(features["inputs_1"]) - - frames = tf.concat([cur_frame, prev_frame], axis=3) - frames = tf.reshape(frames, [-1, 210, 160, 6]) - - h1 = tf.layers.conv2d(frames, filters=64, strides=2, kernel_size=(8, 8), - padding="SAME", activation=tf.nn.relu) - h2 = tf.layers.conv2d(h1, filters=128, strides=2, kernel_size=(6, 6), - padding="SAME", activation=tf.nn.relu) - h3 = tf.layers.conv2d(h2, filters=128, strides=2, kernel_size=(6, 6), - padding="SAME", activation=tf.nn.relu) - h4 = tf.layers.conv2d(h3, filters=128, strides=2, kernel_size=(4, 4), - padding="SAME", activation=tf.nn.relu) - h45 = tf.reshape(h4, [-1, 14 * 10 * 128]) - h5 = tf.layers.dense(h45, 2048, activation=tf.nn.relu) - h6 = tf.layers.dense(h5, 2048, activation=tf.nn.relu) - h7 = tf.layers.dense(h6, 14 * 10 * 128, activation=tf.nn.relu) - h8 = tf.reshape(h7, [-1, 14, 10, 128]) - - h9 = deconv2d(h8, 1, (4, 4), 128, activation=tf.nn.relu) - h9 = h9[:, :27, :, :] - h10 = deconv2d(h9, 2, (6, 6), 128, activation=tf.nn.relu) - h10 = h10[:, :53, :, :] - h11 = deconv2d(h10, 3, (6, 6), 128, activation=tf.nn.relu) - h11 = h11[:, :105, :, :] - h12 = deconv2d(h11, 4, (8, 8), 3 * 256, activation=tf.identity) - - reward = tf.layers.flatten(h12) - - return {"targets": h12, "reward": reward} +@registry.register_hparams +def basic_conv_l2(): + """Basic conv model with L2 modality.""" + hparams = basic_conv() + hparams.target_modality = "video:l2" + return hparams diff --git a/tensor2tensor/models/research/lm_experiments.py b/tensor2tensor/models/research/lm_experiments.py index 
10f2b943a..6e5b11094 100644 --- a/tensor2tensor/models/research/lm_experiments.py +++ b/tensor2tensor/models/research/lm_experiments.py @@ -115,3 +115,44 @@ def lmx_relative_nopos(): hparams = lmx_relative() hparams.pos = "none" return hparams + + +@registry.register_hparams +def lmx_moe(): + """Transformer with mixture of experts. 140M Params.""" + hparams = lmx_base() + hparams.ffn_layer = "local_moe_tpu" + return hparams + + +@registry.register_hparams +def lmx_moe_h1k_f4k_x32(): + """Transformer with mixture of experts. 890M Params.""" + hparams = lmx_h1k_f4k() + hparams.ffn_layer = "local_moe_tpu" + hparams.moe_num_experts = 32 + hparams.weight_dtype = "bfloat16" + hparams.batch_size = 8192 + return hparams + + +@registry.register_hparams +def lmx_moe_h1k_f8k_x16(): + """Transformer with mixture of experts. 890M Params.""" + hparams = lmx_h1k_f4k() + hparams.filter_size = 8192 + hparams.ffn_layer = "local_moe_tpu" + hparams.moe_num_experts = 16 + hparams.weight_dtype = "bfloat16" + hparams.batch_size = 8192 + return hparams + + +@registry.register_hparams +def lmx_h1k_f64k(): + """HParams for training languagemodel_lm1b32k_packed. 880M Params.""" + hparams = lmx_base() + hparams.hidden_size = 1024 + hparams.filter_size = 65536 + hparams.batch_size = 2048 + return hparams diff --git a/tensor2tensor/models/research/r_transformer.py b/tensor2tensor/models/research/r_transformer.py index 9af75e8fc..4631a1441 100644 --- a/tensor2tensor/models/research/r_transformer.py +++ b/tensor2tensor/models/research/r_transformer.py @@ -155,6 +155,10 @@ def body(self, features): Final decoder representation. [batch_size, decoder_length, hidden_dim] """ hparams = self._hparams + if hparams.add_position_timing_signal: + # Turning off addition of positional embedding in the encoder/decoder + # preparation as we do it in the beginning of each step. + hparams.pos = None if self.has_input: inputs = features["inputs"] @@ -203,7 +207,7 @@ def body(self, features): hparams.act_loss_weight * tf.reduce_mean(dec_ponder_times + dec_remainders)) act_loss = enc_act_loss + dec_act_loss - tf.summary.scalar("act_loss", act_loss) + tf.contrib.summary.scalar("act_loss", act_loss) return decoder_output, {"act_loss": act_loss} return decoder_output @@ -278,7 +282,7 @@ def body(self, features): ponder_times, remainders = enc_extra_output act_loss = hparams.act_loss_weight * tf.reduce_mean(ponder_times + remainders) - tf.summary.scalar("act_loss", act_loss) + tf.contrib.summary.scalar("act_loss", act_loss) return encoder_output, {"act_loss": act_loss} return encoder_output @@ -302,6 +306,17 @@ def update_hparams_for_r_transformer(hparams): # Number of steps (which is equivalent to num layer in transformer). hparams.add_hparam("num_rec_steps", hparams.num_hidden_layers) + # Add the positional mebedding at each step(horisontal timing) + hparams.add_hparam("add_position_timing_signal", False) + # Logic of position shifting when using timing signal: + # None, "random", "step" + hparams.add_hparam("position_start_index", None) + + # Add an step embedding at each step (vertical timing) + hparams.add_hparam("add_step_timing_signal", False) + # Either "learned" or "sinusoid" + hparams.add_hparam("step_timing_signal_type", "leaned") + # Default ffn layer is separable convolution. 
hparams.add_hparam("transformer_ffn_type", "sep") @@ -655,3 +670,183 @@ def r_transformer_lstm_base(): hparams = r_transformer_base() hparams.recurrence_type = "lstm" return hparams + + +@registry.register_hparams +def r_transformer_position_timing_tiny(): + hparams = r_transformer_tiny() + hparams.pos = None + hparams.add_position_timing_signal = True + return hparams + + +@registry.register_hparams +def r_transformer_position_random_timing_base(): + hparams = r_transformer_base() + hparams.pos = None + hparams.add_position_timing_signal = True + hparams.position_start_index = "random" + return hparams + + +@registry.register_hparams +def r_transformer_position_random_timing_tiny(): + hparams = r_transformer_tiny() + hparams.pos = None + hparams.add_position_timing_signal = True + hparams.position_start_index = "random" + return hparams + + +@registry.register_hparams +def r_transformer_position_step_timing_tiny(): + hparams = r_transformer_tiny() + hparams.pos = None + hparams.add_position_timing_signal = True + hparams.position_start_index = "step" + return hparams + + +@registry.register_hparams +def r_transformer_step_timing_tiny(): + hparams = r_transformer_tiny() + hparams.add_step_timing_signal = True + return hparams + + +@registry.register_hparams +def r_transformer_step_sinusoid_timing_tiny(): + hparams = r_transformer_tiny() + hparams.add_step_timing_signal = True + hparams.step_timing_signal_type = "sinusoid" + return hparams + + +@registry.register_hparams +def r_transformer_step_position_random_timing_tiny(): + hparams = r_transformer_tiny() + hparams.add_position_timing_signal = True + hparams.pos = None + hparams.position_start_index = "random" + hparams.add_step_timing_signal = True + return hparams + + +@registry.register_hparams +def r_transformer_act_position_timing_tiny(): + hparams = r_transformer_tiny() + hparams.recurrence_type = "act" + hparams.add_position_timing_signal = True + hparams.pos = None + return hparams + + +@registry.register_hparams +def r_transformer_act_position_random_timing_tiny(): + hparams = r_transformer_tiny() + hparams.recurrence_type = "act" + hparams.add_position_timing_signal = True + hparams.pos = None + hparams.position_start_index = "random" + return hparams + + +@registry.register_hparams +def r_transformer_act_position_step_timing_tiny(): + hparams = r_transformer_tiny() + hparams.recurrence_type = "act" + hparams.add_position_timing_signal = True + hparams.pos = None + hparams.position_start_index = "step" + return hparams + + +@registry.register_hparams +def r_transformer_act_step_timing_tiny(): + hparams = r_transformer_tiny() + hparams.recurrence_type = "act" + hparams.add_step_timing_signal = True + return hparams + + +@registry.register_hparams +def r_transformer_act_step_position_random_timing_tiny(): + hparams = r_transformer_tiny() + hparams.recurrence_type = "act" + hparams.add_position_timing_signal = True + hparams.pos = None + hparams.position_start_index = "random" + hparams.add_step_timing_signal = True + return hparams + + +@registry.register_hparams +def r_transformer_act_step_sinusoid_position_random_timing_tiny(): + hparams = r_transformer_tiny() + hparams.recurrence_type = "act" + hparams.add_position_timing_signal = True + hparams.pos = None + hparams.position_start_index = "random" + hparams.add_step_timing_signal = True + hparams.step_timing_signal_type = "sinusoid" + return hparams + + +@registry.register_hparams +def r_transformer_act_step_sinusoid_position_timing_tiny(): + hparams = r_transformer_tiny() + 
hparams.recurrence_type = "act" + hparams.add_position_timing_signal = True + hparams.pos = None + hparams.add_step_timing_signal = True + hparams.step_timing_signal_type = "sinusoid" + return hparams + + +@registry.register_hparams +def r_transformer_act_step_position_timing_tiny(): + hparams = r_transformer_tiny() + hparams.recurrence_type = "act" + hparams.add_position_timing_signal = True + hparams.pos = None + hparams.add_step_timing_signal = True + return hparams + + +@registry.register_hparams +def r_transformer_step_position_timing_tiny(): + hparams = r_transformer_tiny() + hparams.add_position_timing_signal = True + hparams.pos = None + hparams.add_step_timing_signal = True + return hparams + + +@registry.register_hparams +def r_transformer_act_step_position_random_timing_base(): + hparams = r_transformer_base() + hparams.recurrence_type = "act" + hparams.add_position_timing_signal = True + hparams.pos = None + hparams.position_start_index = "random" + hparams.add_step_timing_signal = True + return hparams + + +@registry.register_hparams +def r_transformer_act_step_position_timing_base(): + hparams = r_transformer_base() + hparams.recurrence_type = "act" + hparams.add_position_timing_signal = True + hparams.pos = None + hparams.add_step_timing_signal = True + return hparams + + +@registry.register_hparams +def r_transformer_step_position_timing_base(): + hparams = r_transformer_base() + hparams.add_position_timing_signal = True + hparams.pos = None + hparams.add_step_timing_signal = True + return hparams diff --git a/tensor2tensor/models/research/r_transformer_util.py b/tensor2tensor/models/research/r_transformer_util.py index 5ac242852..bb6028414 100644 --- a/tensor2tensor/models/research/r_transformer_util.py +++ b/tensor2tensor/models/research/r_transformer_util.py @@ -238,7 +238,10 @@ def get_rt_layer(x, hparams, ffn_unit, attention_unit, pad_remover=None): if hparams.recurrence_type == "basic": rt_initializer = (x, x, x) # (state, input, memory) rt_function = functools.partial( - r_transformer_basic, ffn_unit=ffn_unit, attention_unit=attention_unit) + r_transformer_basic, + hparams=hparams, + ffn_unit=ffn_unit, + attention_unit=attention_unit) elif hparams.recurrence_type == "highway": rt_initializer = (x, x, x) # (state, input, memory) @@ -499,7 +502,7 @@ def transformer_decoder_attention_unit(x, return x -def r_transformer_basic(layer_inputs, unused_step, ffn_unit, attention_unit): +def r_transformer_basic(layer_inputs, step, hparams, ffn_unit, attention_unit): """Basic r_transformer. 
This is in fact vanilla transformer in which weights are shared between @@ -510,7 +513,8 @@ def r_transformer_basic(layer_inputs, unused_step, ffn_unit, attention_unit): Args: layer_inputs: - state: state - unused_step: indicating number of steps take so far + step: indicating number of steps take so far + hparams: model hyper-parameters ffn_unit: feed-forward unit attention_unit: multi-head attention unit @@ -519,6 +523,7 @@ def r_transformer_basic(layer_inputs, unused_step, ffn_unit, attention_unit): new_state: new state """ state, inputs, memory = layer_inputs + state = step_preprocess(state, step, hparams) new_state = ffn_unit(attention_unit(state)) @@ -526,7 +531,7 @@ def r_transformer_basic(layer_inputs, unused_step, ffn_unit, attention_unit): def r_transformer_highway(layer_inputs, - unused_step, + step, hparams, ffn_unit, attention_unit, @@ -546,7 +551,7 @@ def r_transformer_highway(layer_inputs, layer_inputs: - state: state - inputs: the original embedded inputs (= inputs to the first step) - unused_step: indicating number of steps take so far + step: indicating number of steps take so far hparams: model hyper-parameters. ffn_unit: feed-forward unit attention_unit: multi-head attention unit @@ -560,6 +565,7 @@ def r_transformer_highway(layer_inputs, """ state, inputs, memory = layer_inputs + state = step_preprocess(state, step, hparams) transformed_state = ffn_unit(attention_unit(state)) state.get_shape().assert_is_compatible_with(state.get_shape()) @@ -614,7 +620,7 @@ def r_transformer_highway(layer_inputs, def r_transformer_skip(layer_inputs, - unused_step, + step, hparams, ffn_unit, attention_unit, @@ -634,7 +640,7 @@ def r_transformer_skip(layer_inputs, layer_inputs: - state: state - inputs: the original embedded inputs (= inputs to the first step) - unused_step: indicating number of steps take so far + step: indicating number of steps take so far hparams: model hyper-parameters. ffn_unit: feed-forward unit attention_unit: multi-head attention unit @@ -648,6 +654,7 @@ def r_transformer_skip(layer_inputs, """ state, inputs, memory = layer_inputs + state = step_preprocess(state, step, hparams) transformed_state = ffn_unit(attention_unit(state)) @@ -744,6 +751,9 @@ def r_transformer_depthwise_attention(layer_inputs, step, hparams, ffn_unit, state_to_be_transformed = tf.reduce_sum( (states_so_far * states_so_far_weights), axis=0) + state_to_be_transformed = step_preprocess(state_to_be_transformed, step, + hparams) + new_state = ffn_unit(attention_unit(state_to_be_transformed)) # add the new state to the memory @@ -753,12 +763,12 @@ def r_transformer_depthwise_attention(layer_inputs, step, hparams, ffn_unit, def r_transformer_rnn(layer_inputs, - unused_step, + step, hparams, ffn_unit, attention_unit, pad_remover=None): - """The RT cell which models recurencey similar to basic RNN cell. + """The RT layer which models recurencey similar to basic RNN cell. It's an R-transformer with an RNN applied over the stats on depth. @@ -766,7 +776,7 @@ def r_transformer_rnn(layer_inputs, layer_inputs: - state: state - inputs: the original embedded inputs (= inputs to the first step) - unused_step: indicating number of steps take so far + step: indicating number of steps take so far hparams: model hyper-parameters. 
ffn_unit: feed-forward unit attention_unit: multi-head attention unit @@ -783,6 +793,7 @@ def r_transformer_rnn(layer_inputs, """ state, inputs, memory = layer_inputs + state = step_preprocess(state, step, hparams) # TODO(dehghani) keep only the meaningful cases: if hparams.inputs_states_combination == "mh_attention_ffn_add": @@ -824,11 +835,11 @@ def r_transformer_rnn(layer_inputs, def r_transformer_gru(layer_inputs, - unused_step, + step, hparams, attention_unit, pad_remover=None): - """The RT cell which models recurencey similar to GRU cell. + """The RT layer which models recurencey similar to GRU cell. It's an R-transformer with a gru applied over the stats on depth. Based on GRU paper: http://arxiv.org/abs/1406.1078 @@ -837,7 +848,7 @@ def r_transformer_gru(layer_inputs, layer_inputs: - state: state - inputs: the original embedded inputs (= inputs to the first step) - unused_step: indicating number of steps take so far + step: indicating number of steps take so far hparams: model hyper-parameters. attention_unit: multi-head attention unit pad_remover: to mask out padding in convolutional layers (efficiency). @@ -850,6 +861,7 @@ def r_transformer_gru(layer_inputs, """ state, inputs, memory = layer_inputs + state = step_preprocess(state, step, hparams) # TODO(dehghani): do we need preprocess here? state = common_layers.layer_preprocess(state, hparams) @@ -898,11 +910,11 @@ def r_transformer_gru(layer_inputs, def r_transformer_lstm(layer_inputs, - unused_step, + step, hparams, attention_unit, pad_remover=None): - """The RT cell which models recurencey similar to GRU cell. + """The RT layer which models recurencey similar to GRU cell. It's an R-transformer with a gru applied over the stats on depth. based on LSTM paper: https://arxiv.org/pdf/1409.2329.pdf @@ -912,7 +924,7 @@ def r_transformer_lstm(layer_inputs, - state: state - inputs: the original embedded inputs (= inputs to the first step) - memory: memory used in lstm. - unused_step: indicating number of steps take so far + step: indicating number of steps take so far hparams: model hyper-parameters. attention_unit: multi-head attention unit pad_remover: to mask out padding in convolutional layers (efficiency). @@ -924,6 +936,7 @@ def r_transformer_lstm(layer_inputs, memory: contains states from all the previous steps. """ state, inputs, memory = layer_inputs + state = step_preprocess(state, step, hparams) state = common_layers.layer_preprocess(state, hparams) inputs = common_layers.layer_preprocess(inputs, hparams) @@ -1079,6 +1092,7 @@ def rt_function(state, step, halting_probability, remainders, n_updates, new_state: new state """ state_shape = state.get_shape() + state = step_preprocess(state, step, hparams) with tf.variable_scope("sigmoid_activation_for_pondering"): p = common_layers.dense( @@ -1159,13 +1173,13 @@ def should_continue(u0, u1, halting_probability, u2, n_updates, u3): ponder_times = n_updates remainders = remainder - tf.summary.scalar("ponder_times", tf.reduce_mean(ponder_times)) + tf.contrib.summary.scalar("ponder_times", tf.reduce_mean(ponder_times)) return new_state, (ponder_times, remainders) def r_transformer_act_accumulated(x, hparams, ffn_unit, attention_unit): - """The RTAct cell where the final state is accumulation of all states. + """The RTAct layer where the final state is accumulation of all states. 
(similar to the main ACT paper: --> check the issue of differentiability) @@ -1230,6 +1244,7 @@ def rt_function(state, step, halting_probability, remainders, n_updates, accumulated_state: accumulated state """ state_shape = state.get_shape() + state = step_preprocess(state, step, hparams) with tf.variable_scope("sigmoid_activation_for_pondering"): p = common_layers.dense( @@ -1309,7 +1324,7 @@ def should_continue(u0, u1, halting_probability, u2, n_updates, u3): ponder_times = n_updates remainders = remainder - tf.summary.scalar("ponder_times", tf.reduce_mean(ponder_times)) + tf.contrib.summary.scalar("ponder_times", tf.reduce_mean(ponder_times)) return accumulated_state, (ponder_times, remainders) @@ -1366,6 +1381,8 @@ def rt_function(state, step, halting_probability, remainders, n_updates, """ + state = step_preprocess(state, step, hparams) + with tf.variable_scope("sigmoid_activation_for_pondering"): p = common_layers.dense( state, @@ -1449,7 +1466,7 @@ def should_continue(u0, u1, halting_probability, u2, n_updates, u3): ponder_times = n_updates remainders = remainder - tf.summary.scalar("ponder_times", tf.reduce_mean(ponder_times)) + tf.contrib.summary.scalar("ponder_times", tf.reduce_mean(ponder_times)) return new_state, (ponder_times, remainders) @@ -1520,6 +1537,7 @@ def rt_function(state, step, halting_probability, remainders, n_updates, """ state_shape = state.get_shape() + state = step_preprocess(state, step, hparams) # random as halting probability p = tf.random_uniform(shape=common_layers.shape_list(halting_probability)) @@ -1595,7 +1613,7 @@ def should_continue(u0, u1, halting_probability, u2, n_updates, u3): ponder_times = n_updates remainders = remainder - tf.summary.scalar("ponder_times", tf.reduce_mean(ponder_times)) + tf.contrib.summary.scalar("ponder_times", tf.reduce_mean(ponder_times)) return new_state, (ponder_times, remainders) @@ -1749,12 +1767,78 @@ def add_depth_embedding(x): x_shape = common_layers.shape_list(x) depth = x_shape[-1] num_steps = x_shape[0] - shape = [num_steps, 1, 1, x_shape[-1]] + shape = [num_steps, 1, 1, depth] depth_embedding = ( tf.get_variable( "depth_embedding", shape, initializer=tf.random_normal_initializer(0, depth**-0.5)) * (depth** 0.5)) + x += depth_embedding return x + + +def step_preprocess(x, step, hparams): + """Preprocess the input at the beginning of each step. + + Args: + x: input tensor + step: step + hparams: model hyper-parameters + + Returns: + preprocessed input. + + """ + if hparams.add_position_timing_signal: + x = add_position_timing_signal(x, step, hparams) + + if hparams.add_step_timing_signal: + num_steps = ( + hparams.act_max_steps + if hparams.recurrence_type == "act" else hparams.num_rec_steps) + if hparams.step_timing_signal_type == "leaned": + x = common_attention.add_layer_timing_signal_learned_1d( + x, step, num_steps) + elif hparams.step_timing_signal_type == "sinusoid": + x = common_attention.add_layer_timing_signal_sinusoid_1d( + x, step, num_steps) + return x + + +def add_position_timing_signal(x, step, hparams): + """Add n-dimensional embedding as the position (horizontal) timing signal. + + Args: + x: a tensor with shape [batch, length, depth] + step: step + hparams: model hyper parameters + + Returns: + a Tensor the same shape as x. + + """ + + if not hparams.position_start_index: + index = 0 + + elif hparams.position_start_index == "random": + # Shift all positions randomly + # TODO(dehghani): What would be reasonable for max number of shift? 
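The effect of a nonzero start index is simply to offset which rows of the usual sinusoidal table get added to the sequence. A minimal NumPy sketch of that idea (illustrative only, not the library's `add_timing_signal_1d`):

```
import numpy as np

def timing_signal_1d(length, channels, start_index=0, max_timescale=1.0e4):
  """Sinusoidal position signal starting at start_index (sketch)."""
  position = np.arange(start_index, start_index + length, dtype=np.float32)
  num_timescales = channels // 2
  log_inc = np.log(max_timescale) / max(num_timescales - 1, 1)
  inv_timescales = np.exp(np.arange(num_timescales, dtype=np.float32) * -log_inc)
  scaled = position[:, None] * inv_timescales[None, :]
  return np.concatenate([np.sin(scaled), np.cos(scaled)], axis=1)  # [length, channels]

base = timing_signal_1d(16, 32)
shifted = timing_signal_1d(8, 32, start_index=3)
assert np.allclose(shifted[0], base[3])  # position 0 now carries position 3's signal
```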
+ index = tf.random_uniform( + [], maxval=common_layers.shape_list(x)[1], dtype=tf.int32) + + elif hparams.position_start_index == "step": + # Shift positions based on the step + num_steps = ( + hparams.act_max_steps + if hparams.recurrence_type == "act" else hparams.num_rec_steps) + index = tf.cast( + common_layers.shape_list(x)[1] * step / num_steps, dtype=tf.int32) + + # No need for the timing signal in the encoder/decoder input preparation + assert hparams.pos is None + + x_with_timing = common_attention.add_timing_signal_1d(x, start_index=index) + return x_with_timing diff --git a/tensor2tensor/models/research/rl.py b/tensor2tensor/models/research/rl.py index 0c39aa8eb..6feeda60c 100644 --- a/tensor2tensor/models/research/rl.py +++ b/tensor2tensor/models/research/rl.py @@ -182,12 +182,13 @@ def feed_forward_cnn_small_categorical_fun(action_space, config, observations): return NetworkOutput(policy, value, lambda a: a) -def random_policy_fun(action_space, config, observations): - """random policy with categorical output""" - obs_shape = observations.shape.as_list() +def random_policy_fun(action_space, unused_config, observations): + """Random policy with categorical output.""" + obs_shape = observations.shape.as_list() with tf.variable_scope("network_parameters"): value = tf.zeros(obs_shape[:2]) - policy = tf.distributions.Categorical(probs=[[[1. / float(action_space.n)]*action_space.n]*(obs_shape[0]*obs_shape[1])]) - + policy = tf.distributions.Categorical( + probs=[[[1. / float(action_space.n)] * action_space.n + ] * (obs_shape[0] * obs_shape[1])]) return NetworkOutput(policy, value, lambda a: a) diff --git a/tensor2tensor/models/research/super_lm.py b/tensor2tensor/models/research/super_lm.py index 5ffe5e256..adcf5f3ad 100644 --- a/tensor2tensor/models/research/super_lm.py +++ b/tensor2tensor/models/research/super_lm.py @@ -100,7 +100,7 @@ def body(self, features): # Bypass the symbol modality and compute logits directly. # We compute a different set of logits on each shard, and sum them. logits = mp(tf.layers.dense, decoder_output, vocab_size, name="logits") - logits = common_layers.all_reduce_ring(logits, mp) + logits = expert_utils.all_reduce_ring(logits, mp) logits = mp(tf.multiply, logits, mp.n ** -0.5) # We now have identical logits on all shards. # Shard 0 gets returned to the estimator. @@ -109,7 +109,7 @@ def body(self, features): logits_shard_0 = tf.expand_dims(logits_shard_0, 3) # On each device, we compute the loss for a part of the batch. # This is faster than computing the whole loss on one shard. 
- mp, logits = common_layers.reduce_by_device(mp, logits, lambda l: l[0]) + mp, logits = expert_utils.reduce_by_device(mp, logits, lambda l: l[0]) def _loss_for_shard(logits, targets, shard): if mp.n > 1: logits = common_layers.approximate_split(logits, mp.n, 0)[shard] @@ -180,7 +180,7 @@ def _split(t): return tuple(tf.split( t, [mix_size, hparams.hidden_size - mix_size], 2)) to_mix, to_keep = mp(_split, x) - mixed = common_layers.all_reduce_ring(to_mix, mp) + mixed = expert_utils.all_reduce_ring(to_mix, mp) mixed = mp(tf.multiply, mixed, mp.n ** -0.5) x = mp(lambda a, b: tf.concat([a, b], 2), mixed, to_keep) elif layer_type == "att": diff --git a/tensor2tensor/models/research/transformer_symshard.py b/tensor2tensor/models/research/transformer_symshard.py index 0acc4251e..fe0741550 100644 --- a/tensor2tensor/models/research/transformer_symshard.py +++ b/tensor2tensor/models/research/transformer_symshard.py @@ -207,10 +207,10 @@ def body(self, features): else: logits = mp( tf.tensordot, decoder_output, output_var, [[[2], [1]]] * mp.n) - logits = common_layers.all_reduce_ring(logits, mp) + logits = expert_utils.all_reduce_ring(logits, mp) # On each device, we compute the loss for a part of the batch. # This is faster than computing the whole loss on one shard. - mp, logits = common_layers.reduce_by_device(mp, logits, lambda l: l[0]) + mp, logits = expert_utils.reduce_by_device(mp, logits, lambda l: l[0]) def _loss_for_shard(logits, targets, shard): logits = common_layers.approximate_split(logits, mp.n, 0)[shard] targets = common_layers.approximate_split(targets, mp.n, 0)[shard] @@ -282,7 +282,7 @@ def _split(t): return tuple(tf.split( t, [mix_size, hparams.hidden_size - mix_size], 2)) to_mix, to_keep = mp(_split, x) - mixed = common_layers.all_reduce_ring(to_mix, mp) + mixed = expert_utils.all_reduce_ring(to_mix, mp) mixed = mp(tf.multiply, mixed, mp.n ** -0.5) x = mp(lambda a, b: tf.concat([a, b], 2), mixed, to_keep) elif layer_type == "att": diff --git a/tensor2tensor/models/research/transformer_vae.py b/tensor2tensor/models/research/transformer_vae.py index 9825d8290..b1cca4792 100644 --- a/tensor2tensor/models/research/transformer_vae.py +++ b/tensor2tensor/models/research/transformer_vae.py @@ -227,8 +227,14 @@ def ae_latent_softmax(latents_pred, latents_discrete, hparams): name="extra_logits") loss = None if latents_discrete is not None: - loss = tf.nn.sparse_softmax_cross_entropy_with_logits( - labels=latents_discrete, logits=latents_logits) + if hparams.soft_em: + # latents_discrete is actually one-hot of multinomial samples + assert hparams.num_decode_blocks == 1 + loss = tf.nn.softmax_cross_entropy_with_logits( + labels=latents_discrete, logits=latents_logits) + else: + loss = tf.nn.sparse_softmax_cross_entropy_with_logits( + labels=latents_discrete, logits=latents_logits) sample = multinomial_sample( latents_logits, vocab_size, hparams.sampling_temp) return sample, loss diff --git a/tensor2tensor/models/transformer.py b/tensor2tensor/models/transformer.py index 95aa91ff1..137829195 100644 --- a/tensor2tensor/models/transformer.py +++ b/tensor2tensor/models/transformer.py @@ -52,7 +52,7 @@ def __init__(self, *args, **kwargs): super(Transformer, self).__init__(*args, **kwargs) self.attention_weights = dict() # For visualizing attention heads. - def encode(self, inputs, target_space, hparams, features=None): + def encode(self, inputs, target_space, hparams, features=None, losses=None): """Encode transformer inputs. 
Args: @@ -62,6 +62,7 @@ def encode(self, inputs, target_space, hparams, features=None): hparams: hyperparameters for model. features: optionally pass the entire features dictionary as well. This is needed now for "packed" datasets. + losses: optional list onto which to append extra training losses Returns: Tuple of: @@ -82,7 +83,8 @@ def encode(self, inputs, target_space, hparams, features=None): encoder_output = transformer_encoder( encoder_input, self_attention_bias, hparams, nonpadding=features_to_nonpadding(features, "inputs"), - save_weights_to=self.attention_weights) + save_weights_to=self.attention_weights, + losses=losses) return encoder_output, encoder_decoder_attention_bias @@ -93,7 +95,8 @@ def decode(self, decoder_self_attention_bias, hparams, cache=None, - nonpadding=None): + nonpadding=None, + losses=None): """Decode Transformer outputs from encoder representation. Args: @@ -109,6 +112,7 @@ def decode(self, cache: dict, containing tensors which are the results of previous attentions, used for fast decoding. nonpadding: optional Tensor with shape [batch_size, decoder_length] + losses: optional list onto which to append extra training losses Returns: Final decoder representation. [batch_size, decoder_length, hidden_dim] @@ -124,7 +128,8 @@ def decode(self, hparams, cache=cache, nonpadding=nonpadding, - save_weights_to=self.attention_weights) + save_weights_to=self.attention_weights, + losses=losses) if (common_layers.is_on_tpu() and hparams.mode == tf.estimator.ModeKeys.TRAIN): @@ -150,11 +155,13 @@ def body(self, features): """ hparams = self._hparams + losses = [] + if self.has_input: inputs = features["inputs"] target_space = features["target_space_id"] encoder_output, encoder_decoder_attention_bias = self.encode( - inputs, target_space, hparams, features=features) + inputs, target_space, hparams, features=features, losses=losses) else: encoder_output, encoder_decoder_attention_bias = (None, None) @@ -171,7 +178,8 @@ def body(self, features): encoder_decoder_attention_bias, decoder_self_attention_bias, hparams, - nonpadding=features_to_nonpadding(features, "targets")) + nonpadding=features_to_nonpadding(features, "targets"), + losses=losses) expected_attentions = features.get("expected_attentions") if expected_attentions is not None: @@ -181,7 +189,11 @@ def body(self, features): hparams.expected_attention_loss_multiplier) return decoder_output, {"attention_loss": attention_loss} - return tf.reshape(decoder_output, targets_shape) + ret = tf.reshape(decoder_output, targets_shape) + if losses: + return ret, {"extra_loss": tf.add_n(losses)} + else: + return ret def _greedy_infer(self, features, decode_length): """Fast version of greedy decoding. @@ -278,10 +290,7 @@ def _fast_decode(self, if target_modality.is_class_modality: decode_length = 1 else: - if 'decode_length' in features: - decode_length = common_layers.shape_list(inputs)[1] + features['decode_length'] - else: - decode_length = common_layers.shape_list(inputs)[1] + decode_length + decode_length = common_layers.shape_list(inputs)[1] + decode_length # TODO(llion): Clean up this reshaping logic. 
inputs = tf.expand_dims(inputs, axis=1) @@ -318,10 +327,7 @@ def _fast_decode(self, partial_targets = tf.to_int64(partial_targets) partial_targets_shape = common_layers.shape_list(partial_targets) partial_targets_length = partial_targets_shape[1] - if 'decode_length' in features: - decode_length = partial_targets_length + features['decode_length'] - else: - decode_length += partial_targets_length + decode_length += partial_targets_length batch_size = partial_targets_shape[0] if hparams.pos == "timing": @@ -395,7 +401,7 @@ def forced_logits(): ret = tf.cond( tf.less(i, partial_targets_length), forced_logits, lambda: ret) return ret, cache - force_decode_length=self._decode_hparams.force_decode_length + ret = fast_decode( encoder_output=encoder_output, encoder_decoder_attention_bias=encoder_decoder_attention_bias, @@ -406,8 +412,7 @@ def forced_logits(): beam_size=beam_size, top_beams=top_beams, alpha=alpha, - batch_size=batch_size, - force_decode_length=force_decode_length) + batch_size=batch_size) if partial_targets is not None: if beam_size <= 1 or top_beams <= 1: ret["outputs"] = ret["outputs"][:, partial_targets_length:] @@ -426,8 +431,7 @@ def fast_decode(encoder_output, top_beams=1, alpha=1.0, eos_id=beam_search.EOS_ID, - batch_size=None, - force_decode_length=False): + batch_size=None): """Given encoder output and a symbols to logits function, does fast decoding. Implements both greedy and beam search decoding, uses beam search iff @@ -448,7 +452,6 @@ def fast_decode(encoder_output, the preference for longer translations. eos_id: End-of-sequence symbol in beam search. batch_size: an integer scalar - must be passed if there is no input - force_decode_length:force_decode_length: if True, decode will be of length decode_length and will not stop at eos_id. Returns: A dict of decoding results { @@ -473,8 +476,8 @@ def fast_decode(encoder_output, "layer_%d" % layer: { "k": tf.zeros([batch_size, 0, key_channels]), "v": tf.zeros([batch_size, 0, value_channels]), - } - for layer in range(num_layers) + "f": tf.zeros([batch_size, 0, hparams.hidden_size]), + } for layer in range(num_layers) } if encoder_output is not None: @@ -520,10 +523,7 @@ def inner_loop(i, finished, next_id, decoded_ids, cache, log_prob): return i + 1, finished, next_id, decoded_ids, cache, log_prob def is_not_finished(i, finished, *_): - if force_decode_length: - return i < decode_length - else: - return (i < decode_length) & tf.logical_not(tf.reduce_all(finished)) + return (i < decode_length) & tf.logical_not(tf.reduce_all(finished)) decoded_ids = tf.zeros([batch_size, 0], dtype=tf.int64) finished = tf.fill([batch_size], False) @@ -747,7 +747,8 @@ def transformer_encoder(encoder_input, name="encoder", nonpadding=None, save_weights_to=None, - make_image_summary=True): + make_image_summary=True, + losses=None): """A stack of transformer layers. Args: @@ -766,6 +767,7 @@ def transformer_encoder(encoder_input, for visualization; the weights tensor will be appended there under a string key created from the variable scope (including name). make_image_summary: Whether to make an attention image summary. 
+ losses: optional list onto which to append extra training losses Returns: y: a Tensors @@ -807,7 +809,8 @@ def transformer_encoder(encoder_input, with tf.variable_scope("ffn"): y = transformer_ffn_layer( common_layers.layer_preprocess(x, hparams), hparams, pad_remover, - conv_padding="SAME", nonpadding_mask=nonpadding) + conv_padding="SAME", nonpadding_mask=nonpadding, + losses=losses) x = common_layers.layer_postprocess(x, y, hparams) # if normalization is done in layer_preprocess, then it should also be done # on the output, since the output can grow very large, being the sum of @@ -824,7 +827,8 @@ def transformer_decoder(decoder_input, name="decoder", nonpadding=None, save_weights_to=None, - make_image_summary=True): + make_image_summary=True, + losses=None): """A stack of transformer layers. Args: @@ -847,6 +851,7 @@ def transformer_decoder(decoder_input, for visualization; the weights tensor will be appended there under a string key created from the variable scope (including name). make_image_summary: Whether to make an attention image summary. + losses: optional list onto which to append extra training losses Returns: y: a Tensors @@ -898,8 +903,12 @@ def transformer_decoder(decoder_input, x = common_layers.layer_postprocess(x, y, hparams) with tf.variable_scope("ffn"): y = transformer_ffn_layer( - common_layers.layer_preprocess(x, hparams), hparams, - conv_padding="LEFT", nonpadding_mask=nonpadding) + common_layers.layer_preprocess(x, hparams), + hparams, + conv_padding="LEFT", + nonpadding_mask=nonpadding, + losses=losses, + cache=layer_cache) x = common_layers.layer_postprocess(x, y, hparams) # if normalization is done in layer_preprocess, then it should also be done # on the output, since the output can grow very large, being the sum of @@ -911,7 +920,9 @@ def transformer_ffn_layer(x, hparams, pad_remover=None, conv_padding="LEFT", - nonpadding_mask=None): + nonpadding_mask=None, + losses=None, + cache=None): """Feed-forward layer in the transformer. Args: @@ -925,9 +936,16 @@ def transformer_ffn_layer(x, nonpadding_mask: an optional Tensor with shape [batch_size, length]. needed for convolutional layers with "SAME" padding. Contains 1.0 in positions corresponding to nonpadding. + losses: optional list onto which to append extra training losses + cache: dict, containing tensors which are the results of previous + attentions, used for fast decoding. + Returns: a Tensor of shape [batch_size, length, hparams.hidden_size] + + Raises: + ValueError: If losses arg is None, but layer generates extra losses. 
""" ffn_layer = hparams.ffn_layer relu_dropout_broadcast_dims = ( @@ -959,11 +977,12 @@ def transformer_ffn_layer(x, x, hparams.filter_size, hparams.hidden_size, - first_kernel_size=3, + first_kernel_size=hparams.conv_first_kernel, second_kernel_size=1, padding=conv_padding, nonpadding_mask=nonpadding_mask, - dropout=hparams.relu_dropout) + dropout=hparams.relu_dropout, + cache=cache) elif ffn_layer == "parameter_attention": return common_attention.parameter_attention( x, hparams.parameter_attention_key_channels or hparams.hidden_size, @@ -979,6 +998,23 @@ def transformer_ffn_layer(x, second_kernel_size=(31, 1), padding="LEFT", dropout=hparams.relu_dropout) + elif ffn_layer == "sru": + return common_layers.sru(x) + elif ffn_layer == "local_moe_tpu": + overhead = (hparams.moe_overhead_train + if hparams.mode == tf.estimator.ModeKeys.TRAIN + else hparams.moe_overhead_eval) + ret, loss = expert_utils.local_moe_tpu( + x, hparams.filter_size // 2, + hparams.hidden_size, + hparams.moe_num_experts, overhead=overhead, + loss_coef=hparams.moe_loss_coef) + if losses is None: + raise ValueError( + "transformer_ffn_layer with type local_moe_tpu must pass in " + "a losses list") + losses.append(loss) + return ret else: assert ffn_layer == "none" return x @@ -1033,6 +1069,12 @@ def transformer_base_v1(): hparams.add_hparam("use_pad_remover", True) hparams.add_hparam("self_attention_type", "dot_product") hparams.add_hparam("max_relative_position", 0) + hparams.add_hparam("conv_first_kernel", 3) + # These parameters are only used when ffn_layer=="local_moe_tpu" + hparams.add_hparam("moe_overhead_train", 1.0) + hparams.add_hparam("moe_overhead_eval", 2.0) + hparams.moe_num_experts = 16 + hparams.moe_loss_coef = 1e-3 return hparams diff --git a/tensor2tensor/notebooks/hello_t2t.ipynb b/tensor2tensor/notebooks/hello_t2t.ipynb index 820744500..b7661fcfc 100644 --- a/tensor2tensor/notebooks/hello_t2t.ipynb +++ b/tensor2tensor/notebooks/hello_t2t.ipynb @@ -3,7 +3,7 @@ "nbformat_minor": 0, "metadata": { "colab": { - "name": "T2T with TF Eager", + "name": "Tensor2Tensor Intro", "version": "0.3.2", "views": {}, "default_view": {}, @@ -17,6 +17,18 @@ } }, "cells": [ + { + "metadata": { + "id": "odi2vIMHC3Rm", + "colab_type": "text" + }, + "cell_type": "markdown", + "source": [ + "# Welcome to the [Tensor2Tensor](https://github.com/tensorflow/tensor2tensor) Colab\n", + "\n", + "Tensor2Tensor, or T2T for short, is a library of deep learning models and datasets designed to make deep learning more accessible and [accelerate ML research](https://research.googleblog.com/2017/06/accelerating-deep-learning-research.html). T2T is actively used and maintained by researchers and engineers within the [Google Brain team](https://research.google.com/teams/brain/) and a community of users. This colab shows you some datasets we have in T2T, how to download and use them, some models we have, how to download pre-trained models and use them, and how to create and train your own models." + ] + }, { "metadata": { "id": "s19ucTii_wYb", @@ -26,9 +38,12 @@ "startup": false, "wait_interval": 0 } - } + }, + "cellView": "form" }, + "cell_type": "code", "source": [ + "#@title\n", "# Copyright 2018 Google LLC.\n", "\n", "# Licensed under the Apache License, Version 2.0 (the \"License\");\n", @@ -43,7 +58,6 @@ "# See the License for the specific language governing permissions and\n", "# limitations under the License." 
], - "cell_type": "code", "execution_count": 0, "outputs": [] }, @@ -58,11 +72,11 @@ } } }, + "cell_type": "code", "source": [ "# Install deps\n", "!pip install -q -U tensor2tensor tensorflow" ], - "cell_type": "code", "execution_count": 0, "outputs": [] }, @@ -77,7 +91,9 @@ } } }, + "cell_type": "code", "source": [ + "# Imports we need.\n", "import tensorflow as tf\n", "import matplotlib.pyplot as plt\n", "import numpy as np\n", @@ -111,7 +127,6 @@ "gs_data_dir = \"gs://tensor2tensor-data\"\n", "gs_ckpt_dir = \"gs://tensor2tensor-checkpoints/\"" ], - "cell_type": "code", "execution_count": 0, "outputs": [] }, @@ -120,10 +135,10 @@ "id": "0a69r1KDiZDe", "colab_type": "text" }, + "cell_type": "markdown", "source": [ "# Download MNIST and inspect it" - ], - "cell_type": "markdown" + ] }, { "metadata": { @@ -134,11 +149,6 @@ "startup": false, "wait_interval": 0 }, - "output_extras": [ - { - "item_id": 1 - } - ], "base_uri": "https://localhost:8080/", "height": 1241 }, @@ -155,6 +165,7 @@ } } }, + "cell_type": "code", "source": [ "# A Problem is a dataset together with some fixed pre-processing.\n", "# It could be a translation dataset with a specific tokenization,\n", @@ -163,8 +174,7 @@ "# There are many problems available in Tensor2Tensor\n", "problems.available()" ], - "cell_type": "code", - "execution_count": 4, + "execution_count": 0, "outputs": [ { "output_type": "execute_result", @@ -260,11 +270,6 @@ "startup": false, "wait_interval": 0 }, - "output_extras": [ - { - "item_id": 12 - } - ], "base_uri": "https://localhost:8080/", "height": 306 }, @@ -281,6 +286,7 @@ } } }, + "cell_type": "code", "source": [ "# Fetch the MNIST problem\n", "mnist_problem = problems.problem(\"image_mnist\")\n", @@ -288,8 +294,7 @@ "# a standard format ready for training and evaluation.\n", "mnist_problem.generate_data(data_dir, tmp_dir)" ], - "cell_type": "code", - "execution_count": 5, + "execution_count": 0, "outputs": [ { "output_type": "stream", @@ -325,14 +330,6 @@ "startup": false, "wait_interval": 0 }, - "output_extras": [ - { - "item_id": 1 - }, - { - "item_id": 2 - } - ], "base_uri": "https://localhost:8080/", "height": 381 }, @@ -349,6 +346,7 @@ } } }, + "cell_type": "code", "source": [ "# Now let's see the training MNIST data as Tensors.\n", "mnist_example = tfe.Iterator(mnist_problem.dataset(Modes.TRAIN, data_dir)).next()\n", @@ -358,8 +356,7 @@ "plt.imshow(image.numpy()[:, :, 0].astype(np.float32), cmap=plt.get_cmap('gray'))\n", "print(\"Label: %d\" % label.numpy())" ], - "cell_type": "code", - "execution_count": 6, + "execution_count": 0, "outputs": [ { "output_type": "stream", @@ -388,10 +385,10 @@ "id": "gXL7_bVH49Kl", "colab_type": "text" }, + "cell_type": "markdown", "source": [ "# Translate from English to German with a pre-trained model" - ], - "cell_type": "markdown" + ] }, { "metadata": { @@ -402,11 +399,6 @@ "startup": false, "wait_interval": 0 }, - "output_extras": [ - { - "item_id": 3 - } - ], "base_uri": "https://localhost:8080/", "height": 170 }, @@ -423,6 +415,7 @@ } } }, + "cell_type": "code", "source": [ "# Fetch the problem\n", "ende_problem = problems.problem(\"translate_ende_wmt32k\")\n", @@ -449,8 +442,7 @@ " integers = integers[:integers.index(1)]\n", " return encoders[\"inputs\"].decode(np.squeeze(integers))" ], - "cell_type": "code", - "execution_count": 7, + "execution_count": 0, "outputs": [ { "output_type": "stream", @@ -480,6 +472,7 @@ } } }, + "cell_type": "code", "source": [ "# # Generate and view the data\n", "# # This cell is commented out because WMT data 
@@ -504,7 +497,6 @@
        "# print(\"Targets, decoded:\")\n",
        "# print(decode(targets))"
      ],
-      "cell_type": "code",
      "execution_count": 0,
      "outputs": []
    },
@@ -517,11 +509,6 @@
          "startup": false,
          "wait_interval": 0
        },
-        "output_extras": [
-          {
-            "item_id": 1
-          }
-        ],
        "base_uri": "https://localhost:8080/",
        "height": 408
      },
@@ -538,12 +525,12 @@
          }
        }
      },
+      "cell_type": "code",
      "source": [
        "# There are many models available in Tensor2Tensor\n",
        "registry.list_models()"
      ],
-      "cell_type": "code",
-      "execution_count": 9,
+      "execution_count": 0,
      "outputs": [
        {
          "output_type": "execute_result",
@@ -592,6 +579,7 @@
          }
        }
      },
+      "cell_type": "code",
      "source": [
        "# Create hparams and the model\n",
        "model_name = \"transformer\"\n",
@@ -604,7 +592,6 @@
        "# that will not match the checkpoint.\n",
        "translate_model = registry.model(model_name)(hparams, Modes.EVAL)"
      ],
-      "cell_type": "code",
      "execution_count": 0,
      "outputs": []
    },
@@ -617,11 +604,6 @@
          "startup": false,
          "wait_interval": 0
        },
-        "output_extras": [
-          {
-            "item_id": 1
-          }
-        ],
        "base_uri": "https://localhost:8080/",
        "height": 34
      },
@@ -638,6 +620,7 @@
          }
        }
      },
+      "cell_type": "code",
      "source": [
        "# Copy the pretrained checkpoint locally\n",
        "ckpt_name = \"transformer_ende_test\"\n",
@@ -646,8 +629,7 @@
        "ckpt_path = tf.train.latest_checkpoint(os.path.join(checkpoint_dir, ckpt_name))\n",
        "ckpt_path"
      ],
-      "cell_type": "code",
-      "execution_count": 11,
+      "execution_count": 0,
      "outputs": [
        {
          "output_type": "execute_result",
@@ -672,11 +654,6 @@
          "startup": false,
          "wait_interval": 0
        },
-        "output_extras": [
-          {
-            "item_id": 2
-          }
-        ],
        "base_uri": "https://localhost:8080/",
        "height": 68
      },
@@ -693,6 +670,7 @@
          }
        }
      },
+      "cell_type": "code",
      "source": [
        "# Restore and translate!\n",
        "def translate(inputs):\n",
@@ -707,8 +685,7 @@
        "print(\"Inputs: %s\" % inputs)\n",
        "print(\"Outputs: %s\" % outputs)"
      ],
-      "cell_type": "code",
-      "execution_count": 13,
+      "execution_count": 0,
      "outputs": [
        {
          "output_type": "stream",
@@ -726,10 +703,10 @@
        "id": "X3mkIEcbfiTP",
        "colab_type": "text"
      },
+      "cell_type": "markdown",
      "source": [
        "## Attention Viz Utils"
-      ],
-      "cell_type": "markdown"
+      ]
    },
@@ -742,6 +719,7 @@
          }
        }
      },
+      "cell_type": "code",
      "source": [
        "from tensor2tensor.visualization import attention\n",
        "from tensor2tensor.data_generators import text_encoder\n",
@@ -796,7 +774,6 @@
        " tokens.append(subtokenizer._subtoken_id_to_subtoken_string(_id))\n",
        " return tokens"
      ],
-      "cell_type": "code",
      "execution_count": 0,
      "outputs": []
    },
@@ -811,6 +788,7 @@
          }
        }
      },
+      "cell_type": "code",
      "source": [
        "def call_html():\n",
        " import IPython\n",
@@ -827,7 +805,6 @@
        " \n",
        " '''))"
      ],
-      "cell_type": "code",
      "execution_count": 0,
      "outputs": []
    },
@@ -836,10 +813,10 @@
        "id": "T7UJzFf6fmhp",
        "colab_type": "text"
      },
+      "cell_type": "markdown",
      "source": [
        "## Display Attention"
-      ],
-      "cell_type": "markdown"
+      ]
    },
@@ -850,23 +827,6 @@
          "startup": false,
          "wait_interval": 0
        },
-        "output_extras": [
-          {
-            "item_id": 1
-          },
-          {
-            "item_id": 2
-          },
-          {
-            "item_id": 3
-          },
-          {
-            "item_id": 4
-          },
-          {
-            "item_id": 5
-          }
-        ],
        "resources": {
          "http://localhost:8080/static/components/requirejs/require.js": {
            "data": "/** vim: et:ts=4:sw=4:sts=4
[... embedded RequireJS 2.1.22 loader source (saved notebook output resource data) omitted ...]
                        //some use a file name.
                        config.pkgs[name] = pkgObj.name + '/' + (pkgObj.main || 'main')
                                     .replace(currDirRegExp, '')
                                     .replace(jsSuffixRegExp, '');
                    });
                }

                //If there are any "waiting to execute" modules in the registry,
                //update the maps for them, since their info, like URLs to load,
                //may have changed.
                eachProp(registry, function (mod, id) {
                    //If a module already has init called, it is too late to
                    //modify it, so skip it; also ignore unnormalized ones
                    //since they are transient.
                    if (!mod.inited && !mod.map.unnormalized) {
                        mod.map = makeModuleMap(id, null, true);
                    }
                });

                //If a deps array or a config callback is specified, then call
                //require with those args. This is useful when require is defined as a
                //config object before require.js is loaded.
                if (cfg.deps || cfg.callback) {
                    context.require(cfg.deps || [], cfg.callback);
                }
            },
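            //Illustrative only (not part of the loader): a typical config
            //object handled by configure() above. Module names and paths are
            //hypothetical.
            //
            //  require.config({
            //      baseUrl: 'js/lib',
            //      paths: { app: '../app' },
            //      shim: { legacyLib: { deps: ['jquery'], exports: 'LegacyLib' } },
            //      deps: ['app/main']
            //  });
            //
            //paths, bundles, config and map entries are merged additively into
            //the existing config, shim entries are normalized (a plain array
            //becomes {deps: [...]}), and deps/callback trigger an immediate
            //require() call.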

            makeShimExports: function (value) {
                function fn() {
                    var ret;
                    if (value.init) {
                        ret = value.init.apply(global, arguments);
                    }
                    return ret || (value.exports && getGlobal(value.exports));
                }
                return fn;
            },
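            //Illustrative only: a shim entry that makeShimExports() above turns
            //into an exportsFn. 'Backbone' is a hypothetical browser global here.
            //
            //  require.config({
            //      shim: {
            //          backbone: {
            //              deps: ['underscore', 'jquery'],
            //              exports: 'Backbone',
            //              init: function (_, $) { return this.Backbone.noConflict(); }
            //          }
            //      }
            //  });
            //
            //When the shimmed script has loaded, exportsFn calls init (if any)
            //and uses its return value if truthy; otherwise it reads the named
            //global via getGlobal(value.exports).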

            makeRequire: function (relMap, options) {
                options = options || {};

                function localRequire(deps, callback, errback) {
                    var id, map, requireMod;

                    if (options.enableBuildCallback && callback && isFunction(callback)) {
                        callback.__requireJsBuild = true;
                    }

                    if (typeof deps === 'string') {
                        if (isFunction(callback)) {
                            //Invalid call
                            return onError(makeError('requireargs', 'Invalid require call'), errback);
                        }

                        //If require|exports|module are requested, get the
                        //value for them from the special handlers. Caveat:
                        //this only works while module is being defined.
                        if (relMap && hasProp(handlers, deps)) {
                            return handlers[deps](registry[relMap.id]);
                        }

                        //Synchronous access to one module. If require.get is
                        //available (as in the Node adapter), prefer that.
                        if (req.get) {
                            return req.get(context, deps, relMap, localRequire);
                        }

                        //Normalize module name, if it contains . or ..
                        map = makeModuleMap(deps, relMap, false, true);
                        id = map.id;

                        if (!hasProp(defined, id)) {
                            return onError(makeError('notloaded', 'Module name "' +
                                        id +
                                        '" has not been loaded yet for context: ' +
                                        contextName +
                                        (relMap ? '' : '. Use require([])')));
                        }
                        return defined[id];
                    }

                    //Grab defines waiting in the global queue.
                    intakeDefines();

                    //Mark all the dependencies as needing to be loaded.
                    context.nextTick(function () {
                        //Some defines could have been added since the
                        //require call, collect them.
                        intakeDefines();

                        requireMod = getModule(makeModuleMap(null, relMap));

                        //Store if map config should be applied to this require
                        //call for dependencies.
                        requireMod.skipMap = options.skipMap;

                        requireMod.init(deps, callback, errback, {
                            enabled: true
                        });

                        checkLoaded();
                    });

                    return localRequire;
                }

                mixin(localRequire, {
                    isBrowser: isBrowser,

                    /**
                     * Converts a module name + .extension into an URL path.
                     * *Requires* the use of a module name. It does not support using
                     * plain URLs like nameToUrl.
                     */
                    toUrl: function (moduleNamePlusExt) {
                        var ext,
                            index = moduleNamePlusExt.lastIndexOf('.'),
                            segment = moduleNamePlusExt.split('/')[0],
                            isRelative = segment === '.' || segment === '..';

                        //Have a file extension alias, and it is not the
                        //dots from a relative path.
                        if (index !== -1 && (!isRelative || index > 1)) {
                            ext = moduleNamePlusExt.substring(index, moduleNamePlusExt.length);
                            moduleNamePlusExt = moduleNamePlusExt.substring(0, index);
                        }

                        return context.nameToUrl(normalize(moduleNamePlusExt,
                                                relMap && relMap.id, true), ext,  true);
                    },

                    defined: function (id) {
                        return hasProp(defined, makeModuleMap(id, relMap, false, true).id);
                    },

                    specified: function (id) {
                        id = makeModuleMap(id, relMap, false, true).id;
                        return hasProp(defined, id) || hasProp(registry, id);
                    }
                });
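                //Illustrative only: how the helpers mixed in above are
                //typically used from module code. Names are hypothetical.
                //
                //  define(['require'], function (require) {
                //      var tmpl = require.toUrl('./templates/row.html');
                //      if (require.defined('optional/plugin')) {
                //          //module has already been loaded and defined
                //      }
                //      if (require.specified('optional/plugin')) {
                //          //module has been requested, but may not be loaded yet
                //      }
                //  });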

                //Only allow undef on top level require calls
                if (!relMap) {
                    localRequire.undef = function (id) {
                        //Bind any waiting define() calls to this context,
                        //fix for #408
                        takeGlobalQueue();

                        var map = makeModuleMap(id, relMap, true),
                            mod = getOwn(registry, id);

                        mod.undefed = true;
                        removeScript(id);

                        delete defined[id];
                        delete urlFetched[map.url];
                        delete undefEvents[id];

                        //Clean queued defines too. Go backwards
                        //in array so that the splices do not
                        //mess up the iteration.
                        eachReverse(defQueue, function(args, i) {
                            if (args[0] === id) {
                                defQueue.splice(i, 1);
                            }
                        });
                        delete context.defQueueMap[id];

                        if (mod) {
                            //Hold on to listeners in case the
                            //module is later reloaded
                            //using a different config.
                            if (mod.events.defined) {
                                undefEvents[id] = mod.events;
                            }

                            cleanRegistry(id);
                        }
                    };
                }
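                //Illustrative only: undef is exposed on the top-level require
                //and removes a module so it can be reloaded, e.g. under a new
                //config. 'some/module' is a hypothetical id.
                //
                //  require.undef('some/module');
                //  require.config({ paths: { 'some/module': 'alternate/location' } });
                //  require(['some/module'], function (mod) { /* reloaded */ });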

                return localRequire;
            },

            /**
             * Called to enable a module if it is still in the registry
             * awaiting enablement. A second arg, parent, the parent module,
             * is passed in for context, when this method is overridden by
             * the optimizer. Not shown here to keep code compact.
             */
            enable: function (depMap) {
                var mod = getOwn(registry, depMap.id);
                if (mod) {
                    getModule(depMap).enable();
                }
            },

            /**
             * Internal method used by environment adapters to complete a load event.
             * A load event could be a script load or just a load pass from a synchronous
             * load call.
             * @param {String} moduleName the name of the module to potentially complete.
             */
            completeLoad: function (moduleName) {
                var found, args, mod,
                    shim = getOwn(config.shim, moduleName) || {},
                    shExports = shim.exports;

                takeGlobalQueue();

                while (defQueue.length) {
                    args = defQueue.shift();
                    if (args[0] === null) {
                        args[0] = moduleName;
                        //If already found an anonymous module and bound it
                        //to this name, then this is some other anon module
                        //waiting for its completeLoad to fire.
                        if (found) {
                            break;
                        }
                        found = true;
                    } else if (args[0] === moduleName) {
                        //Found matching define call for this script!
                        found = true;
                    }

                    callGetModule(args);
                }
                context.defQueueMap = {};

                //Do this after the cycle of callGetModule in case the result
                //of those calls/init calls changes the registry.
                mod = getOwn(registry, moduleName);

                if (!found && !hasProp(defined, moduleName) && mod && !mod.inited) {
                    if (config.enforceDefine && (!shExports || !getGlobal(shExports))) {
                        if (hasPathFallback(moduleName)) {
                            return;
                        } else {
                            return onError(makeError('nodefine',
                                             'No define call for ' + moduleName,
                                             null,
                                             [moduleName]));
                        }
                    } else {
                        //A script that does not call define(), so just simulate
                        //the call for it.
                        callGetModule([moduleName, (shim.deps || []), shim.exportsFn]);
                    }
                }

                checkLoaded();
            },

            /**
             * Converts a module name to a file path. Supports cases where
             * moduleName may actually be just an URL.
             * Note that it **does not** call normalize on the moduleName,
             * it is assumed to have already been normalized. This is an
             * internal API, not a public one. Use toUrl for the public API.
             */
            nameToUrl: function (moduleName, ext, skipExt) {
                var paths, syms, i, parentModule, url,
                    parentPath, bundleId,
                    pkgMain = getOwn(config.pkgs, moduleName);

                if (pkgMain) {
                    moduleName = pkgMain;
                }

                bundleId = getOwn(bundlesMap, moduleName);

                if (bundleId) {
                    return context.nameToUrl(bundleId, ext, skipExt);
                }

                //If a colon is in the URL, it indicates a protocol is used and it is just
                //an URL to a file, or if it starts with a slash, contains a query arg (i.e. ?)
                //or ends with .js, then assume the user meant to use an url and not a module id.
                //The slash is important for protocol-less URLs as well as full paths.
                if (req.jsExtRegExp.test(moduleName)) {
                    //Just a plain path, not module name lookup, so just return it.
                    //Add extension if it is included. This is a bit wonky, only non-.js things pass
                    //an extension, this method probably needs to be reworked.
                    url = moduleName + (ext || '');
                } else {
                    //A module that needs to be converted to a path.
                    paths = config.paths;

                    syms = moduleName.split('/');
                    //For each module name segment, see if there is a path
                    //registered for it. Start with most specific name
                    //and work up from it.
                    for (i = syms.length; i > 0; i -= 1) {
                        parentModule = syms.slice(0, i).join('/');

                        parentPath = getOwn(paths, parentModule);
                        if (parentPath) {
                            //If an array, it means there are a few choices,
                            //Choose the one that is desired
                            if (isArray(parentPath)) {
                                parentPath = parentPath[0];
                            }
                            syms.splice(0, i, parentPath);
                            break;
                        }
                    }

                    //Join the path parts together, then figure out if baseUrl is needed.
                    url = syms.join('/');
                    url += (ext || (/^data\:|\?/.test(url) || skipExt ? '' : '.js'));
                    url = (url.charAt(0) === '/' || url.match(/^[\w\+\.\-]+:/) ? '' : config.baseUrl) + url;
                }

                return config.urlArgs ? url +
                                        ((url.indexOf('?') === -1 ? '?' : '&') +
                                         config.urlArgs) : url;
            },
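            //Illustrative only: a sketch of how nameToUrl() above resolves a
            //module id, assuming baseUrl 'js/' and paths {'app': '../app'}
            //(hypothetical values):
            //
            //  'app/model/user' -> matches the 'app' path prefix
            //                   -> '../app/model/user' + '.js'
            //                   -> 'js/../app/model/user.js'
            //  'lib/util'       -> no path match -> 'js/lib/util.js'
            //  'a.js' or '/x/y' -> treated as a plain URL, returned without
            //                      baseUrl or '.js' appended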

            //Delegates to req.load. Broken out as a separate function to
            //allow overriding in the optimizer.
            load: function (id, url) {
                req.load(context, id, url);
            },

            /**
             * Executes a module callback function. Broken out as a separate function
             * solely to allow the build system to sequence the files in the built
             * layer in the right order.
             *
             * @private
             */
            execCb: function (name, callback, args, exports) {
                return callback.apply(exports, args);
            },

            /**
             * Callback for script loads, used to check status of loading.
             *
             * @param {Event} evt the event from the browser for the script
             * that was loaded.
             */
            onScriptLoad: function (evt) {
                //Using currentTarget instead of target for Firefox 2.0's sake. Not
                //all old browsers will be supported, but this one was easy enough
                //to support and still makes sense.
                if (evt.type === 'load' ||
                        (readyRegExp.test((evt.currentTarget || evt.srcElement).readyState))) {
                    //Reset interactive script so a script node is not held onto for
                    //too long.
                    interactiveScript = null;

                    //Pull out the name of the module and the context.
                    var data = getScriptData(evt);
                    context.completeLoad(data.id);
                }
            },

            /**
             * Callback for script errors.
             */
            onScriptError: function (evt) {
                var data = getScriptData(evt);
                if (!hasPathFallback(data.id)) {
                    var parents = [];
                    eachProp(registry, function(value, key) {
                        if (key.indexOf('_@r') !== 0) {
                            each(value.depMaps, function(depMap) {
                                if (depMap.id === data.id) {
                                    parents.push(key);
                                }
                                return true;
                            });
                        }
                    });
                    return onError(makeError('scripterror', 'Script error for "' + data.id +
                                             (parents.length ?
                                             '", needed by: ' + parents.join(', ') :
                                             '"'), evt, [data.id]));
                }
            }
        };

        context.require = context.makeRequire();
        return context;
    }

    /**
     * Main entry point.
     *
     * If the only argument to require is a string, then the module that
     * is represented by that string is fetched for the appropriate context.
     *
     * If the first argument is an array, then it will be treated as an array
     * of dependency string names to fetch. An optional function callback can
     * be specified to execute when all of those dependencies are available.
     *
     * Make a local req variable to help Caja compliance (it assumes things
     * on a require that are not standardized), and to give a short
     * name for minification/local scope use.
     */
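    //Illustrative only: the two calling conventions described above, plus the
    //config-object form. Module ids are hypothetical.
    //
    //  var a = require('some/loaded/module');        //sync, module must already be loaded
    //  require(['a', 'b'], function (a, b) {});       //async, loads deps then calls back
    //  require({ baseUrl: 'js/' }, ['a'], function (a) {});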
    req = requirejs = function (deps, callback, errback, optional) {

        //Find the right context, use default
        var context, config,
            contextName = defContextName;

        // Determine if there is a config object in the call.
        if (!isArray(deps) && typeof deps !== 'string') {
            // deps is a config object
            config = deps;
            if (isArray(callback)) {
                // Adjust args if there are dependencies
                deps = callback;
                callback = errback;
                errback = optional;
            } else {
                deps = [];
            }
        }

        if (config && config.context) {
            contextName = config.context;
        }

        context = getOwn(contexts, contextName);
        if (!context) {
            context = contexts[contextName] = req.s.newContext(contextName);
        }

        if (config) {
            context.configure(config);
        }

        return context.require(deps, callback, errback);
    };

    /**
     * Support require.config() to make it easier to cooperate with other
     * AMD loaders on globally agreed names.
     */
    req.config = function (config) {
        return req(config);
    };
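    //Illustrative only: req.config just forwards to req, so these two calls
    //are equivalent (the path shown is hypothetical):
    //
    //  require.config({ paths: { app: '../app' } });
    //  require({ paths: { app: '../app' } });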

    /**
     * Execute something after the current tick
     * of the event loop. Override for other envs
     * that have a better solution than setTimeout.
     * @param  {Function} fn function to execute later.
     */
    req.nextTick = typeof setTimeout !== 'undefined' ? function (fn) {
        setTimeout(fn, 4);
    } : function (fn) { fn(); };
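    //Illustrative only: an environment adapter could override this. For
    //example, a Node-based adapter might prefer process.nextTick (an
    //assumption, not shown in this file):
    //
    //  req.nextTick = function (fn) { process.nextTick(fn); };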

    /**
     * Export require as a global, but only if it does not already exist.
     */
    if (!require) {
        require = req;
    }

    req.version = version;

    //Used to filter out dependencies that are already paths.
    req.jsExtRegExp = /^\/|:|\?|\.js$/;
    req.isBrowser = isBrowser;
    s = req.s = {
        contexts: contexts,
        newContext: newContext
    };

    //Create default context.
    req({});

    //Exports some context-sensitive methods on global require.
    each([
        'toUrl',
        'undef',
        'defined',
        'specified'
    ], function (prop) {
        //Reference from contexts instead of early binding to default context,
        //so that during builds, the latest instance of the default context
        //with its config gets used.
        req[prop] = function () {
            var ctx = contexts[defContextName];
            return ctx.require[prop].apply(ctx, arguments);
        };
    });

    if (isBrowser) {
        head = s.head = document.getElementsByTagName('head')[0];
        //If BASE tag is in play, using appendChild is a problem for IE6.
        //When that browser dies, this can be removed. Details in this jQuery bug:
        //http://dev.jquery.com/ticket/2709
        baseElement = document.getElementsByTagName('base')[0];
        if (baseElement) {
            head = s.head = baseElement.parentNode;
        }
    }

    /**
     * Any errors that require explicitly generates will be passed to this
     * function. Intercept/override it if you want custom error handling.
     * @param {Error} err the error object.
     */
    req.onError = defaultOnError;
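    //Illustrative only: a custom handler can be installed in place of
    //defaultOnError. Errors built via makeError (set up earlier in this file)
    //carry requireType and requireModules fields.
    //
    //  requirejs.onError = function (err) {
    //      if (err.requireType === 'timeout') {
    //          console.log('Modules that timed out: ' + err.requireModules);
    //      } else {
    //          throw err;
    //      }
    //  };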

    /**
     * Creates the node for the load command. Only used in browser envs.
     */
    req.createNode = function (config, moduleName, url) {
        var node = config.xhtml ?
                document.createElementNS('http://www.w3.org/1999/xhtml', 'html:script') :
                document.createElement('script');
        node.type = config.scriptType || 'text/javascript';
        node.charset = 'utf-8';
        node.async = true;
        return node;
    };

    /**
     * Does the request to load a module for the browser case.
     * Make this a separate function to allow other environments
     * to override it.
     *
     * @param {Object} context the require context to find state.
     * @param {String} moduleName the name of the module.
     * @param {Object} url the URL to the module.
     */
    req.load = function (context, moduleName, url) {
        var config = (context && context.config) || {},
            node;
        if (isBrowser) {
            //In the browser so use a script tag
            node = req.createNode(config, moduleName, url);
            if (config.onNodeCreated) {
                config.onNodeCreated(node, config, moduleName, url);
            }

            node.setAttribute('data-requirecontext', context.contextName);
            node.setAttribute('data-requiremodule', moduleName);

            //Set up load listener. Test attachEvent first because IE9 has
            //a subtle issue in its addEventListener and script onload firings
            //that do not match the behavior of all other browsers with
            //addEventListener support, which fire the onload event for a
            //script right after the script execution. See:
            //https://connect.microsoft.com/IE/feedback/details/648057/script-onload-event-is-not-fired-immediately-after-script-execution
            //UNFORTUNATELY Opera implements attachEvent but does not follow the
            //script execution mode.
            if (node.attachEvent &&
                    //Check if node.attachEvent is artificially added by custom script or
                    //natively supported by browser
                    //read https://github.com/jrburke/requirejs/issues/187
                    //if we can NOT find [native code] then it must NOT be natively supported.
                    //in IE8, node.attachEvent does not have toString()
                    //Note the test for "[native code" with no closing brace, see:
                    //https://github.com/jrburke/requirejs/issues/273
                    !(node.attachEvent.toString && node.attachEvent.toString().indexOf('[native code') < 0) &&
                    !isOpera) {
                //Probably IE. IE (at least 6-8) do not fire
                //script onload right after executing the script, so
                //we cannot tie the anonymous define call to a name.
                //However, IE reports the script as being in 'interactive'
                //readyState at the time of the define call.
                useInteractive = true;

                node.attachEvent('onreadystatechange', context.onScriptLoad);
                //It would be great to add an error handler here to catch
                //404s in IE9+. However, onreadystatechange will fire before
                //the error handler, so that does not help. If addEventListener
                //is used, then IE will fire error before load, but we cannot
                //use that pathway given the connect.microsoft.com issue
                //mentioned above about not doing the 'script execute,
                //then fire the script load event listener before execute
                //next script' that other browsers do.
                //Best hope: IE10 fixes the issues,
                //and then destroys all installs of IE 6-9.
                //node.attachEvent('onerror', context.onScriptError);
            } else {
                node.addEventListener('load', context.onScriptLoad, false);
                node.addEventListener('error', context.onScriptError, false);
            }
            node.src = url;

            //For some cache cases in IE 6-8, the script executes before the end
            //of the appendChild execution, so to tie an anonymous define
            //call to the module name (which is stored on the node), hold on
            //to a reference to this node, but clear after the DOM insertion.
            currentlyAddingScript = node;
            if (baseElement) {
                head.insertBefore(node, baseElement);
            } else {
                head.appendChild(node);
            }
            currentlyAddingScript = null;

            return node;
        } else if (isWebWorker) {
            try {
                //In a web worker, use importScripts. This is not a very
                //efficient use of importScripts, since it will block until
                //its script is downloaded and evaluated. However, if web workers
                //are in play, the expectation is that a build has been done so
                //that only one script needs to be loaded anyway. This may need
                //to be reevaluated if other use cases become common.
                importScripts(url);

                //Account for anonymous modules
                context.completeLoad(moduleName);
            } catch (e) {
                context.onError(makeError('importscripts',
                                'importScripts failed for ' +
                                    moduleName + ' at ' + url,
                                e,
                                [moduleName]));
            }
        }
    };

    function getInteractiveScript() {
        if (interactiveScript && interactiveScript.readyState === 'interactive') {
            return interactiveScript;
        }

        eachReverse(scripts(), function (script) {
            if (script.readyState === 'interactive') {
                return (interactiveScript = script);
            }
        });
        return interactiveScript;
    }

    //Look for a data-main script attribute, which could also adjust the baseUrl.
    if (isBrowser && !cfg.skipDataMain) {
        //Figure out baseUrl. Get it from the script tag with require.js in it.
        eachReverse(scripts(), function (script) {
            //Set the 'head' where we can append children by
            //using the script's parent.
            if (!head) {
                head = script.parentNode;
            }

            //Look for a data-main attribute to set main script for the page
            //to load. If it is there, the path to data main becomes the
            //baseUrl, if it is not already set.
            dataMain = script.getAttribute('data-main');
            if (dataMain) {
                //Preserve dataMain in case it is a path (i.e. contains '?')
                mainScript = dataMain;

                //Set final baseUrl if there is not already an explicit one.
                if (!cfg.baseUrl) {
                    //Pull off the directory of data-main for use as the
                    //baseUrl.
                    src = mainScript.split('/');
                    mainScript = src.pop();
                    subPath = src.length ? src.join('/')  + '/' : './';

                    cfg.baseUrl = subPath;
                }

                //Strip off any trailing .js since mainScript is now
                //like a module name.
                mainScript = mainScript.replace(jsSuffixRegExp, '');

                //If mainScript is still a path, fall back to dataMain
                if (req.jsExtRegExp.test(mainScript)) {
                    mainScript = dataMain;
                }

                //Put the data-main script in the files to load.
                cfg.deps = cfg.deps ? cfg.deps.concat(mainScript) : [mainScript];

                return true;
            }
        });
    }
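    //Illustrative only: the data-main detection above is driven by a script
    //tag like the following (paths hypothetical):
    //
    //  <script data-main="scripts/main.js" src="require.js"></script>
    //
    //'scripts/' becomes the baseUrl (unless one was configured explicitly)
    //and 'main' is appended to cfg.deps so it loads automatically.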

    /**
     * The function that handles definitions of modules. Differs from
     * require() in that a string for the module should be the first argument,
     * and the function to execute after dependencies are loaded should
     * return a value to define the module corresponding to the first argument's
     * name.
     */
    define = function (name, deps, callback) {
        var node, context;

        //Allow for anonymous modules
        if (typeof name !== 'string') {
            //Adjust args appropriately
            callback = deps;
            deps = name;
            name = null;
        }

        //This module may not have dependencies
        if (!isArray(deps)) {
            callback = deps;
            deps = null;
        }

        //If no name, and callback is a function, then figure out if it is a
        //CommonJS thing with dependencies.
        if (!deps && isFunction(callback)) {
            deps = [];
            //Remove comments from the callback string,
            //look for require calls, and pull them into the dependencies,
            //but only if there are function args.
            if (callback.length) {
                callback
                    .toString()
                    .replace(commentRegExp, '')
                    .replace(cjsRequireRegExp, function (match, dep) {
                        deps.push(dep);
                    });

                //May be a CommonJS thing even without require calls, but still
                //could use exports, and module. Avoid doing exports and module
                //work though if it just needs require.
                //REQUIRES the function to expect the CommonJS variables in the
                //order listed below.
                deps = (callback.length === 1 ? ['require'] : ['require', 'exports', 'module']).concat(deps);
            }
        }

        //If in IE 6-8 and hit an anonymous define() call, do the interactive
        //work.
        if (useInteractive) {
            node = currentlyAddingScript || getInteractiveScript();
            if (node) {
                if (!name) {
                    name = node.getAttribute('data-requiremodule');
                }
                context = contexts[node.getAttribute('data-requirecontext')];
            }
        }

        //Always save off evaluating the def call until the script onload handler.
        //This allows multiple modules to be in a file without prematurely
        //tracing dependencies, and allows for anonymous module support,
        //where the module name is not known until the script onload event
        //occurs. If no context, use the global queue, and get it processed
        //in the onscript load callback.
        if (context) {
            context.defQueue.push([name, deps, callback]);
            context.defQueueMap[name] = true;
        } else {
            globalDefQueue.push([name, deps, callback]);
        }
    };
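    //Illustrative only: forms of define() handled above. Module ids and
    //dependencies are hypothetical.
    //
    //  define('named/module', ['dep'], function (dep) { return {}; });  //named
    //  define(['dep'], function (dep) { return {}; });                  //anonymous
    //  define(function (require, exports, module) {                     //CommonJS wrapper;
    //      var dep = require('dep');                                    //deps are scanned
    //      exports.value = 1;                                           //from the source
    //  });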

    define.amd = {
        jQuery: true
    };

    /**
     * Executes the text. Normally just uses eval, but can be modified
     * to use a better, environment-specific call. Only used for transpiling
     * loader plugins, not for plain JS modules.
     * @param {String} text the text to execute/evaluate.
     */
    req.exec = function (text) {
        /*jslint evil: true */
        return eval(text);
    };

    //Set up with config info.
    req(cfg);
}(this));
", @@ -897,6 +857,7 @@ } } }, + "cell_type": "code", "source": [ "# Convert inputs and outputs to subwords\n", "inp_text = to_tokens(encoders[\"inputs\"].encode(inputs))\n", @@ -904,7 +865,7 @@ "\n", "# Run eval to collect attention weights\n", "example = encode_eval(inputs, outputs)\n", - "with tfe.restore_variables_on_create(ckpt_path):\n", + "with tfe.restore_variables_on_create(tf.train.latest_checkpoint(checkpoint_dir)):\n", " translate_model.set_mode(Modes.EVAL)\n", " translate_model(example)\n", "# Get normalized attention weights for each layer\n", @@ -913,8 +874,7 @@ "call_html()\n", "attention.show(inp_text, out_text, enc_atts, dec_atts, encdec_atts)" ], - "cell_type": "code", - "execution_count": 16, + "execution_count": 0, "outputs": [ { "output_type": "stream", @@ -1376,10 +1336,10 @@ "id": "i7BZuO7T5BB4", "colab_type": "text" }, + "cell_type": "markdown", "source": [ "# Train a custom model on MNIST" - ], - "cell_type": "markdown" + ] }, { "metadata": { @@ -1392,6 +1352,7 @@ } } }, + "cell_type": "code", "source": [ "# Create your own model\n", "\n", @@ -1411,7 +1372,6 @@ "hparams.hidden_size = 64\n", "model = MySimpleModel(hparams, Modes.TRAIN)" ], - "cell_type": "code", "execution_count": 0, "outputs": [] }, @@ -1424,11 +1384,6 @@ "startup": false, "wait_interval": 0 }, - "output_extras": [ - { - "item_id": 1 - } - ], "base_uri": "https://localhost:8080/", "height": 34 }, @@ -1445,6 +1400,7 @@ } } }, + "cell_type": "code", "source": [ "# Prepare for the training loop\n", "\n", @@ -1462,8 +1418,7 @@ "\n", "optimizer = tf.train.AdamOptimizer()" ], - "cell_type": "code", - "execution_count": 42, + "execution_count": 0, "outputs": [ { "output_type": "stream", @@ -1483,11 +1438,6 @@ "startup": false, "wait_interval": 0 }, - "output_extras": [ - { - "item_id": 11 - } - ], "base_uri": "https://localhost:8080/", "height": 204 }, @@ -1504,6 +1454,7 @@ } } }, + "cell_type": "code", "source": [ "# Train\n", "NUM_STEPS = 500\n", @@ -1518,8 +1469,7 @@ " if count >= NUM_STEPS:\n", " break" ], - "cell_type": "code", - "execution_count": 46, + "execution_count": 0, "outputs": [ { "output_type": "stream", @@ -1549,11 +1499,6 @@ "startup": false, "wait_interval": 0 }, - "output_extras": [ - { - "item_id": 2 - } - ], "base_uri": "https://localhost:8080/", "height": 68 }, @@ -1570,6 +1515,7 @@ } } }, + "cell_type": "code", "source": [ "model.set_mode(Modes.EVAL)\n", "mnist_eval_dataset = mnist_problem.dataset(Modes.EVAL, data_dir)\n", @@ -1597,8 +1543,7 @@ "for name, val in metrics_result().items():\n", " print(\"%s: %.2f\" % (name, val))" ], - "cell_type": "code", - "execution_count": 47, + "execution_count": 0, "outputs": [ { "output_type": "stream", diff --git a/tensor2tensor/rl/envs/simulated_batch_env.py b/tensor2tensor/rl/envs/simulated_batch_env.py index 20fe868f5..9a229b424 100644 --- a/tensor2tensor/rl/envs/simulated_batch_env.py +++ b/tensor2tensor/rl/envs/simulated_batch_env.py @@ -100,7 +100,8 @@ def simulate(self, action): action = tf.concat([action, action], axis=0) inputs = {"inputs": tf.expand_dims(inputs_merged, axis=0), # Add batch. 
"input_action": tf.expand_dims(action, axis=0)} - model_output = self._model.infer(inputs) + with tf.variable_scope(tf.get_variable_scope(), reuse=tf.AUTO_REUSE): + model_output = self._model.infer(inputs) observ = model_output["targets"] observ = tf.cast(observ[:, 0, :, :, :], tf.float32) reward = model_output["target_reward"][:, 0, 0, 0] - 1 diff --git a/tensor2tensor/rl/envs/tf_atari_wrappers.py b/tensor2tensor/rl/envs/tf_atari_wrappers.py index 9c0af1ae3..83b9a9ae7 100644 --- a/tensor2tensor/rl/envs/tf_atari_wrappers.py +++ b/tensor2tensor/rl/envs/tf_atari_wrappers.py @@ -177,13 +177,11 @@ def __init__(self, batch_env): def simulate(self, action): with tf.name_scope("environment/simulate"): # Do we need this? - observ_copy = self._batch_env.observ.read_value() - with tf.control_dependencies([observ_copy]): - reward, done = self._batch_env.simulate(action) - encoded_image = tf.image.encode_png( - tf.cast(observ_copy[0, ...], tf.uint8)) - with tf.control_dependencies([reward, done]): - enqueue_op = self.speculum.enqueue( - [encoded_image, reward, action, done]) - with tf.control_dependencies([enqueue_op]): - return tf.identity(reward), tf.identity(done) + reward, done = self._batch_env.simulate(action) + encoded_image = tf.image.encode_png( + tf.cast(self._batch_env.observ[0, ...], tf.uint8)) + with tf.control_dependencies([reward, done]): + enqueue_op = self.speculum.enqueue( + [encoded_image, reward, action, done]) + with tf.control_dependencies([enqueue_op]): + return tf.identity(reward), tf.identity(done) diff --git a/tensor2tensor/rl/envs/utils.py b/tensor2tensor/rl/envs/utils.py index cbd4cbb60..26e12eab7 100644 --- a/tensor2tensor/rl/envs/utils.py +++ b/tensor2tensor/rl/envs/utils.py @@ -271,6 +271,7 @@ def _worker(self, constructor, conn): continue if message == self._CLOSE: assert payload is None + env.close() break raise KeyError("Received message of unknown type {}".format(message)) except Exception: # pylint: disable=broad-except diff --git a/tensor2tensor/rl/model_rl_experiment.py b/tensor2tensor/rl/model_rl_experiment.py index 87c07b26e..4fd89022e 100644 --- a/tensor2tensor/rl/model_rl_experiment.py +++ b/tensor2tensor/rl/model_rl_experiment.py @@ -67,7 +67,7 @@ def train(hparams, output_dir): FLAGS.output_dir = output_dir FLAGS.model = hparams.generative_model FLAGS.hparams_set = hparams.generative_model_params - FLAGS.train_steps = hparams.model_train_steps + FLAGS.train_steps = hparams.model_train_steps * (iloop + 2) FLAGS.eval_steps = 10 t2t_trainer.main([]) @@ -108,10 +108,10 @@ def train(hparams, output_dir): def main(_): hparams = tf.contrib.training.HParams( epochs=10, - true_env_generator_num_steps=5000, + true_env_generator_num_steps=10000, generative_model="basic_conv_gen", generative_model_params="basic_conv", - model_train_steps=15000, + model_train_steps=25000, simulated_env_generator_num_steps=300, ppo_epochs_num=200, ppo_epoch_length=300, diff --git a/tensor2tensor/rl/rl_trainer_lib.py b/tensor2tensor/rl/rl_trainer_lib.py index d244bcf2b..4ede6438e 100644 --- a/tensor2tensor/rl/rl_trainer_lib.py +++ b/tensor2tensor/rl/rl_trainer_lib.py @@ -56,11 +56,10 @@ def define_train(hparams, environment_spec, event_dir): "network", functools.partial(policy_lambda, batch_env.action_space, hparams)) - with tf.variable_scope("", reuse=tf.AUTO_REUSE): - memory, collect_summary = collect.define_collect( - policy_factory, batch_env, hparams, eval_phase=False) - ppo_summary = ppo.define_ppo_epoch(memory, policy_factory, hparams) - summary = tf.summary.merge([collect_summary, 
ppo_summary]) + memory, collect_summary = collect.define_collect( + policy_factory, batch_env, hparams, eval_phase=False) + ppo_summary = ppo.define_ppo_epoch(memory, policy_factory, hparams) + summary = tf.summary.merge([collect_summary, ppo_summary]) with tf.variable_scope("eval", reuse=tf.AUTO_REUSE): eval_env_lambda = env_lambda diff --git a/tensor2tensor/utils/adafactor.py b/tensor2tensor/utils/adafactor.py index 2d328d4ba..6b70fa057 100644 --- a/tensor2tensor/utils/adafactor.py +++ b/tensor2tensor/utils/adafactor.py @@ -34,7 +34,9 @@ class AdafactorOptimizer(tf.train.Optimizer): parameters to maintain the second-moment estimator, instead of AB. This is advantageous on memory-limited systems. In addition, beta1 (momentum) is set to zero by default, saving an additional auxiliary - parameter per weight. + parameter per weight. Variables with >=3 dimensions are treated as + collections of two-dimensional matrices - factorization is over the final + two dimensions. 2. Adafactor incorporates "update-clipping" - a scale-invariant analog of gradient clipping. This adds stability @@ -62,7 +64,7 @@ class AdafactorOptimizer(tf.train.Optimizer): if var is 2-dimensional: v_r <- zeros([num_rows]) v_c <- zeros([num_cols]) - else: + if var is 0-dimensional or 1-dimensional: v <- zeros(shape(var)) ``` @@ -74,10 +76,13 @@ class AdafactorOptimizer(tf.train.Optimizer): v_r <- decay_rate * v_r + (1 - decay_rate) * reduce_mean(grad_squared, 1) v_c <- decay_rate * v_c + (1 - decay_rate) * reduce_mean(grad_squared, 0) v = outer_prod(v_r, v_c) / reduce_mean(v_r) - else: + if var is 0-dimensional or 1-dimensional: v <- decay_rate * v + (1 - decay_rate) * grad_squared ``` + For variables with >=3 dimensions, we factorize the second-moment accumulator + over the final 2 dimensions. See the code for details. + Several parts of this algorithm are configurable from the initializer. @@ -95,8 +100,6 @@ class AdafactorOptimizer(tf.train.Optimizer): factored: whether to factor the second-moment estimator. True means less memory usage. - TODO(noam): we should also apply the 2d logic to the two final dimensions. - of >2d convolutional kernels. 
""" def __init__(self, @@ -159,7 +162,7 @@ def _should_use_factored_second_moment_estimate(self, shape): Returns: a boolean """ - return self._factored and len(shape) == 2 + return self._factored and len(shape) >= 2 def _create_slots(self, var_list): for var in var_list: @@ -167,8 +170,8 @@ def _create_slots(self, var_list): if self._beta1: self._zeros_slot(var, "m", self._name) if self._should_use_factored_second_moment_estimate(shape): - r_val = tf.zeros([shape[0]], dtype=tf.float32) - c_val = tf.zeros([shape[1]], dtype=tf.float32) + r_val = tf.zeros(shape[:-1], dtype=tf.float32) + c_val = tf.zeros(shape[:-2] + shape[-1:], dtype=tf.float32) self._get_or_make_slot(var, r_val, "vr", self._name) self._get_or_make_slot(var, c_val, "vc", self._name) else: @@ -219,8 +222,8 @@ def _resource_apply_dense(self, grad, var): shape = var.get_shape().as_list() updates = [] if self._should_use_factored_second_moment_estimate(shape): - grad_squared_row_mean = tf.reduce_mean(grad_squared, 1) - grad_squared_col_mean = tf.reduce_mean(grad_squared, 0) + grad_squared_row_mean = tf.reduce_mean(grad_squared, -1) + grad_squared_col_mean = tf.reduce_mean(grad_squared, -2) vr = self.get_slot(var, "vr") new_vr = (decay_rate * vr + mixing_rate * grad_squared_row_mean) vc = self.get_slot(var, "vc") @@ -228,10 +231,10 @@ def _resource_apply_dense(self, grad, var): vr_update = tf.assign(vr, new_vr, use_locking=self._use_locking) vc_update = tf.assign(vc, new_vc, use_locking=self._use_locking) updates = [vr_update, vc_update] - long_term_mean = tf.reduce_mean(new_vr) + long_term_mean = tf.reduce_mean(new_vr, -1, keep_dims=True) r_factor = tf.rsqrt(new_vr / long_term_mean) c_factor = tf.rsqrt(new_vc) - x = grad * tf.expand_dims(r_factor, 1) * tf.expand_dims(c_factor, 0) + x = grad * tf.expand_dims(r_factor, -1) * tf.expand_dims(c_factor, -2) else: v = self.get_slot(var, "v") new_v = decay_rate * v + mixing_rate * grad_squared diff --git a/tensor2tensor/utils/cloud_tpu.py b/tensor2tensor/utils/cloud_tpu.py index ef78458a9..d1e0e875d 100644 --- a/tensor2tensor/utils/cloud_tpu.py +++ b/tensor2tensor/utils/cloud_tpu.py @@ -50,6 +50,7 @@ def __init__(self): def cleanup(self, current_vm_name=None, current_tpu_name=None, skip_confirmation=False): + """Delete old instances and cleanup old trainer and tunnel processes.""" process_pids = os.listdir(self._tmp_dir) for pid in process_pids: try: @@ -72,7 +73,7 @@ def cleanup(self, current_vm_name=None, current_tpu_name=None, del_tpu = False if info["delete_on_done"]: if (info["vm_name"] != current_vm_name and - info["vm_name"] in zip(*list_vm_names_and_ips())[0]): + info["vm_name"] in list(zip(*list_vm_names_and_ips()))[0]): print("Old VM %s found. Delete?" % info["vm_name"]) if skip_confirmation: del_vm = True @@ -80,7 +81,7 @@ def cleanup(self, current_vm_name=None, current_tpu_name=None, if confirm(): del_vm = True if (info["tpu_name"] != current_tpu_name and - info["tpu_name"] in zip(*list_tpu_names_and_ips())[0]): + info["tpu_name"] in list(zip(*list_tpu_names_and_ips()))[0]): print("Old TPU %s found. Delete?" 
% info["tpu_name"]) if skip_confirmation: del_tpu = True @@ -340,8 +341,8 @@ def create_vm_tpu_pair(vm_name, tpu_name, reuse_if_exists=True, vm_info = list_vm_names_and_ips() tpu_info = list_tpu_names_and_ips() - vm_names = zip(*vm_info)[0] if vm_info else [] - tpu_names = zip(*tpu_info)[0] if tpu_info else [] + vm_names = list(zip(*vm_info))[0] if vm_info else [] + tpu_names = list(zip(*tpu_info))[0] if tpu_info else [] make_vm = False vm_ip = None diff --git a/tensor2tensor/utils/decoding.py b/tensor2tensor/utils/decoding.py index 887c8a191..2856c76ad 100644 --- a/tensor2tensor/utils/decoding.py +++ b/tensor2tensor/utils/decoding.py @@ -51,8 +51,7 @@ def decode_hparams(overrides=""): max_input_size=-1, identity_output=False, num_samples=-1, - delimiter="\n", - force_decode_length=False) + delimiter="\n") hp = hp.parse(overrides) return hp diff --git a/tensor2tensor/utils/expert_utils.py b/tensor2tensor/utils/expert_utils.py index 3cfa3592d..912f56169 100644 --- a/tensor2tensor/utils/expert_utils.py +++ b/tensor2tensor/utils/expert_utils.py @@ -30,34 +30,12 @@ import six from six.moves import range # pylint: disable=redefined-builtin from six.moves import zip # pylint: disable=redefined-builtin -import tensorflow as tf - -from tensorflow.python.framework import function - -DEFAULT_DEV_STRING = "existing_device" +from tensor2tensor.layers import common_layers -@function.Defun( - python_grad_func=lambda x, dy: tf.convert_to_tensor(dy), - shape_func=lambda op: [op.inputs[0].get_shape()]) -def convert_gradient_to_tensor(x): - """Identity operation whose gradient is converted to a `Tensor`. - - Currently, the gradient to `tf.concat` is particularly expensive to - compute if dy is an `IndexedSlices` (a lack of GPU implementation - forces the gradient operation onto CPU). This situation occurs when - the output of the `tf.concat` is eventually passed to `tf.gather`. - It is sometimes faster to convert the gradient to a `Tensor`, so as - to get the cheaper gradient for `tf.concat`. To do this, replace - `tf.concat(x)` with `convert_gradient_to_tensor(tf.concat(x))`. - - Args: - x: A `Tensor`. +import tensorflow as tf - Returns: - The input `Tensor`. - """ - return x +DEFAULT_DEV_STRING = "existing_device" def add_scope(scope=None, scope_fn=None): @@ -476,7 +454,7 @@ def noisy_top_k_gating(x, noisy_logits = clean_logits + ( tf.random_normal(tf.shape(clean_logits)) * noise_stddev) logits = noisy_logits - if should_generate_summaries(): + if common_layers.should_generate_summaries(): tf.summary.histogram("noisy_logits", noisy_logits) tf.summary.histogram("noise_stddev", noise_stddev) else: @@ -495,7 +473,7 @@ def noisy_top_k_gating(x, k), 0) else: load = _gates_to_load(gates) - if should_generate_summaries(): + if common_layers.should_generate_summaries(): tf.summary.histogram("importance", tf.reduce_sum(gates, 0)) tf.summary.histogram("load", load) return gates, load @@ -672,7 +650,7 @@ class SparseDispatcher(object): The inputs and outputs are all two-dimensional [batch, depth]. Caller is responsible for collapsing additional dimensions prior to calling this class and reshaping the output to the original shape. - See reshape_like(). + See common_layers.reshape_like(). Example use: @@ -746,7 +724,8 @@ def combine(self, expert_out, multiply_by_gates=True): a `Tensor` with shape `[batch_size, ]`. 
""" # see comments on convert_gradient_to_tensor - stitched = convert_gradient_to_tensor(tf.concat(expert_out, 0)) + stitched = common_layers.convert_gradient_to_tensor( + tf.concat(expert_out, 0)) if multiply_by_gates: stitched *= tf.expand_dims(self._nonzero_gates, 1) combined = tf.unsorted_segment_sum(stitched, self._batch_index, @@ -821,8 +800,8 @@ def dispatch(self, inp): dispatched = self._dp(lambda a, b: a.dispatch(b), self._dispatchers, inp) ret = self._ep(tf.concat, transpose_list_of_lists(dispatched), 0) if ret[0].dtype == tf.float32: - # see comments on convert_gradient_to_tensor - ret = self._ep(convert_gradient_to_tensor, ret) + # see comments on common_layers.convert_gradient_to_tensor + ret = self._ep(common_layers.convert_gradient_to_tensor, ret) return ret def combine(self, expert_out, multiply_by_gates=True): @@ -846,7 +825,7 @@ def combine(self, expert_out, multiply_by_gates=True): expert_output_parts_t = transpose_list_of_lists(expert_output_parts) def my_combine(dispatcher, parts): return dispatcher.combine( - convert_gradient_to_tensor(tf.concat(parts, 0)), + common_layers.convert_gradient_to_tensor(tf.concat(parts, 0)), multiply_by_gates=multiply_by_gates) return self._dp(my_combine, self._dispatchers, expert_output_parts_t) @@ -904,14 +883,6 @@ def my_fn(x): return my_fn -def reshape_like(a, b): - """Reshapes a to match the shape of b in all but the last dimension.""" - ret = tf.reshape(a, tf.concat([tf.shape(b)[:-1], tf.shape(a)[-1:]], 0)) - if not tf.contrib.eager.in_eager_mode(): - ret.set_shape(b.get_shape().as_list()[:-1] + a.get_shape().as_list()[-1:]) - return ret - - def flatten_all_but_last(a): """Flatten all dimensions of a except the last.""" ret = tf.reshape(a, [-1, tf.shape(a)[-1]]) @@ -981,7 +952,7 @@ def distributed_moe(data_parallelism, expert_in = dispatcher.dispatch(xs_flat) expert_out = ep(expert_fn, expert_in) ys_flat = dispatcher.combine(expert_out) - ys = dp(reshape_like, ys_flat, xs) + ys = dp(common_layers.reshape_like, ys_flat, xs) # compute some load-balancing losses. load = tf.add_n(load) importance = tf.add_n(dp(tf.reduce_sum, gates, 0)) @@ -1054,7 +1025,7 @@ def local_moe(x, expert_outputs = ep(expert_fn, **expert_kwargs) y_flat = dispatcher.combine(expert_outputs) - y = reshape_like(y_flat, x) + y = common_layers.reshape_like(y_flat, x) importance = tf.reduce_sum(gates, 0) loss = loss_coef * (cv_squared(importance) + cv_squared(load)) @@ -1201,16 +1172,323 @@ def length_coordinate(self): return self._indices -def should_generate_summaries(): - """Is this an appropriate context to generate summaries. +def local_moe_tpu(inputs, + hidden_size, + output_size, + num_experts, + loss_coef=1e-3, + overhead=1.0): + """Local mixture of experts that works well on TPU. + + See https://arxiv.org/abs/1701.06538 + + There are num_experts expert networks, each containing a relu-activated + hidden layer of size hidden_size, followed by an output projection. + + The number of parameters is thus: + num_experts * (input_size * hidden_size + hidden_size * output_size) + + The input is 3d: [batch, length, depth], consisting of the representations + of all positions in a batch of sequences. + + Each position of each sequence is sent to 0-2 experts. The expert + choices and the combination weights are determined by a learned gating + function. + + This function returns a small auxiliary loss that should be added to the + training loss of the model. This loss helps to balance expert usage. 
+ Without the loss, it is very likely that a few experts will be trained and + the rest will starve. + + Several hacks are necessary to get around current TPU limitations: + + - To ensure static shapes, we enforce (by truncation/padding) + that each sequence send the same number of elements to each expert. + + It would make more sense to enforce this equality over the entire batch, + as opposed to on individual sequences. This would allow more freedom + for individual sequences to be unbalanced. Unfortunately, that would + slow down our hacked-up gather-by-matmul implementation. + + TODO(noam): There is no real reason for a single sequence to be the unit + of equal allocation. Reshaping the inputs would allow us to pick a + different unit of equal allocation. + + TODO(noam): Factor this code better. We want to be able to substitute + different code for the experts themselves. We also want to integrate this + gating/dispatching logic into multi-device mixtures-of-experts. + + Args: + inputs: a Tensor with shape [batch, length, depth] + hidden_size: an integer + output_size: an integer + num_experts: an integer + loss_coef: a float scalar + overhead: multiplicative factor of how much spare capacity to assign + + Returns: + outputs: a Tensor with shape [batch, length, output_size] + loss: a scalar + """ + batch, length, input_size = common_layers.shape_list(inputs)[:] + # Each sequence sends expert_capacity positions to each expert. + if isinstance(length, int): + expert_capacity = min( + length, int((length * 2 * overhead) / num_experts)) + else: + expert_capacity = tf.minimum( + length, tf.to_int32( + tf.to_float(length) * 2 * overhead / num_experts)) + expert_capacity_f = tf.to_float(expert_capacity) + + # This is the learned gating function. + gates = tf.nn.softmax( + tf.to_float(common_layers.dense(inputs, num_experts, name="logits"))) + + # Find the top expert for each position. + gate_1, index_1 = common_layers.top_1_tpu(gates) + # [batch, length, num_experts] + mask_1 = tf.one_hot(index_1, num_experts) + # [batch, length, num_experts] + # This is the position within the expert's mini-batch for this sequence + position_in_expert_1 = common_layers.cumsum( + mask_1, axis=1, exclusive=True) * mask_1 + # Remove the elements that don't fit. + mask_1 *= tf.to_float(tf.less(position_in_expert_1, expert_capacity_f)) + # [batch, 1, num_experts] + # How many examples in this sequence go to this expert + mask_1_count = tf.reduce_sum(mask_1, axis=1, keep_dims=True) + # [batch, length] - mostly ones, but zeros where something didn't fit + mask_1_flat = tf.reduce_sum(mask_1, axis=2) + position_in_expert_1 = tf.reduce_sum(position_in_expert_1, axis=2) + # Weight assigned to first expert. + gate_1 *= mask_1_flat + + # Pick a second-place expert for each position. + # We first mask out the experts that we expect to be over-capacity + space_remaining = expert_capacity_f - mask_1_count + use_rate = (mask_1_count + 1.0) / tf.to_float(length) + # At what point in the sequence do we expect the expert to be full. + expected_exhaustion_pos = space_remaining / use_rate + # A Tensor with shape [batch, length, num_experts] representing a boolean + # - whether we expect that the expert will already be full. + expected_exhausted = tf.to_float(tf.greater( + tf.reshape(tf.to_float(tf.range(length)), [1, length, 1]), + expected_exhaustion_pos)) + masked_gates = gates - mask_1 - expected_exhausted + # This section is similar to the section above. 
+ gate_2, index_2 = common_layers.top_1_tpu(masked_gates) + # [batch, length, num_experts] + mask_2 = tf.one_hot(index_2, num_experts) + position_in_expert_2 = ( + common_layers.cumsum(mask_2, axis=1, exclusive=True) + mask_1_count) + position_in_expert_2 *= mask_2 + mask_2 *= tf.to_float(tf.less(position_in_expert_2, expert_capacity_f)) + mask_2_count = tf.reduce_sum(mask_2, axis=1, keep_dims=True) + mask_2_flat = tf.reduce_sum(mask_2, axis=2) + position_in_expert_2 = tf.reduce_sum(position_in_expert_2, axis=2) + gate_2 *= mask_2_flat + + # What fraction didn't fit - show summaries + miss_rate_1 = 1.0 - tf.reduce_sum(mask_1_count) / tf.to_float(batch * length) + miss_rate_2 = 1.0 - tf.reduce_sum(mask_2_count) / tf.to_float(batch * length) + tf.summary.scalar("miss_rate_1", miss_rate_1) + tf.summary.scalar("miss_rate_2", miss_rate_2) + + # renormalize the two gate values to add up to 1 + denom = gate_1 + gate_2 + 1e-9 + gate_1 /= denom + gate_2 /= denom + + # inputs: [batch, length, input_size] + # forward_assignment: [batch, length, num_experts * expert_capacity] + # expert_inputs: [batch, num_experts * expert_capacity, input_size] + + segment_ids_forward_1 = ( + (index_1 * expert_capacity) + + tf.to_int32(position_in_expert_1) + + tf.to_int32(1.0 - mask_1_flat) * (num_experts * expert_capacity)) + + segment_ids_forward_2 = ( + (index_2 * expert_capacity) + + tf.to_int32(position_in_expert_2) + + tf.to_int32(1.0 - mask_2_flat) * (num_experts * expert_capacity)) + + # Gather and scatter are painfully slow on TPU. + # We will use one_hot and matmul instead. + + # [batch, length, num_experts * expert_capacity] + one_hot_1 = tf.one_hot( + segment_ids_forward_1, num_experts * expert_capacity, dtype=inputs.dtype) + one_hot_2 = tf.one_hot( + segment_ids_forward_2, num_experts * expert_capacity, dtype=inputs.dtype) + + forward_assignment = (one_hot_1 + one_hot_2) + + # [batch, num_experts * expert_capacity, input_size] + expert_inputs = tf.matmul(forward_assignment, inputs, transpose_a=True) + + # [batch, num_experts, expert_capacity, input_size] + expert_inputs = tf.reshape( + expert_inputs, [batch, num_experts, expert_capacity, input_size]) + # [num_experts, batch, expert_capacity, input_size] + expert_inputs = tf.transpose(expert_inputs, [1, 0, 2, 3]) + + # [num_experts, batch * expert_capacity, input_size] + expert_inputs = tf.reshape( + expert_inputs, [num_experts, batch * expert_capacity, input_size]) + + # Now feed the expert inputs through the experts. + h = common_layers.batch_dense( + expert_inputs, hidden_size, activation=tf.nn.relu, name="x0") + expert_output = common_layers.batch_dense(h, output_size, name="x1") + expert_output = tf.reshape( + expert_output, [num_experts, batch, expert_capacity, output_size]) + + # [batch, num_experts, expert_capacity, output_size] + expert_output = tf.transpose(expert_output, [1, 0, 2, 3]) + expert_output = tf.reshape( + expert_output, [batch, num_experts * expert_capacity, output_size]) + + # Again, use matmul instead of unsorted_segment_sum. This time, we need + # to multiply by the combination weights gate_1 and gate_2. 
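+  # Each row of backward_assignment has at most two nonzero entries: gate_1 in
+  # the column of the slot used by the first expert and gate_2 in the column
+  # of the slot used by the second. The matmul below therefore computes, for
+  # every position, the gate-weighted sum of its expert outputs; positions
+  # that overflowed both experts get an all-zero row and a zero output vector.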
+
+  # expert_output: [batch, num_experts * expert_capacity, output_size]
+  # backward_assignment: [batch, length, num_experts * expert_capacity]
+  # output: [batch, length, output_size]
+  backward_assignment = (
+      one_hot_1 * tf.cast(tf.expand_dims(gate_1, 2), inputs.dtype) +
+      one_hot_2 * tf.cast(tf.expand_dims(gate_2, 2), inputs.dtype))
+  output = tf.matmul(backward_assignment, expert_output)
+
+  # Compute a loss equal to the coefficient of variation of the
+  # total gate value per expert per sequence.
+  # This loss causes the experts to be used about equally per sequence.
+  importance = tf.reduce_sum(gates * (mask_1 + mask_2), 1)
+  loss = loss_coef * cv_squared(importance)
+  return output, loss
+
+
+def reduce_by_device(parallelism, data, reduce_fn):
+  """Reduces data per device.
+
+  This can be useful, for example, if we want to all-reduce n tensors on k