diff --git a/.travis.yml b/.travis.yml index 02e7e0768..2ae1acf65 100644 --- a/.travis.yml +++ b/.travis.yml @@ -11,12 +11,15 @@ env: - TF_VERSION="1.5.*" - TF_VERSION="1.6.*" - TF_VERSION="1.7.*" + - TF_VERSION="1.8.*" matrix: exclude: - python: "3.6" env: TF_VERSION="1.5.*" - python: "3.6" env: TF_VERSION="1.6.*" + - python: "3.6" + env: TF_VERSION="1.7.*" before_install: - echo "deb [arch=amd64] http://storage.googleapis.com/tensorflow-serving-apt stable tensorflow-model-server tensorflow-model-server-universal" | sudo tee /etc/apt/sources.list.d/tensorflow-serving.list - curl https://storage.googleapis.com/tensorflow-serving-apt/tensorflow-serving.release.pub.gpg | sudo apt-key add - diff --git a/README.md b/README.md index 31b25562f..72c29c8d0 100644 --- a/README.md +++ b/README.md @@ -26,10 +26,10 @@ You can chat with us on ### Quick Start -[This iPython notebook](https://goo.gl/wkHexj) explains T2T and runs in your -browser using a free VM from Google, no installation needed. -Alternatively, here is a one-command version that installs T2T, downloads MNIST, -trains a model and evaluates it: +[This iPython notebook](https://colab.research.google.com/github/tensorflow/tensor2tensor/blob/master/tensor2tensor/notebooks/hello_t2t.ipynb) +explains T2T and runs in your browser using a free VM from Google, +no installation needed. Alternatively, here is a one-command version that +installs T2T, downloads MNIST, trains a model and evaluates it: ``` pip install tensor2tensor && t2t-trainer \ diff --git a/docs/index.md b/docs/index.md index 9262461c7..2ffbb956d 100644 --- a/docs/index.md +++ b/docs/index.md @@ -19,7 +19,7 @@ research](https://research.googleblog.com/2017/06/accelerating-deep-learning-res ## Basics * [Walkthrough](walkthrough.md): Install and run. -* [IPython notebook](https://goo.gl/wkHexj): Get a hands-on experience. +* [IPython notebook](https://colab.research.google.com/github/tensorflow/tensor2tensor/blob/master/tensor2tensor/notebooks/hello_t2t.ipynb): Get a hands-on experience. * [Overview](overview.md): How all parts of T2T code are connected. * [New Problem](new_problem.md): Train T2T models on your data. * [New Model](new_model.md): Create your own T2T model. diff --git a/docs/walkthrough.md b/docs/walkthrough.md index 31b25562f..72c29c8d0 100644 --- a/docs/walkthrough.md +++ b/docs/walkthrough.md @@ -26,10 +26,10 @@ You can chat with us on ### Quick Start -[This iPython notebook](https://goo.gl/wkHexj) explains T2T and runs in your -browser using a free VM from Google, no installation needed. -Alternatively, here is a one-command version that installs T2T, downloads MNIST, -trains a model and evaluates it: +[This iPython notebook](https://colab.research.google.com/github/tensorflow/tensor2tensor/blob/master/tensor2tensor/notebooks/hello_t2t.ipynb) +explains T2T and runs in your browser using a free VM from Google, +no installation needed. Alternatively, here is a one-command version that +installs T2T, downloads MNIST, trains a model and evaluates it: ``` pip install tensor2tensor && t2t-trainer \ diff --git a/tensor2tensor/data_generators/all_problems.py b/tensor2tensor/data_generators/all_problems.py index 7e3b3e008..87ce4df86 100644 --- a/tensor2tensor/data_generators/all_problems.py +++ b/tensor2tensor/data_generators/all_problems.py @@ -46,6 +46,7 @@ "tensor2tensor.data_generators.ptb", "tensor2tensor.data_generators.snli", "tensor2tensor.data_generators.squad", + "tensor2tensor.data_generators.subject_verb_agreement", "tensor2tensor.data_generators.translate_encs", "tensor2tensor.data_generators.translate_ende", "tensor2tensor.data_generators.translate_enet", @@ -56,6 +57,7 @@ "tensor2tensor.data_generators.twentybn", "tensor2tensor.data_generators.wiki", "tensor2tensor.data_generators.wikisum.wikisum", + "tensor2tensor.data_generators.wikitext103", "tensor2tensor.data_generators.wsj_parsing", ] diff --git a/tensor2tensor/data_generators/generator_utils.py b/tensor2tensor/data_generators/generator_utils.py index e6ab96c04..713ecce69 100644 --- a/tensor2tensor/data_generators/generator_utils.py +++ b/tensor2tensor/data_generators/generator_utils.py @@ -44,7 +44,7 @@ def to_example(dictionary): features = {} for (k, v) in six.iteritems(dictionary): if not v: - raise ValueError("Empty generated field: %s", str((k, v))) + raise ValueError("Empty generated field: %s" % str((k, v))) if isinstance(v[0], six.integer_types): features[k] = tf.train.Feature(int64_list=tf.train.Int64List(value=v)) elif isinstance(v[0], float): @@ -130,7 +130,8 @@ def outputs_exist(filenames): return out_fname -def generate_files(generator, output_filenames, max_cases=None): +def generate_files(generator, output_filenames, + max_cases=None, cycle_every_n=1): """Generate cases from a generator and save as TFRecord files. Generated cases are transformed to tf.Example protos and saved as TFRecords @@ -141,6 +142,8 @@ def generate_files(generator, output_filenames, max_cases=None): output_filenames: List of output file paths. max_cases: maximum number of cases to get from the generator; if None (default), we use the generator until StopIteration is raised. + cycle_every_n: how many cases from the generator to take before + switching to the next shard; by default set to 1, switch every case. """ if outputs_exist(output_filenames): tf.logging.info("Skipping generator because outputs files exist") @@ -159,7 +162,8 @@ def generate_files(generator, output_filenames, max_cases=None): break example = to_example(case) writers[shard].write(example.SerializeToString()) - shard = (shard + 1) % num_shards + if counter % cycle_every_n == 0: + shard = (shard + 1) % num_shards for writer in writers: writer.close() @@ -341,6 +345,7 @@ def get_or_generate_vocab(data_dir, tmp_dir, vocab_filename, vocab_size, """Generate a vocabulary from the datasets in sources.""" def generate(): + """Generate lines for vocabulary generation.""" tf.logging.info("Generating vocab from: %s", str(sources)) for source in sources: url = source[0] diff --git a/tensor2tensor/data_generators/gym.py b/tensor2tensor/data_generators/gym.py index 3c70e9c49..9ea820233 100644 --- a/tensor2tensor/data_generators/gym.py +++ b/tensor2tensor/data_generators/gym.py @@ -18,10 +18,11 @@ from __future__ import division from __future__ import print_function -from collections import deque - import functools +import os + # Dependency imports + import gym from tensor2tensor.data_generators import problem @@ -62,9 +63,7 @@ def num_target_frames(self): return 1 def eval_metrics(self): - eval_metrics = [ - metrics.Metrics.ACC, metrics.Metrics.ACC_PER_SEQ, - metrics.Metrics.NEG_LOG_PERPLEXITY] + eval_metrics = [metrics.Metrics.ACC, metrics.Metrics.ACC_PER_SEQ] return eval_metrics @property @@ -108,6 +107,10 @@ def num_rewards(self): def num_steps(self): raise NotImplementedError() + @property + def total_number_of_frames(self): + return self.num_steps + @property def min_reward(self): raise NotImplementedError() @@ -126,13 +129,13 @@ def hparams(self, defaults, unused_model_hparams): p.target_space_id = problem.SpaceID.IMAGE def generate_samples(self, data_dir, tmp_dir, unused_dataset_split): - next_obs = self.env.reset() + next_observation = self.env.reset() for _ in range(self.num_steps): - observation = next_obs + observation = next_observation action = self.get_action(observation) - next_obs, reward, done, _ = self.env.step(action) + next_observation, reward, done, _ = self.env.step(action) if done: - next_obs = self.env.reset() + next_observation = self.env.reset() yield {"frame": observation, "action": [action], "done": [done], @@ -184,23 +187,22 @@ class GymDiscreteProblemWithAgent(GymPongRandom5k): def __init__(self, *args, **kwargs): super(GymDiscreteProblemWithAgent, self).__init__(*args, **kwargs) self._env = None - self.history_size = 2 + self.debug_dump_frames_path = "debug_frames_env" # defaults self.environment_spec = lambda: gym.make("PongDeterministic-v4") - self.in_graph_wrappers = [(atari.MaxAndSkipWrapper, {"skip": 4})] + self.in_graph_wrappers = [] self.collect_hparams = rl.atari_base() - self.settable_num_steps = 1000 + self.settable_num_steps = 20000 self.simulated_environment = None - self.warm_up = 70 + self.warm_up = 10 @property def num_steps(self): return self.settable_num_steps def _setup(self): - in_graph_wrappers = [(atari.ShiftRewardWrapper, {"add_value": 2}), - (atari.MemoryWrapper, {})] + self.in_graph_wrappers + in_graph_wrappers = [(atari.MemoryWrapper, {})] + self.in_graph_wrappers env_hparams = tf.contrib.training.HParams( in_graph_wrappers=in_graph_wrappers, simulated_environment=self.simulated_environment) @@ -229,41 +231,41 @@ def _setup(self): self.avilable_data_size_op = atari.MemoryWrapper.singleton.speculum.size() self.data_get_op = atari.MemoryWrapper.singleton.speculum.dequeue() - self.history_buffer = deque(maxlen=self.history_size+1) def restore_networks(self, sess): if FLAGS.agent_policy_path: model_saver = tf.train.Saver( - tf.global_variables(".*network_parameters.*")) + tf.global_variables(".*network_parameters.*")) model_saver.restore(sess, FLAGS.agent_policy_path) def generate_encoded_samples(self, data_dir, tmp_dir, unused_dataset_split): self._setup() + self.debug_dump_frames_path = os.path.join( + data_dir, self.debug_dump_frames_path) with tf.Session() as sess: sess.run(tf.global_variables_initializer()) self.restore_networks(sess) - + # Actions are shifted by 1 by MemoryWrapper, compensate here. + avilable_data_size = sess.run(self.avilable_data_size_op) + if avilable_data_size < 1: + sess.run(self.collect_trigger_op) pieces_generated = 0 + observ, reward, _, _ = sess.run(self.data_get_op) while pieces_generated < self.num_steps + self.warm_up: avilable_data_size = sess.run(self.avilable_data_size_op) - if avilable_data_size > 0: - observ, reward, action, _ = sess.run(self.data_get_op) - self.history_buffer.append(observ) - - if len(self.history_buffer) == self.history_size + 1: - pieces_generated += 1 - ret_dict = {"image/encoded": [observ], - "image/format": ["png"], - "image/height": [self.frame_height], - "image/width": [self.frame_width], - "action": [int(action)], - "done": [int(False)], - "reward": [int(reward) - self.min_reward]} - if pieces_generated > self.warm_up: - yield ret_dict - else: + if avilable_data_size < 1: sess.run(self.collect_trigger_op) + next_observ, next_reward, action, _ = sess.run(self.data_get_op) + yield {"image/encoded": [observ], + "image/format": ["png"], + "image/height": [self.frame_height], + "image/width": [self.frame_width], + "action": [int(action)], + "done": [int(False)], + "reward": [int(reward) - self.min_reward]} + pieces_generated += 1 + observ, reward = next_observ, next_reward @registry.register_problem @@ -273,7 +275,7 @@ class GymSimulatedDiscreteProblemWithAgent(GymDiscreteProblemWithAgent): def __init__(self, *args, **kwargs): super(GymSimulatedDiscreteProblemWithAgent, self).__init__(*args, **kwargs) self.simulated_environment = True - self.debug_dump_frames_path = "/tmp/t2t_debug_dump_frames" + self.debug_dump_frames_path = "debug_frames_sim" def restore_networks(self, sess): super(GymSimulatedDiscreteProblemWithAgent, self).restore_networks(sess) diff --git a/tensor2tensor/data_generators/imagenet.py b/tensor2tensor/data_generators/imagenet.py index 109d37c5d..06206ce6f 100644 --- a/tensor2tensor/data_generators/imagenet.py +++ b/tensor2tensor/data_generators/imagenet.py @@ -189,7 +189,7 @@ def preprocess_example(self, example, mode, _): @registry.register_problem class ImageImagenet64Gen(ImageImagenet): - """Cifar-10 Tune.""" + """Imagenet 64 from the pixen cnn paper""" @property def train_shards(self): @@ -264,6 +264,33 @@ def preprocess_example(self, example, mode, hparams): return example +@registry.register_problem +class ImageImagenet32Small(ImageImagenet): + """Imagenet small from the pixel cnn paper""" + + @property + def is_small(self): + return False # Modalities like for CIFAR. + + @property + def num_classes(self): + return 1000 + + @property + def train_shards(self): + return 1024 + + @property + def dev_shards(self): + return 10 + + def preprocess_example(self, example, mode, unused_hparams): + example["inputs"].set_shape([_IMAGENET_SMALL_IMAGE_SIZE, + _IMAGENET_SMALL_IMAGE_SIZE, 3]) + example["inputs"] = tf.to_int64(example["inputs"]) + return example + + @registry.register_problem class ImageImagenet64(ImageImagenet32): """Imagenet rescaled to 64x64.""" diff --git a/tensor2tensor/data_generators/squad.py b/tensor2tensor/data_generators/squad.py index 178a3a4f4..e19307242 100644 --- a/tensor2tensor/data_generators/squad.py +++ b/tensor2tensor/data_generators/squad.py @@ -143,5 +143,5 @@ def generate_encoded_samples(self, data_dir, tmp_dir, dataset_split): for sample in samples: sample['targets'] = self.generate_targets(sample['targets'], sample['context']) - if not sample['targets']: + if sample['targets']: yield sample diff --git a/tensor2tensor/data_generators/subject_verb_agreement.py b/tensor2tensor/data_generators/subject_verb_agreement.py new file mode 100644 index 000000000..587ff7da6 --- /dev/null +++ b/tensor2tensor/data_generators/subject_verb_agreement.py @@ -0,0 +1,258 @@ +# coding=utf-8 +# Copyright 2018 The Tensor2Tensor Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Data generators for subject-verb agreement dataset. + +https://arxiv.org/pdf/1611.01368.pdf + +Based on he main paper, predicting verb's number can be done in two setups: +- Language Modeling +- Binary Classification + +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import collections +from collections import defaultdict +import csv +import gzip +import os +import random + +# Dependency imports + +from tensor2tensor.data_generators import generator_utils +from tensor2tensor.data_generators import problem +from tensor2tensor.data_generators import text_encoder +from tensor2tensor.data_generators import text_problems +from tensor2tensor.utils import metrics +from tensor2tensor.utils import registry + +import tensorflow as tf + +_FILE_NAME = 'agr_50_mostcommon_10K' +_TAR = _FILE_NAME + '.tsv.gz' +_URL = 'http://tallinzen.net/media/rnn_agreement/' + _TAR +_LABEL_DICT = {'VBZ': 0, 'VBP': 1} + + +def _build_vocab(examples, example_field, vocab_dir, vocab_name): + """Build a vocabulary from examples. + + Args: + examples: a dict containing all the examples. + example_field: field of example from which the vocabulary is built. + vocab_dir: directory where to save the vocabulary. + vocab_name: vocab file name. + + Returns: + text encoder. + """ + vocab_path = os.path.join(vocab_dir, vocab_name) + if not tf.gfile.Exists(vocab_path): + data = [] + for e in examples: + data.extend(e[example_field].split()) + counter = collections.Counter(data) + count_pairs = sorted(counter.items(), key=lambda x: (-x[1], x[0])) + words, _ = list(zip(*count_pairs)) + encoder = text_encoder.TokenTextEncoder(None, vocab_list=words) + encoder.store_to_file(vocab_path) + else: + encoder = text_encoder.TokenTextEncoder(vocab_path) + return encoder + + +def load_examples(tmp_dir, equalize_classes=False): + """Loads exampls from the tsv file. + + Args: + tmp_dir: temp directory. + equalize_classes: if equalize number of examples in the classes. + + Returns: + All examples in the dataset. + + """ + + infile = generator_utils.maybe_download(tmp_dir, _TAR, _URL) + tf.logging.info('Loading examples') + + all_examples = [] + for i, d in enumerate(csv.DictReader(gzip.open(infile), delimiter='\t')): + if i % 100000 == 0: + tf.logging.info('%d examples have been loaded....' % i) + ex = {x: int(y) if y.isdigit() else y for x, y in d.items()} + all_examples.append(ex) + + classes = defaultdict(list) + for ex in all_examples: + classes[ex['verb_pos']].append(ex) + + del all_examples[:] + assert len(classes) == 2 + + c1 = classes.values()[0] + c2 = classes.values()[1] + random.seed(1) + random.shuffle(c1) + random.shuffle(c2) + if equalize_classes: + l = min(len(c1), len(c2)) + all_examples = c1[:l] + c2[:l] + else: + all_examples = c1 + c2 + random.shuffle(all_examples) + + return all_examples + + +@registry.register_problem +class SvaNumberPrediction(text_problems.Text2ClassProblem): + """Subject verb agreement as verb number predicion (binary classification).""" + + @property + def is_generate_per_split(self): + # generate_data will shard the data into TRAIN and EVAL for us. + return False + + @property + def dataset_splits(self): + """Splits of data to produce and number of output shards for each. + + This is the setup of the main paper. 10% train/ 90% eval + + Returns: + A dict containing splits information. + + """ + return [{ + 'split': problem.DatasetSplit.TRAIN, + 'shards': 1, + }, { + 'split': problem.DatasetSplit.EVAL, + 'shards': 9, + }] + + @property + def vocab_type(self): + return text_problems.VocabType.TOKEN + + @property + def num_classes(self): + return 2 + + def class_labels(self, data_dir): + """Class labels.""" + del data_dir + return ['VBZ', 'VBP'] + + def generate_samples(self, data_dir, tmp_dir, dataset_split): + """Generate samples of text and label pairs. + + Each yielded dict will be a single example. The inputs should be raw text. + The label should be an int in [0, self.num_classes). + + Args: + data_dir: final data directory. Typically only used in this method to copy + over user-supplied vocab files (for example, if vocab_type == + VocabType.TOKEN). + tmp_dir: temporary directory that you can use for downloading and scratch. + dataset_split: problem.DatasetSplit, which data split to generate samples + for (for example, training and evaluation). + + Returns: + sample generator. + """ + example_filed = 'sentence' + examples = load_examples(tmp_dir) + _build_vocab(examples, example_filed, data_dir, self.vocab_filename) + + def _generate_samples(): + for example in examples: + index = int(example['verb_index']) - 1 + inputs = example[example_filed].split()[:index] + yield { + 'inputs': ' '.join(inputs), + 'label': _LABEL_DICT[example['verb_pos']] + } + + return _generate_samples() + + def eval_metrics(self): + """Specify the set of evaluation metrics for this problem. + + Returns: + List of evaluation metrics of interest. + """ + return [metrics.Metrics.ACC] + + +@registry.register_problem +class SvaLanguageModeling(text_problems.Text2SelfProblem): + """Subject verb agreement as language modeling task.""" + + @property + def is_generate_per_split(self): + # generate_data will shard the data into TRAIN and EVAL for us. + return False + + @property + def dataset_splits(self): + """Splits of data to produce and number of output shards for each. + + This is the setup of the main paper. 10% train/ 90% eval + + Returns: + A dict containing splits information. + + """ + return [{ + 'split': problem.DatasetSplit.TRAIN, + 'shards': 1, + }, { + 'split': problem.DatasetSplit.EVAL, + 'shards': 9, + }] + + @property + def vocab_type(self): + return text_problems.VocabType.TOKEN + + def generate_samples(self, data_dir, tmp_dir, dataset_split): + """Generates samples. + + Args: + data_dir: data directory + tmp_dir: temp directory + dataset_split: dataset split + + Returns: + sample generator. + + """ + example_filed = 'sentence' + examples = load_examples(tmp_dir) + _build_vocab(examples, example_filed, data_dir, self.vocab_filename) + + def _generate_samples(): + for example in examples: + index = int(example['verb_index']) - 1 + targets = example[example_filed].split()[:index + 1] + yield {'targets': ' '.join(targets)} + + return _generate_samples() diff --git a/tensor2tensor/data_generators/video_utils.py b/tensor2tensor/data_generators/video_utils.py index db965795d..a5f35db75 100644 --- a/tensor2tensor/data_generators/video_utils.py +++ b/tensor2tensor/data_generators/video_utils.py @@ -67,6 +67,11 @@ def frame_width(self): """Width of each frame.""" raise NotImplementedError + @property + def total_number_of_frames(self): + """The total number of frames, needed for sharding.""" + raise NotImplementedError + @property def num_input_frames(self): """Number of frames to batch on one input.""" @@ -188,7 +193,9 @@ def _preprocess(example): preprocessed_dataset = dataset.map(_preprocess) num_frames = self.num_input_frames + self.num_target_frames - # TODO(lukaszkaiser): should jump by a random position at the beginning. + # We jump by a random position at the beginning to add variety. + random_skip = tf.random_uniform([], maxval=num_frames, dtype=tf.int64) + preprocessed_dataset = preprocessed_dataset.skip(random_skip) batch_dataset = preprocessed_dataset.apply( tf.contrib.data.batch_and_drop_remainder(num_frames)) dataset = batch_dataset.map(features_from_batch).shuffle(8) @@ -265,6 +272,8 @@ def generate_encoded_samples_debug(self, data_dir, tmp_dir, dataset_split): for sample in self.generate_encoded_samples( data_dir, tmp_dir, dataset_split): if self.debug_dump_frames_path: + if not tf.gfile.Exists(self.debug_dump_frames_path): + tf.gfile.MkDir(self.debug_dump_frames_path) path = os.path.join(self.debug_dump_frames_path, "frame_%05d.png" % counter) with tf.gfile.Open(path, "wb") as f: @@ -296,7 +305,9 @@ def generate_data(self, data_dir, tmp_dir, task_id=-1): else: generator_utils.generate_files( self.generate_encoded_samples_debug( - data_dir, tmp_dir, problem.DatasetSplit.TRAIN), all_paths) + data_dir, tmp_dir, problem.DatasetSplit.TRAIN), + all_paths, + cycle_every_n=self.total_number_of_frames // len(all_paths)) # TODO(lukaszkaiser): remove this version after everything is ported. diff --git a/tensor2tensor/data_generators/wikisum/README.md b/tensor2tensor/data_generators/wikisum/README.md index 4713a8d37..20287a760 100644 --- a/tensor2tensor/data_generators/wikisum/README.md +++ b/tensor2tensor/data_generators/wikisum/README.md @@ -116,10 +116,10 @@ Pricing is taken from * `WikisumCommoncrawl` * `get_references_commoncrawl`: $50 (1k machines, 1 CPU, 2G memory, 1 hour) - * `produce_examples`: $350 (1k machines, 1 CPU, 2G memory, 8 hours) + * `produce_examples`: $25 (1k machines, 1 CPU, 2G memory, 30 minutes) * `WikisumWeb` - * `get_references_web`: $750 (1k machines, 4 CPU, 4G memory, 5 hours) - * `produce_examples`: $350 (1k machines, 1 CPU, 2G memory, 8 hours) + * `get_references_web`: $600 (1k machines, 4 CPU, 4G memory, 4 hours) + * `produce_examples`: $25 (1k machines, 1 CPU, 2G memory, 30 minutes) ## Commands to generate `WikisumCommoncrawl` @@ -130,11 +130,11 @@ pip install tensor2tensor -U --user BUCKET=gs://my-gcs-bucket/wikisum_commoncrawl # Extract references from CommonCrawl -python -m tensor2tensor.data_generators.wikisum.parallel_launch.py \ +python -m tensor2tensor.data_generators.wikisum.parallel_launch \ --num_instances=1000 \ --cpu=1 --mem=2 \ - --name=wikisum-refs-cc \ - --log_dir=$BUCKET/refs_logs \ + --name=wikisum-cc-refs \ + --log_dir=$BUCKET/logs \ --setup_command="pip install tensor2tensor tensorflow -U -q --user" \ --command_prefix="python -m tensor2tensor.data_generators.wikisum.get_references_commoncrawl --num_tasks=1000 --out_dir=$BUCKET/wiki_references --task_id" @@ -145,13 +145,13 @@ python -m tensor2tensor.data_generators.wikisum.generate_vocab \ --for_commoncrawl # Produce examples -python -m tensor2tensor.data_generators.wikisum.parallel_launch.py \ +python -m tensor2tensor.data_generators.wikisum.parallel_launch \ --num_instances=1000 \ --cpu=1 --mem=2 \ --name=wikisum-cc-produce \ - --log_dir=$BUCKET/produce_logs \ + --log_dir=$BUCKET/logs \ --setup_command="pip install tensor2tensor tensorflow -U -q --user" \ - --command_prefix="python -m tensor2tensor.data_generators.wikisum.produce_examples.py --out_dir=$BUCKET/data --refs_dir=$BUCKET/wiki_references --num_tasks=1000 --for_commoncrawl --task_id" + --command_prefix="python -m tensor2tensor.data_generators.wikisum.produce_examples --out_dir=$BUCKET/data --refs_dir=$BUCKET/wiki_references --num_tasks=1000 --for_commoncrawl --task_id" ``` ## Commands to generate `WikisumWeb` @@ -163,11 +163,11 @@ pip install tensor2tensor -U --user BUCKET=gs://my-gcs-bucket/wikisum_web # Fetch references from web -python -m tensor2tensor.data_generators.wikisum.parallel_launch.py \ +python -m tensor2tensor.data_generators.wikisum.parallel_launch \ --num_instances=1000 \ --cpu=4 --mem=4 \ - --name=wikisum-refs-web \ - --log_dir=$BUCKET/refs_logs \ + --name=wikisum-web-refs \ + --log_dir=$BUCKET/logs \ --setup_command="pip3 install tensorflow tensor2tensor aiohttp cchardet aiodns bs4 -U -q --user" \ --command_prefix="python3 wikisum/get_references_web.py --out_dir=$BUCKET/wiki_references --shard_id" @@ -177,13 +177,13 @@ python -m tensor2tensor.data_generators.wikisum.generate_vocab \ --refs_dir=$BUCKET/wiki_references # Produce examples -python -m tensor2tensor.data_generators.wikisum.parallel_launch.py \ +python -m tensor2tensor.data_generators.wikisum.parallel_launch \ --num_instances=1000 \ --cpu=1 --mem=2 \ --name=wikisum-web-produce \ - --log_dir=$BUCKET/produce_logs \ + --log_dir=$BUCKET/logs \ --setup_command="pip install tensor2tensor tensorflow -U -q --user" \ - --command_prefix="python -m tensor2tensor.data_generators.wikisum.produce_examples.py --out_dir=$BUCKET/data --refs_dir=$BUCKET/wiki_references --num_tasks=1000 --task_id" + --command_prefix="python -m tensor2tensor.data_generators.wikisum.produce_examples --out_dir=$BUCKET/data --refs_dir=$BUCKET/wiki_references --num_tasks=1000 --task_id" ``` ## Training diff --git a/tensor2tensor/data_generators/wikisum/generate_vocab.py b/tensor2tensor/data_generators/wikisum/generate_vocab.py index b8e64702f..33fbdbb73 100644 --- a/tensor2tensor/data_generators/wikisum/generate_vocab.py +++ b/tensor2tensor/data_generators/wikisum/generate_vocab.py @@ -43,4 +43,5 @@ def main(_): if __name__ == "__main__": + tf.logging.set_verbosity(tf.logging.INFO) tf.app.run() diff --git a/tensor2tensor/data_generators/wikisum/get_references_web_single_group.py b/tensor2tensor/data_generators/wikisum/get_references_web_single_group.py index ca1717c73..5198fec5a 100644 --- a/tensor2tensor/data_generators/wikisum/get_references_web_single_group.py +++ b/tensor2tensor/data_generators/wikisum/get_references_web_single_group.py @@ -206,7 +206,7 @@ def write_ref_content(url, ref, f): async def fetch_url(url, session, side_data): text = None try: - async with session.get(url, timeout=30, verify_ssl=False) as response: + async with session.get(url, timeout=10, verify_ssl=False) as response: if response.status == 200: text = await response.text() else: diff --git a/tensor2tensor/data_generators/wikisum/parallel_launch.py b/tensor2tensor/data_generators/wikisum/parallel_launch.py index 9cec88d07..f398218b0 100644 --- a/tensor2tensor/data_generators/wikisum/parallel_launch.py +++ b/tensor2tensor/data_generators/wikisum/parallel_launch.py @@ -44,8 +44,10 @@ from __future__ import division from __future__ import print_function +import contextlib import multiprocessing as mp import os +import socket import subprocess as sp import time @@ -88,12 +90,12 @@ COPY_CODE = "gcloud compute scp --recurse {local_dir} {instance_name}:~/" SSH = "gcloud compute ssh {instance_name} --command" SCREEN = "screen -dmS test bash -c \"{command}\"" -SSH_CHECK = "nc -w 1 -z {ip} 22" DEFAULT_ZONE = "gcloud config get-value compute/zone" LOGS = "> ~/logs-{task_id}.txt 2>&1; gsutil cp ~/logs-{task_id}.txt {bucket}" def remote_run(cmd, instance_name, detach=False, retries=1): + """Run command on GCS instance, optionally detached.""" if detach: cmd = SCREEN.format(command=cmd) args = SSH.format(instance_name=instance_name).split() @@ -112,19 +114,27 @@ def default_zone(): return cloud.shell_output(DEFAULT_ZONE).strip() +@contextlib.contextmanager +def safe_socket(timeout=2): + s = socket.socket() + s.settimeout(timeout) + try: + yield s + finally: + s.close() + + def wait_for_ssh(ip): """Wait for SSH to be available at given IP address.""" - i = 0 - while True: - try: - cloud.shell_run(SSH_CHECK, ip=ip) - break - except sp.CalledProcessError: - if i > 12: # ~2m - return False - time.sleep(10) - i += 1 - return True + for _ in range(12): + with safe_socket() as s: + try: + s.connect((ip, 22)) + return True + except socket.timeout: + pass + time.sleep(10) + return False def create_instance(instance_name, cpu=1, mem=4): @@ -206,18 +216,22 @@ def main(_): assert len(suffixes) == FLAGS.num_instances vm_info = list_vm_names_and_ips() - vm_names = zip(*vm_info)[0] if vm_info else [] + vm_names = list(zip(*vm_info))[0] if vm_info else [] pool = mp.Pool(FLAGS.num_threads) async_results = [] - log_dir = None - if FLAGS.log_dir: - log_dir = os.path.join(FLAGS.log_dir, FLAGS.name) - tf.gfile.MakeDirs(log_dir) - assert log_dir.startswith("gs://") - if not log_dir.endswith("/"): - log_dir += "/" + assert FLAGS.log_dir + log_dir = os.path.join(FLAGS.log_dir, FLAGS.name) + tf.gfile.MakeDirs(log_dir) + assert log_dir.startswith("gs://") + if not log_dir.endswith("/"): + log_dir += "/" + # Write a test file to make sure gcloud GCS APIs are enabled + test_filename = os.path.join(log_dir, "check_write") + with tf.gfile.Open(test_filename, "w") as f: + f.write("testing GCS write") + tf.gfile.Remove(test_filename) instance_ids = list(range(FLAGS.num_instances)) if FLAGS.instance_ids: @@ -242,23 +256,25 @@ def main(_): FLAGS.cpu, FLAGS.mem, code_dir, FLAGS.setup_command) res = pool.apply_async(launch_instance, args) - async_results.append(res) + async_results.append((res, instance_name, i)) failed = [] - for i, res in enumerate(async_results): + for res, instance_name, i in async_results: try: res.get() - except: # pylint: disable=bare-except - failed.append(i) - tf.logging.error("Failed to launch task %d", i) + except Exception as e: # pylint: disable=broad-except + failed.append((instance_name, i)) + tf.logging.error("Failed to launch task %s due to exception %s", + instance_name, str(e)) results = [] if failed: - tf.logging.error("Failed to launch %d jobs. Task ids: %s. " - "Attempting delete in case they are still up.", - len(failed), str(failed)) - for i in failed: - instance_name = "%s-%d" % (FLAGS.name, i) + ids_for_flag = ",".join([str(i) for i in list(zip(*failed))[1]]) + tf.logging.error("Failed to launch %d jobs. Tasks: %s. " + "Attempting delete in case they are still up. Rerun with " + "--instance_ids='%s' to attempt relaunch.", + len(failed), str(failed), ids_for_flag) + for instance_name, _ in failed: res = pool.apply_async(delete_instance, (instance_name,)) results.append(res) diff --git a/tensor2tensor/data_generators/wikisum/wikisum.py b/tensor2tensor/data_generators/wikisum/wikisum.py index 2dac504d6..c0664f328 100644 --- a/tensor2tensor/data_generators/wikisum/wikisum.py +++ b/tensor2tensor/data_generators/wikisum/wikisum.py @@ -278,7 +278,7 @@ def _parse_example(ex_ser): except tf.errors.OutOfRangeError: break - data[ex["url"]] = ex["content"] + data[ex["url"]] = text_encoder.to_unicode(ex["content"]) i += 1 return data @@ -339,11 +339,14 @@ def _parse_example(ex_ser): break sections = [ - WikipediaSection(title=title, text=text) + WikipediaSection(title=text_encoder.to_unicode(title), + text=text_encoder.to_unicode(text)) for title, text in zip(ex["section_titles"], ex["section_texts"]) ] yield WikipediaArticle( - url=ex["url"], title=ex["title"], sections=sections) + url=text_encoder.to_unicode(ex["url"]), + title=text_encoder.to_unicode(ex["title"]), + sections=sections) def _token_counts(text, token_set=None): diff --git a/tensor2tensor/data_generators/wikitext103.py b/tensor2tensor/data_generators/wikitext103.py new file mode 100644 index 000000000..5e1d4f310 --- /dev/null +++ b/tensor2tensor/data_generators/wikitext103.py @@ -0,0 +1,165 @@ +# coding=utf-8 +# Copyright 2018 The Tensor2Tensor Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Data generators for wikitext-103. + +Wikitext-103: Long term dependency language modeling dataset +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import collections +import os +import zipfile + +# Dependency imports + +from tensor2tensor.data_generators import generator_utils +from tensor2tensor.data_generators import problem +from tensor2tensor.data_generators import text_encoder +from tensor2tensor.data_generators import text_problems +from tensor2tensor.utils import registry + +import tensorflow as tf + + +def _build_vocab(filename, vocab_dir, vocab_name): + """Reads a file to build a vocabulary. + + Args: + filename: file to read list of words from. + vocab_dir: directory where to save the vocabulary. + vocab_name: vocab file name. + + Returns: + text encoder. + """ + vocab_path = os.path.join(vocab_dir, vocab_name) + if not tf.gfile.Exists(vocab_path): + with tf.gfile.GFile(filename, "r") as f: + data = f.read().split() + counter = collections.Counter(data) + count_pairs = sorted(counter.items(), key=lambda x: (-x[1], x[0])) + words, _ = list(zip(*count_pairs)) + encoder = text_encoder.TokenTextEncoder(None, vocab_list=words) + encoder.store_to_file(vocab_path) + else: + encoder = text_encoder.TokenTextEncoder(vocab_path) + return encoder + + +def _maybe_download_corpus(tmp_dir, vocab_type): + """Download and unpack the corpus. + + Args: + tmp_dir: directory containing dataset. + vocab_type: which vocabulary are we using. + + Returns: + The list of names of files. + """ + if vocab_type == text_problems.VocabType.CHARACTER: + + dataset_url = ("https://s3.amazonaws.com/research.metamind.io/wikitext" + "/wikitext-103-raw-v1.zip") + dir_name = "wikitext-103-raw" + else: + dataset_url = ("https://s3.amazonaws.com/research.metamind.io/wikitext" + "/wikitext-103-v1.zip") + dir_name = "wikitext-103" + + fname = os.path.basename(dataset_url) + compressed_filepath = generator_utils.maybe_download(tmp_dir, fname, + dataset_url) + zip_ref = zipfile.ZipFile(compressed_filepath, "r") + zip_ref.extractall(tmp_dir) + zip_ref.close() + + files = os.path.join(tmp_dir, dir_name, "*") + train_file, valid_file, test_file = None, None, None + for f in tf.gfile.Glob(files): + fname = os.path.basename(f) + if "train" in fname: + train_file = f + elif "valid" in fname: + valid_file = f + elif "test" in fname: + test_file = f + + assert train_file, "Training file not found" + assert valid_file, "Validation file not found" + assert test_file, "Testing file not found" + + return train_file, valid_file, test_file + + +@registry.register_problem +class LanguagemodelWikitext103(text_problems.Text2SelfProblem): + """Wikitext103 dataset token-level.""" + + @property + def dataset_splits(self): + return [{ + "split": problem.DatasetSplit.TRAIN, + "shards": 10, + }, { + "split": problem.DatasetSplit.EVAL, + "shards": 1, + }, { + "split": problem.DatasetSplit.TEST, + "shards": 1, + }] + + @property + def is_generate_per_split(self): + return True + + @property + def vocab_type(self): + return text_problems.VocabType.TOKEN + + def generate_samples(self, data_dir, tmp_dir, dataset_split): + train_file, valid_file, test_file = _maybe_download_corpus( + tmp_dir, self.vocab_type) + + if dataset_split == problem.DatasetSplit.TRAIN: + filepath = train_file + if self.vocab_type == text_problems.VocabType.TOKEN: + _build_vocab(train_file, data_dir, self.vocab_filename) + + elif dataset_split == problem.DatasetSplit.EVAL: + filepath = valid_file + + elif dataset_split == problem.DatasetSplit.TEST: + filepath = test_file + + def _generate_samples(): + with tf.gfile.GFile(filepath, "r") as f: + for line in f: + line = " ".join(line.strip().split()) + if line: + yield {"targets": line} + + return _generate_samples() + + +@registry.register_problem +class LanguagemodelWikitext103Characters(LanguagemodelWikitext103): + """Wikitext-103, character-level.""" + + @property + def vocab_type(self): + return text_problems.VocabType.CHARACTER diff --git a/tensor2tensor/layers/common_attention.py b/tensor2tensor/layers/common_attention.py index 1545f477f..e94b470de 100644 --- a/tensor2tensor/layers/common_attention.py +++ b/tensor2tensor/layers/common_attention.py @@ -387,7 +387,8 @@ def mse_loss(expected_logits, actual_weights): def get_timing_signal_1d(length, channels, min_timescale=1.0, - max_timescale=1.0e4): + max_timescale=1.0e4, + start_index=0): """Gets a bunch of sinusoids of different frequencies. Each channel of the input Tensor is incremented by a sinusoid of a different @@ -413,11 +414,12 @@ def get_timing_signal_1d(length, different timescales is equal to channels / 2. min_timescale: a float max_timescale: a float + start_index: index of first position Returns: a Tensor of timing signals [1, length, channels] """ - position = tf.to_float(tf.range(length)) + position = tf.to_float(tf.range(length) + start_index) num_timescales = channels // 2 log_timescale_increment = ( math.log(float(max_timescale) / float(min_timescale)) / @@ -432,7 +434,10 @@ def get_timing_signal_1d(length, @expert_utils.add_name_scope() -def add_timing_signal_1d(x, min_timescale=1.0, max_timescale=1.0e4): +def add_timing_signal_1d(x, + min_timescale=1.0, + max_timescale=1.0e4, + start_index=0): """Adds a bunch of sinusoids of different frequencies to a Tensor. Each channel of the input Tensor is incremented by a sinusoid of a different @@ -456,16 +461,66 @@ def add_timing_signal_1d(x, min_timescale=1.0, max_timescale=1.0e4): x: a Tensor with shape [batch, length, channels] min_timescale: a float max_timescale: a float + start_index: index of first position Returns: a Tensor the same shape as x. """ length = common_layers.shape_list(x)[1] channels = common_layers.shape_list(x)[2] - signal = get_timing_signal_1d(length, channels, min_timescale, max_timescale) + signal = get_timing_signal_1d(length, channels, min_timescale, max_timescale, + start_index) return x + signal +@expert_utils.add_name_scope() +def add_layer_timing_signal_learned_1d(x, layer, num_layers): + """Add n-dimensional embedding as the layer (vertical) timing signal. + + Adds embeddings to represent the position of the layer in the tower. + + Args: + x: a tensor with shape [batch, length, depth] + layer: layer num + num_layers: total number of layers + + Returns: + a Tensor the same shape as x. + """ + x_shape = common_layers.shape_list(x) + depth = x_shape[-1] + + shape = [num_layers, 1, 1, depth] + layer_embedding = ( + tf.get_variable( + "layer_embedding", + shape, + initializer=tf.random_normal_initializer(0, depth**-0.5)) * (depth** + 0.5)) + x += layer_embedding[layer, :, :, :] + return x + + +@expert_utils.add_name_scope() +def add_layer_timing_signal_sinusoid_1d(x, layer, num_layers): + """Add sinusoids of different frequencies as layer (vertical) timing signal. + + Args: + x: a Tensor with shape [batch, length, channels] + layer: layer num + num_layers: total number of layers + + Returns: + a Tensor the same shape as x. + """ + + channels = common_layers.shape_list(x)[-1] + signal = get_timing_signal_1d(num_layers, channels) + layer_signal = tf.expand_dims(signal[:, layer, :], axis=1) + + return x + layer_signal + + @expert_utils.add_name_scope() def add_timing_signal_1d_given_position(x, position, @@ -1242,7 +1297,7 @@ def grouped_attention_multihead(query_antecedent, extra_loss *= extra_loss_multiplier # Show a bunch of summaries. - if expert_utils.should_generate_summaries() and make_image_summary: + if common_layers.should_generate_summaries() and make_image_summary: tf.summary.histogram("q_group_size", q_group_size) tf.summary.histogram("m_group_size", m_group_size) tf.summary.scalar("q_loss", q_loss) @@ -1336,7 +1391,7 @@ def dot_product_attention(q, # dropping out the attention links for each of the heads weights = common_layers.dropout_with_broadcast_dims( weights, 1.0 - dropout_rate, broadcast_dims=dropout_broadcast_dims) - if expert_utils.should_generate_summaries() and make_image_summary: + if common_layers.should_generate_summaries() and make_image_summary: attention_image_summary(weights, image_shapes) return tf.matmul(weights, v) @@ -1583,7 +1638,7 @@ def dot_product_self_attention_relative_v2(q, # dropping out the attention links for each of the heads weights = common_layers.dropout_with_broadcast_dims( weights, 1.0 - dropout_rate, broadcast_dims=dropout_broadcast_dims) - if expert_utils.should_generate_summaries() and make_image_summary: + if common_layers.should_generate_summaries() and make_image_summary: attention_image_summary(weights, image_shapes) ret = tf.matmul(weights, v) # [batch, num_heads, query_length, memory_length] @@ -3796,7 +3851,7 @@ def scaled_dot_product_attention_simple(q, k, v, bias, name=None): if bias is not None: logits += bias weights = tf.nn.softmax(logits, name="attention_weights") - if expert_utils.should_generate_summaries(): + if common_layers.should_generate_summaries(): tf.summary.image( "attention", tf.expand_dims(tf.pow(weights, 0.2), 3), max_outputs=1) return tf.matmul(weights, v) diff --git a/tensor2tensor/layers/common_image_attention.py b/tensor2tensor/layers/common_image_attention.py index 30959398f..ec844b611 100644 --- a/tensor2tensor/layers/common_image_attention.py +++ b/tensor2tensor/layers/common_image_attention.py @@ -284,6 +284,7 @@ def transformer_decoder_layers(inputs, self_attention_bias=None, encoder_decoder_attention_bias=None, attention_type=AttentionType.LOCAL_2D, + losses=None, name="transformer"): """Multi layer transformer.""" x = inputs @@ -335,7 +336,8 @@ def transformer_decoder_layers(inputs, hparams) x = common_layers.layer_postprocess(x, y, hparams) # feed-fwd layers + skip connections - y = ffn_layer(common_layers.layer_preprocess(x, hparams), hparams) + y = ffn_layer(common_layers.layer_preprocess(x, hparams), hparams, + losses=losses) x = common_layers.layer_postprocess(x, y, hparams) return common_layers.layer_preprocess(x, hparams) @@ -375,7 +377,7 @@ def transformer_encoder_layers(inputs, return common_layers.layer_preprocess(x, hparams) -def ffn_layer(x, hparams): +def ffn_layer(x, hparams, losses=None): """ffn layer transformer.""" with tf.variable_scope("ffn"): if hparams.ffn_layer == "none": @@ -402,6 +404,23 @@ def ffn_layer(x, hparams): x, hparams.filter_size, hparams.hidden_size, hparams.num_parts, hparams.attention_dropout, hparams.share_kv) y = tf.reshape(y, x_shape) + elif hparams.ffn_layer == "local_moe_tpu": + overhead = (hparams.moe_overhead_train + if hparams.mode == tf.estimator.ModeKeys.TRAIN + else hparams.moe_overhead_eval) + x, x_shape, is_4d = maybe_reshape_4d_to_3d(x) + y, loss = expert_utils.local_moe_tpu( + x, hparams.filter_size // 2, + hparams.hidden_size, + hparams.moe_num_experts, overhead=overhead, + loss_coef=hparams.moe_loss_coef) + if is_4d: + y = tf.reshape(y, x_shape) + if losses is None: + raise ValueError( + "transformer_ffn_layer with type local_moe_tpu must pass in " + "a losses list") + losses.append(loss) else: assert hparams.ffn_layer == "glu_ffn" y = common_layers.gated_linear_unit_layer(x) diff --git a/tensor2tensor/layers/common_layers.py b/tensor2tensor/layers/common_layers.py index 1a7c2500b..449e19bda 100644 --- a/tensor2tensor/layers/common_layers.py +++ b/tensor2tensor/layers/common_layers.py @@ -30,18 +30,39 @@ import numpy as np from six.moves import range # pylint: disable=redefined-builtin -from tensor2tensor.utils import expert_utils as eu import tensorflow as tf from tensorflow.python.framework import function from tensorflow.python.framework import ops - # This is a global setting. When turned off, no @function.Defun is used. allow_defun = False +@function.Defun( + python_grad_func=lambda x, dy: tf.convert_to_tensor(dy), + shape_func=lambda op: [op.inputs[0].get_shape()]) +def convert_gradient_to_tensor(x): + """Identity operation whose gradient is converted to a `Tensor`. + + Currently, the gradient to `tf.concat` is particularly expensive to + compute if dy is an `IndexedSlices` (a lack of GPU implementation + forces the gradient operation onto CPU). This situation occurs when + the output of the `tf.concat` is eventually passed to `tf.gather`. + It is sometimes faster to convert the gradient to a `Tensor`, so as + to get the cheaper gradient for `tf.concat`. To do this, replace + `tf.concat(x)` with `convert_gradient_to_tensor(tf.concat(x))`. + + Args: + x: A `Tensor`. + + Returns: + The input `Tensor`. + """ + return x + + def is_on_tpu(): # Support TF versions 1.5+ try: @@ -185,9 +206,10 @@ def convert_rgb_to_real(x): """Conversion of pixel values to real numbers.""" with tf.name_scope("rgb_to_real", values=[x]): x = tf.to_float(x) - # Use the formula (value/128) - 1 to convert each channel value into a - # real number in the range -1 to 1. - x = (x / 128) - 1 + # Use the formula (value/127.5) - 1 to convert each channel value into a + # real number in the range -1 to 1. We use 127.5 instead of 128 because + # the intensities are in the range 0 to 255 + x = (x / 127.5) - 1 return x @@ -230,10 +252,40 @@ def gather(params, indices, dtype=tf.float32): vocab_size = params.get_shape().as_list()[0] indices_flat = tf.reshape(indices, [-1]) out = tf.matmul(tf.one_hot(indices_flat, vocab_size, dtype=dtype), params) - out = eu.reshape_like(out, tf.expand_dims(indices, -1)) + out = reshape_like(out, tf.expand_dims(indices, -1)) return out +# TODO(noam): remove this function after TPUs do cumsum faster. +def cumsum(x, axis=0, exclusive=False): + """TPU hack for tf.cumsum. + + This is equivalent to tf.cumsum and is faster on TPU as of 04/2018 unless + the axis dimension is very large. + + Args: + x: a Tensor + axis: an integer + exclusive: a boolean + Returns: + a Tensor with the same shape as x + """ + if not is_on_tpu(): + return tf.cumsum(x, axis=axis, exclusive=exclusive) + x_shape = shape_list(x) + rank = len(x_shape) + length = x_shape[axis] + my_range = tf.range(length) + comparator = tf.less if exclusive else tf.less_equal + mask = tf.to_float( + comparator(tf.expand_dims(my_range, 1), tf.expand_dims(my_range, 0))) + ret = tf.tensordot(x, mask, axes=[[axis], [0]]) + if axis != rank - 1: + ret = tf.transpose( + ret, list(range(axis)) + [rank - 1] + list(range(axis, rank - 1))) + return ret + + def dropout_no_scaling(x, keep_prob): """Like tf.nn.dropout, but does not scale up. Works on integers also. @@ -267,7 +319,7 @@ def embedding(x, # an indexed-slices to a regular tensor before sending it back to the # parameter server. This avoids excess computation on the parameter server. if not tf.contrib.eager.in_eager_mode(): - embedding_var = eu.convert_gradient_to_tensor(embedding_var) + embedding_var = convert_gradient_to_tensor(embedding_var) x = dropout_no_scaling(x, 1.0 - symbol_dropout_rate) emb_x = gather(embedding_var, x, dtype) if multiplier != 1.0: @@ -1029,7 +1081,7 @@ def simple_attention(target, source, bias=None): if bias is not None: attention += tf.expand_dims(tf.squeeze(bias, axis=[2, 3]), axis=1) attention = tf.nn.softmax(attention) - if eu.should_generate_summaries(): + if should_generate_summaries(): tf.summary.image("attention", tf.expand_dims(attention, 3), max_outputs=5) attended = tf.matmul(attention, source) return tf.reshape(attended, target_shape) @@ -1324,7 +1376,7 @@ def _maybe_transform(t, size, should_transform, name): mask = (1.0 - mask) * -1e9 attention += mask attention = tf.nn.softmax(attention) - if eu.should_generate_summaries(): + if should_generate_summaries(): # Compute a color image summary. image = tf.reshape(attention, [batch, num_heads, target_length, source_length]) @@ -1437,13 +1489,23 @@ def conv_relu_conv(inputs, padding="SAME", nonpadding_mask=None, dropout=0.0, - name=None): + name=None, + cache=None): """Hidden layer with RELU activation followed by linear projection.""" with tf.variable_scope(name, "conv_relu_conv", [inputs]): inputs = maybe_zero_out_padding( inputs, first_kernel_size, nonpadding_mask) + + if cache: + inputs = cache["f"] = tf.concat([cache["f"], inputs], axis=1) + inputs = cache["f"] = inputs[:, -first_kernel_size:, :] + h = tpu_conv1d(inputs, filter_size, first_kernel_size, padding=padding, name="conv1") + + if cache: + h = h[:, -1:, :] + h = tf.nn.relu(h) if dropout != 0.0: h = tf.nn.dropout(h, 1.0 - dropout) @@ -1715,6 +1777,7 @@ def padded_cross_entropy(logits, label_smoothing, weights_fn=weights_nonzero, reduce_sum=True, + cutoff=0.0, gaussian=False): """Compute cross-entropy assuming 0s are padding. @@ -1728,6 +1791,7 @@ def padded_cross_entropy(logits, label_smoothing: a floating point `Scalar`. weights_fn: A function from labels to weights. reduce_sum: a Boolean, whether to sum at the end or not. + cutoff: a float, at which point to have no loss. gaussian: If true, use a Gaussian distribution for label smoothing Returns: @@ -1761,6 +1825,8 @@ def padded_cross_entropy(logits, xent = smoothing_cross_entropy(logits, labels, vocab_size, confidence, gaussian=gaussian) weights = weights_fn(labels) + if cutoff > 0.0: + xent = tf.nn.relu(xent - cutoff) if not reduce_sum: return xent * weights, weights return tf.reduce_sum(xent * weights), tf.reduce_sum(weights) @@ -1894,7 +1960,6 @@ def gated_linear_unit_layer(x, name=None): Returns: x: A tensor """ - with tf.variable_scope( name, default_name="glu_layer", values=[x]): depth = shape_list(x)[-1] @@ -1903,16 +1968,12 @@ def gated_linear_unit_layer(x, name=None): return x * tf.nn.sigmoid(gating_x) -def sru(x, num_layers=2, - activation=None, initial_state=None, name=None, reuse=None): +def sru_with_scan(x, num_layers=2, + activation=None, initial_state=None, name=None, reuse=None): """SRU cell as in https://arxiv.org/abs/1709.02755. - As defined in the paper: - (1) x'_t = W x_t - (2) f_t = sigmoid(Wf x_t + bf) - (3) r_t = sigmoid(Wr x_t + br) - (4) c_t = f_t * c_{t-1} + (1 - f_t) * x'_t - (5) h_t = r_t * activation(c_t) + (1 - r_t) * x_t + This implementation uses tf.scan and can incur overhead, see the full SRU + function doc for details and an implementation that is sometimes faster. Args: x: A tensor of shape [batch, ..., channels] ; ... is treated as time. @@ -1962,6 +2023,90 @@ def next_state(cur_state, args_tup): return tf.reshape(x, x_shape) +class CumsumprodCell(object): + """Cumulative sum and product object for use with functional_rnn API.""" + + def __init__(self, initializer): + self._initializer = initializer + + @property + def output_size(self): + return int(shape_list(self._initializer)[-1]) + + def zero_state(self, batch_size, dtype): + dtype = dtype or tf.float32 + return tf.zeros([batch_size, self.output_size], dtype=dtype) + + def __call__(self, inputs_t, state_t): + cur_x_times_one_minus_f, cur_f = tf.split(inputs_t, 2, axis=-1) + state_next = cur_f * state_t + cur_x_times_one_minus_f + outputs_t = state_next + return outputs_t, state_next + + +def sru(x, num_layers=2, + activation=None, initial_state=None, name=None, reuse=None): + """SRU cell as in https://arxiv.org/abs/1709.02755. + + As defined in the paper: + (1) x'_t = W x_t + (2) f_t = sigmoid(Wf x_t + bf) + (3) r_t = sigmoid(Wr x_t + br) + (4) c_t = f_t * c_{t-1} + (1 - f_t) * x'_t + (5) h_t = r_t * activation(c_t) + (1 - r_t) * x_t + + This version uses functional ops to be faster on GPUs with TF-1.9+. + + Args: + x: A tensor of shape [batch, ..., channels] ; ... is treated as time. + num_layers: How many SRU layers; default is 2 as results for 1 disappoint. + activation: Optional activation function, try tf.nn.tanh or tf.nn.relu. + initial_state: Optional initial c-state, set to zeros if None. + name: Optional name, "sru" by default. + reuse: Optional reuse. + + Returns: + A tensor of the same shape as x. + + Raises: + ValueError: if num_layers is not positive. + """ + if num_layers < 1: + raise ValueError("Number of layers must be positive: %d" % num_layers) + if is_on_tpu(): # On TPU the XLA does a good job with while. + return sru_with_scan(x, num_layers, activation, initial_state, name, reuse) + try: + from tensorflow.contrib.recurrent.python.ops import functional_rnn # pylint: disable=g-import-not-at-top + except ImportError: + tf.logging.info("functional_rnn not found, using sru_with_scan instead") + return sru_with_scan(x, num_layers, activation, initial_state, name, reuse) + + with tf.variable_scope(name, default_name="sru", values=[x], reuse=reuse): + # We assume x is [batch, ..., channels] and treat all ... as time. + x_shape = shape_list(x) + x = tf.reshape(x, [x_shape[0], -1, x_shape[-1]]) + initial_state = initial_state or tf.zeros([x_shape[0], x_shape[-1]]) + cell = CumsumprodCell(initial_state) + # Calculate SRU on each layer. + for i in range(num_layers): + # The parallel part of the SRU. + x_orig = x + x, f, r = tf.split(tf.layers.dense(x, 3 * x_shape[-1], + name="kernel_%d" % i), 3, axis=-1) + f, r = tf.sigmoid(f), tf.sigmoid(r) + x_times_one_minus_f = x * (1.0 - f) # Compute in parallel for speed. + # Calculate states. + concat = tf.concat([x_times_one_minus_f, f], axis=-1) + c_states, _ = functional_rnn.functional_rnn( + cell, concat, time_major=False) + # Final output. + if activation is not None: + c_states = activation(c_states) + h = c_states * r + (1.0 - r) * x_orig + x = h # Next layer. + return tf.reshape(x, x_shape) + + def linear_set_layer(layer_size, inputs, context=None, @@ -2448,6 +2593,7 @@ def forward_internal(x, f1, f2, scale, bias): @function.Defun(compiled=True) def grad_fn(x, f1, f2, scale, bias, dy): + """Gradient for efficiency.""" with tf.control_dependencies([dy]): num_splits = 4 x_shape = shape_list(x) @@ -2578,130 +2724,6 @@ def reshape_like_all_dims(a, b): return ret -def reduce_by_device(parallelism, data, reduce_fn): - """Reduces data per device. - - This can be useful, for example, if we want to all-reduce n tensors on k100k samples in length and have a width # of 2 or 4. Mono audio has a single channel while stereo has 2. @@ -422,6 +427,7 @@ def bottom(self, inputs): with tf.variable_scope(self.name): # TODO(aidangomez): Will need to sort out a better audio pipeline def xnet_resblock(x, filters, res_relu, name): + """Xception-like block.""" with tf.variable_scope(name): # We only stride along the length dimension to preserve the spectral # bins (which are tiny in dimensionality relative to length) @@ -464,8 +470,15 @@ def bottom(self, inputs): "[batch, time, height, width, channels] but got one " "of shape: %s" % str(inputs_shape)) if not tf.contrib.eager.in_eager_mode(): - tf.summary.image("inputs", tf.cast(inputs[:, -1, :, :, :], tf.uint8), - max_outputs=1) + if inputs.get_shape().as_list()[1] is None: + tf.summary.image( + "inputs_last_frame", tf.cast(inputs[:, -1, :, :, :], tf.uint8), + max_outputs=1) + else: + for k in range(inputs_shape[1]): + tf.summary.image( + "inputs_frame_%d" % k, tf.cast(inputs[:, k, :, :, :], tf.uint8), + max_outputs=1) # Standardize frames. inputs = tf.reshape(inputs, [-1] + inputs_shape[2:]) inputs = common_layers.standardize_images(inputs) @@ -533,9 +546,56 @@ def loss(self, logits, targets): logits, targets, self._model_hparams.label_smoothing, + cutoff=0.001, weights_fn=self.targets_weights_fn) +@registry.register_video_modality("l1") +class VideoModalityL1(VideoModality): + """Video modality that predicts a scalar per channel with an L1 loss.""" + + def top(self, body_output, _): + num_channels = self._model_hparams.problem.num_channels + num_frames = self._model_hparams.problem.num_target_frames + with tf.variable_scope("rgb"): + body_output_shape = common_layers.shape_list(body_output) + res = tf.layers.dense(body_output, num_channels * num_frames, name="cast") + res = tf.reshape(res, body_output_shape[:3] + [num_channels, num_frames]) + res = tf.transpose(res, [0, 4, 1, 2, 3]) # Move frames next to batch. + if not tf.get_variable_scope().reuse: + res_argmax = tf.cast(res[:, -1, :, :, :], tf.uint8) + tf.summary.image("result", res_argmax, max_outputs=1) + return tf.expand_dims(res, axis=-1) # Add an axis like in perplexity. + + @property + def cutoff(self): + return 0.2 + + def internal_loss(self, logits, targets): + return tf.nn.relu(tf.abs(logits - targets) - self.cutoff) + + def loss(self, logits, targets): + """Compute loss numerator and denominator for one shard of output.""" + logits = tf.reshape(logits, [-1] + common_layers.shape_list(logits)[2:-1]) + targets = tf.reshape(targets, [-1] + common_layers.shape_list(targets)[2:]) + weights = self.targets_weights_fn(targets) + # Shift targets by 0.5 so later just casting to int gives the prediction. + # So for int targets, say 0 and 7, we actually train to predict 0.5 and 7.5. + # Later (in merics or infer) this is cast to int anyway. Also, we have no + # loss beyond self.cutoff = 0.2 as these are already correct predictions. + targets = tf.to_float(targets) + 0.5 + loss = self.internal_loss(logits, targets) + return tf.reduce_sum(loss * weights), tf.reduce_sum(weights) + + +@registry.register_video_modality("l2") +class VideoModalityL2(VideoModalityL1): + """Modality for videos with L2 loss.""" + + def internal_loss(self, logits, targets): + return tf.nn.relu((logits - targets)**2 - self.cutoff * self.cutoff) + + @registry.register_class_label_modality("default") class ClassLabelModality(modality.Modality): """Used for label data.""" diff --git a/tensor2tensor/models/image_transformer.py b/tensor2tensor/models/image_transformer.py index 834a80d34..ed4a25692 100644 --- a/tensor2tensor/models/image_transformer.py +++ b/tensor2tensor/models/image_transformer.py @@ -48,6 +48,8 @@ def body(self, features): hparams.mode == tf.contrib.learn.ModeKeys.INFER): tf.summary.image("targets", targets, max_outputs=1) + # Extra losses list if we want to use moe. + losses = [] # Prepare decoder inputs and bias. decoder_input, rows, cols = cia.prepare_decoder(targets, hparams) # Add class label to decoder input. @@ -61,9 +63,14 @@ def body(self, features): hparams.num_decoder_layers or hparams.num_hidden_layers, hparams, attention_type=hparams.dec_attention_type, + losses=losses, name="decoder") output = cia.create_output(decoder_output, rows, cols, targets, hparams) - return output + + if losses: + return output, {"extra_loss": tf.add_n(losses)} + else: + return output @registry.register_model @@ -171,6 +178,11 @@ def image_transformer_base(): hparams.add_hparam("unconditional", False) # unconditional generation + # These parameters are only used when ffn_layer=="local_moe_tpu" + hparams.add_hparam("moe_overhead_train", 1.0) + hparams.add_hparam("moe_overhead_eval", 2.0) + hparams.moe_num_experts = 8 + hparams.moe_loss_coef = 1e-3 return hparams @@ -678,9 +690,209 @@ def imagetransformer_sep_channels_8l_tpu(): @registry.register_hparams -def imagetransformer_bas8l_8h_big_uncond_dr03_imgnet_tpu(): +def imagetransformer_b10l_4h_big_uncond_dr03_tpu(): hparams = imagetransformer_bas8l_8h_big_uncond_dr03_imgnet() update_hparams_for_tpu(hparams) - hparams.batch_size = 1 + hparams.batch_size = 4 + hparams.num_heads = 4 # heads are expensive on tpu + hparams.num_decoder_layers = 10 + hparams.layer_preprocess_sequence = "none" + hparams.layer_postprocess_sequence = "dan" + return hparams + + +@registry.register_hparams +def imagetransformer_b10l_dr03_moe_tpu(): + """Moe tpu params.""" + hparams = imagetransformer_b10l_4h_big_uncond_dr03_tpu() + update_hparams_for_tpu(hparams) + hparams.batch_size = 4 + hparams.num_heads = 4 # heads are expensive on tpu + hparams.num_decoder_layers = 10 + hparams.layer_preprocess_sequence = "none" + hparams.layer_postprocess_sequence = "dan" + hparams.ffn_layer = "local_moe_tpu" + return hparams + + +@registry.register_hparams +def imagetransformer_b10l_4h_big_uncond_dr03_lr025_tpu(): + """TPU related small model.""" + hparams = imagetransformer_bas8l_8h_big_uncond_dr03_imgnet() + update_hparams_for_tpu(hparams) + hparams.batch_size = 4 + hparams.num_heads = 4 # heads are expensive on tpu + hparams.num_decoder_layers = 10 + hparams.learning_rate = 0.25 + hparams.learning_rate_warmup_steps = 8000 + hparams.layer_preprocess_sequence = "none" + hparams.layer_postprocess_sequence = "dan" + # hparams.unconditional = True + return hparams + + +@registry.register_hparams +def imagetransformer_b12l_4h_big_uncond_dr03_tpu(): + """TPU 12 layer model.""" + hparams = imagetransformer_bas8l_8h_big_uncond_dr03_imgnet() + update_hparams_for_tpu(hparams) + hparams.batch_size = 4 + hparams.num_heads = 4 # heads are expensive on tpu + hparams.num_decoder_layers = 12 + hparams.block_length = 128 + hparams.layer_preprocess_sequence = "none" + hparams.layer_postprocess_sequence = "dan" + hparams.layer_prepostprocess_dropout = 0.3 + return hparams + + +@registry.register_hparams +def imagetransformer_b12l_4h_big_uncond_dr03_lr025_tpu(): + hparams = imagetransformer_b12l_4h_big_uncond_dr03_tpu() + update_hparams_for_tpu(hparams) + hparams.learning_rate = 0.25 + hparams.learning_rate_warmup_steps = 5000 + return hparams + + +@registry.register_hparams +def imagetransformer_b12l_4h_b256_uncond_dr03_tpu(): + """works very well on 4x4.""" + hparams = imagetransformer_bas8l_8h_big_uncond_dr03_imgnet() + update_hparams_for_tpu(hparams) + hparams.batch_size = 4 + hparams.num_heads = 4 # heads are expensive on tpu + hparams.num_decoder_layers = 12 + hparams.block_length = 256 + hparams.hidden_size = 512 + hparams.filter_size = 2048 + hparams.learning_rate = 0.5 + hparams.learning_rate_warmup_steps = 4000 + hparams.layer_preprocess_sequence = "none" + hparams.layer_postprocess_sequence = "dan" + hparams.layer_prepostprocess_dropout = 0.3 + hparams.unconditional = True + return hparams + + +@registry.register_hparams +def imagetransformer_b12l_4h_b128_h512_uncond_dr03_tpu(): + """TPU related big model.""" + hparams = imagetransformer_bas8l_8h_big_uncond_dr03_imgnet() + update_hparams_for_tpu(hparams) + hparams.batch_size = 4 + hparams.num_heads = 4 # heads are expensive on tpu + hparams.num_decoder_layers = 12 + hparams.block_length = 128 + hparams.hidden_size = 512 + hparams.filter_size = 2048 + hparams.learning_rate = 0.2 + hparams.learning_rate_warmup_steps = 6000 + hparams.layer_preprocess_sequence = "none" + hparams.layer_postprocess_sequence = "dan" + hparams.layer_prepostprocess_dropout = 0.3 + return hparams + + +@registry.register_hparams +def imagetransformer_b12l_4h_b128_h512_uncond_dr03_im(): + """TPU related imagenet model.""" + hparams = imagetransformer_bas8l_8h_big_uncond_dr03_imgnet() + update_hparams_for_tpu(hparams) + hparams.batch_size = 4 + hparams.num_heads = 4 # heads are expensive on tpu + hparams.num_decoder_layers = 12 + hparams.block_length = 128 + hparams.hidden_size = 512 + hparams.filter_size = 2048 + hparams.learning_rate = 0.2 + hparams.learning_rate_warmup_steps = 6000 + hparams.layer_preprocess_sequence = "none" + hparams.layer_postprocess_sequence = "dan" + hparams.layer_prepostprocess_dropout = 0.1 + return hparams + + +@registry.register_hparams +def imagetransformer_b12l_4h_small_uncond_dr03_tpu(): + """TPU related small model.""" + hparams = imagetransformer_bas8l_8h_big_uncond_dr03_imgnet() + update_hparams_for_tpu(hparams) + hparams.batch_size = 4 + hparams.num_heads = 4 # heads are expensive on tpu + hparams.num_decoder_layers = 8 + hparams.block_length = 256 + hparams.hidden_size = 512 + hparams.filter_size = 2048 + hparams.learning_rate = 0.2 + hparams.learning_rate_warmup_steps = 4000 + hparams.layer_preprocess_sequence = "none" + hparams.layer_postprocess_sequence = "dan" + hparams.layer_prepostprocess_dropout = 0.3 + return hparams + + +@registry.register_hparams +def imagetransformer_b12l_4h_b128_uncond_dr03_tpu(): + """TPU config for cifar 10.""" + hparams = imagetransformer_bas8l_8h_big_uncond_dr03_imgnet() + update_hparams_for_tpu(hparams) + hparams.batch_size = 2 + hparams.num_heads = 4 # heads are expensive on tpu + hparams.num_decoder_layers = 12 + hparams.block_length = 128 + hparams.hidden_size = 256 + hparams.filter_size = 2048 + hparams.layer_preprocess_sequence = "none" + hparams.layer_postprocess_sequence = "dan" + hparams.layer_prepostprocess_dropout = 0.3 + return hparams + + +@registry.register_hparams +def imagetransformer_b12l_8h_b256_uncond_dr03_tpu(): + """TPU related 12 layer 8 heads model.""" + hparams = imagetransformer_bas8l_8h_big_uncond_dr03_imgnet() + update_hparams_for_tpu(hparams) + hparams.batch_size = 2 hparams.num_heads = 8 # heads are expensive on tpu + hparams.num_decoder_layers = 12 + hparams.block_length = 256 + hparams.hidden_size = 512 + hparams.filter_size = 2048 + hparams.layer_preprocess_sequence = "none" + hparams.layer_postprocess_sequence = "dan" + hparams.layer_prepostprocess_dropout = 0.3 + return hparams + + +@registry.register_hparams +def imagetransformer_b12l_4h_b256_uncond_dr03_lr025_tpu(): + hparams = imagetransformer_b12l_4h_b256_uncond_dr03_tpu() + update_hparams_for_tpu(hparams) + hparams.learning_rate = 0.25 + hparams.learning_rate_warmup_steps = 10000 + return hparams + + +@registry.register_hparams +def imagetransformer_b10l_4h_big_uncond_dr03_lr05_tpu(): + hparams = imagetransformer_b10l_4h_big_uncond_dr03_lr025_tpu() + update_hparams_for_tpu(hparams) + hparams.learning_rate = 0.5 + hparams.learning_rate_warmup_steps = 16000 + return hparams + + +@registry.register_hparams +def imagetransformer_b10l_4h_big_uncond_dr01_tpu(): + """big 1d model for conditional image generation.""" + hparams = imagetransformer_b12l_4h_big_uncond_dr03_tpu() + # num_hidden_layers + hparams.num_decoder_layers = 10 + hparams.num_heads = 4 + hparams.hidden_size = 1024 + hparams.filter_size = 4096 + hparams.batch_size = 1 + hparams.layer_prepostprocess_dropout = 0.1 return hparams diff --git a/tensor2tensor/models/research/aligned.py b/tensor2tensor/models/research/aligned.py index ed048a68c..7aa731bd9 100644 --- a/tensor2tensor/models/research/aligned.py +++ b/tensor2tensor/models/research/aligned.py @@ -229,7 +229,7 @@ def _pseudolocal_bias(x): expert_utils.ffn_expert_fn(hparams.hidden_size, ffn_hidden_sizes, hparams.hidden_size), dp(expert_utils.flatten_all_but_last, x)) - y = dp(expert_utils.reshape_like, y, x) + y = dp(common_layers.reshape_like, y, x) elif layer_type == "conv": y = dp( common_layers.conv1d, diff --git a/tensor2tensor/models/research/attention_lm_moe.py b/tensor2tensor/models/research/attention_lm_moe.py index 6fd549cbe..dd0163bfb 100644 --- a/tensor2tensor/models/research/attention_lm_moe.py +++ b/tensor2tensor/models/research/attention_lm_moe.py @@ -453,7 +453,7 @@ def restore_pad(x, ref_x, pad_remover, mode): x = tf.squeeze(x, axis=0) if mode != ModeKeys.PREDICT: x = pad_remover.restore(x) - x = expert_utils.reshape_like(x, ref_x) + x = common_layers.reshape_like(x, ref_x) return x diff --git a/tensor2tensor/models/research/autoencoders.py b/tensor2tensor/models/research/autoencoders.py index 4a05af024..aee732e42 100644 --- a/tensor2tensor/models/research/autoencoders.py +++ b/tensor2tensor/models/research/autoencoders.py @@ -54,6 +54,11 @@ def body(self, features): features["targets"], tf.zeros_like(basic_result), hparams.bottleneck_warmup_steps, is_training, max_prob=1.0 - hparams.autoregressive_dropout, broadcast_last=True) + # Sometimes it's useful to look at non-autoregressive evals. + if (hparams.mode == tf.estimator.ModeKeys.EVAL and + hparams.autoregressive_eval_pure_autoencoder): + targets_dropout = tf.zeros_like(basic_result) + # Now combine the basic reconstruction with shifted targets. targets1d = tf.reshape(targets_dropout, [shape[0], -1, shape[3]]) targets_shifted = common_layers.shift_right_3d(targets1d) concat1d = tf.concat([basic1d, targets_shifted], axis=-1) @@ -85,7 +90,7 @@ def body(self, features): raise ValueError("Unsupported autoregressive mode: %s" % hparams.autoregressive_mode) - def infer(self, features=None, *args, **kwargs): + def infer(self, features, *args, **kwargs): """Produce predictions from the model by sampling.""" # Inputs and features preparation needed to handle edge cases. if not features: @@ -109,7 +114,7 @@ def infer(self, features=None, *args, **kwargs): shape = common_layers.shape_list(samples) # Sample again if requested for the autoregressive part. - extra_samples = 0 + extra_samples = self.hparams.autoregressive_decode_steps self.hparams.autoregressive_dropout = 0.2 for i in range(extra_samples): if i == extra_samples - 2: @@ -141,6 +146,17 @@ def infer(self, features=None, *args, **kwargs): class AutoencoderResidual(AutoencoderAutoregressive): """Residual autoencoder.""" + def dropout(self, x): + if self.hparams.dropout <= 0.0: + return x + # For simple dropout just do this: + # return tf.nn.dropout(x, 1.0 - self.hparams.dropout) + is_training = self.hparams.mode == tf.estimator.ModeKeys.TRAIN + return common_layers.mix( + tf.zeros_like(x), x, + self.hparams.bottleneck_warmup_steps, is_training, + max_prob=self.hparams.dropout, broadcast_last=True) + def encoder(self, x): with tf.variable_scope("encoder"): hparams = self.hparams @@ -156,7 +172,7 @@ def encoder(self, x): for i in range(hparams.num_hidden_layers): with tf.variable_scope("layer_%d" % i): x = self.make_even_size(x) - x = tf.nn.dropout(x, 1.0 - hparams.dropout) + x = self.dropout(x) filters = hparams.hidden_size * 2**(i + 1) filters = min(filters, hparams.max_hidden_size) x = tf.layers.conv2d( @@ -189,7 +205,6 @@ def decoder(self, x): residual_conv = tf.layers.separable_conv2d # Up-convolutions. for i in range(hparams.num_hidden_layers): - x = tf.nn.dropout(x, 1.0 - hparams.dropout) j = hparams.num_hidden_layers - i - 1 filters = hparams.hidden_size * 2**j filters = min(filters, hparams.max_hidden_size) @@ -405,6 +420,8 @@ def autoencoder_autoregressive(): hparams.add_hparam("autoregressive_forget_base", False) hparams.add_hparam("autoregressive_mode", "conv3") hparams.add_hparam("autoregressive_dropout", 0.4) + hparams.add_hparam("autoregressive_decode_steps", 0) + hparams.add_hparam("autoregressive_eval_pure_autoencoder", False) return hparams @@ -473,6 +490,7 @@ def autoencoder_residual_discrete_big(): def autoencoder_ordered_discrete(): """Basic autoencoder model.""" hparams = autoencoder_residual_discrete() + hparams.bottleneck_noise = 1.0 return hparams diff --git a/tensor2tensor/models/research/basic_conv_gen.py b/tensor2tensor/models/research/basic_conv_gen.py index 9d4a810bb..77729f0ca 100644 --- a/tensor2tensor/models/research/basic_conv_gen.py +++ b/tensor2tensor/models/research/basic_conv_gen.py @@ -34,33 +34,53 @@ class BasicConvGen(t2t_model.T2TModel): """Basic convolutional next-frame model.""" + def make_even_size(self, x): + """Pad x to be even-sized on axis 1 and 2, but only if necessary.""" + shape = [dim if dim is not None else -1 for dim in x.get_shape().as_list()] + if shape[1] % 2 == 0 and shape[2] % 2 == 0: + return x + if shape[1] % 2 == 0: + x, _ = common_layers.pad_to_same_length( + x, x, final_length_divisible_by=2, axis=2) + return x + if shape[2] % 2 == 0: + x, _ = common_layers.pad_to_same_length( + x, x, final_length_divisible_by=2, axis=1) + return x + x, _ = common_layers.pad_to_same_length( + x, x, final_length_divisible_by=2, axis=1) + x, _ = common_layers.pad_to_same_length( + x, x, final_length_divisible_by=2, axis=2) + return x + def body(self, features): hparams = self.hparams filters = hparams.hidden_size kernel1, kernel2 = (3, 3), (4, 4) - # Pad to make size powers of 2 as needed. - x = features["inputs"] - inputs_shape = common_layers.shape_list(x) - x, _ = common_layers.pad_to_same_length( - x, x, final_length_divisible_by=2**hparams.num_compress_steps, axis=1) - x, _ = common_layers.pad_to_same_length( - x, x, final_length_divisible_by=2**hparams.num_compress_steps, axis=2) + # Embed the inputs. + inputs_shape = common_layers.shape_list(features["inputs"]) + x = tf.layers.dense(features["inputs"], filters, name="inputs_embed") # Down-stride. + layer_inputs = [x] for i in range(hparams.num_compress_steps): with tf.variable_scope("downstride%d" % i): + layer_inputs.append(x) + x = self.make_even_size(x) + if i < hparams.filter_double_steps: + filters *= 2 x = tf.layers.conv2d(x, filters, kernel2, activation=common_layers.belu, strides=(2, 2), padding="SAME") x = common_layers.layer_norm(x) - filters *= 2 # Add embedded action. action = tf.reshape(features["input_action"][:, 1, :], [-1, 1, 1, hparams.hidden_size]) - zeros = tf.zeros(common_layers.shape_list(x)[:-1] + [hparams.hidden_size], - dtype=tf.float32) - x = tf.concat([x, action + zeros], axis=-1) + action_mask = tf.layers.dense(action, filters, name="action_mask") + zeros_mask = tf.zeros(common_layers.shape_list(x)[:-1] + [filters], + dtype=tf.float32) + x *= action_mask + zeros_mask # Run a stack of convolutions. for i in range(hparams.num_hidden_layers): @@ -74,14 +94,18 @@ def body(self, features): x = common_layers.layer_norm(x + y) # Up-convolve. + layer_inputs = list(reversed(layer_inputs)) for i in range(hparams.num_compress_steps): with tf.variable_scope("upstride%d" % i): - filters //= 2 + if i >= hparams.num_compress_steps - hparams.filter_double_steps: + filters //= 2 x = tf.layers.conv2d_transpose( x, filters, kernel2, activation=common_layers.belu, strides=(2, 2), padding="SAME") - x = common_layers.layer_norm(x) - x = tf.nn.dropout(x, 1.0 - hparams.dropout) + y = layer_inputs[i] + shape = common_layers.shape_list(y) + x = x[:, :shape[1], :shape[2], :] + x = common_layers.layer_norm(x + y) # Cut down to original size. x = x[:, :inputs_shape[1], :inputs_shape[2], :] @@ -90,7 +114,7 @@ def body(self, features): reward_pred = tf.reduce_mean(x, axis=[1, 2], keep_dims=True) return {"targets": x, "target_reward": reward_pred} - def infer(self, features=None, *args, **kwargs): + def infer(self, features, *args, **kwargs): """Produce predictions from the model by running it.""" # Inputs and features preparation needed to handle edge cases. if not features: @@ -100,6 +124,16 @@ def infer(self, features=None, *args, **kwargs): inputs_old = features["inputs"] features["inputs"] = tf.expand_dims(features["inputs"], 2) + def logits_to_samples(logits): + """Get samples from logits.""" + # If the last dimension is 1 then we're using L1/L2 loss. + if common_layers.shape_list(logits)[-1] == 1: + return tf.to_int32(tf.squeeze(logits, axis=-1)) + # Argmax in TF doesn't handle more than 5 dimensions yet. + logits_shape = common_layers.shape_list(logits) + argmax = tf.argmax(tf.reshape(logits, [-1, logits_shape[-1]]), axis=-1) + return tf.reshape(argmax, logits_shape[:-1]) + # Get predictions. try: num_channels = self._hparams.problem.num_channels @@ -113,15 +147,9 @@ def infer(self, features=None, *args, **kwargs): if isinstance(logits, dict): results = {} for k, v in six.iteritems(logits): - # Argmax in TF doesn't handle more than 5 dimensions yet. - v_shape = common_layers.shape_list(v) - argmax = tf.argmax(tf.reshape(v, [-1, v_shape[-1]]), axis=-1) - results[k] = tf.reshape(argmax, v_shape[:-1]) + results[k] = logits_to_samples(v) else: - # Argmax in TF doesn't handle more than 5 dimensions yet. - logits_shape = common_layers.shape_list(logits) - argmax = tf.argmax(tf.reshape(logits, [-1, logits_shape[-1]]), axis=-1) - results = tf.reshape(argmax, logits_shape[:-1]) + results = logits_to_samples(logits) # Restore inputs to not confuse Estimator in edge cases. if inputs_old is not None: @@ -135,19 +163,20 @@ def infer(self, features=None, *args, **kwargs): def basic_conv(): """Basic 2-frame conv model.""" hparams = common_hparams.basic_params1() - hparams.hidden_size = 64 + hparams.hidden_size = 32 hparams.batch_size = 8 - hparams.num_hidden_layers = 3 - hparams.optimizer = "Adam" - hparams.learning_rate_constant = 0.0002 - hparams.learning_rate_warmup_steps = 500 - hparams.learning_rate_schedule = "constant * linear_warmup" + hparams.num_hidden_layers = 2 + hparams.optimizer = "Adafactor" + hparams.learning_rate_constant = 0.5 + hparams.learning_rate_warmup_steps = 1500 + hparams.learning_rate_schedule = "linear_warmup * constant * rsqrt_decay" hparams.label_smoothing = 0.0 hparams.initializer = "uniform_unit_scaling" hparams.initializer_gain = 1.0 hparams.weight_decay = 0.0 hparams.dropout = 0.2 - hparams.add_hparam("num_compress_steps", 5) + hparams.add_hparam("num_compress_steps", 6) + hparams.add_hparam("filter_double_steps", 5) return hparams @@ -160,58 +189,16 @@ def basic_conv_small(): @registry.register_hparams -def basic_conv_small_per_image_standardization(): - """Small conv model.""" - hparams = common_hparams.basic_params1() - hparams.kernel_sizes = [(3, 3), (5, 5)] - hparams.filter_numbers = [32, 3*256] - hparams.batch_size = 2 - hparams.add_hparam("per_image_standardization", True) +def basic_conv_l1(): + """Basic conv model with L1 modality.""" + hparams = basic_conv() + hparams.target_modality = "video:l1" return hparams -@registry.register_model -class MichiganBasicConvGen(t2t_model.T2TModel): - - def body(self, features): - def deconv2d(cur, i, kernel_size, output_filters, activation=tf.nn.relu): - thicker = common_layers.conv( - cur, - output_filters * 4, - kernel_size, - padding="SAME", - activation=activation, - name="deconv2d" + str(i)) - return tf.depth_to_space(thicker, 2) - - cur_frame = common_layers.standardize_images(features["inputs_0"]) - prev_frame = common_layers.standardize_images(features["inputs_1"]) - - frames = tf.concat([cur_frame, prev_frame], axis=3) - frames = tf.reshape(frames, [-1, 210, 160, 6]) - - h1 = tf.layers.conv2d(frames, filters=64, strides=2, kernel_size=(8, 8), - padding="SAME", activation=tf.nn.relu) - h2 = tf.layers.conv2d(h1, filters=128, strides=2, kernel_size=(6, 6), - padding="SAME", activation=tf.nn.relu) - h3 = tf.layers.conv2d(h2, filters=128, strides=2, kernel_size=(6, 6), - padding="SAME", activation=tf.nn.relu) - h4 = tf.layers.conv2d(h3, filters=128, strides=2, kernel_size=(4, 4), - padding="SAME", activation=tf.nn.relu) - h45 = tf.reshape(h4, [-1, 14 * 10 * 128]) - h5 = tf.layers.dense(h45, 2048, activation=tf.nn.relu) - h6 = tf.layers.dense(h5, 2048, activation=tf.nn.relu) - h7 = tf.layers.dense(h6, 14 * 10 * 128, activation=tf.nn.relu) - h8 = tf.reshape(h7, [-1, 14, 10, 128]) - - h9 = deconv2d(h8, 1, (4, 4), 128, activation=tf.nn.relu) - h9 = h9[:, :27, :, :] - h10 = deconv2d(h9, 2, (6, 6), 128, activation=tf.nn.relu) - h10 = h10[:, :53, :, :] - h11 = deconv2d(h10, 3, (6, 6), 128, activation=tf.nn.relu) - h11 = h11[:, :105, :, :] - h12 = deconv2d(h11, 4, (8, 8), 3 * 256, activation=tf.identity) - - reward = tf.layers.flatten(h12) - - return {"targets": h12, "reward": reward} +@registry.register_hparams +def basic_conv_l2(): + """Basic conv model with L2 modality.""" + hparams = basic_conv() + hparams.target_modality = "video:l2" + return hparams diff --git a/tensor2tensor/models/research/lm_experiments.py b/tensor2tensor/models/research/lm_experiments.py index 10f2b943a..6e5b11094 100644 --- a/tensor2tensor/models/research/lm_experiments.py +++ b/tensor2tensor/models/research/lm_experiments.py @@ -115,3 +115,44 @@ def lmx_relative_nopos(): hparams = lmx_relative() hparams.pos = "none" return hparams + + +@registry.register_hparams +def lmx_moe(): + """Transformer with mixture of experts. 140M Params.""" + hparams = lmx_base() + hparams.ffn_layer = "local_moe_tpu" + return hparams + + +@registry.register_hparams +def lmx_moe_h1k_f4k_x32(): + """Transformer with mixture of experts. 890M Params.""" + hparams = lmx_h1k_f4k() + hparams.ffn_layer = "local_moe_tpu" + hparams.moe_num_experts = 32 + hparams.weight_dtype = "bfloat16" + hparams.batch_size = 8192 + return hparams + + +@registry.register_hparams +def lmx_moe_h1k_f8k_x16(): + """Transformer with mixture of experts. 890M Params.""" + hparams = lmx_h1k_f4k() + hparams.filter_size = 8192 + hparams.ffn_layer = "local_moe_tpu" + hparams.moe_num_experts = 16 + hparams.weight_dtype = "bfloat16" + hparams.batch_size = 8192 + return hparams + + +@registry.register_hparams +def lmx_h1k_f64k(): + """HParams for training languagemodel_lm1b32k_packed. 880M Params.""" + hparams = lmx_base() + hparams.hidden_size = 1024 + hparams.filter_size = 65536 + hparams.batch_size = 2048 + return hparams diff --git a/tensor2tensor/models/research/r_transformer.py b/tensor2tensor/models/research/r_transformer.py index 9af75e8fc..4631a1441 100644 --- a/tensor2tensor/models/research/r_transformer.py +++ b/tensor2tensor/models/research/r_transformer.py @@ -155,6 +155,10 @@ def body(self, features): Final decoder representation. [batch_size, decoder_length, hidden_dim] """ hparams = self._hparams + if hparams.add_position_timing_signal: + # Turning off addition of positional embedding in the encoder/decoder + # preparation as we do it in the beginning of each step. + hparams.pos = None if self.has_input: inputs = features["inputs"] @@ -203,7 +207,7 @@ def body(self, features): hparams.act_loss_weight * tf.reduce_mean(dec_ponder_times + dec_remainders)) act_loss = enc_act_loss + dec_act_loss - tf.summary.scalar("act_loss", act_loss) + tf.contrib.summary.scalar("act_loss", act_loss) return decoder_output, {"act_loss": act_loss} return decoder_output @@ -278,7 +282,7 @@ def body(self, features): ponder_times, remainders = enc_extra_output act_loss = hparams.act_loss_weight * tf.reduce_mean(ponder_times + remainders) - tf.summary.scalar("act_loss", act_loss) + tf.contrib.summary.scalar("act_loss", act_loss) return encoder_output, {"act_loss": act_loss} return encoder_output @@ -302,6 +306,17 @@ def update_hparams_for_r_transformer(hparams): # Number of steps (which is equivalent to num layer in transformer). hparams.add_hparam("num_rec_steps", hparams.num_hidden_layers) + # Add the positional mebedding at each step(horisontal timing) + hparams.add_hparam("add_position_timing_signal", False) + # Logic of position shifting when using timing signal: + # None, "random", "step" + hparams.add_hparam("position_start_index", None) + + # Add an step embedding at each step (vertical timing) + hparams.add_hparam("add_step_timing_signal", False) + # Either "learned" or "sinusoid" + hparams.add_hparam("step_timing_signal_type", "leaned") + # Default ffn layer is separable convolution. hparams.add_hparam("transformer_ffn_type", "sep") @@ -655,3 +670,183 @@ def r_transformer_lstm_base(): hparams = r_transformer_base() hparams.recurrence_type = "lstm" return hparams + + +@registry.register_hparams +def r_transformer_position_timing_tiny(): + hparams = r_transformer_tiny() + hparams.pos = None + hparams.add_position_timing_signal = True + return hparams + + +@registry.register_hparams +def r_transformer_position_random_timing_base(): + hparams = r_transformer_base() + hparams.pos = None + hparams.add_position_timing_signal = True + hparams.position_start_index = "random" + return hparams + + +@registry.register_hparams +def r_transformer_position_random_timing_tiny(): + hparams = r_transformer_tiny() + hparams.pos = None + hparams.add_position_timing_signal = True + hparams.position_start_index = "random" + return hparams + + +@registry.register_hparams +def r_transformer_position_step_timing_tiny(): + hparams = r_transformer_tiny() + hparams.pos = None + hparams.add_position_timing_signal = True + hparams.position_start_index = "step" + return hparams + + +@registry.register_hparams +def r_transformer_step_timing_tiny(): + hparams = r_transformer_tiny() + hparams.add_step_timing_signal = True + return hparams + + +@registry.register_hparams +def r_transformer_step_sinusoid_timing_tiny(): + hparams = r_transformer_tiny() + hparams.add_step_timing_signal = True + hparams.step_timing_signal_type = "sinusoid" + return hparams + + +@registry.register_hparams +def r_transformer_step_position_random_timing_tiny(): + hparams = r_transformer_tiny() + hparams.add_position_timing_signal = True + hparams.pos = None + hparams.position_start_index = "random" + hparams.add_step_timing_signal = True + return hparams + + +@registry.register_hparams +def r_transformer_act_position_timing_tiny(): + hparams = r_transformer_tiny() + hparams.recurrence_type = "act" + hparams.add_position_timing_signal = True + hparams.pos = None + return hparams + + +@registry.register_hparams +def r_transformer_act_position_random_timing_tiny(): + hparams = r_transformer_tiny() + hparams.recurrence_type = "act" + hparams.add_position_timing_signal = True + hparams.pos = None + hparams.position_start_index = "random" + return hparams + + +@registry.register_hparams +def r_transformer_act_position_step_timing_tiny(): + hparams = r_transformer_tiny() + hparams.recurrence_type = "act" + hparams.add_position_timing_signal = True + hparams.pos = None + hparams.position_start_index = "step" + return hparams + + +@registry.register_hparams +def r_transformer_act_step_timing_tiny(): + hparams = r_transformer_tiny() + hparams.recurrence_type = "act" + hparams.add_step_timing_signal = True + return hparams + + +@registry.register_hparams +def r_transformer_act_step_position_random_timing_tiny(): + hparams = r_transformer_tiny() + hparams.recurrence_type = "act" + hparams.add_position_timing_signal = True + hparams.pos = None + hparams.position_start_index = "random" + hparams.add_step_timing_signal = True + return hparams + + +@registry.register_hparams +def r_transformer_act_step_sinusoid_position_random_timing_tiny(): + hparams = r_transformer_tiny() + hparams.recurrence_type = "act" + hparams.add_position_timing_signal = True + hparams.pos = None + hparams.position_start_index = "random" + hparams.add_step_timing_signal = True + hparams.step_timing_signal_type = "sinusoid" + return hparams + + +@registry.register_hparams +def r_transformer_act_step_sinusoid_position_timing_tiny(): + hparams = r_transformer_tiny() + hparams.recurrence_type = "act" + hparams.add_position_timing_signal = True + hparams.pos = None + hparams.add_step_timing_signal = True + hparams.step_timing_signal_type = "sinusoid" + return hparams + + +@registry.register_hparams +def r_transformer_act_step_position_timing_tiny(): + hparams = r_transformer_tiny() + hparams.recurrence_type = "act" + hparams.add_position_timing_signal = True + hparams.pos = None + hparams.add_step_timing_signal = True + return hparams + + +@registry.register_hparams +def r_transformer_step_position_timing_tiny(): + hparams = r_transformer_tiny() + hparams.add_position_timing_signal = True + hparams.pos = None + hparams.add_step_timing_signal = True + return hparams + + +@registry.register_hparams +def r_transformer_act_step_position_random_timing_base(): + hparams = r_transformer_base() + hparams.recurrence_type = "act" + hparams.add_position_timing_signal = True + hparams.pos = None + hparams.position_start_index = "random" + hparams.add_step_timing_signal = True + return hparams + + +@registry.register_hparams +def r_transformer_act_step_position_timing_base(): + hparams = r_transformer_base() + hparams.recurrence_type = "act" + hparams.add_position_timing_signal = True + hparams.pos = None + hparams.add_step_timing_signal = True + return hparams + + +@registry.register_hparams +def r_transformer_step_position_timing_base(): + hparams = r_transformer_base() + hparams.add_position_timing_signal = True + hparams.pos = None + hparams.add_step_timing_signal = True + return hparams diff --git a/tensor2tensor/models/research/r_transformer_util.py b/tensor2tensor/models/research/r_transformer_util.py index 5ac242852..bb6028414 100644 --- a/tensor2tensor/models/research/r_transformer_util.py +++ b/tensor2tensor/models/research/r_transformer_util.py @@ -238,7 +238,10 @@ def get_rt_layer(x, hparams, ffn_unit, attention_unit, pad_remover=None): if hparams.recurrence_type == "basic": rt_initializer = (x, x, x) # (state, input, memory) rt_function = functools.partial( - r_transformer_basic, ffn_unit=ffn_unit, attention_unit=attention_unit) + r_transformer_basic, + hparams=hparams, + ffn_unit=ffn_unit, + attention_unit=attention_unit) elif hparams.recurrence_type == "highway": rt_initializer = (x, x, x) # (state, input, memory) @@ -499,7 +502,7 @@ def transformer_decoder_attention_unit(x, return x -def r_transformer_basic(layer_inputs, unused_step, ffn_unit, attention_unit): +def r_transformer_basic(layer_inputs, step, hparams, ffn_unit, attention_unit): """Basic r_transformer. This is in fact vanilla transformer in which weights are shared between @@ -510,7 +513,8 @@ def r_transformer_basic(layer_inputs, unused_step, ffn_unit, attention_unit): Args: layer_inputs: - state: state - unused_step: indicating number of steps take so far + step: indicating number of steps take so far + hparams: model hyper-parameters ffn_unit: feed-forward unit attention_unit: multi-head attention unit @@ -519,6 +523,7 @@ def r_transformer_basic(layer_inputs, unused_step, ffn_unit, attention_unit): new_state: new state """ state, inputs, memory = layer_inputs + state = step_preprocess(state, step, hparams) new_state = ffn_unit(attention_unit(state)) @@ -526,7 +531,7 @@ def r_transformer_basic(layer_inputs, unused_step, ffn_unit, attention_unit): def r_transformer_highway(layer_inputs, - unused_step, + step, hparams, ffn_unit, attention_unit, @@ -546,7 +551,7 @@ def r_transformer_highway(layer_inputs, layer_inputs: - state: state - inputs: the original embedded inputs (= inputs to the first step) - unused_step: indicating number of steps take so far + step: indicating number of steps take so far hparams: model hyper-parameters. ffn_unit: feed-forward unit attention_unit: multi-head attention unit @@ -560,6 +565,7 @@ def r_transformer_highway(layer_inputs, """ state, inputs, memory = layer_inputs + state = step_preprocess(state, step, hparams) transformed_state = ffn_unit(attention_unit(state)) state.get_shape().assert_is_compatible_with(state.get_shape()) @@ -614,7 +620,7 @@ def r_transformer_highway(layer_inputs, def r_transformer_skip(layer_inputs, - unused_step, + step, hparams, ffn_unit, attention_unit, @@ -634,7 +640,7 @@ def r_transformer_skip(layer_inputs, layer_inputs: - state: state - inputs: the original embedded inputs (= inputs to the first step) - unused_step: indicating number of steps take so far + step: indicating number of steps take so far hparams: model hyper-parameters. ffn_unit: feed-forward unit attention_unit: multi-head attention unit @@ -648,6 +654,7 @@ def r_transformer_skip(layer_inputs, """ state, inputs, memory = layer_inputs + state = step_preprocess(state, step, hparams) transformed_state = ffn_unit(attention_unit(state)) @@ -744,6 +751,9 @@ def r_transformer_depthwise_attention(layer_inputs, step, hparams, ffn_unit, state_to_be_transformed = tf.reduce_sum( (states_so_far * states_so_far_weights), axis=0) + state_to_be_transformed = step_preprocess(state_to_be_transformed, step, + hparams) + new_state = ffn_unit(attention_unit(state_to_be_transformed)) # add the new state to the memory @@ -753,12 +763,12 @@ def r_transformer_depthwise_attention(layer_inputs, step, hparams, ffn_unit, def r_transformer_rnn(layer_inputs, - unused_step, + step, hparams, ffn_unit, attention_unit, pad_remover=None): - """The RT cell which models recurencey similar to basic RNN cell. + """The RT layer which models recurencey similar to basic RNN cell. It's an R-transformer with an RNN applied over the stats on depth. @@ -766,7 +776,7 @@ def r_transformer_rnn(layer_inputs, layer_inputs: - state: state - inputs: the original embedded inputs (= inputs to the first step) - unused_step: indicating number of steps take so far + step: indicating number of steps take so far hparams: model hyper-parameters. ffn_unit: feed-forward unit attention_unit: multi-head attention unit @@ -783,6 +793,7 @@ def r_transformer_rnn(layer_inputs, """ state, inputs, memory = layer_inputs + state = step_preprocess(state, step, hparams) # TODO(dehghani) keep only the meaningful cases: if hparams.inputs_states_combination == "mh_attention_ffn_add": @@ -824,11 +835,11 @@ def r_transformer_rnn(layer_inputs, def r_transformer_gru(layer_inputs, - unused_step, + step, hparams, attention_unit, pad_remover=None): - """The RT cell which models recurencey similar to GRU cell. + """The RT layer which models recurencey similar to GRU cell. It's an R-transformer with a gru applied over the stats on depth. Based on GRU paper: http://arxiv.org/abs/1406.1078 @@ -837,7 +848,7 @@ def r_transformer_gru(layer_inputs, layer_inputs: - state: state - inputs: the original embedded inputs (= inputs to the first step) - unused_step: indicating number of steps take so far + step: indicating number of steps take so far hparams: model hyper-parameters. attention_unit: multi-head attention unit pad_remover: to mask out padding in convolutional layers (efficiency). @@ -850,6 +861,7 @@ def r_transformer_gru(layer_inputs, """ state, inputs, memory = layer_inputs + state = step_preprocess(state, step, hparams) # TODO(dehghani): do we need preprocess here? state = common_layers.layer_preprocess(state, hparams) @@ -898,11 +910,11 @@ def r_transformer_gru(layer_inputs, def r_transformer_lstm(layer_inputs, - unused_step, + step, hparams, attention_unit, pad_remover=None): - """The RT cell which models recurencey similar to GRU cell. + """The RT layer which models recurencey similar to GRU cell. It's an R-transformer with a gru applied over the stats on depth. based on LSTM paper: https://arxiv.org/pdf/1409.2329.pdf @@ -912,7 +924,7 @@ def r_transformer_lstm(layer_inputs, - state: state - inputs: the original embedded inputs (= inputs to the first step) - memory: memory used in lstm. - unused_step: indicating number of steps take so far + step: indicating number of steps take so far hparams: model hyper-parameters. attention_unit: multi-head attention unit pad_remover: to mask out padding in convolutional layers (efficiency). @@ -924,6 +936,7 @@ def r_transformer_lstm(layer_inputs, memory: contains states from all the previous steps. """ state, inputs, memory = layer_inputs + state = step_preprocess(state, step, hparams) state = common_layers.layer_preprocess(state, hparams) inputs = common_layers.layer_preprocess(inputs, hparams) @@ -1079,6 +1092,7 @@ def rt_function(state, step, halting_probability, remainders, n_updates, new_state: new state """ state_shape = state.get_shape() + state = step_preprocess(state, step, hparams) with tf.variable_scope("sigmoid_activation_for_pondering"): p = common_layers.dense( @@ -1159,13 +1173,13 @@ def should_continue(u0, u1, halting_probability, u2, n_updates, u3): ponder_times = n_updates remainders = remainder - tf.summary.scalar("ponder_times", tf.reduce_mean(ponder_times)) + tf.contrib.summary.scalar("ponder_times", tf.reduce_mean(ponder_times)) return new_state, (ponder_times, remainders) def r_transformer_act_accumulated(x, hparams, ffn_unit, attention_unit): - """The RTAct cell where the final state is accumulation of all states. + """The RTAct layer where the final state is accumulation of all states. (similar to the main ACT paper: --> check the issue of differentiability) @@ -1230,6 +1244,7 @@ def rt_function(state, step, halting_probability, remainders, n_updates, accumulated_state: accumulated state """ state_shape = state.get_shape() + state = step_preprocess(state, step, hparams) with tf.variable_scope("sigmoid_activation_for_pondering"): p = common_layers.dense( @@ -1309,7 +1324,7 @@ def should_continue(u0, u1, halting_probability, u2, n_updates, u3): ponder_times = n_updates remainders = remainder - tf.summary.scalar("ponder_times", tf.reduce_mean(ponder_times)) + tf.contrib.summary.scalar("ponder_times", tf.reduce_mean(ponder_times)) return accumulated_state, (ponder_times, remainders) @@ -1366,6 +1381,8 @@ def rt_function(state, step, halting_probability, remainders, n_updates, """ + state = step_preprocess(state, step, hparams) + with tf.variable_scope("sigmoid_activation_for_pondering"): p = common_layers.dense( state, @@ -1449,7 +1466,7 @@ def should_continue(u0, u1, halting_probability, u2, n_updates, u3): ponder_times = n_updates remainders = remainder - tf.summary.scalar("ponder_times", tf.reduce_mean(ponder_times)) + tf.contrib.summary.scalar("ponder_times", tf.reduce_mean(ponder_times)) return new_state, (ponder_times, remainders) @@ -1520,6 +1537,7 @@ def rt_function(state, step, halting_probability, remainders, n_updates, """ state_shape = state.get_shape() + state = step_preprocess(state, step, hparams) # random as halting probability p = tf.random_uniform(shape=common_layers.shape_list(halting_probability)) @@ -1595,7 +1613,7 @@ def should_continue(u0, u1, halting_probability, u2, n_updates, u3): ponder_times = n_updates remainders = remainder - tf.summary.scalar("ponder_times", tf.reduce_mean(ponder_times)) + tf.contrib.summary.scalar("ponder_times", tf.reduce_mean(ponder_times)) return new_state, (ponder_times, remainders) @@ -1749,12 +1767,78 @@ def add_depth_embedding(x): x_shape = common_layers.shape_list(x) depth = x_shape[-1] num_steps = x_shape[0] - shape = [num_steps, 1, 1, x_shape[-1]] + shape = [num_steps, 1, 1, depth] depth_embedding = ( tf.get_variable( "depth_embedding", shape, initializer=tf.random_normal_initializer(0, depth**-0.5)) * (depth** 0.5)) + x += depth_embedding return x + + +def step_preprocess(x, step, hparams): + """Preprocess the input at the beginning of each step. + + Args: + x: input tensor + step: step + hparams: model hyper-parameters + + Returns: + preprocessed input. + + """ + if hparams.add_position_timing_signal: + x = add_position_timing_signal(x, step, hparams) + + if hparams.add_step_timing_signal: + num_steps = ( + hparams.act_max_steps + if hparams.recurrence_type == "act" else hparams.num_rec_steps) + if hparams.step_timing_signal_type == "leaned": + x = common_attention.add_layer_timing_signal_learned_1d( + x, step, num_steps) + elif hparams.step_timing_signal_type == "sinusoid": + x = common_attention.add_layer_timing_signal_sinusoid_1d( + x, step, num_steps) + return x + + +def add_position_timing_signal(x, step, hparams): + """Add n-dimensional embedding as the position (horizontal) timing signal. + + Args: + x: a tensor with shape [batch, length, depth] + step: step + hparams: model hyper parameters + + Returns: + a Tensor the same shape as x. + + """ + + if not hparams.position_start_index: + index = 0 + + elif hparams.position_start_index == "random": + # Shift all positions randomly + # TODO(dehghani): What would be reasonable for max number of shift? + index = tf.random_uniform( + [], maxval=common_layers.shape_list(x)[1], dtype=tf.int32) + + elif hparams.position_start_index == "step": + # Shift positions based on the step + num_steps = ( + hparams.act_max_steps + if hparams.recurrence_type == "act" else hparams.num_rec_steps) + index = tf.cast( + common_layers.shape_list(x)[1] * step / num_steps, dtype=tf.int32) + + # No need for the timing signal in the encoder/decoder input preparation + assert hparams.pos is None + + x_with_timing = common_attention.add_timing_signal_1d(x, start_index=index) + return x_with_timing diff --git a/tensor2tensor/models/research/rl.py b/tensor2tensor/models/research/rl.py index 0c39aa8eb..6feeda60c 100644 --- a/tensor2tensor/models/research/rl.py +++ b/tensor2tensor/models/research/rl.py @@ -182,12 +182,13 @@ def feed_forward_cnn_small_categorical_fun(action_space, config, observations): return NetworkOutput(policy, value, lambda a: a) -def random_policy_fun(action_space, config, observations): - """random policy with categorical output""" - obs_shape = observations.shape.as_list() +def random_policy_fun(action_space, unused_config, observations): + """Random policy with categorical output.""" + obs_shape = observations.shape.as_list() with tf.variable_scope("network_parameters"): value = tf.zeros(obs_shape[:2]) - policy = tf.distributions.Categorical(probs=[[[1. / float(action_space.n)]*action_space.n]*(obs_shape[0]*obs_shape[1])]) - + policy = tf.distributions.Categorical( + probs=[[[1. / float(action_space.n)] * action_space.n + ] * (obs_shape[0] * obs_shape[1])]) return NetworkOutput(policy, value, lambda a: a) diff --git a/tensor2tensor/models/research/super_lm.py b/tensor2tensor/models/research/super_lm.py index 5ffe5e256..adcf5f3ad 100644 --- a/tensor2tensor/models/research/super_lm.py +++ b/tensor2tensor/models/research/super_lm.py @@ -100,7 +100,7 @@ def body(self, features): # Bypass the symbol modality and compute logits directly. # We compute a different set of logits on each shard, and sum them. logits = mp(tf.layers.dense, decoder_output, vocab_size, name="logits") - logits = common_layers.all_reduce_ring(logits, mp) + logits = expert_utils.all_reduce_ring(logits, mp) logits = mp(tf.multiply, logits, mp.n ** -0.5) # We now have identical logits on all shards. # Shard 0 gets returned to the estimator. @@ -109,7 +109,7 @@ def body(self, features): logits_shard_0 = tf.expand_dims(logits_shard_0, 3) # On each device, we compute the loss for a part of the batch. # This is faster than computing the whole loss on one shard. - mp, logits = common_layers.reduce_by_device(mp, logits, lambda l: l[0]) + mp, logits = expert_utils.reduce_by_device(mp, logits, lambda l: l[0]) def _loss_for_shard(logits, targets, shard): if mp.n > 1: logits = common_layers.approximate_split(logits, mp.n, 0)[shard] @@ -180,7 +180,7 @@ def _split(t): return tuple(tf.split( t, [mix_size, hparams.hidden_size - mix_size], 2)) to_mix, to_keep = mp(_split, x) - mixed = common_layers.all_reduce_ring(to_mix, mp) + mixed = expert_utils.all_reduce_ring(to_mix, mp) mixed = mp(tf.multiply, mixed, mp.n ** -0.5) x = mp(lambda a, b: tf.concat([a, b], 2), mixed, to_keep) elif layer_type == "att": diff --git a/tensor2tensor/models/research/transformer_symshard.py b/tensor2tensor/models/research/transformer_symshard.py index 0acc4251e..fe0741550 100644 --- a/tensor2tensor/models/research/transformer_symshard.py +++ b/tensor2tensor/models/research/transformer_symshard.py @@ -207,10 +207,10 @@ def body(self, features): else: logits = mp( tf.tensordot, decoder_output, output_var, [[[2], [1]]] * mp.n) - logits = common_layers.all_reduce_ring(logits, mp) + logits = expert_utils.all_reduce_ring(logits, mp) # On each device, we compute the loss for a part of the batch. # This is faster than computing the whole loss on one shard. - mp, logits = common_layers.reduce_by_device(mp, logits, lambda l: l[0]) + mp, logits = expert_utils.reduce_by_device(mp, logits, lambda l: l[0]) def _loss_for_shard(logits, targets, shard): logits = common_layers.approximate_split(logits, mp.n, 0)[shard] targets = common_layers.approximate_split(targets, mp.n, 0)[shard] @@ -282,7 +282,7 @@ def _split(t): return tuple(tf.split( t, [mix_size, hparams.hidden_size - mix_size], 2)) to_mix, to_keep = mp(_split, x) - mixed = common_layers.all_reduce_ring(to_mix, mp) + mixed = expert_utils.all_reduce_ring(to_mix, mp) mixed = mp(tf.multiply, mixed, mp.n ** -0.5) x = mp(lambda a, b: tf.concat([a, b], 2), mixed, to_keep) elif layer_type == "att": diff --git a/tensor2tensor/models/research/transformer_vae.py b/tensor2tensor/models/research/transformer_vae.py index 9825d8290..b1cca4792 100644 --- a/tensor2tensor/models/research/transformer_vae.py +++ b/tensor2tensor/models/research/transformer_vae.py @@ -227,8 +227,14 @@ def ae_latent_softmax(latents_pred, latents_discrete, hparams): name="extra_logits") loss = None if latents_discrete is not None: - loss = tf.nn.sparse_softmax_cross_entropy_with_logits( - labels=latents_discrete, logits=latents_logits) + if hparams.soft_em: + # latents_discrete is actually one-hot of multinomial samples + assert hparams.num_decode_blocks == 1 + loss = tf.nn.softmax_cross_entropy_with_logits( + labels=latents_discrete, logits=latents_logits) + else: + loss = tf.nn.sparse_softmax_cross_entropy_with_logits( + labels=latents_discrete, logits=latents_logits) sample = multinomial_sample( latents_logits, vocab_size, hparams.sampling_temp) return sample, loss diff --git a/tensor2tensor/models/transformer.py b/tensor2tensor/models/transformer.py index 95aa91ff1..137829195 100644 --- a/tensor2tensor/models/transformer.py +++ b/tensor2tensor/models/transformer.py @@ -52,7 +52,7 @@ def __init__(self, *args, **kwargs): super(Transformer, self).__init__(*args, **kwargs) self.attention_weights = dict() # For visualizing attention heads. - def encode(self, inputs, target_space, hparams, features=None): + def encode(self, inputs, target_space, hparams, features=None, losses=None): """Encode transformer inputs. Args: @@ -62,6 +62,7 @@ def encode(self, inputs, target_space, hparams, features=None): hparams: hyperparameters for model. features: optionally pass the entire features dictionary as well. This is needed now for "packed" datasets. + losses: optional list onto which to append extra training losses Returns: Tuple of: @@ -82,7 +83,8 @@ def encode(self, inputs, target_space, hparams, features=None): encoder_output = transformer_encoder( encoder_input, self_attention_bias, hparams, nonpadding=features_to_nonpadding(features, "inputs"), - save_weights_to=self.attention_weights) + save_weights_to=self.attention_weights, + losses=losses) return encoder_output, encoder_decoder_attention_bias @@ -93,7 +95,8 @@ def decode(self, decoder_self_attention_bias, hparams, cache=None, - nonpadding=None): + nonpadding=None, + losses=None): """Decode Transformer outputs from encoder representation. Args: @@ -109,6 +112,7 @@ def decode(self, cache: dict, containing tensors which are the results of previous attentions, used for fast decoding. nonpadding: optional Tensor with shape [batch_size, decoder_length] + losses: optional list onto which to append extra training losses Returns: Final decoder representation. [batch_size, decoder_length, hidden_dim] @@ -124,7 +128,8 @@ def decode(self, hparams, cache=cache, nonpadding=nonpadding, - save_weights_to=self.attention_weights) + save_weights_to=self.attention_weights, + losses=losses) if (common_layers.is_on_tpu() and hparams.mode == tf.estimator.ModeKeys.TRAIN): @@ -150,11 +155,13 @@ def body(self, features): """ hparams = self._hparams + losses = [] + if self.has_input: inputs = features["inputs"] target_space = features["target_space_id"] encoder_output, encoder_decoder_attention_bias = self.encode( - inputs, target_space, hparams, features=features) + inputs, target_space, hparams, features=features, losses=losses) else: encoder_output, encoder_decoder_attention_bias = (None, None) @@ -171,7 +178,8 @@ def body(self, features): encoder_decoder_attention_bias, decoder_self_attention_bias, hparams, - nonpadding=features_to_nonpadding(features, "targets")) + nonpadding=features_to_nonpadding(features, "targets"), + losses=losses) expected_attentions = features.get("expected_attentions") if expected_attentions is not None: @@ -181,7 +189,11 @@ def body(self, features): hparams.expected_attention_loss_multiplier) return decoder_output, {"attention_loss": attention_loss} - return tf.reshape(decoder_output, targets_shape) + ret = tf.reshape(decoder_output, targets_shape) + if losses: + return ret, {"extra_loss": tf.add_n(losses)} + else: + return ret def _greedy_infer(self, features, decode_length): """Fast version of greedy decoding. @@ -278,10 +290,7 @@ def _fast_decode(self, if target_modality.is_class_modality: decode_length = 1 else: - if 'decode_length' in features: - decode_length = common_layers.shape_list(inputs)[1] + features['decode_length'] - else: - decode_length = common_layers.shape_list(inputs)[1] + decode_length + decode_length = common_layers.shape_list(inputs)[1] + decode_length # TODO(llion): Clean up this reshaping logic. inputs = tf.expand_dims(inputs, axis=1) @@ -318,10 +327,7 @@ def _fast_decode(self, partial_targets = tf.to_int64(partial_targets) partial_targets_shape = common_layers.shape_list(partial_targets) partial_targets_length = partial_targets_shape[1] - if 'decode_length' in features: - decode_length = partial_targets_length + features['decode_length'] - else: - decode_length += partial_targets_length + decode_length += partial_targets_length batch_size = partial_targets_shape[0] if hparams.pos == "timing": @@ -395,7 +401,7 @@ def forced_logits(): ret = tf.cond( tf.less(i, partial_targets_length), forced_logits, lambda: ret) return ret, cache - force_decode_length=self._decode_hparams.force_decode_length + ret = fast_decode( encoder_output=encoder_output, encoder_decoder_attention_bias=encoder_decoder_attention_bias, @@ -406,8 +412,7 @@ def forced_logits(): beam_size=beam_size, top_beams=top_beams, alpha=alpha, - batch_size=batch_size, - force_decode_length=force_decode_length) + batch_size=batch_size) if partial_targets is not None: if beam_size <= 1 or top_beams <= 1: ret["outputs"] = ret["outputs"][:, partial_targets_length:] @@ -426,8 +431,7 @@ def fast_decode(encoder_output, top_beams=1, alpha=1.0, eos_id=beam_search.EOS_ID, - batch_size=None, - force_decode_length=False): + batch_size=None): """Given encoder output and a symbols to logits function, does fast decoding. Implements both greedy and beam search decoding, uses beam search iff @@ -448,7 +452,6 @@ def fast_decode(encoder_output, the preference for longer translations. eos_id: End-of-sequence symbol in beam search. batch_size: an integer scalar - must be passed if there is no input - force_decode_length:force_decode_length: if True, decode will be of length decode_length and will not stop at eos_id. Returns: A dict of decoding results { @@ -473,8 +476,8 @@ def fast_decode(encoder_output, "layer_%d" % layer: { "k": tf.zeros([batch_size, 0, key_channels]), "v": tf.zeros([batch_size, 0, value_channels]), - } - for layer in range(num_layers) + "f": tf.zeros([batch_size, 0, hparams.hidden_size]), + } for layer in range(num_layers) } if encoder_output is not None: @@ -520,10 +523,7 @@ def inner_loop(i, finished, next_id, decoded_ids, cache, log_prob): return i + 1, finished, next_id, decoded_ids, cache, log_prob def is_not_finished(i, finished, *_): - if force_decode_length: - return i < decode_length - else: - return (i < decode_length) & tf.logical_not(tf.reduce_all(finished)) + return (i < decode_length) & tf.logical_not(tf.reduce_all(finished)) decoded_ids = tf.zeros([batch_size, 0], dtype=tf.int64) finished = tf.fill([batch_size], False) @@ -747,7 +747,8 @@ def transformer_encoder(encoder_input, name="encoder", nonpadding=None, save_weights_to=None, - make_image_summary=True): + make_image_summary=True, + losses=None): """A stack of transformer layers. Args: @@ -766,6 +767,7 @@ def transformer_encoder(encoder_input, for visualization; the weights tensor will be appended there under a string key created from the variable scope (including name). make_image_summary: Whether to make an attention image summary. + losses: optional list onto which to append extra training losses Returns: y: a Tensors @@ -807,7 +809,8 @@ def transformer_encoder(encoder_input, with tf.variable_scope("ffn"): y = transformer_ffn_layer( common_layers.layer_preprocess(x, hparams), hparams, pad_remover, - conv_padding="SAME", nonpadding_mask=nonpadding) + conv_padding="SAME", nonpadding_mask=nonpadding, + losses=losses) x = common_layers.layer_postprocess(x, y, hparams) # if normalization is done in layer_preprocess, then it should also be done # on the output, since the output can grow very large, being the sum of @@ -824,7 +827,8 @@ def transformer_decoder(decoder_input, name="decoder", nonpadding=None, save_weights_to=None, - make_image_summary=True): + make_image_summary=True, + losses=None): """A stack of transformer layers. Args: @@ -847,6 +851,7 @@ def transformer_decoder(decoder_input, for visualization; the weights tensor will be appended there under a string key created from the variable scope (including name). make_image_summary: Whether to make an attention image summary. + losses: optional list onto which to append extra training losses Returns: y: a Tensors @@ -898,8 +903,12 @@ def transformer_decoder(decoder_input, x = common_layers.layer_postprocess(x, y, hparams) with tf.variable_scope("ffn"): y = transformer_ffn_layer( - common_layers.layer_preprocess(x, hparams), hparams, - conv_padding="LEFT", nonpadding_mask=nonpadding) + common_layers.layer_preprocess(x, hparams), + hparams, + conv_padding="LEFT", + nonpadding_mask=nonpadding, + losses=losses, + cache=layer_cache) x = common_layers.layer_postprocess(x, y, hparams) # if normalization is done in layer_preprocess, then it should also be done # on the output, since the output can grow very large, being the sum of @@ -911,7 +920,9 @@ def transformer_ffn_layer(x, hparams, pad_remover=None, conv_padding="LEFT", - nonpadding_mask=None): + nonpadding_mask=None, + losses=None, + cache=None): """Feed-forward layer in the transformer. Args: @@ -925,9 +936,16 @@ def transformer_ffn_layer(x, nonpadding_mask: an optional Tensor with shape [batch_size, length]. needed for convolutional layers with "SAME" padding. Contains 1.0 in positions corresponding to nonpadding. + losses: optional list onto which to append extra training losses + cache: dict, containing tensors which are the results of previous + attentions, used for fast decoding. + Returns: a Tensor of shape [batch_size, length, hparams.hidden_size] + + Raises: + ValueError: If losses arg is None, but layer generates extra losses. """ ffn_layer = hparams.ffn_layer relu_dropout_broadcast_dims = ( @@ -959,11 +977,12 @@ def transformer_ffn_layer(x, x, hparams.filter_size, hparams.hidden_size, - first_kernel_size=3, + first_kernel_size=hparams.conv_first_kernel, second_kernel_size=1, padding=conv_padding, nonpadding_mask=nonpadding_mask, - dropout=hparams.relu_dropout) + dropout=hparams.relu_dropout, + cache=cache) elif ffn_layer == "parameter_attention": return common_attention.parameter_attention( x, hparams.parameter_attention_key_channels or hparams.hidden_size, @@ -979,6 +998,23 @@ def transformer_ffn_layer(x, second_kernel_size=(31, 1), padding="LEFT", dropout=hparams.relu_dropout) + elif ffn_layer == "sru": + return common_layers.sru(x) + elif ffn_layer == "local_moe_tpu": + overhead = (hparams.moe_overhead_train + if hparams.mode == tf.estimator.ModeKeys.TRAIN + else hparams.moe_overhead_eval) + ret, loss = expert_utils.local_moe_tpu( + x, hparams.filter_size // 2, + hparams.hidden_size, + hparams.moe_num_experts, overhead=overhead, + loss_coef=hparams.moe_loss_coef) + if losses is None: + raise ValueError( + "transformer_ffn_layer with type local_moe_tpu must pass in " + "a losses list") + losses.append(loss) + return ret else: assert ffn_layer == "none" return x @@ -1033,6 +1069,12 @@ def transformer_base_v1(): hparams.add_hparam("use_pad_remover", True) hparams.add_hparam("self_attention_type", "dot_product") hparams.add_hparam("max_relative_position", 0) + hparams.add_hparam("conv_first_kernel", 3) + # These parameters are only used when ffn_layer=="local_moe_tpu" + hparams.add_hparam("moe_overhead_train", 1.0) + hparams.add_hparam("moe_overhead_eval", 2.0) + hparams.moe_num_experts = 16 + hparams.moe_loss_coef = 1e-3 return hparams diff --git a/tensor2tensor/notebooks/hello_t2t.ipynb b/tensor2tensor/notebooks/hello_t2t.ipynb index 820744500..b7661fcfc 100644 --- a/tensor2tensor/notebooks/hello_t2t.ipynb +++ b/tensor2tensor/notebooks/hello_t2t.ipynb @@ -3,7 +3,7 @@ "nbformat_minor": 0, "metadata": { "colab": { - "name": "T2T with TF Eager", + "name": "Tensor2Tensor Intro", "version": "0.3.2", "views": {}, "default_view": {}, @@ -17,6 +17,18 @@ } }, "cells": [ + { + "metadata": { + "id": "odi2vIMHC3Rm", + "colab_type": "text" + }, + "cell_type": "markdown", + "source": [ + "# Welcome to the [Tensor2Tensor](https://github.com/tensorflow/tensor2tensor) Colab\n", + "\n", + "Tensor2Tensor, or T2T for short, is a library of deep learning models and datasets designed to make deep learning more accessible and [accelerate ML research](https://research.googleblog.com/2017/06/accelerating-deep-learning-research.html). T2T is actively used and maintained by researchers and engineers within the [Google Brain team](https://research.google.com/teams/brain/) and a community of users. This colab shows you some datasets we have in T2T, how to download and use them, some models we have, how to download pre-trained models and use them, and how to create and train your own models." + ] + }, { "metadata": { "id": "s19ucTii_wYb", @@ -26,9 +38,12 @@ "startup": false, "wait_interval": 0 } - } + }, + "cellView": "form" }, + "cell_type": "code", "source": [ + "#@title\n", "# Copyright 2018 Google LLC.\n", "\n", "# Licensed under the Apache License, Version 2.0 (the \"License\");\n", @@ -43,7 +58,6 @@ "# See the License for the specific language governing permissions and\n", "# limitations under the License." ], - "cell_type": "code", "execution_count": 0, "outputs": [] }, @@ -58,11 +72,11 @@ } } }, + "cell_type": "code", "source": [ "# Install deps\n", "!pip install -q -U tensor2tensor tensorflow" ], - "cell_type": "code", "execution_count": 0, "outputs": [] }, @@ -77,7 +91,9 @@ } } }, + "cell_type": "code", "source": [ + "# Imports we need.\n", "import tensorflow as tf\n", "import matplotlib.pyplot as plt\n", "import numpy as np\n", @@ -111,7 +127,6 @@ "gs_data_dir = \"gs://tensor2tensor-data\"\n", "gs_ckpt_dir = \"gs://tensor2tensor-checkpoints/\"" ], - "cell_type": "code", "execution_count": 0, "outputs": [] }, @@ -120,10 +135,10 @@ "id": "0a69r1KDiZDe", "colab_type": "text" }, + "cell_type": "markdown", "source": [ "# Download MNIST and inspect it" - ], - "cell_type": "markdown" + ] }, { "metadata": { @@ -134,11 +149,6 @@ "startup": false, "wait_interval": 0 }, - "output_extras": [ - { - "item_id": 1 - } - ], "base_uri": "https://localhost:8080/", "height": 1241 }, @@ -155,6 +165,7 @@ } } }, + "cell_type": "code", "source": [ "# A Problem is a dataset together with some fixed pre-processing.\n", "# It could be a translation dataset with a specific tokenization,\n", @@ -163,8 +174,7 @@ "# There are many problems available in Tensor2Tensor\n", "problems.available()" ], - "cell_type": "code", - "execution_count": 4, + "execution_count": 0, "outputs": [ { "output_type": "execute_result", @@ -260,11 +270,6 @@ "startup": false, "wait_interval": 0 }, - "output_extras": [ - { - "item_id": 12 - } - ], "base_uri": "https://localhost:8080/", "height": 306 }, @@ -281,6 +286,7 @@ } } }, + "cell_type": "code", "source": [ "# Fetch the MNIST problem\n", "mnist_problem = problems.problem(\"image_mnist\")\n", @@ -288,8 +294,7 @@ "# a standard format ready for training and evaluation.\n", "mnist_problem.generate_data(data_dir, tmp_dir)" ], - "cell_type": "code", - "execution_count": 5, + "execution_count": 0, "outputs": [ { "output_type": "stream", @@ -325,14 +330,6 @@ "startup": false, "wait_interval": 0 }, - "output_extras": [ - { - "item_id": 1 - }, - { - "item_id": 2 - } - ], "base_uri": "https://localhost:8080/", "height": 381 }, @@ -349,6 +346,7 @@ } } }, + "cell_type": "code", "source": [ "# Now let's see the training MNIST data as Tensors.\n", "mnist_example = tfe.Iterator(mnist_problem.dataset(Modes.TRAIN, data_dir)).next()\n", @@ -358,8 +356,7 @@ "plt.imshow(image.numpy()[:, :, 0].astype(np.float32), cmap=plt.get_cmap('gray'))\n", "print(\"Label: %d\" % label.numpy())" ], - "cell_type": "code", - "execution_count": 6, + "execution_count": 0, "outputs": [ { "output_type": "stream", @@ -388,10 +385,10 @@ "id": "gXL7_bVH49Kl", "colab_type": "text" }, + "cell_type": "markdown", "source": [ "# Translate from English to German with a pre-trained model" - ], - "cell_type": "markdown" + ] }, { "metadata": { @@ -402,11 +399,6 @@ "startup": false, "wait_interval": 0 }, - "output_extras": [ - { - "item_id": 3 - } - ], "base_uri": "https://localhost:8080/", "height": 170 }, @@ -423,6 +415,7 @@ } } }, + "cell_type": "code", "source": [ "# Fetch the problem\n", "ende_problem = problems.problem(\"translate_ende_wmt32k\")\n", @@ -449,8 +442,7 @@ " integers = integers[:integers.index(1)]\n", " return encoders[\"inputs\"].decode(np.squeeze(integers))" ], - "cell_type": "code", - "execution_count": 7, + "execution_count": 0, "outputs": [ { "output_type": "stream", @@ -480,6 +472,7 @@ } } }, + "cell_type": "code", "source": [ "# # Generate and view the data\n", "# # This cell is commented out because WMT data generation can take hours\n", @@ -504,7 +497,6 @@ "# print(\"Targets, decoded:\")\n", "# print(decode(targets))" ], - "cell_type": "code", "execution_count": 0, "outputs": [] }, @@ -517,11 +509,6 @@ "startup": false, "wait_interval": 0 }, - "output_extras": [ - { - "item_id": 1 - } - ], "base_uri": "https://localhost:8080/", "height": 408 }, @@ -538,12 +525,12 @@ } } }, + "cell_type": "code", "source": [ "# There are many models available in Tensor2Tensor\n", "registry.list_models()" ], - "cell_type": "code", - "execution_count": 9, + "execution_count": 0, "outputs": [ { "output_type": "execute_result", @@ -592,6 +579,7 @@ } } }, + "cell_type": "code", "source": [ "# Create hparams and the model\n", "model_name = \"transformer\"\n", @@ -604,7 +592,6 @@ "# that will not match the checkpoint.\n", "translate_model = registry.model(model_name)(hparams, Modes.EVAL)" ], - "cell_type": "code", "execution_count": 0, "outputs": [] }, @@ -617,11 +604,6 @@ "startup": false, "wait_interval": 0 }, - "output_extras": [ - { - "item_id": 1 - } - ], "base_uri": "https://localhost:8080/", "height": 34 }, @@ -638,6 +620,7 @@ } } }, + "cell_type": "code", "source": [ "# Copy the pretrained checkpoint locally\n", "ckpt_name = \"transformer_ende_test\"\n", @@ -646,8 +629,7 @@ "ckpt_path = tf.train.latest_checkpoint(os.path.join(checkpoint_dir, ckpt_name))\n", "ckpt_path" ], - "cell_type": "code", - "execution_count": 11, + "execution_count": 0, "outputs": [ { "output_type": "execute_result", @@ -672,11 +654,6 @@ "startup": false, "wait_interval": 0 }, - "output_extras": [ - { - "item_id": 2 - } - ], "base_uri": "https://localhost:8080/", "height": 68 }, @@ -693,6 +670,7 @@ } } }, + "cell_type": "code", "source": [ "# Restore and translate!\n", "def translate(inputs):\n", @@ -707,8 +685,7 @@ "print(\"Inputs: %s\" % inputs)\n", "print(\"Outputs: %s\" % outputs)" ], - "cell_type": "code", - "execution_count": 13, + "execution_count": 0, "outputs": [ { "output_type": "stream", @@ -726,10 +703,10 @@ "id": "X3mkIEcbfiTP", "colab_type": "text" }, + "cell_type": "markdown", "source": [ "## Attention Viz Utils" - ], - "cell_type": "markdown" + ] }, { "metadata": { @@ -742,6 +719,7 @@ } } }, + "cell_type": "code", "source": [ "from tensor2tensor.visualization import attention\n", "from tensor2tensor.data_generators import text_encoder\n", @@ -796,7 +774,6 @@ " tokens.append(subtokenizer._subtoken_id_to_subtoken_string(_id))\n", " return tokens" ], - "cell_type": "code", "execution_count": 0, "outputs": [] }, @@ -811,6 +788,7 @@ } } }, + "cell_type": "code", "source": [ "def call_html():\n", " import IPython\n", @@ -827,7 +805,6 @@ " \n", " '''))" ], - "cell_type": "code", "execution_count": 0, "outputs": [] }, @@ -836,10 +813,10 @@ "id": "T7UJzFf6fmhp", "colab_type": "text" }, + "cell_type": "markdown", "source": [ "## Display Attention" - ], - "cell_type": "markdown" + ] }, { "metadata": { @@ -850,23 +827,6 @@ "startup": false, "wait_interval": 0 }, - "output_extras": [ - { - "item_id": 1 - }, - { - "item_id": 2 - }, - { - "item_id": 3 - }, - { - "item_id": 4 - }, - { - "item_id": 5 - } - ], "resources": { "http://localhost:8080/static/components/requirejs/require.js": { "data": "LyoqIHZpbTogZXQ6dHM9NDpzdz00OnN0cz00CiAqIEBsaWNlbnNlIFJlcXVpcmVKUyAyLjEuMjIgQ29weXJpZ2h0IChjKSAyMDEwLTIwMTUsIFRoZSBEb2pvIEZvdW5kYXRpb24gQWxsIFJpZ2h0cyBSZXNlcnZlZC4KICogQXZhaWxhYmxlIHZpYSB0aGUgTUlUIG9yIG5ldyBCU0QgbGljZW5zZS4KICogc2VlOiBodHRwOi8vZ2l0aHViLmNvbS9qcmJ1cmtlL3JlcXVpcmVqcyBmb3IgZGV0YWlscwogKi8KLy9Ob3QgdXNpbmcgc3RyaWN0OiB1bmV2ZW4gc3RyaWN0IHN1cHBvcnQgaW4gYnJvd3NlcnMsICMzOTIsIGFuZCBjYXVzZXMKLy9wcm9ibGVtcyB3aXRoIHJlcXVpcmVqcy5leGVjKCkvdHJhbnNwaWxlciBwbHVnaW5zIHRoYXQgbWF5IG5vdCBiZSBzdHJpY3QuCi8qanNsaW50IHJlZ2V4cDogdHJ1ZSwgbm9tZW46IHRydWUsIHNsb3BweTogdHJ1ZSAqLwovKmdsb2JhbCB3aW5kb3csIG5hdmlnYXRvciwgZG9jdW1lbnQsIGltcG9ydFNjcmlwdHMsIHNldFRpbWVvdXQsIG9wZXJhICovCgp2YXIgcmVxdWlyZWpzLCByZXF1aXJlLCBkZWZpbmU7CihmdW5jdGlvbiAoZ2xvYmFsKSB7CiAgICB2YXIgcmVxLCBzLCBoZWFkLCBiYXNlRWxlbWVudCwgZGF0YU1haW4sIHNyYywKICAgICAgICBpbnRlcmFjdGl2ZVNjcmlwdCwgY3VycmVudGx5QWRkaW5nU2NyaXB0LCBtYWluU2NyaXB0LCBzdWJQYXRoLAogICAgICAgIHZlcnNpb24gPSAnMi4xLjIyJywKICAgICAgICBjb21tZW50UmVnRXhwID0gLyhcL1wqKFtcc1xTXSo/KVwqXC98KFteOl18XilcL1wvKC4qKSQpL21nLAogICAgICAgIGNqc1JlcXVpcmVSZWdFeHAgPSAvW14uXVxzKnJlcXVpcmVccypcKFxzKlsiJ10oW14nIlxzXSspWyInXVxzKlwpL2csCiAgICAgICAganNTdWZmaXhSZWdFeHAgPSAvXC5qcyQvLAogICAgICAgIGN1cnJEaXJSZWdFeHAgPSAvXlwuXC8vLAogICAgICAgIG9wID0gT2JqZWN0LnByb3RvdHlwZSwKICAgICAgICBvc3RyaW5nID0gb3AudG9TdHJpbmcsCiAgICAgICAgaGFzT3duID0gb3AuaGFzT3duUHJvcGVydHksCiAgICAgICAgYXAgPSBBcnJheS5wcm90b3R5cGUsCiAgICAgICAgaXNCcm93c2VyID0gISEodHlwZW9mIHdpbmRvdyAhPT0gJ3VuZGVmaW5lZCcgJiYgdHlwZW9mIG5hdmlnYXRvciAhPT0gJ3VuZGVmaW5lZCcgJiYgd2luZG93LmRvY3VtZW50KSwKICAgICAgICBpc1dlYldvcmtlciA9ICFpc0Jyb3dzZXIgJiYgdHlwZW9mIGltcG9ydFNjcmlwdHMgIT09ICd1bmRlZmluZWQnLAogICAgICAgIC8vUFMzIGluZGljYXRlcyBsb2FkZWQgYW5kIGNvbXBsZXRlLCBidXQgbmVlZCB0byB3YWl0IGZvciBjb21wbGV0ZQogICAgICAgIC8vc3BlY2lmaWNhbGx5LiBTZXF1ZW5jZSBpcyAnbG9hZGluZycsICdsb2FkZWQnLCBleGVjdXRpb24sCiAgICAgICAgLy8gdGhlbiAnY29tcGxldGUnLiBUaGUgVUEgY2hlY2sgaXMgdW5mb3J0dW5hdGUsIGJ1dCBub3Qgc3VyZSBob3cKICAgICAgICAvL3RvIGZlYXR1cmUgdGVzdCB3L28gY2F1c2luZyBwZXJmIGlzc3Vlcy4KICAgICAgICByZWFkeVJlZ0V4cCA9IGlzQnJvd3NlciAmJiBuYXZpZ2F0b3IucGxhdGZvcm0gPT09ICdQTEFZU1RBVElPTiAzJyA/CiAgICAgICAgICAgICAgICAgICAgICAvXmNvbXBsZXRlJC8gOiAvXihjb21wbGV0ZXxsb2FkZWQpJC8sCiAgICAgICAgZGVmQ29udGV4dE5hbWUgPSAnXycsCiAgICAgICAgLy9PaCB0aGUgdHJhZ2VkeSwgZGV0ZWN0aW5nIG9wZXJhLiBTZWUgdGhlIHVzYWdlIG9mIGlzT3BlcmEgZm9yIHJlYXNvbi4KICAgICAgICBpc09wZXJhID0gdHlwZW9mIG9wZXJhICE9PSAndW5kZWZpbmVkJyAmJiBvcGVyYS50b1N0cmluZygpID09PSAnW29iamVjdCBPcGVyYV0nLAogICAgICAgIGNvbnRleHRzID0ge30sCiAgICAgICAgY2ZnID0ge30sCiAgICAgICAgZ2xvYmFsRGVmUXVldWUgPSBbXSwKICAgICAgICB1c2VJbnRlcmFjdGl2ZSA9IGZhbHNlOwoKICAgIGZ1bmN0aW9uIGlzRnVuY3Rpb24oaXQpIHsKICAgICAgICByZXR1cm4gb3N0cmluZy5jYWxsKGl0KSA9PT0gJ1tvYmplY3QgRnVuY3Rpb25dJzsKICAgIH0KCiAgICBmdW5jdGlvbiBpc0FycmF5KGl0KSB7CiAgICAgICAgcmV0dXJuIG9zdHJpbmcuY2FsbChpdCkgPT09ICdbb2JqZWN0IEFycmF5XSc7CiAgICB9CgogICAgLyoqCiAgICAgKiBIZWxwZXIgZnVuY3Rpb24gZm9yIGl0ZXJhdGluZyBvdmVyIGFuIGFycmF5LiBJZiB0aGUgZnVuYyByZXR1cm5zCiAgICAgKiBhIHRydWUgdmFsdWUsIGl0IHdpbGwgYnJlYWsgb3V0IG9mIHRoZSBsb29wLgogICAgICovCiAgICBmdW5jdGlvbiBlYWNoKGFyeSwgZnVuYykgewogICAgICAgIGlmIChhcnkpIHsKICAgICAgICAgICAgdmFyIGk7CiAgICAgICAgICAgIGZvciAoaSA9IDA7IGkgPCBhcnkubGVuZ3RoOyBpICs9IDEpIHsKICAgICAgICAgICAgICAgIGlmIChhcnlbaV0gJiYgZnVuYyhhcnlbaV0sIGksIGFyeSkpIHsKICAgICAgICAgICAgICAgICAgICBicmVhazsKICAgICAgICAgICAgICAgIH0KICAgICAgICAgICAgfQogICAgICAgIH0KICAgIH0KCiAgICAvKioKICAgICAqIEhlbHBlciBmdW5jdGlvbiBmb3IgaXRlcmF0aW5nIG92ZXIgYW4gYXJyYXkgYmFja3dhcmRzLiBJZiB0aGUgZnVuYwogICAgICogcmV0dXJucyBhIHRydWUgdmFsdWUsIGl0IHdpbGwgYnJlYWsgb3V0IG9mIHRoZSBsb29wLgogICAgICovCiAgICBmdW5jdGlvbiBlYWNoUmV2ZXJzZShhcnksIGZ1bmMpIHsKICAgICAgICBpZiAoYXJ5KSB7CiAgICAgICAgICAgIHZhciBpOwogICAgICAgICAgICBmb3IgKGkgPSBhcnkubGVuZ3RoIC0gMTsgaSA+IC0xOyBpIC09IDEpIHsKICAgICAgICAgICAgICAgIGlmIChhcnlbaV0gJiYgZnVuYyhhcnlbaV0sIGksIGFyeSkpIHsKICAgICAgICAgICAgICAgICAgICBicmVhazsKICAgICAgICAgICAgICAgIH0KICAgICAgICAgICAgfQogICAgICAgIH0KICAgIH0KCiAgICBmdW5jdGlvbiBoYXNQcm9wKG9iaiwgcHJvcCkgewogICAgICAgIHJldHVybiBoYXNPd24uY2FsbChvYmosIHByb3ApOwogICAgfQoKICAgIGZ1bmN0aW9uIGdldE93bihvYmosIHByb3ApIHsKICAgICAgICByZXR1cm4gaGFzUHJvcChvYmosIHByb3ApICYmIG9ialtwcm9wXTsKICAgIH0KCiAgICAvKioKICAgICAqIEN5Y2xlcyBvdmVyIHByb3BlcnRpZXMgaW4gYW4gb2JqZWN0IGFuZCBjYWxscyBhIGZ1bmN0aW9uIGZvciBlYWNoCiAgICAgKiBwcm9wZXJ0eSB2YWx1ZS4gSWYgdGhlIGZ1bmN0aW9uIHJldHVybnMgYSB0cnV0aHkgdmFsdWUsIHRoZW4gdGhlCiAgICAgKiBpdGVyYXRpb24gaXMgc3RvcHBlZC4KICAgICAqLwogICAgZnVuY3Rpb24gZWFjaFByb3Aob2JqLCBmdW5jKSB7CiAgICAgICAgdmFyIHByb3A7CiAgICAgICAgZm9yIChwcm9wIGluIG9iaikgewogICAgICAgICAgICBpZiAoaGFzUHJvcChvYmosIHByb3ApKSB7CiAgICAgICAgICAgICAgICBpZiAoZnVuYyhvYmpbcHJvcF0sIHByb3ApKSB7CiAgICAgICAgICAgICAgICAgICAgYnJlYWs7CiAgICAgICAgICAgICAgICB9CiAgICAgICAgICAgIH0KICAgICAgICB9CiAgICB9CgogICAgLyoqCiAgICAgKiBTaW1wbGUgZnVuY3Rpb24gdG8gbWl4IGluIHByb3BlcnRpZXMgZnJvbSBzb3VyY2UgaW50byB0YXJnZXQsCiAgICAgKiBidXQgb25seSBpZiB0YXJnZXQgZG9lcyBub3QgYWxyZWFkeSBoYXZlIGEgcHJvcGVydHkgb2YgdGhlIHNhbWUgbmFtZS4KICAgICAqLwogICAgZnVuY3Rpb24gbWl4aW4odGFyZ2V0LCBzb3VyY2UsIGZvcmNlLCBkZWVwU3RyaW5nTWl4aW4pIHsKICAgICAgICBpZiAoc291cmNlKSB7CiAgICAgICAgICAgIGVhY2hQcm9wKHNvdXJjZSwgZnVuY3Rpb24gKHZhbHVlLCBwcm9wKSB7CiAgICAgICAgICAgICAgICBpZiAoZm9yY2UgfHwgIWhhc1Byb3AodGFyZ2V0LCBwcm9wKSkgewogICAgICAgICAgICAgICAgICAgIGlmIChkZWVwU3RyaW5nTWl4aW4gJiYgdHlwZW9mIHZhbHVlID09PSAnb2JqZWN0JyAmJiB2YWx1ZSAmJgogICAgICAgICAgICAgICAgICAgICAgICAhaXNBcnJheSh2YWx1ZSkgJiYgIWlzRnVuY3Rpb24odmFsdWUpICYmCiAgICAgICAgICAgICAgICAgICAgICAgICEodmFsdWUgaW5zdGFuY2VvZiBSZWdFeHApKSB7CgogICAgICAgICAgICAgICAgICAgICAgICBpZiAoIXRhcmdldFtwcm9wXSkgewogICAgICAgICAgICAgICAgICAgICAgICAgICAgdGFyZ2V0W3Byb3BdID0ge307CiAgICAgICAgICAgICAgICAgICAgICAgIH0KICAgICAgICAgICAgICAgICAgICAgICAgbWl4aW4odGFyZ2V0W3Byb3BdLCB2YWx1ZSwgZm9yY2UsIGRlZXBTdHJpbmdNaXhpbik7CiAgICAgICAgICAgICAgICAgICAgfSBlbHNlIHsKICAgICAgICAgICAgICAgICAgICAgICAgdGFyZ2V0W3Byb3BdID0gdmFsdWU7CiAgICAgICAgICAgICAgICAgICAgfQogICAgICAgICAgICAgICAgfQogICAgICAgICAgICB9KTsKICAgICAgICB9CiAgICAgICAgcmV0dXJuIHRhcmdldDsKICAgIH0KCiAgICAvL1NpbWlsYXIgdG8gRnVuY3Rpb24ucHJvdG90eXBlLmJpbmQsIGJ1dCB0aGUgJ3RoaXMnIG9iamVjdCBpcyBzcGVjaWZpZWQKICAgIC8vZmlyc3QsIHNpbmNlIGl0IGlzIGVhc2llciB0byByZWFkL2ZpZ3VyZSBvdXQgd2hhdCAndGhpcycgd2lsbCBiZS4KICAgIGZ1bmN0aW9uIGJpbmQob2JqLCBmbikgewogICAgICAgIHJldHVybiBmdW5jdGlvbiAoKSB7CiAgICAgICAgICAgIHJldHVybiBmbi5hcHBseShvYmosIGFyZ3VtZW50cyk7CiAgICAgICAgfTsKICAgIH0KCiAgICBmdW5jdGlvbiBzY3JpcHRzKCkgewogICAgICAgIHJldHVybiBkb2N1bWVudC5nZXRFbGVtZW50c0J5VGFnTmFtZSgnc2NyaXB0Jyk7CiAgICB9CgogICAgZnVuY3Rpb24gZGVmYXVsdE9uRXJyb3IoZXJyKSB7CiAgICAgICAgdGhyb3cgZXJyOwogICAgfQoKICAgIC8vQWxsb3cgZ2V0dGluZyBhIGdsb2JhbCB0aGF0IGlzIGV4cHJlc3NlZCBpbgogICAgLy9kb3Qgbm90YXRpb24sIGxpa2UgJ2EuYi5jJy4KICAgIGZ1bmN0aW9uIGdldEdsb2JhbCh2YWx1ZSkgewogICAgICAgIGlmICghdmFsdWUpIHsKICAgICAgICAgICAgcmV0dXJuIHZhbHVlOwogICAgICAgIH0KICAgICAgICB2YXIgZyA9IGdsb2JhbDsKICAgICAgICBlYWNoKHZhbHVlLnNwbGl0KCcuJyksIGZ1bmN0aW9uIChwYXJ0KSB7CiAgICAgICAgICAgIGcgPSBnW3BhcnRdOwogICAgICAgIH0pOwogICAgICAgIHJldHVybiBnOwogICAgfQoKICAgIC8qKgogICAgICogQ29uc3RydWN0cyBhbiBlcnJvciB3aXRoIGEgcG9pbnRlciB0byBhbiBVUkwgd2l0aCBtb3JlIGluZm9ybWF0aW9uLgogICAgICogQHBhcmFtIHtTdHJpbmd9IGlkIHRoZSBlcnJvciBJRCB0aGF0IG1hcHMgdG8gYW4gSUQgb24gYSB3ZWIgcGFnZS4KICAgICAqIEBwYXJhbSB7U3RyaW5nfSBtZXNzYWdlIGh1bWFuIHJlYWRhYmxlIGVycm9yLgogICAgICogQHBhcmFtIHtFcnJvcn0gW2Vycl0gdGhlIG9yaWdpbmFsIGVycm9yLCBpZiB0aGVyZSBpcyBvbmUuCiAgICAgKgogICAgICogQHJldHVybnMge0Vycm9yfQogICAgICovCiAgICBmdW5jdGlvbiBtYWtlRXJyb3IoaWQsIG1zZywgZXJyLCByZXF1aXJlTW9kdWxlcykgewogICAgICAgIHZhciBlID0gbmV3IEVycm9yKG1zZyArICdcbmh0dHA6Ly9yZXF1aXJlanMub3JnL2RvY3MvZXJyb3JzLmh0bWwjJyArIGlkKTsKICAgICAgICBlLnJlcXVpcmVUeXBlID0gaWQ7CiAgICAgICAgZS5yZXF1aXJlTW9kdWxlcyA9IHJlcXVpcmVNb2R1bGVzOwogICAgICAgIGlmIChlcnIpIHsKICAgICAgICAgICAgZS5vcmlnaW5hbEVycm9yID0gZXJyOwogICAgICAgIH0KICAgICAgICByZXR1cm4gZTsKICAgIH0KCiAgICBpZiAodHlwZW9mIGRlZmluZSAhPT0gJ3VuZGVmaW5lZCcpIHsKICAgICAgICAvL0lmIGEgZGVmaW5lIGlzIGFscmVhZHkgaW4gcGxheSB2aWEgYW5vdGhlciBBTUQgbG9hZGVyLAogICAgICAgIC8vZG8gbm90IG92ZXJ3cml0ZS4KICAgICAgICByZXR1cm47CiAgICB9CgogICAgaWYgKHR5cGVvZiByZXF1aXJlanMgIT09ICd1bmRlZmluZWQnKSB7CiAgICAgICAgaWYgKGlzRnVuY3Rpb24ocmVxdWlyZWpzKSkgewogICAgICAgICAgICAvL0RvIG5vdCBvdmVyd3JpdGUgYW4gZXhpc3RpbmcgcmVxdWlyZWpzIGluc3RhbmNlLgogICAgICAgICAgICByZXR1cm47CiAgICAgICAgfQogICAgICAgIGNmZyA9IHJlcXVpcmVqczsKICAgICAgICByZXF1aXJlanMgPSB1bmRlZmluZWQ7CiAgICB9CgogICAgLy9BbGxvdyBmb3IgYSByZXF1aXJlIGNvbmZpZyBvYmplY3QKICAgIGlmICh0eXBlb2YgcmVxdWlyZSAhPT0gJ3VuZGVmaW5lZCcgJiYgIWlzRnVuY3Rpb24ocmVxdWlyZSkpIHsKICAgICAgICAvL2Fzc3VtZSBpdCBpcyBhIGNvbmZpZyBvYmplY3QuCiAgICAgICAgY2ZnID0gcmVxdWlyZTsKICAgICAgICByZXF1aXJlID0gdW5kZWZpbmVkOwogICAgfQoKICAgIGZ1bmN0aW9uIG5ld0NvbnRleHQoY29udGV4dE5hbWUpIHsKICAgICAgICB2YXIgaW5DaGVja0xvYWRlZCwgTW9kdWxlLCBjb250ZXh0LCBoYW5kbGVycywKICAgICAgICAgICAgY2hlY2tMb2FkZWRUaW1lb3V0SWQsCiAgICAgICAgICAgIGNvbmZpZyA9IHsKICAgICAgICAgICAgICAgIC8vRGVmYXVsdHMuIERvIG5vdCBzZXQgYSBkZWZhdWx0IGZvciBtYXAKICAgICAgICAgICAgICAgIC8vY29uZmlnIHRvIHNwZWVkIHVwIG5vcm1hbGl6ZSgpLCB3aGljaAogICAgICAgICAgICAgICAgLy93aWxsIHJ1biBmYXN0ZXIgaWYgdGhlcmUgaXMgbm8gZGVmYXVsdC4KICAgICAgICAgICAgICAgIHdhaXRTZWNvbmRzOiA3LAogICAgICAgICAgICAgICAgYmFzZVVybDogJy4vJywKICAgICAgICAgICAgICAgIHBhdGhzOiB7fSwKICAgICAgICAgICAgICAgIGJ1bmRsZXM6IHt9LAogICAgICAgICAgICAgICAgcGtnczoge30sCiAgICAgICAgICAgICAgICBzaGltOiB7fSwKICAgICAgICAgICAgICAgIGNvbmZpZzoge30KICAgICAgICAgICAgfSwKICAgICAgICAgICAgcmVnaXN0cnkgPSB7fSwKICAgICAgICAgICAgLy9yZWdpc3RyeSBvZiBqdXN0IGVuYWJsZWQgbW9kdWxlcywgdG8gc3BlZWQKICAgICAgICAgICAgLy9jeWNsZSBicmVha2luZyBjb2RlIHdoZW4gbG90cyBvZiBtb2R1bGVzCiAgICAgICAgICAgIC8vYXJlIHJlZ2lzdGVyZWQsIGJ1dCBub3QgYWN0aXZhdGVkLgogICAgICAgICAgICBlbmFibGVkUmVnaXN0cnkgPSB7fSwKICAgICAgICAgICAgdW5kZWZFdmVudHMgPSB7fSwKICAgICAgICAgICAgZGVmUXVldWUgPSBbXSwKICAgICAgICAgICAgZGVmaW5lZCA9IHt9LAogICAgICAgICAgICB1cmxGZXRjaGVkID0ge30sCiAgICAgICAgICAgIGJ1bmRsZXNNYXAgPSB7fSwKICAgICAgICAgICAgcmVxdWlyZUNvdW50ZXIgPSAxLAogICAgICAgICAgICB1bm5vcm1hbGl6ZWRDb3VudGVyID0gMTsKCiAgICAgICAgLyoqCiAgICAgICAgICogVHJpbXMgdGhlIC4gYW5kIC4uIGZyb20gYW4gYXJyYXkgb2YgcGF0aCBzZWdtZW50cy4KICAgICAgICAgKiBJdCB3aWxsIGtlZXAgYSBsZWFkaW5nIHBhdGggc2VnbWVudCBpZiBhIC4uIHdpbGwgYmVjb21lCiAgICAgICAgICogdGhlIGZpcnN0IHBhdGggc2VnbWVudCwgdG8gaGVscCB3aXRoIG1vZHVsZSBuYW1lIGxvb2t1cHMsCiAgICAgICAgICogd2hpY2ggYWN0IGxpa2UgcGF0aHMsIGJ1dCBjYW4gYmUgcmVtYXBwZWQuIEJ1dCB0aGUgZW5kIHJlc3VsdCwKICAgICAgICAgKiBhbGwgcGF0aHMgdGhhdCB1c2UgdGhpcyBmdW5jdGlvbiBzaG91bGQgbG9vayBub3JtYWxpemVkLgogICAgICAgICAqIE5PVEU6IHRoaXMgbWV0aG9kIE1PRElGSUVTIHRoZSBpbnB1dCBhcnJheS4KICAgICAgICAgKiBAcGFyYW0ge0FycmF5fSBhcnkgdGhlIGFycmF5IG9mIHBhdGggc2VnbWVudHMuCiAgICAgICAgICovCiAgICAgICAgZnVuY3Rpb24gdHJpbURvdHMoYXJ5KSB7CiAgICAgICAgICAgIHZhciBpLCBwYXJ0OwogICAgICAgICAgICBmb3IgKGkgPSAwOyBpIDwgYXJ5Lmxlbmd0aDsgaSsrKSB7CiAgICAgICAgICAgICAgICBwYXJ0ID0gYXJ5W2ldOwogICAgICAgICAgICAgICAgaWYgKHBhcnQgPT09ICcuJykgewogICAgICAgICAgICAgICAgICAgIGFyeS5zcGxpY2UoaSwgMSk7CiAgICAgICAgICAgICAgICAgICAgaSAtPSAxOwogICAgICAgICAgICAgICAgfSBlbHNlIGlmIChwYXJ0ID09PSAnLi4nKSB7CiAgICAgICAgICAgICAgICAgICAgLy8gSWYgYXQgdGhlIHN0YXJ0LCBvciBwcmV2aW91cyB2YWx1ZSBpcyBzdGlsbCAuLiwKICAgICAgICAgICAgICAgICAgICAvLyBrZWVwIHRoZW0gc28gdGhhdCB3aGVuIGNvbnZlcnRlZCB0byBhIHBhdGggaXQgbWF5CiAgICAgICAgICAgICAgICAgICAgLy8gc3RpbGwgd29yayB3aGVuIGNvbnZlcnRlZCB0byBhIHBhdGgsIGV2ZW4gdGhvdWdoCiAgICAgICAgICAgICAgICAgICAgLy8gYXMgYW4gSUQgaXQgaXMgbGVzcyB0aGFuIGlkZWFsLiBJbiBsYXJnZXIgcG9pbnQKICAgICAgICAgICAgICAgICAgICAvLyByZWxlYXNlcywgbWF5IGJlIGJldHRlciB0byBqdXN0IGtpY2sgb3V0IGFuIGVycm9yLgogICAgICAgICAgICAgICAgICAgIGlmIChpID09PSAwIHx8IChpID09PSAxICYmIGFyeVsyXSA9PT0gJy4uJykgfHwgYXJ5W2kgLSAxXSA9PT0gJy4uJykgewogICAgICAgICAgICAgICAgICAgICAgICBjb250aW51ZTsKICAgICAgICAgICAgICAgICAgICB9IGVsc2UgaWYgKGkgPiAwKSB7CiAgICAgICAgICAgICAgICAgICAgICAgIGFyeS5zcGxpY2UoaSAtIDEsIDIpOwogICAgICAgICAgICAgICAgICAgICAgICBpIC09IDI7CiAgICAgICAgICAgICAgICAgICAgfQogICAgICAgICAgICAgICAgfQogICAgICAgICAgICB9CiAgICAgICAgfQoKICAgICAgICAvKioKICAgICAgICAgKiBHaXZlbiBhIHJlbGF0aXZlIG1vZHVsZSBuYW1lLCBsaWtlIC4vc29tZXRoaW5nLCBub3JtYWxpemUgaXQgdG8KICAgICAgICAgKiBhIHJlYWwgbmFtZSB0aGF0IGNhbiBiZSBtYXBwZWQgdG8gYSBwYXRoLgogICAgICAgICAqIEBwYXJhbSB7U3RyaW5nfSBuYW1lIHRoZSByZWxhdGl2ZSBuYW1lCiAgICAgICAgICogQHBhcmFtIHtTdHJpbmd9IGJhc2VOYW1lIGEgcmVhbCBuYW1lIHRoYXQgdGhlIG5hbWUgYXJnIGlzIHJlbGF0aXZlCiAgICAgICAgICogdG8uCiAgICAgICAgICogQHBhcmFtIHtCb29sZWFufSBhcHBseU1hcCBhcHBseSB0aGUgbWFwIGNvbmZpZyB0byB0aGUgdmFsdWUuIFNob3VsZAogICAgICAgICAqIG9ubHkgYmUgZG9uZSBpZiB0aGlzIG5vcm1hbGl6YXRpb24gaXMgZm9yIGEgZGVwZW5kZW5jeSBJRC4KICAgICAgICAgKiBAcmV0dXJucyB7U3RyaW5nfSBub3JtYWxpemVkIG5hbWUKICAgICAgICAgKi8KICAgICAgICBmdW5jdGlvbiBub3JtYWxpemUobmFtZSwgYmFzZU5hbWUsIGFwcGx5TWFwKSB7CiAgICAgICAgICAgIHZhciBwa2dNYWluLCBtYXBWYWx1ZSwgbmFtZVBhcnRzLCBpLCBqLCBuYW1lU2VnbWVudCwgbGFzdEluZGV4LAogICAgICAgICAgICAgICAgZm91bmRNYXAsIGZvdW5kSSwgZm91bmRTdGFyTWFwLCBzdGFySSwgbm9ybWFsaXplZEJhc2VQYXJ0cywKICAgICAgICAgICAgICAgIGJhc2VQYXJ0cyA9IChiYXNlTmFtZSAmJiBiYXNlTmFtZS5zcGxpdCgnLycpKSwKICAgICAgICAgICAgICAgIG1hcCA9IGNvbmZpZy5tYXAsCiAgICAgICAgICAgICAgICBzdGFyTWFwID0gbWFwICYmIG1hcFsnKiddOwoKICAgICAgICAgICAgLy9BZGp1c3QgYW55IHJlbGF0aXZlIHBhdGhzLgogICAgICAgICAgICBpZiAobmFtZSkgewogICAgICAgICAgICAgICAgbmFtZSA9IG5hbWUuc3BsaXQoJy8nKTsKICAgICAgICAgICAgICAgIGxhc3RJbmRleCA9IG5hbWUubGVuZ3RoIC0gMTsKCiAgICAgICAgICAgICAgICAvLyBJZiB3YW50aW5nIG5vZGUgSUQgY29tcGF0aWJpbGl0eSwgc3RyaXAgLmpzIGZyb20gZW5kCiAgICAgICAgICAgICAgICAvLyBvZiBJRHMuIEhhdmUgdG8gZG8gdGhpcyBoZXJlLCBhbmQgbm90IGluIG5hbWVUb1VybAogICAgICAgICAgICAgICAgLy8gYmVjYXVzZSBub2RlIGFsbG93cyBlaXRoZXIgLmpzIG9yIG5vbiAuanMgdG8gbWFwCiAgICAgICAgICAgICAgICAvLyB0byBzYW1lIGZpbGUuCiAgICAgICAgICAgICAgICBpZiAoY29uZmlnLm5vZGVJZENvbXBhdCAmJiBqc1N1ZmZpeFJlZ0V4cC50ZXN0KG5hbWVbbGFzdEluZGV4XSkpIHsKICAgICAgICAgICAgICAgICAgICBuYW1lW2xhc3RJbmRleF0gPSBuYW1lW2xhc3RJbmRleF0ucmVwbGFjZShqc1N1ZmZpeFJlZ0V4cCwgJycpOwogICAgICAgICAgICAgICAgfQoKICAgICAgICAgICAgICAgIC8vIFN0YXJ0cyB3aXRoIGEgJy4nIHNvIG5lZWQgdGhlIGJhc2VOYW1lCiAgICAgICAgICAgICAgICBpZiAobmFtZVswXS5jaGFyQXQoMCkgPT09ICcuJyAmJiBiYXNlUGFydHMpIHsKICAgICAgICAgICAgICAgICAgICAvL0NvbnZlcnQgYmFzZU5hbWUgdG8gYXJyYXksIGFuZCBsb3Agb2ZmIHRoZSBsYXN0IHBhcnQsCiAgICAgICAgICAgICAgICAgICAgLy9zbyB0aGF0IC4gbWF0Y2hlcyB0aGF0ICdkaXJlY3RvcnknIGFuZCBub3QgbmFtZSBvZiB0aGUgYmFzZU5hbWUncwogICAgICAgICAgICAgICAgICAgIC8vbW9kdWxlLiBGb3IgaW5zdGFuY2UsIGJhc2VOYW1lIG9mICdvbmUvdHdvL3RocmVlJywgbWFwcyB0bwogICAgICAgICAgICAgICAgICAgIC8vJ29uZS90d28vdGhyZWUuanMnLCBidXQgd2Ugd2FudCB0aGUgZGlyZWN0b3J5LCAnb25lL3R3bycgZm9yCiAgICAgICAgICAgICAgICAgICAgLy90aGlzIG5vcm1hbGl6YXRpb24uCiAgICAgICAgICAgICAgICAgICAgbm9ybWFsaXplZEJhc2VQYXJ0cyA9IGJhc2VQYXJ0cy5zbGljZSgwLCBiYXNlUGFydHMubGVuZ3RoIC0gMSk7CiAgICAgICAgICAgICAgICAgICAgbmFtZSA9IG5vcm1hbGl6ZWRCYXNlUGFydHMuY29uY2F0KG5hbWUpOwogICAgICAgICAgICAgICAgfQoKICAgICAgICAgICAgICAgIHRyaW1Eb3RzKG5hbWUpOwogICAgICAgICAgICAgICAgbmFtZSA9IG5hbWUuam9pbignLycpOwogICAgICAgICAgICB9CgogICAgICAgICAgICAvL0FwcGx5IG1hcCBjb25maWcgaWYgYXZhaWxhYmxlLgogICAgICAgICAgICBpZiAoYXBwbHlNYXAgJiYgbWFwICYmIChiYXNlUGFydHMgfHwgc3Rhck1hcCkpIHsKICAgICAgICAgICAgICAgIG5hbWVQYXJ0cyA9IG5hbWUuc3BsaXQoJy8nKTsKCiAgICAgICAgICAgICAgICBvdXRlckxvb3A6IGZvciAoaSA9IG5hbWVQYXJ0cy5sZW5ndGg7IGkgPiAwOyBpIC09IDEpIHsKICAgICAgICAgICAgICAgICAgICBuYW1lU2VnbWVudCA9IG5hbWVQYXJ0cy5zbGljZSgwLCBpKS5qb2luKCcvJyk7CgogICAgICAgICAgICAgICAgICAgIGlmIChiYXNlUGFydHMpIHsKICAgICAgICAgICAgICAgICAgICAgICAgLy9GaW5kIHRoZSBsb25nZXN0IGJhc2VOYW1lIHNlZ21lbnQgbWF0Y2ggaW4gdGhlIGNvbmZpZy4KICAgICAgICAgICAgICAgICAgICAgICAgLy9TbywgZG8gam9pbnMgb24gdGhlIGJpZ2dlc3QgdG8gc21hbGxlc3QgbGVuZ3RocyBvZiBiYXNlUGFydHMuCiAgICAgICAgICAgICAgICAgICAgICAgIGZvciAoaiA9IGJhc2VQYXJ0cy5sZW5ndGg7IGogPiAwOyBqIC09IDEpIHsKICAgICAgICAgICAgICAgICAgICAgICAgICAgIG1hcFZhbHVlID0gZ2V0T3duKG1hcCwgYmFzZVBhcnRzLnNsaWNlKDAsIGopLmpvaW4oJy8nKSk7CgogICAgICAgICAgICAgICAgICAgICAgICAgICAgLy9iYXNlTmFtZSBzZWdtZW50IGhhcyBjb25maWcsIGZpbmQgaWYgaXQgaGFzIG9uZSBmb3IKICAgICAgICAgICAgICAgICAgICAgICAgICAgIC8vdGhpcyBuYW1lLgogICAgICAgICAgICAgICAgICAgICAgICAgICAgaWYgKG1hcFZhbHVlKSB7CiAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgbWFwVmFsdWUgPSBnZXRPd24obWFwVmFsdWUsIG5hbWVTZWdtZW50KTsKICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICBpZiAobWFwVmFsdWUpIHsKICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgLy9NYXRjaCwgdXBkYXRlIG5hbWUgdG8gdGhlIG5ldyB2YWx1ZS4KICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgZm91bmRNYXAgPSBtYXBWYWx1ZTsKICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgZm91bmRJID0gaTsKICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgYnJlYWsgb3V0ZXJMb29wOwogICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgIH0KICAgICAgICAgICAgICAgICAgICAgICAgICAgIH0KICAgICAgICAgICAgICAgICAgICAgICAgfQogICAgICAgICAgICAgICAgICAgIH0KCiAgICAgICAgICAgICAgICAgICAgLy9DaGVjayBmb3IgYSBzdGFyIG1hcCBtYXRjaCwgYnV0IGp1c3QgaG9sZCBvbiB0byBpdCwKICAgICAgICAgICAgICAgICAgICAvL2lmIHRoZXJlIGlzIGEgc2hvcnRlciBzZWdtZW50IG1hdGNoIGxhdGVyIGluIGEgbWF0Y2hpbmcKICAgICAgICAgICAgICAgICAgICAvL2NvbmZpZywgdGhlbiBmYXZvciBvdmVyIHRoaXMgc3RhciBtYXAuCiAgICAgICAgICAgICAgICAgICAgaWYgKCFmb3VuZFN0YXJNYXAgJiYgc3Rhck1hcCAmJiBnZXRPd24oc3Rhck1hcCwgbmFtZVNlZ21lbnQpKSB7CiAgICAgICAgICAgICAgICAgICAgICAgIGZvdW5kU3Rhck1hcCA9IGdldE93bihzdGFyTWFwLCBuYW1lU2VnbWVudCk7CiAgICAgICAgICAgICAgICAgICAgICAgIHN0YXJJID0gaTsKICAgICAgICAgICAgICAgICAgICB9CiAgICAgICAgICAgICAgICB9CgogICAgICAgICAgICAgICAgaWYgKCFmb3VuZE1hcCAmJiBmb3VuZFN0YXJNYXApIHsKICAgICAgICAgICAgICAgICAgICBmb3VuZE1hcCA9IGZvdW5kU3Rhck1hcDsKICAgICAgICAgICAgICAgICAgICBmb3VuZEkgPSBzdGFySTsKICAgICAgICAgICAgICAgIH0KCiAgICAgICAgICAgICAgICBpZiAoZm91bmRNYXApIHsKICAgICAgICAgICAgICAgICAgICBuYW1lUGFydHMuc3BsaWNlKDAsIGZvdW5kSSwgZm91bmRNYXApOwogICAgICAgICAgICAgICAgICAgIG5hbWUgPSBuYW1lUGFydHMuam9pbignLycpOwogICAgICAgICAgICAgICAgfQogICAgICAgICAgICB9CgogICAgICAgICAgICAvLyBJZiB0aGUgbmFtZSBwb2ludHMgdG8gYSBwYWNrYWdlJ3MgbmFtZSwgdXNlCiAgICAgICAgICAgIC8vIHRoZSBwYWNrYWdlIG1haW4gaW5zdGVhZC4KICAgICAgICAgICAgcGtnTWFpbiA9IGdldE93bihjb25maWcucGtncywgbmFtZSk7CgogICAgICAgICAgICByZXR1cm4gcGtnTWFpbiA/IHBrZ01haW4gOiBuYW1lOwogICAgICAgIH0KCiAgICAgICAgZnVuY3Rpb24gcmVtb3ZlU2NyaXB0KG5hbWUpIHsKICAgICAgICAgICAgaWYgKGlzQnJvd3NlcikgewogICAgICAgICAgICAgICAgZWFjaChzY3JpcHRzKCksIGZ1bmN0aW9uIChzY3JpcHROb2RlKSB7CiAgICAgICAgICAgICAgICAgICAgaWYgKHNjcmlwdE5vZGUuZ2V0QXR0cmlidXRlKCdkYXRhLXJlcXVpcmVtb2R1bGUnKSA9PT0gbmFtZSAmJgogICAgICAgICAgICAgICAgICAgICAgICAgICAgc2NyaXB0Tm9kZS5nZXRBdHRyaWJ1dGUoJ2RhdGEtcmVxdWlyZWNvbnRleHQnKSA9PT0gY29udGV4dC5jb250ZXh0TmFtZSkgewogICAgICAgICAgICAgICAgICAgICAgICBzY3JpcHROb2RlLnBhcmVudE5vZGUucmVtb3ZlQ2hpbGQoc2NyaXB0Tm9kZSk7CiAgICAgICAgICAgICAgICAgICAgICAgIHJldHVybiB0cnVlOwogICAgICAgICAgICAgICAgICAgIH0KICAgICAgICAgICAgICAgIH0pOwogICAgICAgICAgICB9CiAgICAgICAgfQoKICAgICAgICBmdW5jdGlvbiBoYXNQYXRoRmFsbGJhY2soaWQpIHsKICAgICAgICAgICAgdmFyIHBhdGhDb25maWcgPSBnZXRPd24oY29uZmlnLnBhdGhzLCBpZCk7CiAgICAgICAgICAgIGlmIChwYXRoQ29uZmlnICYmIGlzQXJyYXkocGF0aENvbmZpZykgJiYgcGF0aENvbmZpZy5sZW5ndGggPiAxKSB7CiAgICAgICAgICAgICAgICAvL1BvcCBvZmYgdGhlIGZpcnN0IGFycmF5IHZhbHVlLCBzaW5jZSBpdCBmYWlsZWQsIGFuZAogICAgICAgICAgICAgICAgLy9yZXRyeQogICAgICAgICAgICAgICAgcGF0aENvbmZpZy5zaGlmdCgpOwogICAgICAgICAgICAgICAgY29udGV4dC5yZXF1aXJlLnVuZGVmKGlkKTsKCiAgICAgICAgICAgICAgICAvL0N1c3RvbSByZXF1aXJlIHRoYXQgZG9lcyBub3QgZG8gbWFwIHRyYW5zbGF0aW9uLCBzaW5jZQogICAgICAgICAgICAgICAgLy9JRCBpcyAiYWJzb2x1dGUiLCBhbHJlYWR5IG1hcHBlZC9yZXNvbHZlZC4KICAgICAgICAgICAgICAgIGNvbnRleHQubWFrZVJlcXVpcmUobnVsbCwgewogICAgICAgICAgICAgICAgICAgIHNraXBNYXA6IHRydWUKICAgICAgICAgICAgICAgIH0pKFtpZF0pOwoKICAgICAgICAgICAgICAgIHJldHVybiB0cnVlOwogICAgICAgICAgICB9CiAgICAgICAgfQoKICAgICAgICAvL1R1cm5zIGEgcGx1Z2luIXJlc291cmNlIHRvIFtwbHVnaW4sIHJlc291cmNlXQogICAgICAgIC8vd2l0aCB0aGUgcGx1Z2luIGJlaW5nIHVuZGVmaW5lZCBpZiB0aGUgbmFtZQogICAgICAgIC8vZGlkIG5vdCBoYXZlIGEgcGx1Z2luIHByZWZpeC4KICAgICAgICBmdW5jdGlvbiBzcGxpdFByZWZpeChuYW1lKSB7CiAgICAgICAgICAgIHZhciBwcmVmaXgsCiAgICAgICAgICAgICAgICBpbmRleCA9IG5hbWUgPyBuYW1lLmluZGV4T2YoJyEnKSA6IC0xOwogICAgICAgICAgICBpZiAoaW5kZXggPiAtMSkgewogICAgICAgICAgICAgICAgcHJlZml4ID0gbmFtZS5zdWJzdHJpbmcoMCwgaW5kZXgpOwogICAgICAgICAgICAgICAgbmFtZSA9IG5hbWUuc3Vic3RyaW5nKGluZGV4ICsgMSwgbmFtZS5sZW5ndGgpOwogICAgICAgICAgICB9CiAgICAgICAgICAgIHJldHVybiBbcHJlZml4LCBuYW1lXTsKICAgICAgICB9CgogICAgICAgIC8qKgogICAgICAgICAqIENyZWF0ZXMgYSBtb2R1bGUgbWFwcGluZyB0aGF0IGluY2x1ZGVzIHBsdWdpbiBwcmVmaXgsIG1vZHVsZQogICAgICAgICAqIG5hbWUsIGFuZCBwYXRoLiBJZiBwYXJlbnRNb2R1bGVNYXAgaXMgcHJvdmlkZWQgaXQgd2lsbAogICAgICAgICAqIGFsc28gbm9ybWFsaXplIHRoZSBuYW1lIHZpYSByZXF1aXJlLm5vcm1hbGl6ZSgpCiAgICAgICAgICoKICAgICAgICAgKiBAcGFyYW0ge1N0cmluZ30gbmFtZSB0aGUgbW9kdWxlIG5hbWUKICAgICAgICAgKiBAcGFyYW0ge1N0cmluZ30gW3BhcmVudE1vZHVsZU1hcF0gcGFyZW50IG1vZHVsZSBtYXAKICAgICAgICAgKiBmb3IgdGhlIG1vZHVsZSBuYW1lLCB1c2VkIHRvIHJlc29sdmUgcmVsYXRpdmUgbmFtZXMuCiAgICAgICAgICogQHBhcmFtIHtCb29sZWFufSBpc05vcm1hbGl6ZWQ6IGlzIHRoZSBJRCBhbHJlYWR5IG5vcm1hbGl6ZWQuCiAgICAgICAgICogVGhpcyBpcyB0cnVlIGlmIHRoaXMgY2FsbCBpcyBkb25lIGZvciBhIGRlZmluZSgpIG1vZHVsZSBJRC4KICAgICAgICAgKiBAcGFyYW0ge0Jvb2xlYW59IGFwcGx5TWFwOiBhcHBseSB0aGUgbWFwIGNvbmZpZyB0byB0aGUgSUQuCiAgICAgICAgICogU2hvdWxkIG9ubHkgYmUgdHJ1ZSBpZiB0aGlzIG1hcCBpcyBmb3IgYSBkZXBlbmRlbmN5LgogICAgICAgICAqCiAgICAgICAgICogQHJldHVybnMge09iamVjdH0KICAgICAgICAgKi8KICAgICAgICBmdW5jdGlvbiBtYWtlTW9kdWxlTWFwKG5hbWUsIHBhcmVudE1vZHVsZU1hcCwgaXNOb3JtYWxpemVkLCBhcHBseU1hcCkgewogICAgICAgICAgICB2YXIgdXJsLCBwbHVnaW5Nb2R1bGUsIHN1ZmZpeCwgbmFtZVBhcnRzLAogICAgICAgICAgICAgICAgcHJlZml4ID0gbnVsbCwKICAgICAgICAgICAgICAgIHBhcmVudE5hbWUgPSBwYXJlbnRNb2R1bGVNYXAgPyBwYXJlbnRNb2R1bGVNYXAubmFtZSA6IG51bGwsCiAgICAgICAgICAgICAgICBvcmlnaW5hbE5hbWUgPSBuYW1lLAogICAgICAgICAgICAgICAgaXNEZWZpbmUgPSB0cnVlLAogICAgICAgICAgICAgICAgbm9ybWFsaXplZE5hbWUgPSAnJzsKCiAgICAgICAgICAgIC8vSWYgbm8gbmFtZSwgdGhlbiBpdCBtZWFucyBpdCBpcyBhIHJlcXVpcmUgY2FsbCwgZ2VuZXJhdGUgYW4KICAgICAgICAgICAgLy9pbnRlcm5hbCBuYW1lLgogICAgICAgICAgICBpZiAoIW5hbWUpIHsKICAgICAgICAgICAgICAgIGlzRGVmaW5lID0gZmFsc2U7CiAgICAgICAgICAgICAgICBuYW1lID0gJ19AcicgKyAocmVxdWlyZUNvdW50ZXIgKz0gMSk7CiAgICAgICAgICAgIH0KCiAgICAgICAgICAgIG5hbWVQYXJ0cyA9IHNwbGl0UHJlZml4KG5hbWUpOwogICAgICAgICAgICBwcmVmaXggPSBuYW1lUGFydHNbMF07CiAgICAgICAgICAgIG5hbWUgPSBuYW1lUGFydHNbMV07CgogICAgICAgICAgICBpZiAocHJlZml4KSB7CiAgICAgICAgICAgICAgICBwcmVmaXggPSBub3JtYWxpemUocHJlZml4LCBwYXJlbnROYW1lLCBhcHBseU1hcCk7CiAgICAgICAgICAgICAgICBwbHVnaW5Nb2R1bGUgPSBnZXRPd24oZGVmaW5lZCwgcHJlZml4KTsKICAgICAgICAgICAgfQoKICAgICAgICAgICAgLy9BY2NvdW50IGZvciByZWxhdGl2ZSBwYXRocyBpZiB0aGVyZSBpcyBhIGJhc2UgbmFtZS4KICAgICAgICAgICAgaWYgKG5hbWUpIHsKICAgICAgICAgICAgICAgIGlmIChwcmVmaXgpIHsKICAgICAgICAgICAgICAgICAgICBpZiAocGx1Z2luTW9kdWxlICYmIHBsdWdpbk1vZHVsZS5ub3JtYWxpemUpIHsKICAgICAgICAgICAgICAgICAgICAgICAgLy9QbHVnaW4gaXMgbG9hZGVkLCB1c2UgaXRzIG5vcm1hbGl6ZSBtZXRob2QuCiAgICAgICAgICAgICAgICAgICAgICAgIG5vcm1hbGl6ZWROYW1lID0gcGx1Z2luTW9kdWxlLm5vcm1hbGl6ZShuYW1lLCBmdW5jdGlvbiAobmFtZSkgewogICAgICAgICAgICAgICAgICAgICAgICAgICAgcmV0dXJuIG5vcm1hbGl6ZShuYW1lLCBwYXJlbnROYW1lLCBhcHBseU1hcCk7CiAgICAgICAgICAgICAgICAgICAgICAgIH0pOwogICAgICAgICAgICAgICAgICAgIH0gZWxzZSB7CiAgICAgICAgICAgICAgICAgICAgICAgIC8vIElmIG5lc3RlZCBwbHVnaW4gcmVmZXJlbmNlcywgdGhlbiBkbyBub3QgdHJ5IHRvCiAgICAgICAgICAgICAgICAgICAgICAgIC8vIG5vcm1hbGl6ZSwgYXMgaXQgd2lsbCBub3Qgbm9ybWFsaXplIGNvcnJlY3RseS4gVGhpcwogICAgICAgICAgICAgICAgICAgICAgICAvLyBwbGFjZXMgYSByZXN0cmljdGlvbiBvbiByZXNvdXJjZUlkcywgYW5kIHRoZSBsb25nZXIKICAgICAgICAgICAgICAgICAgICAgICAgLy8gdGVybSBzb2x1dGlvbiBpcyBub3QgdG8gbm9ybWFsaXplIHVudGlsIHBsdWdpbnMgYXJlCiAgICAgICAgICAgICAgICAgICAgICAgIC8vIGxvYWRlZCBhbmQgYWxsIG5vcm1hbGl6YXRpb25zIHRvIGFsbG93IGZvciBhc3luYwogICAgICAgICAgICAgICAgICAgICAgICAvLyBsb2FkaW5nIG9mIGEgbG9hZGVyIHBsdWdpbi4gQnV0IGZvciBub3csIGZpeGVzIHRoZQogICAgICAgICAgICAgICAgICAgICAgICAvLyBjb21tb24gdXNlcy4gRGV0YWlscyBpbiAjMTEzMQogICAgICAgICAgICAgICAgICAgICAgICBub3JtYWxpemVkTmFtZSA9IG5hbWUuaW5kZXhPZignIScpID09PSAtMSA/CiAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgbm9ybWFsaXplKG5hbWUsIHBhcmVudE5hbWUsIGFwcGx5TWFwKSA6CiAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgbmFtZTsKICAgICAgICAgICAgICAgICAgICB9CiAgICAgICAgICAgICAgICB9IGVsc2UgewogICAgICAgICAgICAgICAgICAgIC8vQSByZWd1bGFyIG1vZHVsZS4KICAgICAgICAgICAgICAgICAgICBub3JtYWxpemVkTmFtZSA9IG5vcm1hbGl6ZShuYW1lLCBwYXJlbnROYW1lLCBhcHBseU1hcCk7CgogICAgICAgICAgICAgICAgICAgIC8vTm9ybWFsaXplZCBuYW1lIG1heSBiZSBhIHBsdWdpbiBJRCBkdWUgdG8gbWFwIGNvbmZpZwogICAgICAgICAgICAgICAgICAgIC8vYXBwbGljYXRpb24gaW4gbm9ybWFsaXplLiBUaGUgbWFwIGNvbmZpZyB2YWx1ZXMgbXVzdAogICAgICAgICAgICAgICAgICAgIC8vYWxyZWFkeSBiZSBub3JtYWxpemVkLCBzbyBkbyBub3QgbmVlZCB0byByZWRvIHRoYXQgcGFydC4KICAgICAgICAgICAgICAgICAgICBuYW1lUGFydHMgPSBzcGxpdFByZWZpeChub3JtYWxpemVkTmFtZSk7CiAgICAgICAgICAgICAgICAgICAgcHJlZml4ID0gbmFtZVBhcnRzWzBdOwogICAgICAgICAgICAgICAgICAgIG5vcm1hbGl6ZWROYW1lID0gbmFtZVBhcnRzWzFdOwogICAgICAgICAgICAgICAgICAgIGlzTm9ybWFsaXplZCA9IHRydWU7CgogICAgICAgICAgICAgICAgICAgIHVybCA9IGNvbnRleHQubmFtZVRvVXJsKG5vcm1hbGl6ZWROYW1lKTsKICAgICAgICAgICAgICAgIH0KICAgICAgICAgICAgfQoKICAgICAgICAgICAgLy9JZiB0aGUgaWQgaXMgYSBwbHVnaW4gaWQgdGhhdCBjYW5ub3QgYmUgZGV0ZXJtaW5lZCBpZiBpdCBuZWVkcwogICAgICAgICAgICAvL25vcm1hbGl6YXRpb24sIHN0YW1wIGl0IHdpdGggYSB1bmlxdWUgSUQgc28gdHdvIG1hdGNoaW5nIHJlbGF0aXZlCiAgICAgICAgICAgIC8vaWRzIHRoYXQgbWF5IGNvbmZsaWN0IGNhbiBiZSBzZXBhcmF0ZS4KICAgICAgICAgICAgc3VmZml4ID0gcHJlZml4ICYmICFwbHVnaW5Nb2R1bGUgJiYgIWlzTm9ybWFsaXplZCA/CiAgICAgICAgICAgICAgICAgICAgICdfdW5ub3JtYWxpemVkJyArICh1bm5vcm1hbGl6ZWRDb3VudGVyICs9IDEpIDoKICAgICAgICAgICAgICAgICAgICAgJyc7CgogICAgICAgICAgICByZXR1cm4gewogICAgICAgICAgICAgICAgcHJlZml4OiBwcmVmaXgsCiAgICAgICAgICAgICAgICBuYW1lOiBub3JtYWxpemVkTmFtZSwKICAgICAgICAgICAgICAgIHBhcmVudE1hcDogcGFyZW50TW9kdWxlTWFwLAogICAgICAgICAgICAgICAgdW5ub3JtYWxpemVkOiAhIXN1ZmZpeCwKICAgICAgICAgICAgICAgIHVybDogdXJsLAogICAgICAgICAgICAgICAgb3JpZ2luYWxOYW1lOiBvcmlnaW5hbE5hbWUsCiAgICAgICAgICAgICAgICBpc0RlZmluZTogaXNEZWZpbmUsCiAgICAgICAgICAgICAgICBpZDogKHByZWZpeCA/CiAgICAgICAgICAgICAgICAgICAgICAgIHByZWZpeCArICchJyArIG5vcm1hbGl6ZWROYW1lIDoKICAgICAgICAgICAgICAgICAgICAgICAgbm9ybWFsaXplZE5hbWUpICsgc3VmZml4CiAgICAgICAgICAgIH07CiAgICAgICAgfQoKICAgICAgICBmdW5jdGlvbiBnZXRNb2R1bGUoZGVwTWFwKSB7CiAgICAgICAgICAgIHZhciBpZCA9IGRlcE1hcC5pZCwKICAgICAgICAgICAgICAgIG1vZCA9IGdldE93bihyZWdpc3RyeSwgaWQpOwoKICAgICAgICAgICAgaWYgKCFtb2QpIHsKICAgICAgICAgICAgICAgIG1vZCA9IHJlZ2lzdHJ5W2lkXSA9IG5ldyBjb250ZXh0Lk1vZHVsZShkZXBNYXApOwogICAgICAgICAgICB9CgogICAgICAgICAgICByZXR1cm4gbW9kOwogICAgICAgIH0KCiAgICAgICAgZnVuY3Rpb24gb24oZGVwTWFwLCBuYW1lLCBmbikgewogICAgICAgICAgICB2YXIgaWQgPSBkZXBNYXAuaWQsCiAgICAgICAgICAgICAgICBtb2QgPSBnZXRPd24ocmVnaXN0cnksIGlkKTsKCiAgICAgICAgICAgIGlmIChoYXNQcm9wKGRlZmluZWQsIGlkKSAmJgogICAgICAgICAgICAgICAgICAgICghbW9kIHx8IG1vZC5kZWZpbmVFbWl0Q29tcGxldGUpKSB7CiAgICAgICAgICAgICAgICBpZiAobmFtZSA9PT0gJ2RlZmluZWQnKSB7CiAgICAgICAgICAgICAgICAgICAgZm4oZGVmaW5lZFtpZF0pOwogICAgICAgICAgICAgICAgfQogICAgICAgICAgICB9IGVsc2UgewogICAgICAgICAgICAgICAgbW9kID0gZ2V0TW9kdWxlKGRlcE1hcCk7CiAgICAgICAgICAgICAgICBpZiAobW9kLmVycm9yICYmIG5hbWUgPT09ICdlcnJvcicpIHsKICAgICAgICAgICAgICAgICAgICBmbihtb2QuZXJyb3IpOwogICAgICAgICAgICAgICAgfSBlbHNlIHsKICAgICAgICAgICAgICAgICAgICBtb2Qub24obmFtZSwgZm4pOwogICAgICAgICAgICAgICAgfQogICAgICAgICAgICB9CiAgICAgICAgfQoKICAgICAgICBmdW5jdGlvbiBvbkVycm9yKGVyciwgZXJyYmFjaykgewogICAgICAgICAgICB2YXIgaWRzID0gZXJyLnJlcXVpcmVNb2R1bGVzLAogICAgICAgICAgICAgICAgbm90aWZpZWQgPSBmYWxzZTsKCiAgICAgICAgICAgIGlmIChlcnJiYWNrKSB7CiAgICAgICAgICAgICAgICBlcnJiYWNrKGVycik7CiAgICAgICAgICAgIH0gZWxzZSB7CiAgICAgICAgICAgICAgICBlYWNoKGlkcywgZnVuY3Rpb24gKGlkKSB7CiAgICAgICAgICAgICAgICAgICAgdmFyIG1vZCA9IGdldE93bihyZWdpc3RyeSwgaWQpOwogICAgICAgICAgICAgICAgICAgIGlmIChtb2QpIHsKICAgICAgICAgICAgICAgICAgICAgICAgLy9TZXQgZXJyb3Igb24gbW9kdWxlLCBzbyBpdCBza2lwcyB0aW1lb3V0IGNoZWNrcy4KICAgICAgICAgICAgICAgICAgICAgICAgbW9kLmVycm9yID0gZXJyOwogICAgICAgICAgICAgICAgICAgICAgICBpZiAobW9kLmV2ZW50cy5lcnJvcikgewogICAgICAgICAgICAgICAgICAgICAgICAgICAgbm90aWZpZWQgPSB0cnVlOwogICAgICAgICAgICAgICAgICAgICAgICAgICAgbW9kLmVtaXQoJ2Vycm9yJywgZXJyKTsKICAgICAgICAgICAgICAgICAgICAgICAgfQogICAgICAgICAgICAgICAgICAgIH0KICAgICAgICAgICAgICAgIH0pOwoKICAgICAgICAgICAgICAgIGlmICghbm90aWZpZWQpIHsKICAgICAgICAgICAgICAgICAgICByZXEub25FcnJvcihlcnIpOwogICAgICAgICAgICAgICAgfQogICAgICAgICAgICB9CiAgICAgICAgfQoKICAgICAgICAvKioKICAgICAgICAgKiBJbnRlcm5hbCBtZXRob2QgdG8gdHJhbnNmZXIgZ2xvYmFsUXVldWUgaXRlbXMgdG8gdGhpcyBjb250ZXh0J3MKICAgICAgICAgKiBkZWZRdWV1ZS4KICAgICAgICAgKi8KICAgICAgICBmdW5jdGlvbiB0YWtlR2xvYmFsUXVldWUoKSB7CiAgICAgICAgICAgIC8vUHVzaCBhbGwgdGhlIGdsb2JhbERlZlF1ZXVlIGl0ZW1zIGludG8gdGhlIGNvbnRleHQncyBkZWZRdWV1ZQogICAgICAgICAgICBpZiAoZ2xvYmFsRGVmUXVldWUubGVuZ3RoKSB7CiAgICAgICAgICAgICAgICBlYWNoKGdsb2JhbERlZlF1ZXVlLCBmdW5jdGlvbihxdWV1ZUl0ZW0pIHsKICAgICAgICAgICAgICAgICAgICB2YXIgaWQgPSBxdWV1ZUl0ZW1bMF07CiAgICAgICAgICAgICAgICAgICAgaWYgKHR5cGVvZiBpZCA9PT0gJ3N0cmluZycpIHsKICAgICAgICAgICAgICAgICAgICAgICAgY29udGV4dC5kZWZRdWV1ZU1hcFtpZF0gPSB0cnVlOwogICAgICAgICAgICAgICAgICAgIH0KICAgICAgICAgICAgICAgICAgICBkZWZRdWV1ZS5wdXNoKHF1ZXVlSXRlbSk7CiAgICAgICAgICAgICAgICB9KTsKICAgICAgICAgICAgICAgIGdsb2JhbERlZlF1ZXVlID0gW107CiAgICAgICAgICAgIH0KICAgICAgICB9CgogICAgICAgIGhhbmRsZXJzID0gewogICAgICAgICAgICAncmVxdWlyZSc6IGZ1bmN0aW9uIChtb2QpIHsKICAgICAgICAgICAgICAgIGlmIChtb2QucmVxdWlyZSkgewogICAgICAgICAgICAgICAgICAgIHJldHVybiBtb2QucmVxdWlyZTsKICAgICAgICAgICAgICAgIH0gZWxzZSB7CiAgICAgICAgICAgICAgICAgICAgcmV0dXJuIChtb2QucmVxdWlyZSA9IGNvbnRleHQubWFrZVJlcXVpcmUobW9kLm1hcCkpOwogICAgICAgICAgICAgICAgfQogICAgICAgICAgICB9LAogICAgICAgICAgICAnZXhwb3J0cyc6IGZ1bmN0aW9uIChtb2QpIHsKICAgICAgICAgICAgICAgIG1vZC51c2luZ0V4cG9ydHMgPSB0cnVlOwogICAgICAgICAgICAgICAgaWYgKG1vZC5tYXAuaXNEZWZpbmUpIHsKICAgICAgICAgICAgICAgICAgICBpZiAobW9kLmV4cG9ydHMpIHsKICAgICAgICAgICAgICAgICAgICAgICAgcmV0dXJuIChkZWZpbmVkW21vZC5tYXAuaWRdID0gbW9kLmV4cG9ydHMpOwogICAgICAgICAgICAgICAgICAgIH0gZWxzZSB7CiAgICAgICAgICAgICAgICAgICAgICAgIHJldHVybiAobW9kLmV4cG9ydHMgPSBkZWZpbmVkW21vZC5tYXAuaWRdID0ge30pOwogICAgICAgICAgICAgICAgICAgIH0KICAgICAgICAgICAgICAgIH0KICAgICAgICAgICAgfSwKICAgICAgICAgICAgJ21vZHVsZSc6IGZ1bmN0aW9uIChtb2QpIHsKICAgICAgICAgICAgICAgIGlmIChtb2QubW9kdWxlKSB7CiAgICAgICAgICAgICAgICAgICAgcmV0dXJuIG1vZC5tb2R1bGU7CiAgICAgICAgICAgICAgICB9IGVsc2UgewogICAgICAgICAgICAgICAgICAgIHJldHVybiAobW9kLm1vZHVsZSA9IHsKICAgICAgICAgICAgICAgICAgICAgICAgaWQ6IG1vZC5tYXAuaWQsCiAgICAgICAgICAgICAgICAgICAgICAgIHVyaTogbW9kLm1hcC51cmwsCiAgICAgICAgICAgICAgICAgICAgICAgIGNvbmZpZzogZnVuY3Rpb24gKCkgewogICAgICAgICAgICAgICAgICAgICAgICAgICAgcmV0dXJuIGdldE93bihjb25maWcuY29uZmlnLCBtb2QubWFwLmlkKSB8fCB7fTsKICAgICAgICAgICAgICAgICAgICAgICAgfSwKICAgICAgICAgICAgICAgICAgICAgICAgZXhwb3J0czogbW9kLmV4cG9ydHMgfHwgKG1vZC5leHBvcnRzID0ge30pCiAgICAgICAgICAgICAgICAgICAgfSk7CiAgICAgICAgICAgICAgICB9CiAgICAgICAgICAgIH0KICAgICAgICB9OwoKICAgICAgICBmdW5jdGlvbiBjbGVhblJlZ2lzdHJ5KGlkKSB7CiAgICAgICAgICAgIC8vQ2xlYW4gdXAgbWFjaGluZXJ5IHVzZWQgZm9yIHdhaXRpbmcgbW9kdWxlcy4KICAgICAgICAgICAgZGVsZXRlIHJlZ2lzdHJ5W2lkXTsKICAgICAgICAgICAgZGVsZXRlIGVuYWJsZWRSZWdpc3RyeVtpZF07CiAgICAgICAgfQoKICAgICAgICBmdW5jdGlvbiBicmVha0N5Y2xlKG1vZCwgdHJhY2VkLCBwcm9jZXNzZWQpIHsKICAgICAgICAgICAgdmFyIGlkID0gbW9kLm1hcC5pZDsKCiAgICAgICAgICAgIGlmIChtb2QuZXJyb3IpIHsKICAgICAgICAgICAgICAgIG1vZC5lbWl0KCdlcnJvcicsIG1vZC5lcnJvcik7CiAgICAgICAgICAgIH0gZWxzZSB7CiAgICAgICAgICAgICAgICB0cmFjZWRbaWRdID0gdHJ1ZTsKICAgICAgICAgICAgICAgIGVhY2gobW9kLmRlcE1hcHMsIGZ1bmN0aW9uIChkZXBNYXAsIGkpIHsKICAgICAgICAgICAgICAgICAgICB2YXIgZGVwSWQgPSBkZXBNYXAuaWQsCiAgICAgICAgICAgICAgICAgICAgICAgIGRlcCA9IGdldE93bihyZWdpc3RyeSwgZGVwSWQpOwoKICAgICAgICAgICAgICAgICAgICAvL09ubHkgZm9yY2UgdGhpbmdzIHRoYXQgaGF2ZSBub3QgY29tcGxldGVkCiAgICAgICAgICAgICAgICAgICAgLy9iZWluZyBkZWZpbmVkLCBzbyBzdGlsbCBpbiB0aGUgcmVnaXN0cnksCiAgICAgICAgICAgICAgICAgICAgLy9hbmQgb25seSBpZiBpdCBoYXMgbm90IGJlZW4gbWF0Y2hlZCB1cAogICAgICAgICAgICAgICAgICAgIC8vaW4gdGhlIG1vZHVsZSBhbHJlYWR5LgogICAgICAgICAgICAgICAgICAgIGlmIChkZXAgJiYgIW1vZC5kZXBNYXRjaGVkW2ldICYmICFwcm9jZXNzZWRbZGVwSWRdKSB7CiAgICAgICAgICAgICAgICAgICAgICAgIGlmIChnZXRPd24odHJhY2VkLCBkZXBJZCkpIHsKICAgICAgICAgICAgICAgICAgICAgICAgICAgIG1vZC5kZWZpbmVEZXAoaSwgZGVmaW5lZFtkZXBJZF0pOwogICAgICAgICAgICAgICAgICAgICAgICAgICAgbW9kLmNoZWNrKCk7IC8vcGFzcyBmYWxzZT8KICAgICAgICAgICAgICAgICAgICAgICAgfSBlbHNlIHsKICAgICAgICAgICAgICAgICAgICAgICAgICAgIGJyZWFrQ3ljbGUoZGVwLCB0cmFjZWQsIHByb2Nlc3NlZCk7CiAgICAgICAgICAgICAgICAgICAgICAgIH0KICAgICAgICAgICAgICAgICAgICB9CiAgICAgICAgICAgICAgICB9KTsKICAgICAgICAgICAgICAgIHByb2Nlc3NlZFtpZF0gPSB0cnVlOwogICAgICAgICAgICB9CiAgICAgICAgfQoKICAgICAgICBmdW5jdGlvbiBjaGVja0xvYWRlZCgpIHsKICAgICAgICAgICAgdmFyIGVyciwgdXNpbmdQYXRoRmFsbGJhY2ssCiAgICAgICAgICAgICAgICB3YWl0SW50ZXJ2YWwgPSBjb25maWcud2FpdFNlY29uZHMgKiAxMDAwLAogICAgICAgICAgICAgICAgLy9JdCBpcyBwb3NzaWJsZSB0byBkaXNhYmxlIHRoZSB3YWl0IGludGVydmFsIGJ5IHVzaW5nIHdhaXRTZWNvbmRzIG9mIDAuCiAgICAgICAgICAgICAgICBleHBpcmVkID0gd2FpdEludGVydmFsICYmIChjb250ZXh0LnN0YXJ0VGltZSArIHdhaXRJbnRlcnZhbCkgPCBuZXcgRGF0ZSgpLmdldFRpbWUoKSwKICAgICAgICAgICAgICAgIG5vTG9hZHMgPSBbXSwKICAgICAgICAgICAgICAgIHJlcUNhbGxzID0gW10sCiAgICAgICAgICAgICAgICBzdGlsbExvYWRpbmcgPSBmYWxzZSwKICAgICAgICAgICAgICAgIG5lZWRDeWNsZUNoZWNrID0gdHJ1ZTsKCiAgICAgICAgICAgIC8vRG8gbm90IGJvdGhlciBpZiB0aGlzIGNhbGwgd2FzIGEgcmVzdWx0IG9mIGEgY3ljbGUgYnJlYWsuCiAgICAgICAgICAgIGlmIChpbkNoZWNrTG9hZGVkKSB7CiAgICAgICAgICAgICAgICByZXR1cm47CiAgICAgICAgICAgIH0KCiAgICAgICAgICAgIGluQ2hlY2tMb2FkZWQgPSB0cnVlOwoKICAgICAgICAgICAgLy9GaWd1cmUgb3V0IHRoZSBzdGF0ZSBvZiBhbGwgdGhlIG1vZHVsZXMuCiAgICAgICAgICAgIGVhY2hQcm9wKGVuYWJsZWRSZWdpc3RyeSwgZnVuY3Rpb24gKG1vZCkgewogICAgICAgICAgICAgICAgdmFyIG1hcCA9IG1vZC5tYXAsCiAgICAgICAgICAgICAgICAgICAgbW9kSWQgPSBtYXAuaWQ7CgogICAgICAgICAgICAgICAgLy9Ta2lwIHRoaW5ncyB0aGF0IGFyZSBub3QgZW5hYmxlZCBvciBpbiBlcnJvciBzdGF0ZS4KICAgICAgICAgICAgICAgIGlmICghbW9kLmVuYWJsZWQpIHsKICAgICAgICAgICAgICAgICAgICByZXR1cm47CiAgICAgICAgICAgICAgICB9CgogICAgICAgICAgICAgICAgaWYgKCFtYXAuaXNEZWZpbmUpIHsKICAgICAgICAgICAgICAgICAgICByZXFDYWxscy5wdXNoKG1vZCk7CiAgICAgICAgICAgICAgICB9CgogICAgICAgICAgICAgICAgaWYgKCFtb2QuZXJyb3IpIHsKICAgICAgICAgICAgICAgICAgICAvL0lmIHRoZSBtb2R1bGUgc2hvdWxkIGJlIGV4ZWN1dGVkLCBhbmQgaXQgaGFzIG5vdAogICAgICAgICAgICAgICAgICAgIC8vYmVlbiBpbml0ZWQgYW5kIHRpbWUgaXMgdXAsIHJlbWVtYmVyIGl0LgogICAgICAgICAgICAgICAgICAgIGlmICghbW9kLmluaXRlZCAmJiBleHBpcmVkKSB7CiAgICAgICAgICAgICAgICAgICAgICAgIGlmIChoYXNQYXRoRmFsbGJhY2sobW9kSWQpKSB7CiAgICAgICAgICAgICAgICAgICAgICAgICAgICB1c2luZ1BhdGhGYWxsYmFjayA9IHRydWU7CiAgICAgICAgICAgICAgICAgICAgICAgICAgICBzdGlsbExvYWRpbmcgPSB0cnVlOwogICAgICAgICAgICAgICAgICAgICAgICB9IGVsc2UgewogICAgICAgICAgICAgICAgICAgICAgICAgICAgbm9Mb2Fkcy5wdXNoKG1vZElkKTsKICAgICAgICAgICAgICAgICAgICAgICAgICAgIHJlbW92ZVNjcmlwdChtb2RJZCk7CiAgICAgICAgICAgICAgICAgICAgICAgIH0KICAgICAgICAgICAgICAgICAgICB9IGVsc2UgaWYgKCFtb2QuaW5pdGVkICYmIG1vZC5mZXRjaGVkICYmIG1hcC5pc0RlZmluZSkgewogICAgICAgICAgICAgICAgICAgICAgICBzdGlsbExvYWRpbmcgPSB0cnVlOwogICAgICAgICAgICAgICAgICAgICAgICBpZiAoIW1hcC5wcmVmaXgpIHsKICAgICAgICAgICAgICAgICAgICAgICAgICAgIC8vTm8gcmVhc29uIHRvIGtlZXAgbG9va2luZyBmb3IgdW5maW5pc2hlZAogICAgICAgICAgICAgICAgICAgICAgICAgICAgLy9sb2FkaW5nLiBJZiB0aGUgb25seSBzdGlsbExvYWRpbmcgaXMgYQogICAgICAgICAgICAgICAgICAgICAgICAgICAgLy9wbHVnaW4gcmVzb3VyY2UgdGhvdWdoLCBrZWVwIGdvaW5nLAogICAgICAgICAgICAgICAgICAgICAgICAgICAgLy9iZWNhdXNlIGl0IG1heSBiZSB0aGF0IGEgcGx1Z2luIHJlc291cmNlCiAgICAgICAgICAgICAgICAgICAgICAgICAgICAvL2lzIHdhaXRpbmcgb24gYSBub24tcGx1Z2luIGN5Y2xlLgogICAgICAgICAgICAgICAgICAgICAgICAgICAgcmV0dXJuIChuZWVkQ3ljbGVDaGVjayA9IGZhbHNlKTsKICAgICAgICAgICAgICAgICAgICAgICAgfQogICAgICAgICAgICAgICAgICAgIH0KICAgICAgICAgICAgICAgIH0KICAgICAgICAgICAgfSk7CgogICAgICAgICAgICBpZiAoZXhwaXJlZCAmJiBub0xvYWRzLmxlbmd0aCkgewogICAgICAgICAgICAgICAgLy9JZiB3YWl0IHRpbWUgZXhwaXJlZCwgdGhyb3cgZXJyb3Igb2YgdW5sb2FkZWQgbW9kdWxlcy4KICAgICAgICAgICAgICAgIGVyciA9IG1ha2VFcnJvcigndGltZW91dCcsICdMb2FkIHRpbWVvdXQgZm9yIG1vZHVsZXM6ICcgKyBub0xvYWRzLCBudWxsLCBub0xvYWRzKTsKICAgICAgICAgICAgICAgIGVyci5jb250ZXh0TmFtZSA9IGNvbnRleHQuY29udGV4dE5hbWU7CiAgICAgICAgICAgICAgICByZXR1cm4gb25FcnJvcihlcnIpOwogICAgICAgICAgICB9CgogICAgICAgICAgICAvL05vdCBleHBpcmVkLCBjaGVjayBmb3IgYSBjeWNsZS4KICAgICAgICAgICAgaWYgKG5lZWRDeWNsZUNoZWNrKSB7CiAgICAgICAgICAgICAgICBlYWNoKHJlcUNhbGxzLCBmdW5jdGlvbiAobW9kKSB7CiAgICAgICAgICAgICAgICAgICAgYnJlYWtDeWNsZShtb2QsIHt9LCB7fSk7CiAgICAgICAgICAgICAgICB9KTsKICAgICAgICAgICAgfQoKICAgICAgICAgICAgLy9JZiBzdGlsbCB3YWl0aW5nIG9uIGxvYWRzLCBhbmQgdGhlIHdhaXRpbmcgbG9hZCBpcyBzb21ldGhpbmcKICAgICAgICAgICAgLy9vdGhlciB0aGFuIGEgcGx1Z2luIHJlc291cmNlLCBvciB0aGVyZSBhcmUgc3RpbGwgb3V0c3RhbmRpbmcKICAgICAgICAgICAgLy9zY3JpcHRzLCB0aGVuIGp1c3QgdHJ5IGJhY2sgbGF0ZXIuCiAgICAgICAgICAgIGlmICgoIWV4cGlyZWQgfHwgdXNpbmdQYXRoRmFsbGJhY2spICYmIHN0aWxsTG9hZGluZykgewogICAgICAgICAgICAgICAgLy9Tb21ldGhpbmcgaXMgc3RpbGwgd2FpdGluZyB0byBsb2FkLiBXYWl0IGZvciBpdCwgYnV0IG9ubHkKICAgICAgICAgICAgICAgIC8vaWYgYSB0aW1lb3V0IGlzIG5vdCBhbHJlYWR5IGluIGVmZmVjdC4KICAgICAgICAgICAgICAgIGlmICgoaXNCcm93c2VyIHx8IGlzV2ViV29ya2VyKSAmJiAhY2hlY2tMb2FkZWRUaW1lb3V0SWQpIHsKICAgICAgICAgICAgICAgICAgICBjaGVja0xvYWRlZFRpbWVvdXRJZCA9IHNldFRpbWVvdXQoZnVuY3Rpb24gKCkgewogICAgICAgICAgICAgICAgICAgICAgICBjaGVja0xvYWRlZFRpbWVvdXRJZCA9IDA7CiAgICAgICAgICAgICAgICAgICAgICAgIGNoZWNrTG9hZGVkKCk7CiAgICAgICAgICAgICAgICAgICAgfSwgNTApOwogICAgICAgICAgICAgICAgfQogICAgICAgICAgICB9CgogICAgICAgICAgICBpbkNoZWNrTG9hZGVkID0gZmFsc2U7CiAgICAgICAgfQoKICAgICAgICBNb2R1bGUgPSBmdW5jdGlvbiAobWFwKSB7CiAgICAgICAgICAgIHRoaXMuZXZlbnRzID0gZ2V0T3duKHVuZGVmRXZlbnRzLCBtYXAuaWQpIHx8IHt9OwogICAgICAgICAgICB0aGlzLm1hcCA9IG1hcDsKICAgICAgICAgICAgdGhpcy5zaGltID0gZ2V0T3duKGNvbmZpZy5zaGltLCBtYXAuaWQpOwogICAgICAgICAgICB0aGlzLmRlcEV4cG9ydHMgPSBbXTsKICAgICAgICAgICAgdGhpcy5kZXBNYXBzID0gW107CiAgICAgICAgICAgIHRoaXMuZGVwTWF0Y2hlZCA9IFtdOwogICAgICAgICAgICB0aGlzLnBsdWdpbk1hcHMgPSB7fTsKICAgICAgICAgICAgdGhpcy5kZXBDb3VudCA9IDA7CgogICAgICAgICAgICAvKiB0aGlzLmV4cG9ydHMgdGhpcy5mYWN0b3J5CiAgICAgICAgICAgICAgIHRoaXMuZGVwTWFwcyA9IFtdLAogICAgICAgICAgICAgICB0aGlzLmVuYWJsZWQsIHRoaXMuZmV0Y2hlZAogICAgICAgICAgICAqLwogICAgICAgIH07CgogICAgICAgIE1vZHVsZS5wcm90b3R5cGUgPSB7CiAgICAgICAgICAgIGluaXQ6IGZ1bmN0aW9uIChkZXBNYXBzLCBmYWN0b3J5LCBlcnJiYWNrLCBvcHRpb25zKSB7CiAgICAgICAgICAgICAgICBvcHRpb25zID0gb3B0aW9ucyB8fCB7fTsKCiAgICAgICAgICAgICAgICAvL0RvIG5vdCBkbyBtb3JlIGluaXRzIGlmIGFscmVhZHkgZG9uZS4gQ2FuIGhhcHBlbiBpZiB0aGVyZQogICAgICAgICAgICAgICAgLy9hcmUgbXVsdGlwbGUgZGVmaW5lIGNhbGxzIGZvciB0aGUgc2FtZSBtb2R1bGUuIFRoYXQgaXMgbm90CiAgICAgICAgICAgICAgICAvL2Egbm9ybWFsLCBjb21tb24gY2FzZSwgYnV0IGl0IGlzIGFsc28gbm90IHVuZXhwZWN0ZWQuCiAgICAgICAgICAgICAgICBpZiAodGhpcy5pbml0ZWQpIHsKICAgICAgICAgICAgICAgICAgICByZXR1cm47CiAgICAgICAgICAgICAgICB9CgogICAgICAgICAgICAgICAgdGhpcy5mYWN0b3J5ID0gZmFjdG9yeTsKCiAgICAgICAgICAgICAgICBpZiAoZXJyYmFjaykgewogICAgICAgICAgICAgICAgICAgIC8vUmVnaXN0ZXIgZm9yIGVycm9ycyBvbiB0aGlzIG1vZHVsZS4KICAgICAgICAgICAgICAgICAgICB0aGlzLm9uKCdlcnJvcicsIGVycmJhY2spOwogICAgICAgICAgICAgICAgfSBlbHNlIGlmICh0aGlzLmV2ZW50cy5lcnJvcikgewogICAgICAgICAgICAgICAgICAgIC8vSWYgbm8gZXJyYmFjayBhbHJlYWR5LCBidXQgdGhlcmUgYXJlIGVycm9yIGxpc3RlbmVycwogICAgICAgICAgICAgICAgICAgIC8vb24gdGhpcyBtb2R1bGUsIHNldCB1cCBhbiBlcnJiYWNrIHRvIHBhc3MgdG8gdGhlIGRlcHMuCiAgICAgICAgICAgICAgICAgICAgZXJyYmFjayA9IGJpbmQodGhpcywgZnVuY3Rpb24gKGVycikgewogICAgICAgICAgICAgICAgICAgICAgICB0aGlzLmVtaXQoJ2Vycm9yJywgZXJyKTsKICAgICAgICAgICAgICAgICAgICB9KTsKICAgICAgICAgICAgICAgIH0KCiAgICAgICAgICAgICAgICAvL0RvIGEgY29weSBvZiB0aGUgZGVwZW5kZW5jeSBhcnJheSwgc28gdGhhdAogICAgICAgICAgICAgICAgLy9zb3VyY2UgaW5wdXRzIGFyZSBub3QgbW9kaWZpZWQuIEZvciBleGFtcGxlCiAgICAgICAgICAgICAgICAvLyJzaGltIiBkZXBzIGFyZSBwYXNzZWQgaW4gaGVyZSBkaXJlY3RseSwgYW5kCiAgICAgICAgICAgICAgICAvL2RvaW5nIGEgZGlyZWN0IG1vZGlmaWNhdGlvbiBvZiB0aGUgZGVwTWFwcyBhcnJheQogICAgICAgICAgICAgICAgLy93b3VsZCBhZmZlY3QgdGhhdCBjb25maWcuCiAgICAgICAgICAgICAgICB0aGlzLmRlcE1hcHMgPSBkZXBNYXBzICYmIGRlcE1hcHMuc2xpY2UoMCk7CgogICAgICAgICAgICAgICAgdGhpcy5lcnJiYWNrID0gZXJyYmFjazsKCiAgICAgICAgICAgICAgICAvL0luZGljYXRlIHRoaXMgbW9kdWxlIGhhcyBiZSBpbml0aWFsaXplZAogICAgICAgICAgICAgICAgdGhpcy5pbml0ZWQgPSB0cnVlOwoKICAgICAgICAgICAgICAgIHRoaXMuaWdub3JlID0gb3B0aW9ucy5pZ25vcmU7CgogICAgICAgICAgICAgICAgLy9Db3VsZCBoYXZlIG9wdGlvbiB0byBpbml0IHRoaXMgbW9kdWxlIGluIGVuYWJsZWQgbW9kZSwKICAgICAgICAgICAgICAgIC8vb3IgY291bGQgaGF2ZSBiZWVuIHByZXZpb3VzbHkgbWFya2VkIGFzIGVuYWJsZWQuIEhvd2V2ZXIsCiAgICAgICAgICAgICAgICAvL3RoZSBkZXBlbmRlbmNpZXMgYXJlIG5vdCBrbm93biB1bnRpbCBpbml0IGlzIGNhbGxlZC4gU28KICAgICAgICAgICAgICAgIC8vaWYgZW5hYmxlZCBwcmV2aW91c2x5LCBub3cgdHJpZ2dlciBkZXBlbmRlbmNpZXMgYXMgZW5hYmxlZC4KICAgICAgICAgICAgICAgIGlmIChvcHRpb25zLmVuYWJsZWQgfHwgdGhpcy5lbmFibGVkKSB7CiAgICAgICAgICAgICAgICAgICAgLy9FbmFibGUgdGhpcyBtb2R1bGUgYW5kIGRlcGVuZGVuY2llcy4KICAgICAgICAgICAgICAgICAgICAvL1dpbGwgY2FsbCB0aGlzLmNoZWNrKCkKICAgICAgICAgICAgICAgICAgICB0aGlzLmVuYWJsZSgpOwogICAgICAgICAgICAgICAgfSBlbHNlIHsKICAgICAgICAgICAgICAgICAgICB0aGlzLmNoZWNrKCk7CiAgICAgICAgICAgICAgICB9CiAgICAgICAgICAgIH0sCgogICAgICAgICAgICBkZWZpbmVEZXA6IGZ1bmN0aW9uIChpLCBkZXBFeHBvcnRzKSB7CiAgICAgICAgICAgICAgICAvL0JlY2F1c2Ugb2YgY3ljbGVzLCBkZWZpbmVkIGNhbGxiYWNrIGZvciBhIGdpdmVuCiAgICAgICAgICAgICAgICAvL2V4cG9ydCBjYW4gYmUgY2FsbGVkIG1vcmUgdGhhbiBvbmNlLgogICAgICAgICAgICAgICAgaWYgKCF0aGlzLmRlcE1hdGNoZWRbaV0pIHsKICAgICAgICAgICAgICAgICAgICB0aGlzLmRlcE1hdGNoZWRbaV0gPSB0cnVlOwogICAgICAgICAgICAgICAgICAgIHRoaXMuZGVwQ291bnQgLT0gMTsKICAgICAgICAgICAgICAgICAgICB0aGlzLmRlcEV4cG9ydHNbaV0gPSBkZXBFeHBvcnRzOwogICAgICAgICAgICAgICAgfQogICAgICAgICAgICB9LAoKICAgICAgICAgICAgZmV0Y2g6IGZ1bmN0aW9uICgpIHsKICAgICAgICAgICAgICAgIGlmICh0aGlzLmZldGNoZWQpIHsKICAgICAgICAgICAgICAgICAgICByZXR1cm47CiAgICAgICAgICAgICAgICB9CiAgICAgICAgICAgICAgICB0aGlzLmZldGNoZWQgPSB0cnVlOwoKICAgICAgICAgICAgICAgIGNvbnRleHQuc3RhcnRUaW1lID0gKG5ldyBEYXRlKCkpLmdldFRpbWUoKTsKCiAgICAgICAgICAgICAgICB2YXIgbWFwID0gdGhpcy5tYXA7CgogICAgICAgICAgICAgICAgLy9JZiB0aGUgbWFuYWdlciBpcyBmb3IgYSBwbHVnaW4gbWFuYWdlZCByZXNvdXJjZSwKICAgICAgICAgICAgICAgIC8vYXNrIHRoZSBwbHVnaW4gdG8gbG9hZCBpdCBub3cuCiAgICAgICAgICAgICAgICBpZiAodGhpcy5zaGltKSB7CiAgICAgICAgICAgICAgICAgICAgY29udGV4dC5tYWtlUmVxdWlyZSh0aGlzLm1hcCwgewogICAgICAgICAgICAgICAgICAgICAgICBlbmFibGVCdWlsZENhbGxiYWNrOiB0cnVlCiAgICAgICAgICAgICAgICAgICAgfSkodGhpcy5zaGltLmRlcHMgfHwgW10sIGJpbmQodGhpcywgZnVuY3Rpb24gKCkgewogICAgICAgICAgICAgICAgICAgICAgICByZXR1cm4gbWFwLnByZWZpeCA/IHRoaXMuY2FsbFBsdWdpbigpIDogdGhpcy5sb2FkKCk7CiAgICAgICAgICAgICAgICAgICAgfSkpOwogICAgICAgICAgICAgICAgfSBlbHNlIHsKICAgICAgICAgICAgICAgICAgICAvL1JlZ3VsYXIgZGVwZW5kZW5jeS4KICAgICAgICAgICAgICAgICAgICByZXR1cm4gbWFwLnByZWZpeCA/IHRoaXMuY2FsbFBsdWdpbigpIDogdGhpcy5sb2FkKCk7CiAgICAgICAgICAgICAgICB9CiAgICAgICAgICAgIH0sCgogICAgICAgICAgICBsb2FkOiBmdW5jdGlvbiAoKSB7CiAgICAgICAgICAgICAgICB2YXIgdXJsID0gdGhpcy5tYXAudXJsOwoKICAgICAgICAgICAgICAgIC8vUmVndWxhciBkZXBlbmRlbmN5LgogICAgICAgICAgICAgICAgaWYgKCF1cmxGZXRjaGVkW3VybF0pIHsKICAgICAgICAgICAgICAgICAgICB1cmxGZXRjaGVkW3VybF0gPSB0cnVlOwogICAgICAgICAgICAgICAgICAgIGNvbnRleHQubG9hZCh0aGlzLm1hcC5pZCwgdXJsKTsKICAgICAgICAgICAgICAgIH0KICAgICAgICAgICAgfSwKCiAgICAgICAgICAgIC8qKgogICAgICAgICAgICAgKiBDaGVja3MgaWYgdGhlIG1vZHVsZSBpcyByZWFkeSB0byBkZWZpbmUgaXRzZWxmLCBhbmQgaWYgc28sCiAgICAgICAgICAgICAqIGRlZmluZSBpdC4KICAgICAgICAgICAgICovCiAgICAgICAgICAgIGNoZWNrOiBmdW5jdGlvbiAoKSB7CiAgICAgICAgICAgICAgICBpZiAoIXRoaXMuZW5hYmxlZCB8fCB0aGlzLmVuYWJsaW5nKSB7CiAgICAgICAgICAgICAgICAgICAgcmV0dXJuOwogICAgICAgICAgICAgICAgfQoKICAgICAgICAgICAgICAgIHZhciBlcnIsIGNqc01vZHVsZSwKICAgICAgICAgICAgICAgICAgICBpZCA9IHRoaXMubWFwLmlkLAogICAgICAgICAgICAgICAgICAgIGRlcEV4cG9ydHMgPSB0aGlzLmRlcEV4cG9ydHMsCiAgICAgICAgICAgICAgICAgICAgZXhwb3J0cyA9IHRoaXMuZXhwb3J0cywKICAgICAgICAgICAgICAgICAgICBmYWN0b3J5ID0gdGhpcy5mYWN0b3J5OwoKICAgICAgICAgICAgICAgIGlmICghdGhpcy5pbml0ZWQpIHsKICAgICAgICAgICAgICAgICAgICAvLyBPbmx5IGZldGNoIGlmIG5vdCBhbHJlYWR5IGluIHRoZSBkZWZRdWV1ZS4KICAgICAgICAgICAgICAgICAgICBpZiAoIWhhc1Byb3AoY29udGV4dC5kZWZRdWV1ZU1hcCwgaWQpKSB7CiAgICAgICAgICAgICAgICAgICAgICAgIHRoaXMuZmV0Y2goKTsKICAgICAgICAgICAgICAgICAgICB9CiAgICAgICAgICAgICAgICB9IGVsc2UgaWYgKHRoaXMuZXJyb3IpIHsKICAgICAgICAgICAgICAgICAgICB0aGlzLmVtaXQoJ2Vycm9yJywgdGhpcy5lcnJvcik7CiAgICAgICAgICAgICAgICB9IGVsc2UgaWYgKCF0aGlzLmRlZmluaW5nKSB7CiAgICAgICAgICAgICAgICAgICAgLy9UaGUgZmFjdG9yeSBjb3VsZCB0cmlnZ2VyIGFub3RoZXIgcmVxdWlyZSBjYWxsCiAgICAgICAgICAgICAgICAgICAgLy90aGF0IHdvdWxkIHJlc3VsdCBpbiBjaGVja2luZyB0aGlzIG1vZHVsZSB0bwogICAgICAgICAgICAgICAgICAgIC8vZGVmaW5lIGl0c2VsZiBhZ2Fpbi4gSWYgYWxyZWFkeSBpbiB0aGUgcHJvY2VzcwogICAgICAgICAgICAgICAgICAgIC8vb2YgZG9pbmcgdGhhdCwgc2tpcCB0aGlzIHdvcmsuCiAgICAgICAgICAgICAgICAgICAgdGhpcy5kZWZpbmluZyA9IHRydWU7CgogICAgICAgICAgICAgICAgICAgIGlmICh0aGlzLmRlcENvdW50IDwgMSAmJiAhdGhpcy5kZWZpbmVkKSB7CiAgICAgICAgICAgICAgICAgICAgICAgIGlmIChpc0Z1bmN0aW9uKGZhY3RvcnkpKSB7CiAgICAgICAgICAgICAgICAgICAgICAgICAgICB0cnkgewogICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgIGV4cG9ydHMgPSBjb250ZXh0LmV4ZWNDYihpZCwgZmFjdG9yeSwgZGVwRXhwb3J0cywgZXhwb3J0cyk7CiAgICAgICAgICAgICAgICAgICAgICAgICAgICB9IGNhdGNoIChlKSB7CiAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgZXJyID0gZTsKICAgICAgICAgICAgICAgICAgICAgICAgICAgIH0KCiAgICAgICAgICAgICAgICAgICAgICAgICAgICAvLyBGYXZvciByZXR1cm4gdmFsdWUgb3ZlciBleHBvcnRzLiBJZiBub2RlL2NqcyBpbiBwbGF5LAogICAgICAgICAgICAgICAgICAgICAgICAgICAgLy8gdGhlbiB3aWxsIG5vdCBoYXZlIGEgcmV0dXJuIHZhbHVlIGFueXdheS4gRmF2b3IKICAgICAgICAgICAgICAgICAgICAgICAgICAgIC8vIG1vZHVsZS5leHBvcnRzIGFzc2lnbm1lbnQgb3ZlciBleHBvcnRzIG9iamVjdC4KICAgICAgICAgICAgICAgICAgICAgICAgICAgIGlmICh0aGlzLm1hcC5pc0RlZmluZSAmJiBleHBvcnRzID09PSB1bmRlZmluZWQpIHsKICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICBjanNNb2R1bGUgPSB0aGlzLm1vZHVsZTsKICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICBpZiAoY2pzTW9kdWxlKSB7CiAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgIGV4cG9ydHMgPSBjanNNb2R1bGUuZXhwb3J0czsKICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICB9IGVsc2UgaWYgKHRoaXMudXNpbmdFeHBvcnRzKSB7CiAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgIC8vZXhwb3J0cyBhbHJlYWR5IHNldCB0aGUgZGVmaW5lZCB2YWx1ZS4KICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgZXhwb3J0cyA9IHRoaXMuZXhwb3J0czsKICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICB9CiAgICAgICAgICAgICAgICAgICAgICAgICAgICB9CgogICAgICAgICAgICAgICAgICAgICAgICAgICAgaWYgKGVycikgewogICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgIC8vIElmIHRoZXJlIGlzIGFuIGVycm9yIGxpc3RlbmVyLCBmYXZvciBwYXNzaW5nCiAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgLy8gdG8gdGhhdCBpbnN0ZWFkIG9mIHRocm93aW5nIGFuIGVycm9yLiBIb3dldmVyLAogICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgIC8vIG9ubHkgZG8gaXQgZm9yIGRlZmluZSgpJ2QgIG1vZHVsZXMuIHJlcXVpcmUKICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAvLyBlcnJiYWNrcyBzaG91bGQgbm90IGJlIGNhbGxlZCBmb3IgZmFpbHVyZXMgaW4KICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAvLyB0aGVpciBjYWxsYmFja3MgKCM2OTkpLiBIb3dldmVyIGlmIGEgZ2xvYmFsCiAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgLy8gb25FcnJvciBpcyBzZXQsIHVzZSB0aGF0LgogICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgIGlmICgodGhpcy5ldmVudHMuZXJyb3IgJiYgdGhpcy5tYXAuaXNEZWZpbmUpIHx8CiAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgIHJlcS5vbkVycm9yICE9PSBkZWZhdWx0T25FcnJvcikgewogICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICBlcnIucmVxdWlyZU1hcCA9IHRoaXMubWFwOwogICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICBlcnIucmVxdWlyZU1vZHVsZXMgPSB0aGlzLm1hcC5pc0RlZmluZSA/IFt0aGlzLm1hcC5pZF0gOiBudWxsOwogICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICBlcnIucmVxdWlyZVR5cGUgPSB0aGlzLm1hcC5pc0RlZmluZSA/ICdkZWZpbmUnIDogJ3JlcXVpcmUnOwogICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICByZXR1cm4gb25FcnJvcigodGhpcy5lcnJvciA9IGVycikpOwogICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgIH0gZWxzZSBpZiAodHlwZW9mIGNvbnNvbGUgIT09ICd1bmRlZmluZWQnICYmCiAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICBjb25zb2xlLmVycm9yKSB7CiAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgIC8vIExvZyB0aGUgZXJyb3IgZm9yIGRlYnVnZ2luZy4gSWYgcHJvbWlzZXMgY291bGQgYmUKICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgLy8gdXNlZCwgdGhpcyB3b3VsZCBiZSBkaWZmZXJlbnQsIGJ1dCBtYWtpbmcgZG8uCiAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgIGNvbnNvbGUuZXJyb3IoZXJyKTsKICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICB9IGVsc2UgewogICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAvLyBEbyBub3Qgd2FudCB0byBjb21wbGV0ZWx5IGxvc2UgdGhlIGVycm9yLiBXaGlsZSB0aGlzCiAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgIC8vIHdpbGwgbWVzcyB1cCBwcm9jZXNzaW5nIGFuZCBsZWFkIHRvIHNpbWlsYXIgcmVzdWx0cwogICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAvLyBhcyBidWcgMTQ0MCwgaXQgYXQgbGVhc3Qgc3VyZmFjZXMgdGhlIGVycm9yLgogICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICByZXEub25FcnJvcihlcnIpOwogICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgIH0KICAgICAgICAgICAgICAgICAgICAgICAgICAgIH0KICAgICAgICAgICAgICAgICAgICAgICAgfSBlbHNlIHsKICAgICAgICAgICAgICAgICAgICAgICAgICAgIC8vSnVzdCBhIGxpdGVyYWwgdmFsdWUKICAgICAgICAgICAgICAgICAgICAgICAgICAgIGV4cG9ydHMgPSBmYWN0b3J5OwogICAgICAgICAgICAgICAgICAgICAgICB9CgogICAgICAgICAgICAgICAgICAgICAgICB0aGlzLmV4cG9ydHMgPSBleHBvcnRzOwoKICAgICAgICAgICAgICAgICAgICAgICAgaWYgKHRoaXMubWFwLmlzRGVmaW5lICYmICF0aGlzLmlnbm9yZSkgewogICAgICAgICAgICAgICAgICAgICAgICAgICAgZGVmaW5lZFtpZF0gPSBleHBvcnRzOwoKICAgICAgICAgICAgICAgICAgICAgICAgICAgIGlmIChyZXEub25SZXNvdXJjZUxvYWQpIHsKICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICB2YXIgcmVzTG9hZE1hcHMgPSBbXTsKICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICBlYWNoKHRoaXMuZGVwTWFwcywgZnVuY3Rpb24gKGRlcE1hcCkgewogICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICByZXNMb2FkTWFwcy5wdXNoKGRlcE1hcC5ub3JtYWxpemVkTWFwIHx8IGRlcE1hcCk7CiAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgfSk7CiAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgcmVxLm9uUmVzb3VyY2VMb2FkKGNvbnRleHQsIHRoaXMubWFwLCByZXNMb2FkTWFwcyk7CiAgICAgICAgICAgICAgICAgICAgICAgICAgICB9CiAgICAgICAgICAgICAgICAgICAgICAgIH0KCiAgICAgICAgICAgICAgICAgICAgICAgIC8vQ2xlYW4gdXAKICAgICAgICAgICAgICAgICAgICAgICAgY2xlYW5SZWdpc3RyeShpZCk7CgogICAgICAgICAgICAgICAgICAgICAgICB0aGlzLmRlZmluZWQgPSB0cnVlOwogICAgICAgICAgICAgICAgICAgIH0KCiAgICAgICAgICAgICAgICAgICAgLy9GaW5pc2hlZCB0aGUgZGVmaW5lIHN0YWdlLiBBbGxvdyBjYWxsaW5nIGNoZWNrIGFnYWluCiAgICAgICAgICAgICAgICAgICAgLy90byBhbGxvdyBkZWZpbmUgbm90aWZpY2F0aW9ucyBiZWxvdyBpbiB0aGUgY2FzZSBvZiBhCiAgICAgICAgICAgICAgICAgICAgLy9jeWNsZS4KICAgICAgICAgICAgICAgICAgICB0aGlzLmRlZmluaW5nID0gZmFsc2U7CgogICAgICAgICAgICAgICAgICAgIGlmICh0aGlzLmRlZmluZWQgJiYgIXRoaXMuZGVmaW5lRW1pdHRlZCkgewogICAgICAgICAgICAgICAgICAgICAgICB0aGlzLmRlZmluZUVtaXR0ZWQgPSB0cnVlOwogICAgICAgICAgICAgICAgICAgICAgICB0aGlzLmVtaXQoJ2RlZmluZWQnLCB0aGlzLmV4cG9ydHMpOwogICAgICAgICAgICAgICAgICAgICAgICB0aGlzLmRlZmluZUVtaXRDb21wbGV0ZSA9IHRydWU7CiAgICAgICAgICAgICAgICAgICAgfQoKICAgICAgICAgICAgICAgIH0KICAgICAgICAgICAgfSwKCiAgICAgICAgICAgIGNhbGxQbHVnaW46IGZ1bmN0aW9uICgpIHsKICAgICAgICAgICAgICAgIHZhciBtYXAgPSB0aGlzLm1hcCwKICAgICAgICAgICAgICAgICAgICBpZCA9IG1hcC5pZCwKICAgICAgICAgICAgICAgICAgICAvL01hcCBhbHJlYWR5IG5vcm1hbGl6ZWQgdGhlIHByZWZpeC4KICAgICAgICAgICAgICAgICAgICBwbHVnaW5NYXAgPSBtYWtlTW9kdWxlTWFwKG1hcC5wcmVmaXgpOwoKICAgICAgICAgICAgICAgIC8vTWFyayB0aGlzIGFzIGEgZGVwZW5kZW5jeSBmb3IgdGhpcyBwbHVnaW4sIHNvIGl0CiAgICAgICAgICAgICAgICAvL2NhbiBiZSB0cmFjZWQgZm9yIGN5Y2xlcy4KICAgICAgICAgICAgICAgIHRoaXMuZGVwTWFwcy5wdXNoKHBsdWdpbk1hcCk7CgogICAgICAgICAgICAgICAgb24ocGx1Z2luTWFwLCAnZGVmaW5lZCcsIGJpbmQodGhpcywgZnVuY3Rpb24gKHBsdWdpbikgewogICAgICAgICAgICAgICAgICAgIHZhciBsb2FkLCBub3JtYWxpemVkTWFwLCBub3JtYWxpemVkTW9kLAogICAgICAgICAgICAgICAgICAgICAgICBidW5kbGVJZCA9IGdldE93bihidW5kbGVzTWFwLCB0aGlzLm1hcC5pZCksCiAgICAgICAgICAgICAgICAgICAgICAgIG5hbWUgPSB0aGlzLm1hcC5uYW1lLAogICAgICAgICAgICAgICAgICAgICAgICBwYXJlbnROYW1lID0gdGhpcy5tYXAucGFyZW50TWFwID8gdGhpcy5tYXAucGFyZW50TWFwLm5hbWUgOiBudWxsLAogICAgICAgICAgICAgICAgICAgICAgICBsb2NhbFJlcXVpcmUgPSBjb250ZXh0Lm1ha2VSZXF1aXJlKG1hcC5wYXJlbnRNYXAsIHsKICAgICAgICAgICAgICAgICAgICAgICAgICAgIGVuYWJsZUJ1aWxkQ2FsbGJhY2s6IHRydWUKICAgICAgICAgICAgICAgICAgICAgICAgfSk7CgogICAgICAgICAgICAgICAgICAgIC8vSWYgY3VycmVudCBtYXAgaXMgbm90IG5vcm1hbGl6ZWQsIHdhaXQgZm9yIHRoYXQKICAgICAgICAgICAgICAgICAgICAvL25vcm1hbGl6ZWQgbmFtZSB0byBsb2FkIGluc3RlYWQgb2YgY29udGludWluZy4KICAgICAgICAgICAgICAgICAgICBpZiAodGhpcy5tYXAudW5ub3JtYWxpemVkKSB7CiAgICAgICAgICAgICAgICAgICAgICAgIC8vTm9ybWFsaXplIHRoZSBJRCBpZiB0aGUgcGx1Z2luIGFsbG93cyBpdC4KICAgICAgICAgICAgICAgICAgICAgICAgaWYgKHBsdWdpbi5ub3JtYWxpemUpIHsKICAgICAgICAgICAgICAgICAgICAgICAgICAgIG5hbWUgPSBwbHVnaW4ubm9ybWFsaXplKG5hbWUsIGZ1bmN0aW9uIChuYW1lKSB7CiAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgcmV0dXJuIG5vcm1hbGl6ZShuYW1lLCBwYXJlbnROYW1lLCB0cnVlKTsKICAgICAgICAgICAgICAgICAgICAgICAgICAgIH0pIHx8ICcnOwogICAgICAgICAgICAgICAgICAgICAgICB9CgogICAgICAgICAgICAgICAgICAgICAgICAvL3ByZWZpeCBhbmQgbmFtZSBzaG91bGQgYWxyZWFkeSBiZSBub3JtYWxpemVkLCBubyBuZWVkCiAgICAgICAgICAgICAgICAgICAgICAgIC8vZm9yIGFwcGx5aW5nIG1hcCBjb25maWcgYWdhaW4gZWl0aGVyLgogICAgICAgICAgICAgICAgICAgICAgICBub3JtYWxpemVkTWFwID0gbWFrZU1vZHVsZU1hcChtYXAucHJlZml4ICsgJyEnICsgbmFtZSwKICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgdGhpcy5tYXAucGFyZW50TWFwKTsKICAgICAgICAgICAgICAgICAgICAgICAgb24obm9ybWFsaXplZE1hcCwKICAgICAgICAgICAgICAgICAgICAgICAgICAgICdkZWZpbmVkJywgYmluZCh0aGlzLCBmdW5jdGlvbiAodmFsdWUpIHsKICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICB0aGlzLm1hcC5ub3JtYWxpemVkTWFwID0gbm9ybWFsaXplZE1hcDsKICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICB0aGlzLmluaXQoW10sIGZ1bmN0aW9uICgpIHsgcmV0dXJuIHZhbHVlOyB9LCBudWxsLCB7CiAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgIGVuYWJsZWQ6IHRydWUsCiAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgIGlnbm9yZTogdHJ1ZQogICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgIH0pOwogICAgICAgICAgICAgICAgICAgICAgICAgICAgfSkpOwoKICAgICAgICAgICAgICAgICAgICAgICAgbm9ybWFsaXplZE1vZCA9IGdldE93bihyZWdpc3RyeSwgbm9ybWFsaXplZE1hcC5pZCk7CiAgICAgICAgICAgICAgICAgICAgICAgIGlmIChub3JtYWxpemVkTW9kKSB7CiAgICAgICAgICAgICAgICAgICAgICAgICAgICAvL01hcmsgdGhpcyBhcyBhIGRlcGVuZGVuY3kgZm9yIHRoaXMgcGx1Z2luLCBzbyBpdAogICAgICAgICAgICAgICAgICAgICAgICAgICAgLy9jYW4gYmUgdHJhY2VkIGZvciBjeWNsZXMuCiAgICAgICAgICAgICAgICAgICAgICAgICAgICB0aGlzLmRlcE1hcHMucHVzaChub3JtYWxpemVkTWFwKTsKCiAgICAgICAgICAgICAgICAgICAgICAgICAgICBpZiAodGhpcy5ldmVudHMuZXJyb3IpIHsKICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICBub3JtYWxpemVkTW9kLm9uKCdlcnJvcicsIGJpbmQodGhpcywgZnVuY3Rpb24gKGVycikgewogICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICB0aGlzLmVtaXQoJ2Vycm9yJywgZXJyKTsKICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICB9KSk7CiAgICAgICAgICAgICAgICAgICAgICAgICAgICB9CiAgICAgICAgICAgICAgICAgICAgICAgICAgICBub3JtYWxpemVkTW9kLmVuYWJsZSgpOwogICAgICAgICAgICAgICAgICAgICAgICB9CgogICAgICAgICAgICAgICAgICAgICAgICByZXR1cm47CiAgICAgICAgICAgICAgICAgICAgfQoKICAgICAgICAgICAgICAgICAgICAvL0lmIGEgcGF0aHMgY29uZmlnLCB0aGVuIGp1c3QgbG9hZCB0aGF0IGZpbGUgaW5zdGVhZCB0bwogICAgICAgICAgICAgICAgICAgIC8vcmVzb2x2ZSB0aGUgcGx1Z2luLCBhcyBpdCBpcyBidWlsdCBpbnRvIHRoYXQgcGF0aHMgbGF5ZXIuCiAgICAgICAgICAgICAgICAgICAgaWYgKGJ1bmRsZUlkKSB7CiAgICAgICAgICAgICAgICAgICAgICAgIHRoaXMubWFwLnVybCA9IGNvbnRleHQubmFtZVRvVXJsKGJ1bmRsZUlkKTsKICAgICAgICAgICAgICAgICAgICAgICAgdGhpcy5sb2FkKCk7CiAgICAgICAgICAgICAgICAgICAgICAgIHJldHVybjsKICAgICAgICAgICAgICAgICAgICB9CgogICAgICAgICAgICAgICAgICAgIGxvYWQgPSBiaW5kKHRoaXMsIGZ1bmN0aW9uICh2YWx1ZSkgewogICAgICAgICAgICAgICAgICAgICAgICB0aGlzLmluaXQoW10sIGZ1bmN0aW9uICgpIHsgcmV0dXJuIHZhbHVlOyB9LCBudWxsLCB7CiAgICAgICAgICAgICAgICAgICAgICAgICAgICBlbmFibGVkOiB0cnVlCiAgICAgICAgICAgICAgICAgICAgICAgIH0pOwogICAgICAgICAgICAgICAgICAgIH0pOwoKICAgICAgICAgICAgICAgICAgICBsb2FkLmVycm9yID0gYmluZCh0aGlzLCBmdW5jdGlvbiAoZXJyKSB7CiAgICAgICAgICAgICAgICAgICAgICAgIHRoaXMuaW5pdGVkID0gdHJ1ZTsKICAgICAgICAgICAgICAgICAgICAgICAgdGhpcy5lcnJvciA9IGVycjsKICAgICAgICAgICAgICAgICAgICAgICAgZXJyLnJlcXVpcmVNb2R1bGVzID0gW2lkXTsKCiAgICAgICAgICAgICAgICAgICAgICAgIC8vUmVtb3ZlIHRlbXAgdW5ub3JtYWxpemVkIG1vZHVsZXMgZm9yIHRoaXMgbW9kdWxlLAogICAgICAgICAgICAgICAgICAgICAgICAvL3NpbmNlIHRoZXkgd2lsbCBuZXZlciBiZSByZXNvbHZlZCBvdGhlcndpc2Ugbm93LgogICAgICAgICAgICAgICAgICAgICAgICBlYWNoUHJvcChyZWdpc3RyeSwgZnVuY3Rpb24gKG1vZCkgewogICAgICAgICAgICAgICAgICAgICAgICAgICAgaWYgKG1vZC5tYXAuaWQuaW5kZXhPZihpZCArICdfdW5ub3JtYWxpemVkJykgPT09IDApIHsKICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICBjbGVhblJlZ2lzdHJ5KG1vZC5tYXAuaWQpOwogICAgICAgICAgICAgICAgICAgICAgICAgICAgfQogICAgICAgICAgICAgICAgICAgICAgICB9KTsKCiAgICAgICAgICAgICAgICAgICAgICAgIG9uRXJyb3IoZXJyKTsKICAgICAgICAgICAgICAgICAgICB9KTsKCiAgICAgICAgICAgICAgICAgICAgLy9BbGxvdyBwbHVnaW5zIHRvIGxvYWQgb3RoZXIgY29kZSB3aXRob3V0IGhhdmluZyB0byBrbm93IHRoZQogICAgICAgICAgICAgICAgICAgIC8vY29udGV4dCBvciBob3cgdG8gJ2NvbXBsZXRlJyB0aGUgbG9hZC4KICAgICAgICAgICAgICAgICAgICBsb2FkLmZyb21UZXh0ID0gYmluZCh0aGlzLCBmdW5jdGlvbiAodGV4dCwgdGV4dEFsdCkgewogICAgICAgICAgICAgICAgICAgICAgICAvKmpzbGludCBldmlsOiB0cnVlICovCiAgICAgICAgICAgICAgICAgICAgICAgIHZhciBtb2R1bGVOYW1lID0gbWFwLm5hbWUsCiAgICAgICAgICAgICAgICAgICAgICAgICAgICBtb2R1bGVNYXAgPSBtYWtlTW9kdWxlTWFwKG1vZHVsZU5hbWUpLAogICAgICAgICAgICAgICAgICAgICAgICAgICAgaGFzSW50ZXJhY3RpdmUgPSB1c2VJbnRlcmFjdGl2ZTsKCiAgICAgICAgICAgICAgICAgICAgICAgIC8vQXMgb2YgMi4xLjAsIHN1cHBvcnQganVzdCBwYXNzaW5nIHRoZSB0ZXh0LCB0byByZWluZm9yY2UKICAgICAgICAgICAgICAgICAgICAgICAgLy9mcm9tVGV4dCBvbmx5IGJlaW5nIGNhbGxlZCBvbmNlIHBlciByZXNvdXJjZS4gU3RpbGwKICAgICAgICAgICAgICAgICAgICAgICAgLy9zdXBwb3J0IG9sZCBzdHlsZSBvZiBwYXNzaW5nIG1vZHVsZU5hbWUgYnV0IGRpc2NhcmQKICAgICAgICAgICAgICAgICAgICAgICAgLy90aGF0IG1vZHVsZU5hbWUgaW4gZmF2b3Igb2YgdGhlIGludGVybmFsIHJlZi4KICAgICAgICAgICAgICAgICAgICAgICAgaWYgKHRleHRBbHQpIHsKICAgICAgICAgICAgICAgICAgICAgICAgICAgIHRleHQgPSB0ZXh0QWx0OwogICAgICAgICAgICAgICAgICAgICAgICB9CgogICAgICAgICAgICAgICAgICAgICAgICAvL1R1cm4gb2ZmIGludGVyYWN0aXZlIHNjcmlwdCBtYXRjaGluZyBmb3IgSUUgZm9yIGFueSBkZWZpbmUKICAgICAgICAgICAgICAgICAgICAgICAgLy9jYWxscyBpbiB0aGUgdGV4dCwgdGhlbiB0dXJuIGl0IGJhY2sgb24gYXQgdGhlIGVuZC4KICAgICAgICAgICAgICAgICAgICAgICAgaWYgKGhhc0ludGVyYWN0aXZlKSB7CiAgICAgICAgICAgICAgICAgICAgICAgICAgICB1c2VJbnRlcmFjdGl2ZSA9IGZhbHNlOwogICAgICAgICAgICAgICAgICAgICAgICB9CgogICAgICAgICAgICAgICAgICAgICAgICAvL1ByaW1lIHRoZSBzeXN0ZW0gYnkgY3JlYXRpbmcgYSBtb2R1bGUgaW5zdGFuY2UgZm9yCiAgICAgICAgICAgICAgICAgICAgICAgIC8vaXQuCiAgICAgICAgICAgICAgICAgICAgICAgIGdldE1vZHVsZShtb2R1bGVNYXApOwoKICAgICAgICAgICAgICAgICAgICAgICAgLy9UcmFuc2ZlciBhbnkgY29uZmlnIHRvIHRoaXMgb3RoZXIgbW9kdWxlLgogICAgICAgICAgICAgICAgICAgICAgICBpZiAoaGFzUHJvcChjb25maWcuY29uZmlnLCBpZCkpIHsKICAgICAgICAgICAgICAgICAgICAgICAgICAgIGNvbmZpZy5jb25maWdbbW9kdWxlTmFtZV0gPSBjb25maWcuY29uZmlnW2lkXTsKICAgICAgICAgICAgICAgICAgICAgICAgfQoKICAgICAgICAgICAgICAgICAgICAgICAgdHJ5IHsKICAgICAgICAgICAgICAgICAgICAgICAgICAgIHJlcS5leGVjKHRleHQpOwogICAgICAgICAgICAgICAgICAgICAgICB9IGNhdGNoIChlKSB7CiAgICAgICAgICAgICAgICAgICAgICAgICAgICByZXR1cm4gb25FcnJvcihtYWtlRXJyb3IoJ2Zyb210ZXh0ZXZhbCcsCiAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICdmcm9tVGV4dCBldmFsIGZvciAnICsgaWQgKwogICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICcgZmFpbGVkOiAnICsgZSwKICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgZSwKICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgW2lkXSkpOwogICAgICAgICAgICAgICAgICAgICAgICB9CgogICAgICAgICAgICAgICAgICAgICAgICBpZiAoaGFzSW50ZXJhY3RpdmUpIHsKICAgICAgICAgICAgICAgICAgICAgICAgICAgIHVzZUludGVyYWN0aXZlID0gdHJ1ZTsKICAgICAgICAgICAgICAgICAgICAgICAgfQoKICAgICAgICAgICAgICAgICAgICAgICAgLy9NYXJrIHRoaXMgYXMgYSBkZXBlbmRlbmN5IGZvciB0aGUgcGx1Z2luCiAgICAgICAgICAgICAgICAgICAgICAgIC8vcmVzb3VyY2UKICAgICAgICAgICAgICAgICAgICAgICAgdGhpcy5kZXBNYXBzLnB1c2gobW9kdWxlTWFwKTsKCiAgICAgICAgICAgICAgICAgICAgICAgIC8vU3VwcG9ydCBhbm9ueW1vdXMgbW9kdWxlcy4KICAgICAgICAgICAgICAgICAgICAgICAgY29udGV4dC5jb21wbGV0ZUxvYWQobW9kdWxlTmFtZSk7CgogICAgICAgICAgICAgICAgICAgICAgICAvL0JpbmQgdGhlIHZhbHVlIG9mIHRoYXQgbW9kdWxlIHRvIHRoZSB2YWx1ZSBmb3IgdGhpcwogICAgICAgICAgICAgICAgICAgICAgICAvL3Jlc291cmNlIElELgogICAgICAgICAgICAgICAgICAgICAgICBsb2NhbFJlcXVpcmUoW21vZHVsZU5hbWVdLCBsb2FkKTsKICAgICAgICAgICAgICAgICAgICB9KTsKCiAgICAgICAgICAgICAgICAgICAgLy9Vc2UgcGFyZW50TmFtZSBoZXJlIHNpbmNlIHRoZSBwbHVnaW4ncyBuYW1lIGlzIG5vdCByZWxpYWJsZSwKICAgICAgICAgICAgICAgICAgICAvL2NvdWxkIGJlIHNvbWUgd2VpcmQgc3RyaW5nIHdpdGggbm8gcGF0aCB0aGF0IGFjdHVhbGx5IHdhbnRzIHRvCiAgICAgICAgICAgICAgICAgICAgLy9yZWZlcmVuY2UgdGhlIHBhcmVudE5hbWUncyBwYXRoLgogICAgICAgICAgICAgICAgICAgIHBsdWdpbi5sb2FkKG1hcC5uYW1lLCBsb2NhbFJlcXVpcmUsIGxvYWQsIGNvbmZpZyk7CiAgICAgICAgICAgICAgICB9KSk7CgogICAgICAgICAgICAgICAgY29udGV4dC5lbmFibGUocGx1Z2luTWFwLCB0aGlzKTsKICAgICAgICAgICAgICAgIHRoaXMucGx1Z2luTWFwc1twbHVnaW5NYXAuaWRdID0gcGx1Z2luTWFwOwogICAgICAgICAgICB9LAoKICAgICAgICAgICAgZW5hYmxlOiBmdW5jdGlvbiAoKSB7CiAgICAgICAgICAgICAgICBlbmFibGVkUmVnaXN0cnlbdGhpcy5tYXAuaWRdID0gdGhpczsKICAgICAgICAgICAgICAgIHRoaXMuZW5hYmxlZCA9IHRydWU7CgogICAgICAgICAgICAgICAgLy9TZXQgZmxhZyBtZW50aW9uaW5nIHRoYXQgdGhlIG1vZHVsZSBpcyBlbmFibGluZywKICAgICAgICAgICAgICAgIC8vc28gdGhhdCBpbW1lZGlhdGUgY2FsbHMgdG8gdGhlIGRlZmluZWQgY2FsbGJhY2tzCiAgICAgICAgICAgICAgICAvL2ZvciBkZXBlbmRlbmNpZXMgZG8gbm90IHRyaWdnZXIgaW5hZHZlcnRlbnQgbG9hZAogICAgICAgICAgICAgICAgLy93aXRoIHRoZSBkZXBDb3VudCBzdGlsbCBiZWluZyB6ZXJvLgogICAgICAgICAgICAgICAgdGhpcy5lbmFibGluZyA9IHRydWU7CgogICAgICAgICAgICAgICAgLy9FbmFibGUgZWFjaCBkZXBlbmRlbmN5CiAgICAgICAgICAgICAgICBlYWNoKHRoaXMuZGVwTWFwcywgYmluZCh0aGlzLCBmdW5jdGlvbiAoZGVwTWFwLCBpKSB7CiAgICAgICAgICAgICAgICAgICAgdmFyIGlkLCBtb2QsIGhhbmRsZXI7CgogICAgICAgICAgICAgICAgICAgIGlmICh0eXBlb2YgZGVwTWFwID09PSAnc3RyaW5nJykgewogICAgICAgICAgICAgICAgICAgICAgICAvL0RlcGVuZGVuY3kgbmVlZHMgdG8gYmUgY29udmVydGVkIHRvIGEgZGVwTWFwCiAgICAgICAgICAgICAgICAgICAgICAgIC8vYW5kIHdpcmVkIHVwIHRvIHRoaXMgbW9kdWxlLgogICAgICAgICAgICAgICAgICAgICAgICBkZXBNYXAgPSBtYWtlTW9kdWxlTWFwKGRlcE1hcCwKICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAodGhpcy5tYXAuaXNEZWZpbmUgPyB0aGlzLm1hcCA6IHRoaXMubWFwLnBhcmVudE1hcCksCiAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgZmFsc2UsCiAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgIXRoaXMuc2tpcE1hcCk7CiAgICAgICAgICAgICAgICAgICAgICAgIHRoaXMuZGVwTWFwc1tpXSA9IGRlcE1hcDsKCiAgICAgICAgICAgICAgICAgICAgICAgIGhhbmRsZXIgPSBnZXRPd24oaGFuZGxlcnMsIGRlcE1hcC5pZCk7CgogICAgICAgICAgICAgICAgICAgICAgICBpZiAoaGFuZGxlcikgewogICAgICAgICAgICAgICAgICAgICAgICAgICAgdGhpcy5kZXBFeHBvcnRzW2ldID0gaGFuZGxlcih0aGlzKTsKICAgICAgICAgICAgICAgICAgICAgICAgICAgIHJldHVybjsKICAgICAgICAgICAgICAgICAgICAgICAgfQoKICAgICAgICAgICAgICAgICAgICAgICAgdGhpcy5kZXBDb3VudCArPSAxOwoKICAgICAgICAgICAgICAgICAgICAgICAgb24oZGVwTWFwLCAnZGVmaW5lZCcsIGJpbmQodGhpcywgZnVuY3Rpb24gKGRlcEV4cG9ydHMpIHsKICAgICAgICAgICAgICAgICAgICAgICAgICAgIGlmICh0aGlzLnVuZGVmZWQpIHsKICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICByZXR1cm47CiAgICAgICAgICAgICAgICAgICAgICAgICAgICB9CiAgICAgICAgICAgICAgICAgICAgICAgICAgICB0aGlzLmRlZmluZURlcChpLCBkZXBFeHBvcnRzKTsKICAgICAgICAgICAgICAgICAgICAgICAgICAgIHRoaXMuY2hlY2soKTsKICAgICAgICAgICAgICAgICAgICAgICAgfSkpOwoKICAgICAgICAgICAgICAgICAgICAgICAgaWYgKHRoaXMuZXJyYmFjaykgewogICAgICAgICAgICAgICAgICAgICAgICAgICAgb24oZGVwTWFwLCAnZXJyb3InLCBiaW5kKHRoaXMsIHRoaXMuZXJyYmFjaykpOwogICAgICAgICAgICAgICAgICAgICAgICB9IGVsc2UgaWYgKHRoaXMuZXZlbnRzLmVycm9yKSB7CiAgICAgICAgICAgICAgICAgICAgICAgICAgICAvLyBObyBkaXJlY3QgZXJyYmFjayBvbiB0aGlzIG1vZHVsZSwgYnV0IHNvbWV0aGluZwogICAgICAgICAgICAgICAgICAgICAgICAgICAgLy8gZWxzZSBpcyBsaXN0ZW5pbmcgZm9yIGVycm9ycywgc28gYmUgc3VyZSB0bwogICAgICAgICAgICAgICAgICAgICAgICAgICAgLy8gcHJvcGFnYXRlIHRoZSBlcnJvciBjb3JyZWN0bHkuCiAgICAgICAgICAgICAgICAgICAgICAgICAgICBvbihkZXBNYXAsICdlcnJvcicsIGJpbmQodGhpcywgZnVuY3Rpb24oZXJyKSB7CiAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgdGhpcy5lbWl0KCdlcnJvcicsIGVycik7CiAgICAgICAgICAgICAgICAgICAgICAgICAgICB9KSk7CiAgICAgICAgICAgICAgICAgICAgICAgIH0KICAgICAgICAgICAgICAgICAgICB9CgogICAgICAgICAgICAgICAgICAgIGlkID0gZGVwTWFwLmlkOwogICAgICAgICAgICAgICAgICAgIG1vZCA9IHJlZ2lzdHJ5W2lkXTsKCiAgICAgICAgICAgICAgICAgICAgLy9Ta2lwIHNwZWNpYWwgbW9kdWxlcyBsaWtlICdyZXF1aXJlJywgJ2V4cG9ydHMnLCAnbW9kdWxlJwogICAgICAgICAgICAgICAgICAgIC8vQWxzbywgZG9uJ3QgY2FsbCBlbmFibGUgaWYgaXQgaXMgYWxyZWFkeSBlbmFibGVkLAogICAgICAgICAgICAgICAgICAgIC8vaW1wb3J0YW50IGluIGNpcmN1bGFyIGRlcGVuZGVuY3kgY2FzZXMuCiAgICAgICAgICAgICAgICAgICAgaWYgKCFoYXNQcm9wKGhhbmRsZXJzLCBpZCkgJiYgbW9kICYmICFtb2QuZW5hYmxlZCkgewogICAgICAgICAgICAgICAgICAgICAgICBjb250ZXh0LmVuYWJsZShkZXBNYXAsIHRoaXMpOwogICAgICAgICAgICAgICAgICAgIH0KICAgICAgICAgICAgICAgIH0pKTsKCiAgICAgICAgICAgICAgICAvL0VuYWJsZSBlYWNoIHBsdWdpbiB0aGF0IGlzIHVzZWQgaW4KICAgICAgICAgICAgICAgIC8vYSBkZXBlbmRlbmN5CiAgICAgICAgICAgICAgICBlYWNoUHJvcCh0aGlzLnBsdWdpbk1hcHMsIGJpbmQodGhpcywgZnVuY3Rpb24gKHBsdWdpbk1hcCkgewogICAgICAgICAgICAgICAgICAgIHZhciBtb2QgPSBnZXRPd24ocmVnaXN0cnksIHBsdWdpbk1hcC5pZCk7CiAgICAgICAgICAgICAgICAgICAgaWYgKG1vZCAmJiAhbW9kLmVuYWJsZWQpIHsKICAgICAgICAgICAgICAgICAgICAgICAgY29udGV4dC5lbmFibGUocGx1Z2luTWFwLCB0aGlzKTsKICAgICAgICAgICAgICAgICAgICB9CiAgICAgICAgICAgICAgICB9KSk7CgogICAgICAgICAgICAgICAgdGhpcy5lbmFibGluZyA9IGZhbHNlOwoKICAgICAgICAgICAgICAgIHRoaXMuY2hlY2soKTsKICAgICAgICAgICAgfSwKCiAgICAgICAgICAgIG9uOiBmdW5jdGlvbiAobmFtZSwgY2IpIHsKICAgICAgICAgICAgICAgIHZhciBjYnMgPSB0aGlzLmV2ZW50c1tuYW1lXTsKICAgICAgICAgICAgICAgIGlmICghY2JzKSB7CiAgICAgICAgICAgICAgICAgICAgY2JzID0gdGhpcy5ldmVudHNbbmFtZV0gPSBbXTsKICAgICAgICAgICAgICAgIH0KICAgICAgICAgICAgICAgIGNicy5wdXNoKGNiKTsKICAgICAgICAgICAgfSwKCiAgICAgICAgICAgIGVtaXQ6IGZ1bmN0aW9uIChuYW1lLCBldnQpIHsKICAgICAgICAgICAgICAgIGVhY2godGhpcy5ldmVudHNbbmFtZV0sIGZ1bmN0aW9uIChjYikgewogICAgICAgICAgICAgICAgICAgIGNiKGV2dCk7CiAgICAgICAgICAgICAgICB9KTsKICAgICAgICAgICAgICAgIGlmIChuYW1lID09PSAnZXJyb3InKSB7CiAgICAgICAgICAgICAgICAgICAgLy9Ob3cgdGhhdCB0aGUgZXJyb3IgaGFuZGxlciB3YXMgdHJpZ2dlcmVkLCByZW1vdmUKICAgICAgICAgICAgICAgICAgICAvL3RoZSBsaXN0ZW5lcnMsIHNpbmNlIHRoaXMgYnJva2VuIE1vZHVsZSBpbnN0YW5jZQogICAgICAgICAgICAgICAgICAgIC8vY2FuIHN0YXkgYXJvdW5kIGZvciBhIHdoaWxlIGluIHRoZSByZWdpc3RyeS4KICAgICAgICAgICAgICAgICAgICBkZWxldGUgdGhpcy5ldmVudHNbbmFtZV07CiAgICAgICAgICAgICAgICB9CiAgICAgICAgICAgIH0KICAgICAgICB9OwoKICAgICAgICBmdW5jdGlvbiBjYWxsR2V0TW9kdWxlKGFyZ3MpIHsKICAgICAgICAgICAgLy9Ta2lwIG1vZHVsZXMgYWxyZWFkeSBkZWZpbmVkLgogICAgICAgICAgICBpZiAoIWhhc1Byb3AoZGVmaW5lZCwgYXJnc1swXSkpIHsKICAgICAgICAgICAgICAgIGdldE1vZHVsZShtYWtlTW9kdWxlTWFwKGFyZ3NbMF0sIG51bGwsIHRydWUpKS5pbml0KGFyZ3NbMV0sIGFyZ3NbMl0pOwogICAgICAgICAgICB9CiAgICAgICAgfQoKICAgICAgICBmdW5jdGlvbiByZW1vdmVMaXN0ZW5lcihub2RlLCBmdW5jLCBuYW1lLCBpZU5hbWUpIHsKICAgICAgICAgICAgLy9GYXZvciBkZXRhY2hFdmVudCBiZWNhdXNlIG9mIElFOQogICAgICAgICAgICAvL2lzc3VlLCBzZWUgYXR0YWNoRXZlbnQvYWRkRXZlbnRMaXN0ZW5lciBjb21tZW50IGVsc2V3aGVyZQogICAgICAgICAgICAvL2luIHRoaXMgZmlsZS4KICAgICAgICAgICAgaWYgKG5vZGUuZGV0YWNoRXZlbnQgJiYgIWlzT3BlcmEpIHsKICAgICAgICAgICAgICAgIC8vUHJvYmFibHkgSUUuIElmIG5vdCBpdCB3aWxsIHRocm93IGFuIGVycm9yLCB3aGljaCB3aWxsIGJlCiAgICAgICAgICAgICAgICAvL3VzZWZ1bCB0byBrbm93LgogICAgICAgICAgICAgICAgaWYgKGllTmFtZSkgewogICAgICAgICAgICAgICAgICAgIG5vZGUuZGV0YWNoRXZlbnQoaWVOYW1lLCBmdW5jKTsKICAgICAgICAgICAgICAgIH0KICAgICAgICAgICAgfSBlbHNlIHsKICAgICAgICAgICAgICAgIG5vZGUucmVtb3ZlRXZlbnRMaXN0ZW5lcihuYW1lLCBmdW5jLCBmYWxzZSk7CiAgICAgICAgICAgIH0KICAgICAgICB9CgogICAgICAgIC8qKgogICAgICAgICAqIEdpdmVuIGFuIGV2ZW50IGZyb20gYSBzY3JpcHQgbm9kZSwgZ2V0IHRoZSByZXF1aXJlanMgaW5mbyBmcm9tIGl0LAogICAgICAgICAqIGFuZCB0aGVuIHJlbW92ZXMgdGhlIGV2ZW50IGxpc3RlbmVycyBvbiB0aGUgbm9kZS4KICAgICAgICAgKiBAcGFyYW0ge0V2ZW50fSBldnQKICAgICAgICAgKiBAcmV0dXJucyB7T2JqZWN0fQogICAgICAgICAqLwogICAgICAgIGZ1bmN0aW9uIGdldFNjcmlwdERhdGEoZXZ0KSB7CiAgICAgICAgICAgIC8vVXNpbmcgY3VycmVudFRhcmdldCBpbnN0ZWFkIG9mIHRhcmdldCBmb3IgRmlyZWZveCAyLjAncyBzYWtlLiBOb3QKICAgICAgICAgICAgLy9hbGwgb2xkIGJyb3dzZXJzIHdpbGwgYmUgc3VwcG9ydGVkLCBidXQgdGhpcyBvbmUgd2FzIGVhc3kgZW5vdWdoCiAgICAgICAgICAgIC8vdG8gc3VwcG9ydCBhbmQgc3RpbGwgbWFrZXMgc2Vuc2UuCiAgICAgICAgICAgIHZhciBub2RlID0gZXZ0LmN1cnJlbnRUYXJnZXQgfHwgZXZ0LnNyY0VsZW1lbnQ7CgogICAgICAgICAgICAvL1JlbW92ZSB0aGUgbGlzdGVuZXJzIG9uY2UgaGVyZS4KICAgICAgICAgICAgcmVtb3ZlTGlzdGVuZXIobm9kZSwgY29udGV4dC5vblNjcmlwdExvYWQsICdsb2FkJywgJ29ucmVhZHlzdGF0ZWNoYW5nZScpOwogICAgICAgICAgICByZW1vdmVMaXN0ZW5lcihub2RlLCBjb250ZXh0Lm9uU2NyaXB0RXJyb3IsICdlcnJvcicpOwoKICAgICAgICAgICAgcmV0dXJuIHsKICAgICAgICAgICAgICAgIG5vZGU6IG5vZGUsCiAgICAgICAgICAgICAgICBpZDogbm9kZSAmJiBub2RlLmdldEF0dHJpYnV0ZSgnZGF0YS1yZXF1aXJlbW9kdWxlJykKICAgICAgICAgICAgfTsKICAgICAgICB9CgogICAgICAgIGZ1bmN0aW9uIGludGFrZURlZmluZXMoKSB7CiAgICAgICAgICAgIHZhciBhcmdzOwoKICAgICAgICAgICAgLy9BbnkgZGVmaW5lZCBtb2R1bGVzIGluIHRoZSBnbG9iYWwgcXVldWUsIGludGFrZSB0aGVtIG5vdy4KICAgICAgICAgICAgdGFrZUdsb2JhbFF1ZXVlKCk7CgogICAgICAgICAgICAvL01ha2Ugc3VyZSBhbnkgcmVtYWluaW5nIGRlZlF1ZXVlIGl0ZW1zIGdldCBwcm9wZXJseSBwcm9jZXNzZWQuCiAgICAgICAgICAgIHdoaWxlIChkZWZRdWV1ZS5sZW5ndGgpIHsKICAgICAgICAgICAgICAgIGFyZ3MgPSBkZWZRdWV1ZS5zaGlmdCgpOwogICAgICAgICAgICAgICAgaWYgKGFyZ3NbMF0gPT09IG51bGwpIHsKICAgICAgICAgICAgICAgICAgICByZXR1cm4gb25FcnJvcihtYWtlRXJyb3IoJ21pc21hdGNoJywgJ01pc21hdGNoZWQgYW5vbnltb3VzIGRlZmluZSgpIG1vZHVsZTogJyArCiAgICAgICAgICAgICAgICAgICAgICAgIGFyZ3NbYXJncy5sZW5ndGggLSAxXSkpOwogICAgICAgICAgICAgICAgfSBlbHNlIHsKICAgICAgICAgICAgICAgICAgICAvL2FyZ3MgYXJlIGlkLCBkZXBzLCBmYWN0b3J5LiBTaG91bGQgYmUgbm9ybWFsaXplZCBieSB0aGUKICAgICAgICAgICAgICAgICAgICAvL2RlZmluZSgpIGZ1bmN0aW9uLgogICAgICAgICAgICAgICAgICAgIGNhbGxHZXRNb2R1bGUoYXJncyk7CiAgICAgICAgICAgICAgICB9CiAgICAgICAgICAgIH0KICAgICAgICAgICAgY29udGV4dC5kZWZRdWV1ZU1hcCA9IHt9OwogICAgICAgIH0KCiAgICAgICAgY29udGV4dCA9IHsKICAgICAgICAgICAgY29uZmlnOiBjb25maWcsCiAgICAgICAgICAgIGNvbnRleHROYW1lOiBjb250ZXh0TmFtZSwKICAgICAgICAgICAgcmVnaXN0cnk6IHJlZ2lzdHJ5LAogICAgICAgICAgICBkZWZpbmVkOiBkZWZpbmVkLAogICAgICAgICAgICB1cmxGZXRjaGVkOiB1cmxGZXRjaGVkLAogICAgICAgICAgICBkZWZRdWV1ZTogZGVmUXVldWUsCiAgICAgICAgICAgIGRlZlF1ZXVlTWFwOiB7fSwKICAgICAgICAgICAgTW9kdWxlOiBNb2R1bGUsCiAgICAgICAgICAgIG1ha2VNb2R1bGVNYXA6IG1ha2VNb2R1bGVNYXAsCiAgICAgICAgICAgIG5leHRUaWNrOiByZXEubmV4dFRpY2ssCiAgICAgICAgICAgIG9uRXJyb3I6IG9uRXJyb3IsCgogICAgICAgICAgICAvKioKICAgICAgICAgICAgICogU2V0IGEgY29uZmlndXJhdGlvbiBmb3IgdGhlIGNvbnRleHQuCiAgICAgICAgICAgICAqIEBwYXJhbSB7T2JqZWN0fSBjZmcgY29uZmlnIG9iamVjdCB0byBpbnRlZ3JhdGUuCiAgICAgICAgICAgICAqLwogICAgICAgICAgICBjb25maWd1cmU6IGZ1bmN0aW9uIChjZmcpIHsKICAgICAgICAgICAgICAgIC8vTWFrZSBzdXJlIHRoZSBiYXNlVXJsIGVuZHMgaW4gYSBzbGFzaC4KICAgICAgICAgICAgICAgIGlmIChjZmcuYmFzZVVybCkgewogICAgICAgICAgICAgICAgICAgIGlmIChjZmcuYmFzZVVybC5jaGFyQXQoY2ZnLmJhc2VVcmwubGVuZ3RoIC0gMSkgIT09ICcvJykgewogICAgICAgICAgICAgICAgICAgICAgICBjZmcuYmFzZVVybCArPSAnLyc7CiAgICAgICAgICAgICAgICAgICAgfQogICAgICAgICAgICAgICAgfQoKICAgICAgICAgICAgICAgIC8vU2F2ZSBvZmYgdGhlIHBhdGhzIHNpbmNlIHRoZXkgcmVxdWlyZSBzcGVjaWFsIHByb2Nlc3NpbmcsCiAgICAgICAgICAgICAgICAvL3RoZXkgYXJlIGFkZGl0aXZlLgogICAgICAgICAgICAgICAgdmFyIHNoaW0gPSBjb25maWcuc2hpbSwKICAgICAgICAgICAgICAgICAgICBvYmpzID0gewogICAgICAgICAgICAgICAgICAgICAgICBwYXRoczogdHJ1ZSwKICAgICAgICAgICAgICAgICAgICAgICAgYnVuZGxlczogdHJ1ZSwKICAgICAgICAgICAgICAgICAgICAgICAgY29uZmlnOiB0cnVlLAogICAgICAgICAgICAgICAgICAgICAgICBtYXA6IHRydWUKICAgICAgICAgICAgICAgICAgICB9OwoKICAgICAgICAgICAgICAgIGVhY2hQcm9wKGNmZywgZnVuY3Rpb24gKHZhbHVlLCBwcm9wKSB7CiAgICAgICAgICAgICAgICAgICAgaWYgKG9ianNbcHJvcF0pIHsKICAgICAgICAgICAgICAgICAgICAgICAgaWYgKCFjb25maWdbcHJvcF0pIHsKICAgICAgICAgICAgICAgICAgICAgICAgICAgIGNvbmZpZ1twcm9wXSA9IHt9OwogICAgICAgICAgICAgICAgICAgICAgICB9CiAgICAgICAgICAgICAgICAgICAgICAgIG1peGluKGNvbmZpZ1twcm9wXSwgdmFsdWUsIHRydWUsIHRydWUpOwogICAgICAgICAgICAgICAgICAgIH0gZWxzZSB7CiAgICAgICAgICAgICAgICAgICAgICAgIGNvbmZpZ1twcm9wXSA9IHZhbHVlOwogICAgICAgICAgICAgICAgICAgIH0KICAgICAgICAgICAgICAgIH0pOwoKICAgICAgICAgICAgICAgIC8vUmV2ZXJzZSBtYXAgdGhlIGJ1bmRsZXMKICAgICAgICAgICAgICAgIGlmIChjZmcuYnVuZGxlcykgewogICAgICAgICAgICAgICAgICAgIGVhY2hQcm9wKGNmZy5idW5kbGVzLCBmdW5jdGlvbiAodmFsdWUsIHByb3ApIHsKICAgICAgICAgICAgICAgICAgICAgICAgZWFjaCh2YWx1ZSwgZnVuY3Rpb24gKHYpIHsKICAgICAgICAgICAgICAgICAgICAgICAgICAgIGlmICh2ICE9PSBwcm9wKSB7CiAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgYnVuZGxlc01hcFt2XSA9IHByb3A7CiAgICAgICAgICAgICAgICAgICAgICAgICAgICB9CiAgICAgICAgICAgICAgICAgICAgICAgIH0pOwogICAgICAgICAgICAgICAgICAgIH0pOwogICAgICAgICAgICAgICAgfQoKICAgICAgICAgICAgICAgIC8vTWVyZ2Ugc2hpbQogICAgICAgICAgICAgICAgaWYgKGNmZy5zaGltKSB7CiAgICAgICAgICAgICAgICAgICAgZWFjaFByb3AoY2ZnLnNoaW0sIGZ1bmN0aW9uICh2YWx1ZSwgaWQpIHsKICAgICAgICAgICAgICAgICAgICAgICAgLy9Ob3JtYWxpemUgdGhlIHN0cnVjdHVyZQogICAgICAgICAgICAgICAgICAgICAgICBpZiAoaXNBcnJheSh2YWx1ZSkpIHsKICAgICAgICAgICAgICAgICAgICAgICAgICAgIHZhbHVlID0gewogICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgIGRlcHM6IHZhbHVlCiAgICAgICAgICAgICAgICAgICAgICAgICAgICB9OwogICAgICAgICAgICAgICAgICAgICAgICB9CiAgICAgICAgICAgICAgICAgICAgICAgIGlmICgodmFsdWUuZXhwb3J0cyB8fCB2YWx1ZS5pbml0KSAmJiAhdmFsdWUuZXhwb3J0c0ZuKSB7CiAgICAgICAgICAgICAgICAgICAgICAgICAgICB2YWx1ZS5leHBvcnRzRm4gPSBjb250ZXh0Lm1ha2VTaGltRXhwb3J0cyh2YWx1ZSk7CiAgICAgICAgICAgICAgICAgICAgICAgIH0KICAgICAgICAgICAgICAgICAgICAgICAgc2hpbVtpZF0gPSB2YWx1ZTsKICAgICAgICAgICAgICAgICAgICB9KTsKICAgICAgICAgICAgICAgICAgICBjb25maWcuc2hpbSA9IHNoaW07CiAgICAgICAgICAgICAgICB9CgogICAgICAgICAgICAgICAgLy9BZGp1c3QgcGFja2FnZXMgaWYgbmVjZXNzYXJ5LgogICAgICAgICAgICAgICAgaWYgKGNmZy5wYWNrYWdlcykgewogICAgICAgICAgICAgICAgICAgIGVhY2goY2ZnLnBhY2thZ2VzLCBmdW5jdGlvbiAocGtnT2JqKSB7CiAgICAgICAgICAgICAgICAgICAgICAgIHZhciBsb2NhdGlvbiwgbmFtZTsKCiAgICAgICAgICAgICAgICAgICAgICAgIHBrZ09iaiA9IHR5cGVvZiBwa2dPYmogPT09ICdzdHJpbmcnID8ge25hbWU6IHBrZ09ian0gOiBwa2dPYmo7CgogICAgICAgICAgICAgICAgICAgICAgICBuYW1lID0gcGtnT2JqLm5hbWU7CiAgICAgICAgICAgICAgICAgICAgICAgIGxvY2F0aW9uID0gcGtnT2JqLmxvY2F0aW9uOwogICAgICAgICAgICAgICAgICAgICAgICBpZiAobG9jYXRpb24pIHsKICAgICAgICAgICAgICAgICAgICAgICAgICAgIGNvbmZpZy5wYXRoc1tuYW1lXSA9IHBrZ09iai5sb2NhdGlvbjsKICAgICAgICAgICAgICAgICAgICAgICAgfQoKICAgICAgICAgICAgICAgICAgICAgICAgLy9TYXZlIHBvaW50ZXIgdG8gbWFpbiBtb2R1bGUgSUQgZm9yIHBrZyBuYW1lLgogICAgICAgICAgICAgICAgICAgICAgICAvL1JlbW92ZSBsZWFkaW5nIGRvdCBpbiBtYWluLCBzbyBtYWluIHBhdGhzIGFyZSBub3JtYWxpemVkLAogICAgICAgICAgICAgICAgICAgICAgICAvL2FuZCByZW1vdmUgYW55IHRyYWlsaW5nIC5qcywgc2luY2UgZGlmZmVyZW50IHBhY2thZ2UKICAgICAgICAgICAgICAgICAgICAgICAgLy9lbnZzIGhhdmUgZGlmZmVyZW50IGNvbnZlbnRpb25zOiBzb21lIHVzZSBhIG1vZHVsZSBuYW1lLAogICAgICAgICAgICAgICAgICAgICAgICAvL3NvbWUgdXNlIGEgZmlsZSBuYW1lLgogICAgICAgICAgICAgICAgICAgICAgICBjb25maWcucGtnc1tuYW1lXSA9IHBrZ09iai5uYW1lICsgJy8nICsgKHBrZ09iai5tYWluIHx8ICdtYWluJykKICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgIC5yZXBsYWNlKGN1cnJEaXJSZWdFeHAsICcnKQogICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgLnJlcGxhY2UoanNTdWZmaXhSZWdFeHAsICcnKTsKICAgICAgICAgICAgICAgICAgICB9KTsKICAgICAgICAgICAgICAgIH0KCiAgICAgICAgICAgICAgICAvL0lmIHRoZXJlIGFyZSBhbnkgIndhaXRpbmcgdG8gZXhlY3V0ZSIgbW9kdWxlcyBpbiB0aGUgcmVnaXN0cnksCiAgICAgICAgICAgICAgICAvL3VwZGF0ZSB0aGUgbWFwcyBmb3IgdGhlbSwgc2luY2UgdGhlaXIgaW5mbywgbGlrZSBVUkxzIHRvIGxvYWQsCiAgICAgICAgICAgICAgICAvL21heSBoYXZlIGNoYW5nZWQuCiAgICAgICAgICAgICAgICBlYWNoUHJvcChyZWdpc3RyeSwgZnVuY3Rpb24gKG1vZCwgaWQpIHsKICAgICAgICAgICAgICAgICAgICAvL0lmIG1vZHVsZSBhbHJlYWR5IGhhcyBpbml0IGNhbGxlZCwgc2luY2UgaXQgaXMgdG9vCiAgICAgICAgICAgICAgICAgICAgLy9sYXRlIHRvIG1vZGlmeSB0aGVtLCBhbmQgaWdub3JlIHVubm9ybWFsaXplZCBvbmVzCiAgICAgICAgICAgICAgICAgICAgLy9zaW5jZSB0aGV5IGFyZSB0cmFuc2llbnQuCiAgICAgICAgICAgICAgICAgICAgaWYgKCFtb2QuaW5pdGVkICYmICFtb2QubWFwLnVubm9ybWFsaXplZCkgewogICAgICAgICAgICAgICAgICAgICAgICBtb2QubWFwID0gbWFrZU1vZHVsZU1hcChpZCwgbnVsbCwgdHJ1ZSk7CiAgICAgICAgICAgICAgICAgICAgfQogICAgICAgICAgICAgICAgfSk7CgogICAgICAgICAgICAgICAgLy9JZiBhIGRlcHMgYXJyYXkgb3IgYSBjb25maWcgY2FsbGJhY2sgaXMgc3BlY2lmaWVkLCB0aGVuIGNhbGwKICAgICAgICAgICAgICAgIC8vcmVxdWlyZSB3aXRoIHRob3NlIGFyZ3MuIFRoaXMgaXMgdXNlZnVsIHdoZW4gcmVxdWlyZSBpcyBkZWZpbmVkIGFzIGEKICAgICAgICAgICAgICAgIC8vY29uZmlnIG9iamVjdCBiZWZvcmUgcmVxdWlyZS5qcyBpcyBsb2FkZWQuCiAgICAgICAgICAgICAgICBpZiAoY2ZnLmRlcHMgfHwgY2ZnLmNhbGxiYWNrKSB7CiAgICAgICAgICAgICAgICAgICAgY29udGV4dC5yZXF1aXJlKGNmZy5kZXBzIHx8IFtdLCBjZmcuY2FsbGJhY2spOwogICAgICAgICAgICAgICAgfQogICAgICAgICAgICB9LAoKICAgICAgICAgICAgbWFrZVNoaW1FeHBvcnRzOiBmdW5jdGlvbiAodmFsdWUpIHsKICAgICAgICAgICAgICAgIGZ1bmN0aW9uIGZuKCkgewogICAgICAgICAgICAgICAgICAgIHZhciByZXQ7CiAgICAgICAgICAgICAgICAgICAgaWYgKHZhbHVlLmluaXQpIHsKICAgICAgICAgICAgICAgICAgICAgICAgcmV0ID0gdmFsdWUuaW5pdC5hcHBseShnbG9iYWwsIGFyZ3VtZW50cyk7CiAgICAgICAgICAgICAgICAgICAgfQogICAgICAgICAgICAgICAgICAgIHJldHVybiByZXQgfHwgKHZhbHVlLmV4cG9ydHMgJiYgZ2V0R2xvYmFsKHZhbHVlLmV4cG9ydHMpKTsKICAgICAgICAgICAgICAgIH0KICAgICAgICAgICAgICAgIHJldHVybiBmbjsKICAgICAgICAgICAgfSwKCiAgICAgICAgICAgIG1ha2VSZXF1aXJlOiBmdW5jdGlvbiAocmVsTWFwLCBvcHRpb25zKSB7CiAgICAgICAgICAgICAgICBvcHRpb25zID0gb3B0aW9ucyB8fCB7fTsKCiAgICAgICAgICAgICAgICBmdW5jdGlvbiBsb2NhbFJlcXVpcmUoZGVwcywgY2FsbGJhY2ssIGVycmJhY2spIHsKICAgICAgICAgICAgICAgICAgICB2YXIgaWQsIG1hcCwgcmVxdWlyZU1vZDsKCiAgICAgICAgICAgICAgICAgICAgaWYgKG9wdGlvbnMuZW5hYmxlQnVpbGRDYWxsYmFjayAmJiBjYWxsYmFjayAmJiBpc0Z1bmN0aW9uKGNhbGxiYWNrKSkgewogICAgICAgICAgICAgICAgICAgICAgICBjYWxsYmFjay5fX3JlcXVpcmVKc0J1aWxkID0gdHJ1ZTsKICAgICAgICAgICAgICAgICAgICB9CgogICAgICAgICAgICAgICAgICAgIGlmICh0eXBlb2YgZGVwcyA9PT0gJ3N0cmluZycpIHsKICAgICAgICAgICAgICAgICAgICAgICAgaWYgKGlzRnVuY3Rpb24oY2FsbGJhY2spKSB7CiAgICAgICAgICAgICAgICAgICAgICAgICAgICAvL0ludmFsaWQgY2FsbAogICAgICAgICAgICAgICAgICAgICAgICAgICAgcmV0dXJuIG9uRXJyb3IobWFrZUVycm9yKCdyZXF1aXJlYXJncycsICdJbnZhbGlkIHJlcXVpcmUgY2FsbCcpLCBlcnJiYWNrKTsKICAgICAgICAgICAgICAgICAgICAgICAgfQoKICAgICAgICAgICAgICAgICAgICAgICAgLy9JZiByZXF1aXJlfGV4cG9ydHN8bW9kdWxlIGFyZSByZXF1ZXN0ZWQsIGdldCB0aGUKICAgICAgICAgICAgICAgICAgICAgICAgLy92YWx1ZSBmb3IgdGhlbSBmcm9tIHRoZSBzcGVjaWFsIGhhbmRsZXJzLiBDYXZlYXQ6CiAgICAgICAgICAgICAgICAgICAgICAgIC8vdGhpcyBvbmx5IHdvcmtzIHdoaWxlIG1vZHVsZSBpcyBiZWluZyBkZWZpbmVkLgogICAgICAgICAgICAgICAgICAgICAgICBpZiAocmVsTWFwICYmIGhhc1Byb3AoaGFuZGxlcnMsIGRlcHMpKSB7CiAgICAgICAgICAgICAgICAgICAgICAgICAgICByZXR1cm4gaGFuZGxlcnNbZGVwc10ocmVnaXN0cnlbcmVsTWFwLmlkXSk7CiAgICAgICAgICAgICAgICAgICAgICAgIH0KCiAgICAgICAgICAgICAgICAgICAgICAgIC8vU3luY2hyb25vdXMgYWNjZXNzIHRvIG9uZSBtb2R1bGUuIElmIHJlcXVpcmUuZ2V0IGlzCiAgICAgICAgICAgICAgICAgICAgICAgIC8vYXZhaWxhYmxlIChhcyBpbiB0aGUgTm9kZSBhZGFwdGVyKSwgcHJlZmVyIHRoYXQuCiAgICAgICAgICAgICAgICAgICAgICAgIGlmIChyZXEuZ2V0KSB7CiAgICAgICAgICAgICAgICAgICAgICAgICAgICByZXR1cm4gcmVxLmdldChjb250ZXh0LCBkZXBzLCByZWxNYXAsIGxvY2FsUmVxdWlyZSk7CiAgICAgICAgICAgICAgICAgICAgICAgIH0KCiAgICAgICAgICAgICAgICAgICAgICAgIC8vTm9ybWFsaXplIG1vZHVsZSBuYW1lLCBpZiBpdCBjb250YWlucyAuIG9yIC4uCiAgICAgICAgICAgICAgICAgICAgICAgIG1hcCA9IG1ha2VNb2R1bGVNYXAoZGVwcywgcmVsTWFwLCBmYWxzZSwgdHJ1ZSk7CiAgICAgICAgICAgICAgICAgICAgICAgIGlkID0gbWFwLmlkOwoKICAgICAgICAgICAgICAgICAgICAgICAgaWYgKCFoYXNQcm9wKGRlZmluZWQsIGlkKSkgewogICAgICAgICAgICAgICAgICAgICAgICAgICAgcmV0dXJuIG9uRXJyb3IobWFrZUVycm9yKCdub3Rsb2FkZWQnLCAnTW9kdWxlIG5hbWUgIicgKwogICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgaWQgKwogICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgJyIgaGFzIG5vdCBiZWVuIGxvYWRlZCB5ZXQgZm9yIGNvbnRleHQ6ICcgKwogICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgY29udGV4dE5hbWUgKwogICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgKHJlbE1hcCA/ICcnIDogJy4gVXNlIHJlcXVpcmUoW10pJykpKTsKICAgICAgICAgICAgICAgICAgICAgICAgfQogICAgICAgICAgICAgICAgICAgICAgICByZXR1cm4gZGVmaW5lZFtpZF07CiAgICAgICAgICAgICAgICAgICAgfQoKICAgICAgICAgICAgICAgICAgICAvL0dyYWIgZGVmaW5lcyB3YWl0aW5nIGluIHRoZSBnbG9iYWwgcXVldWUuCiAgICAgICAgICAgICAgICAgICAgaW50YWtlRGVmaW5lcygpOwoKICAgICAgICAgICAgICAgICAgICAvL01hcmsgYWxsIHRoZSBkZXBlbmRlbmNpZXMgYXMgbmVlZGluZyB0byBiZSBsb2FkZWQuCiAgICAgICAgICAgICAgICAgICAgY29udGV4dC5uZXh0VGljayhmdW5jdGlvbiAoKSB7CiAgICAgICAgICAgICAgICAgICAgICAgIC8vU29tZSBkZWZpbmVzIGNvdWxkIGhhdmUgYmVlbiBhZGRlZCBzaW5jZSB0aGUKICAgICAgICAgICAgICAgICAgICAgICAgLy9yZXF1aXJlIGNhbGwsIGNvbGxlY3QgdGhlbS4KICAgICAgICAgICAgICAgICAgICAgICAgaW50YWtlRGVmaW5lcygpOwoKICAgICAgICAgICAgICAgICAgICAgICAgcmVxdWlyZU1vZCA9IGdldE1vZHVsZShtYWtlTW9kdWxlTWFwKG51bGwsIHJlbE1hcCkpOwoKICAgICAgICAgICAgICAgICAgICAgICAgLy9TdG9yZSBpZiBtYXAgY29uZmlnIHNob3VsZCBiZSBhcHBsaWVkIHRvIHRoaXMgcmVxdWlyZQogICAgICAgICAgICAgICAgICAgICAgICAvL2NhbGwgZm9yIGRlcGVuZGVuY2llcy4KICAgICAgICAgICAgICAgICAgICAgICAgcmVxdWlyZU1vZC5za2lwTWFwID0gb3B0aW9ucy5za2lwTWFwOwoKICAgICAgICAgICAgICAgICAgICAgICAgcmVxdWlyZU1vZC5pbml0KGRlcHMsIGNhbGxiYWNrLCBlcnJiYWNrLCB7CiAgICAgICAgICAgICAgICAgICAgICAgICAgICBlbmFibGVkOiB0cnVlCiAgICAgICAgICAgICAgICAgICAgICAgIH0pOwoKICAgICAgICAgICAgICAgICAgICAgICAgY2hlY2tMb2FkZWQoKTsKICAgICAgICAgICAgICAgICAgICB9KTsKCiAgICAgICAgICAgICAgICAgICAgcmV0dXJuIGxvY2FsUmVxdWlyZTsKICAgICAgICAgICAgICAgIH0KCiAgICAgICAgICAgICAgICBtaXhpbihsb2NhbFJlcXVpcmUsIHsKICAgICAgICAgICAgICAgICAgICBpc0Jyb3dzZXI6IGlzQnJvd3NlciwKCiAgICAgICAgICAgICAgICAgICAgLyoqCiAgICAgICAgICAgICAgICAgICAgICogQ29udmVydHMgYSBtb2R1bGUgbmFtZSArIC5leHRlbnNpb24gaW50byBhbiBVUkwgcGF0aC4KICAgICAgICAgICAgICAgICAgICAgKiAqUmVxdWlyZXMqIHRoZSB1c2Ugb2YgYSBtb2R1bGUgbmFtZS4gSXQgZG9lcyBub3Qgc3VwcG9ydCB1c2luZwogICAgICAgICAgICAgICAgICAgICAqIHBsYWluIFVSTHMgbGlrZSBuYW1lVG9VcmwuCiAgICAgICAgICAgICAgICAgICAgICovCiAgICAgICAgICAgICAgICAgICAgdG9Vcmw6IGZ1bmN0aW9uIChtb2R1bGVOYW1lUGx1c0V4dCkgewogICAgICAgICAgICAgICAgICAgICAgICB2YXIgZXh0LAogICAgICAgICAgICAgICAgICAgICAgICAgICAgaW5kZXggPSBtb2R1bGVOYW1lUGx1c0V4dC5sYXN0SW5kZXhPZignLicpLAogICAgICAgICAgICAgICAgICAgICAgICAgICAgc2VnbWVudCA9IG1vZHVsZU5hbWVQbHVzRXh0LnNwbGl0KCcvJylbMF0sCiAgICAgICAgICAgICAgICAgICAgICAgICAgICBpc1JlbGF0aXZlID0gc2VnbWVudCA9PT0gJy4nIHx8IHNlZ21lbnQgPT09ICcuLic7CgogICAgICAgICAgICAgICAgICAgICAgICAvL0hhdmUgYSBmaWxlIGV4dGVuc2lvbiBhbGlhcywgYW5kIGl0IGlzIG5vdCB0aGUKICAgICAgICAgICAgICAgICAgICAgICAgLy9kb3RzIGZyb20gYSByZWxhdGl2ZSBwYXRoLgogICAgICAgICAgICAgICAgICAgICAgICBpZiAoaW5kZXggIT09IC0xICYmICghaXNSZWxhdGl2ZSB8fCBpbmRleCA+IDEpKSB7CiAgICAgICAgICAgICAgICAgICAgICAgICAgICBleHQgPSBtb2R1bGVOYW1lUGx1c0V4dC5zdWJzdHJpbmcoaW5kZXgsIG1vZHVsZU5hbWVQbHVzRXh0Lmxlbmd0aCk7CiAgICAgICAgICAgICAgICAgICAgICAgICAgICBtb2R1bGVOYW1lUGx1c0V4dCA9IG1vZHVsZU5hbWVQbHVzRXh0LnN1YnN0cmluZygwLCBpbmRleCk7CiAgICAgICAgICAgICAgICAgICAgICAgIH0KCiAgICAgICAgICAgICAgICAgICAgICAgIHJldHVybiBjb250ZXh0Lm5hbWVUb1VybChub3JtYWxpemUobW9kdWxlTmFtZVBsdXNFeHQsCiAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgIHJlbE1hcCAmJiByZWxNYXAuaWQsIHRydWUpLCBleHQsICB0cnVlKTsKICAgICAgICAgICAgICAgICAgICB9LAoKICAgICAgICAgICAgICAgICAgICBkZWZpbmVkOiBmdW5jdGlvbiAoaWQpIHsKICAgICAgICAgICAgICAgICAgICAgICAgcmV0dXJuIGhhc1Byb3AoZGVmaW5lZCwgbWFrZU1vZHVsZU1hcChpZCwgcmVsTWFwLCBmYWxzZSwgdHJ1ZSkuaWQpOwogICAgICAgICAgICAgICAgICAgIH0sCgogICAgICAgICAgICAgICAgICAgIHNwZWNpZmllZDogZnVuY3Rpb24gKGlkKSB7CiAgICAgICAgICAgICAgICAgICAgICAgIGlkID0gbWFrZU1vZHVsZU1hcChpZCwgcmVsTWFwLCBmYWxzZSwgdHJ1ZSkuaWQ7CiAgICAgICAgICAgICAgICAgICAgICAgIHJldHVybiBoYXNQcm9wKGRlZmluZWQsIGlkKSB8fCBoYXNQcm9wKHJlZ2lzdHJ5LCBpZCk7CiAgICAgICAgICAgICAgICAgICAgfQogICAgICAgICAgICAgICAgfSk7CgogICAgICAgICAgICAgICAgLy9Pbmx5IGFsbG93IHVuZGVmIG9uIHRvcCBsZXZlbCByZXF1aXJlIGNhbGxzCiAgICAgICAgICAgICAgICBpZiAoIXJlbE1hcCkgewogICAgICAgICAgICAgICAgICAgIGxvY2FsUmVxdWlyZS51bmRlZiA9IGZ1bmN0aW9uIChpZCkgewogICAgICAgICAgICAgICAgICAgICAgICAvL0JpbmQgYW55IHdhaXRpbmcgZGVmaW5lKCkgY2FsbHMgdG8gdGhpcyBjb250ZXh0LAogICAgICAgICAgICAgICAgICAgICAgICAvL2ZpeCBmb3IgIzQwOAogICAgICAgICAgICAgICAgICAgICAgICB0YWtlR2xvYmFsUXVldWUoKTsKCiAgICAgICAgICAgICAgICAgICAgICAgIHZhciBtYXAgPSBtYWtlTW9kdWxlTWFwKGlkLCByZWxNYXAsIHRydWUpLAogICAgICAgICAgICAgICAgICAgICAgICAgICAgbW9kID0gZ2V0T3duKHJlZ2lzdHJ5LCBpZCk7CgogICAgICAgICAgICAgICAgICAgICAgICBtb2QudW5kZWZlZCA9IHRydWU7CiAgICAgICAgICAgICAgICAgICAgICAgIHJlbW92ZVNjcmlwdChpZCk7CgogICAgICAgICAgICAgICAgICAgICAgICBkZWxldGUgZGVmaW5lZFtpZF07CiAgICAgICAgICAgICAgICAgICAgICAgIGRlbGV0ZSB1cmxGZXRjaGVkW21hcC51cmxdOwogICAgICAgICAgICAgICAgICAgICAgICBkZWxldGUgdW5kZWZFdmVudHNbaWRdOwoKICAgICAgICAgICAgICAgICAgICAgICAgLy9DbGVhbiBxdWV1ZWQgZGVmaW5lcyB0b28uIEdvIGJhY2t3YXJkcwogICAgICAgICAgICAgICAgICAgICAgICAvL2luIGFycmF5IHNvIHRoYXQgdGhlIHNwbGljZXMgZG8gbm90CiAgICAgICAgICAgICAgICAgICAgICAgIC8vbWVzcyB1cCB0aGUgaXRlcmF0aW9uLgogICAgICAgICAgICAgICAgICAgICAgICBlYWNoUmV2ZXJzZShkZWZRdWV1ZSwgZnVuY3Rpb24oYXJncywgaSkgewogICAgICAgICAgICAgICAgICAgICAgICAgICAgaWYgKGFyZ3NbMF0gPT09IGlkKSB7CiAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgZGVmUXVldWUuc3BsaWNlKGksIDEpOwogICAgICAgICAgICAgICAgICAgICAgICAgICAgfQogICAgICAgICAgICAgICAgICAgICAgICB9KTsKICAgICAgICAgICAgICAgICAgICAgICAgZGVsZXRlIGNvbnRleHQuZGVmUXVldWVNYXBbaWRdOwoKICAgICAgICAgICAgICAgICAgICAgICAgaWYgKG1vZCkgewogICAgICAgICAgICAgICAgICAgICAgICAgICAgLy9Ib2xkIG9uIHRvIGxpc3RlbmVycyBpbiBjYXNlIHRoZQogICAgICAgICAgICAgICAgICAgICAgICAgICAgLy9tb2R1bGUgd2lsbCBiZSBhdHRlbXB0ZWQgdG8gYmUgcmVsb2FkZWQKICAgICAgICAgICAgICAgICAgICAgICAgICAgIC8vdXNpbmcgYSBkaWZmZXJlbnQgY29uZmlnLgogICAgICAgICAgICAgICAgICAgICAgICAgICAgaWYgKG1vZC5ldmVudHMuZGVmaW5lZCkgewogICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgIHVuZGVmRXZlbnRzW2lkXSA9IG1vZC5ldmVudHM7CiAgICAgICAgICAgICAgICAgICAgICAgICAgICB9CgogICAgICAgICAgICAgICAgICAgICAgICAgICAgY2xlYW5SZWdpc3RyeShpZCk7CiAgICAgICAgICAgICAgICAgICAgICAgIH0KICAgICAgICAgICAgICAgICAgICB9OwogICAgICAgICAgICAgICAgfQoKICAgICAgICAgICAgICAgIHJldHVybiBsb2NhbFJlcXVpcmU7CiAgICAgICAgICAgIH0sCgogICAgICAgICAgICAvKioKICAgICAgICAgICAgICogQ2FsbGVkIHRvIGVuYWJsZSBhIG1vZHVsZSBpZiBpdCBpcyBzdGlsbCBpbiB0aGUgcmVnaXN0cnkKICAgICAgICAgICAgICogYXdhaXRpbmcgZW5hYmxlbWVudC4gQSBzZWNvbmQgYXJnLCBwYXJlbnQsIHRoZSBwYXJlbnQgbW9kdWxlLAogICAgICAgICAgICAgKiBpcyBwYXNzZWQgaW4gZm9yIGNvbnRleHQsIHdoZW4gdGhpcyBtZXRob2QgaXMgb3ZlcnJpZGRlbiBieQogICAgICAgICAgICAgKiB0aGUgb3B0aW1pemVyLiBOb3Qgc2hvd24gaGVyZSB0byBrZWVwIGNvZGUgY29tcGFjdC4KICAgICAgICAgICAgICovCiAgICAgICAgICAgIGVuYWJsZTogZnVuY3Rpb24gKGRlcE1hcCkgewogICAgICAgICAgICAgICAgdmFyIG1vZCA9IGdldE93bihyZWdpc3RyeSwgZGVwTWFwLmlkKTsKICAgICAgICAgICAgICAgIGlmIChtb2QpIHsKICAgICAgICAgICAgICAgICAgICBnZXRNb2R1bGUoZGVwTWFwKS5lbmFibGUoKTsKICAgICAgICAgICAgICAgIH0KICAgICAgICAgICAgfSwKCiAgICAgICAgICAgIC8qKgogICAgICAgICAgICAgKiBJbnRlcm5hbCBtZXRob2QgdXNlZCBieSBlbnZpcm9ubWVudCBhZGFwdGVycyB0byBjb21wbGV0ZSBhIGxvYWQgZXZlbnQuCiAgICAgICAgICAgICAqIEEgbG9hZCBldmVudCBjb3VsZCBiZSBhIHNjcmlwdCBsb2FkIG9yIGp1c3QgYSBsb2FkIHBhc3MgZnJvbSBhIHN5bmNocm9ub3VzCiAgICAgICAgICAgICAqIGxvYWQgY2FsbC4KICAgICAgICAgICAgICogQHBhcmFtIHtTdHJpbmd9IG1vZHVsZU5hbWUgdGhlIG5hbWUgb2YgdGhlIG1vZHVsZSB0byBwb3RlbnRpYWxseSBjb21wbGV0ZS4KICAgICAgICAgICAgICovCiAgICAgICAgICAgIGNvbXBsZXRlTG9hZDogZnVuY3Rpb24gKG1vZHVsZU5hbWUpIHsKICAgICAgICAgICAgICAgIHZhciBmb3VuZCwgYXJncywgbW9kLAogICAgICAgICAgICAgICAgICAgIHNoaW0gPSBnZXRPd24oY29uZmlnLnNoaW0sIG1vZHVsZU5hbWUpIHx8IHt9LAogICAgICAgICAgICAgICAgICAgIHNoRXhwb3J0cyA9IHNoaW0uZXhwb3J0czsKCiAgICAgICAgICAgICAgICB0YWtlR2xvYmFsUXVldWUoKTsKCiAgICAgICAgICAgICAgICB3aGlsZSAoZGVmUXVldWUubGVuZ3RoKSB7CiAgICAgICAgICAgICAgICAgICAgYXJncyA9IGRlZlF1ZXVlLnNoaWZ0KCk7CiAgICAgICAgICAgICAgICAgICAgaWYgKGFyZ3NbMF0gPT09IG51bGwpIHsKICAgICAgICAgICAgICAgICAgICAgICAgYXJnc1swXSA9IG1vZHVsZU5hbWU7CiAgICAgICAgICAgICAgICAgICAgICAgIC8vSWYgYWxyZWFkeSBmb3VuZCBhbiBhbm9ueW1vdXMgbW9kdWxlIGFuZCBib3VuZCBpdAogICAgICAgICAgICAgICAgICAgICAgICAvL3RvIHRoaXMgbmFtZSwgdGhlbiB0aGlzIGlzIHNvbWUgb3RoZXIgYW5vbiBtb2R1bGUKICAgICAgICAgICAgICAgICAgICAgICAgLy93YWl0aW5nIGZvciBpdHMgY29tcGxldGVMb2FkIHRvIGZpcmUuCiAgICAgICAgICAgICAgICAgICAgICAgIGlmIChmb3VuZCkgewogICAgICAgICAgICAgICAgICAgICAgICAgICAgYnJlYWs7CiAgICAgICAgICAgICAgICAgICAgICAgIH0KICAgICAgICAgICAgICAgICAgICAgICAgZm91bmQgPSB0cnVlOwogICAgICAgICAgICAgICAgICAgIH0gZWxzZSBpZiAoYXJnc1swXSA9PT0gbW9kdWxlTmFtZSkgewogICAgICAgICAgICAgICAgICAgICAgICAvL0ZvdW5kIG1hdGNoaW5nIGRlZmluZSBjYWxsIGZvciB0aGlzIHNjcmlwdCEKICAgICAgICAgICAgICAgICAgICAgICAgZm91bmQgPSB0cnVlOwogICAgICAgICAgICAgICAgICAgIH0KCiAgICAgICAgICAgICAgICAgICAgY2FsbEdldE1vZHVsZShhcmdzKTsKICAgICAgICAgICAgICAgIH0KICAgICAgICAgICAgICAgIGNvbnRleHQuZGVmUXVldWVNYXAgPSB7fTsKCiAgICAgICAgICAgICAgICAvL0RvIHRoaXMgYWZ0ZXIgdGhlIGN5Y2xlIG9mIGNhbGxHZXRNb2R1bGUgaW4gY2FzZSB0aGUgcmVzdWx0CiAgICAgICAgICAgICAgICAvL29mIHRob3NlIGNhbGxzL2luaXQgY2FsbHMgY2hhbmdlcyB0aGUgcmVnaXN0cnkuCiAgICAgICAgICAgICAgICBtb2QgPSBnZXRPd24ocmVnaXN0cnksIG1vZHVsZU5hbWUpOwoKICAgICAgICAgICAgICAgIGlmICghZm91bmQgJiYgIWhhc1Byb3AoZGVmaW5lZCwgbW9kdWxlTmFtZSkgJiYgbW9kICYmICFtb2QuaW5pdGVkKSB7CiAgICAgICAgICAgICAgICAgICAgaWYgKGNvbmZpZy5lbmZvcmNlRGVmaW5lICYmICghc2hFeHBvcnRzIHx8ICFnZXRHbG9iYWwoc2hFeHBvcnRzKSkpIHsKICAgICAgICAgICAgICAgICAgICAgICAgaWYgKGhhc1BhdGhGYWxsYmFjayhtb2R1bGVOYW1lKSkgewogICAgICAgICAgICAgICAgICAgICAgICAgICAgcmV0dXJuOwogICAgICAgICAgICAgICAgICAgICAgICB9IGVsc2UgewogICAgICAgICAgICAgICAgICAgICAgICAgICAgcmV0dXJuIG9uRXJyb3IobWFrZUVycm9yKCdub2RlZmluZScsCiAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICdObyBkZWZpbmUgY2FsbCBmb3IgJyArIG1vZHVsZU5hbWUsCiAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgIG51bGwsCiAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgIFttb2R1bGVOYW1lXSkpOwogICAgICAgICAgICAgICAgICAgICAgICB9CiAgICAgICAgICAgICAgICAgICAgfSBlbHNlIHsKICAgICAgICAgICAgICAgICAgICAgICAgLy9BIHNjcmlwdCB0aGF0IGRvZXMgbm90IGNhbGwgZGVmaW5lKCksIHNvIGp1c3Qgc2ltdWxhdGUKICAgICAgICAgICAgICAgICAgICAgICAgLy90aGUgY2FsbCBmb3IgaXQuCiAgICAgICAgICAgICAgICAgICAgICAgIGNhbGxHZXRNb2R1bGUoW21vZHVsZU5hbWUsIChzaGltLmRlcHMgfHwgW10pLCBzaGltLmV4cG9ydHNGbl0pOwogICAgICAgICAgICAgICAgICAgIH0KICAgICAgICAgICAgICAgIH0KCiAgICAgICAgICAgICAgICBjaGVja0xvYWRlZCgpOwogICAgICAgICAgICB9LAoKICAgICAgICAgICAgLyoqCiAgICAgICAgICAgICAqIENvbnZlcnRzIGEgbW9kdWxlIG5hbWUgdG8gYSBmaWxlIHBhdGguIFN1cHBvcnRzIGNhc2VzIHdoZXJlCiAgICAgICAgICAgICAqIG1vZHVsZU5hbWUgbWF5IGFjdHVhbGx5IGJlIGp1c3QgYW4gVVJMLgogICAgICAgICAgICAgKiBOb3RlIHRoYXQgaXQgKipkb2VzIG5vdCoqIGNhbGwgbm9ybWFsaXplIG9uIHRoZSBtb2R1bGVOYW1lLAogICAgICAgICAgICAgKiBpdCBpcyBhc3N1bWVkIHRvIGhhdmUgYWxyZWFkeSBiZWVuIG5vcm1hbGl6ZWQuIFRoaXMgaXMgYW4KICAgICAgICAgICAgICogaW50ZXJuYWwgQVBJLCBub3QgYSBwdWJsaWMgb25lLiBVc2UgdG9VcmwgZm9yIHRoZSBwdWJsaWMgQVBJLgogICAgICAgICAgICAgKi8KICAgICAgICAgICAgbmFtZVRvVXJsOiBmdW5jdGlvbiAobW9kdWxlTmFtZSwgZXh0LCBza2lwRXh0KSB7CiAgICAgICAgICAgICAgICB2YXIgcGF0aHMsIHN5bXMsIGksIHBhcmVudE1vZHVsZSwgdXJsLAogICAgICAgICAgICAgICAgICAgIHBhcmVudFBhdGgsIGJ1bmRsZUlkLAogICAgICAgICAgICAgICAgICAgIHBrZ01haW4gPSBnZXRPd24oY29uZmlnLnBrZ3MsIG1vZHVsZU5hbWUpOwoKICAgICAgICAgICAgICAgIGlmIChwa2dNYWluKSB7CiAgICAgICAgICAgICAgICAgICAgbW9kdWxlTmFtZSA9IHBrZ01haW47CiAgICAgICAgICAgICAgICB9CgogICAgICAgICAgICAgICAgYnVuZGxlSWQgPSBnZXRPd24oYnVuZGxlc01hcCwgbW9kdWxlTmFtZSk7CgogICAgICAgICAgICAgICAgaWYgKGJ1bmRsZUlkKSB7CiAgICAgICAgICAgICAgICAgICAgcmV0dXJuIGNvbnRleHQubmFtZVRvVXJsKGJ1bmRsZUlkLCBleHQsIHNraXBFeHQpOwogICAgICAgICAgICAgICAgfQoKICAgICAgICAgICAgICAgIC8vSWYgYSBjb2xvbiBpcyBpbiB0aGUgVVJMLCBpdCBpbmRpY2F0ZXMgYSBwcm90b2NvbCBpcyB1c2VkIGFuZCBpdCBpcyBqdXN0CiAgICAgICAgICAgICAgICAvL2FuIFVSTCB0byBhIGZpbGUsIG9yIGlmIGl0IHN0YXJ0cyB3aXRoIGEgc2xhc2gsIGNvbnRhaW5zIGEgcXVlcnkgYXJnIChpLmUuID8pCiAgICAgICAgICAgICAgICAvL29yIGVuZHMgd2l0aCAuanMsIHRoZW4gYXNzdW1lIHRoZSB1c2VyIG1lYW50IHRvIHVzZSBhbiB1cmwgYW5kIG5vdCBhIG1vZHVsZSBpZC4KICAgICAgICAgICAgICAgIC8vVGhlIHNsYXNoIGlzIGltcG9ydGFudCBmb3IgcHJvdG9jb2wtbGVzcyBVUkxzIGFzIHdlbGwgYXMgZnVsbCBwYXRocy4KICAgICAgICAgICAgICAgIGlmIChyZXEuanNFeHRSZWdFeHAudGVzdChtb2R1bGVOYW1lKSkgewogICAgICAgICAgICAgICAgICAgIC8vSnVzdCBhIHBsYWluIHBhdGgsIG5vdCBtb2R1bGUgbmFtZSBsb29rdXAsIHNvIGp1c3QgcmV0dXJuIGl0LgogICAgICAgICAgICAgICAgICAgIC8vQWRkIGV4dGVuc2lvbiBpZiBpdCBpcyBpbmNsdWRlZC4gVGhpcyBpcyBhIGJpdCB3b25reSwgb25seSBub24tLmpzIHRoaW5ncyBwYXNzCiAgICAgICAgICAgICAgICAgICAgLy9hbiBleHRlbnNpb24sIHRoaXMgbWV0aG9kIHByb2JhYmx5IG5lZWRzIHRvIGJlIHJld29ya2VkLgogICAgICAgICAgICAgICAgICAgIHVybCA9IG1vZHVsZU5hbWUgKyAoZXh0IHx8ICcnKTsKICAgICAgICAgICAgICAgIH0gZWxzZSB7CiAgICAgICAgICAgICAgICAgICAgLy9BIG1vZHVsZSB0aGF0IG5lZWRzIHRvIGJlIGNvbnZlcnRlZCB0byBhIHBhdGguCiAgICAgICAgICAgICAgICAgICAgcGF0aHMgPSBjb25maWcucGF0aHM7CgogICAgICAgICAgICAgICAgICAgIHN5bXMgPSBtb2R1bGVOYW1lLnNwbGl0KCcvJyk7CiAgICAgICAgICAgICAgICAgICAgLy9Gb3IgZWFjaCBtb2R1bGUgbmFtZSBzZWdtZW50LCBzZWUgaWYgdGhlcmUgaXMgYSBwYXRoCiAgICAgICAgICAgICAgICAgICAgLy9yZWdpc3RlcmVkIGZvciBpdC4gU3RhcnQgd2l0aCBtb3N0IHNwZWNpZmljIG5hbWUKICAgICAgICAgICAgICAgICAgICAvL2FuZCB3b3JrIHVwIGZyb20gaXQuCiAgICAgICAgICAgICAgICAgICAgZm9yIChpID0gc3ltcy5sZW5ndGg7IGkgPiAwOyBpIC09IDEpIHsKICAgICAgICAgICAgICAgICAgICAgICAgcGFyZW50TW9kdWxlID0gc3ltcy5zbGljZSgwLCBpKS5qb2luKCcvJyk7CgogICAgICAgICAgICAgICAgICAgICAgICBwYXJlbnRQYXRoID0gZ2V0T3duKHBhdGhzLCBwYXJlbnRNb2R1bGUpOwogICAgICAgICAgICAgICAgICAgICAgICBpZiAocGFyZW50UGF0aCkgewogICAgICAgICAgICAgICAgICAgICAgICAgICAgLy9JZiBhbiBhcnJheSwgaXQgbWVhbnMgdGhlcmUgYXJlIGEgZmV3IGNob2ljZXMsCiAgICAgICAgICAgICAgICAgICAgICAgICAgICAvL0Nob29zZSB0aGUgb25lIHRoYXQgaXMgZGVzaXJlZAogICAgICAgICAgICAgICAgICAgICAgICAgICAgaWYgKGlzQXJyYXkocGFyZW50UGF0aCkpIHsKICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICBwYXJlbnRQYXRoID0gcGFyZW50UGF0aFswXTsKICAgICAgICAgICAgICAgICAgICAgICAgICAgIH0KICAgICAgICAgICAgICAgICAgICAgICAgICAgIHN5bXMuc3BsaWNlKDAsIGksIHBhcmVudFBhdGgpOwogICAgICAgICAgICAgICAgICAgICAgICAgICAgYnJlYWs7CiAgICAgICAgICAgICAgICAgICAgICAgIH0KICAgICAgICAgICAgICAgICAgICB9CgogICAgICAgICAgICAgICAgICAgIC8vSm9pbiB0aGUgcGF0aCBwYXJ0cyB0b2dldGhlciwgdGhlbiBmaWd1cmUgb3V0IGlmIGJhc2VVcmwgaXMgbmVlZGVkLgogICAgICAgICAgICAgICAgICAgIHVybCA9IHN5bXMuam9pbignLycpOwogICAgICAgICAgICAgICAgICAgIHVybCArPSAoZXh0IHx8ICgvXmRhdGFcOnxcPy8udGVzdCh1cmwpIHx8IHNraXBFeHQgPyAnJyA6ICcuanMnKSk7CiAgICAgICAgICAgICAgICAgICAgdXJsID0gKHVybC5jaGFyQXQoMCkgPT09ICcvJyB8fCB1cmwubWF0Y2goL15bXHdcK1wuXC1dKzovKSA/ICcnIDogY29uZmlnLmJhc2VVcmwpICsgdXJsOwogICAgICAgICAgICAgICAgfQoKICAgICAgICAgICAgICAgIHJldHVybiBjb25maWcudXJsQXJncyA/IHVybCArCiAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAoKHVybC5pbmRleE9mKCc/JykgPT09IC0xID8gJz8nIDogJyYnKSArCiAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgY29uZmlnLnVybEFyZ3MpIDogdXJsOwogICAgICAgICAgICB9LAoKICAgICAgICAgICAgLy9EZWxlZ2F0ZXMgdG8gcmVxLmxvYWQuIEJyb2tlbiBvdXQgYXMgYSBzZXBhcmF0ZSBmdW5jdGlvbiB0bwogICAgICAgICAgICAvL2FsbG93IG92ZXJyaWRpbmcgaW4gdGhlIG9wdGltaXplci4KICAgICAgICAgICAgbG9hZDogZnVuY3Rpb24gKGlkLCB1cmwpIHsKICAgICAgICAgICAgICAgIHJlcS5sb2FkKGNvbnRleHQsIGlkLCB1cmwpOwogICAgICAgICAgICB9LAoKICAgICAgICAgICAgLyoqCiAgICAgICAgICAgICAqIEV4ZWN1dGVzIGEgbW9kdWxlIGNhbGxiYWNrIGZ1bmN0aW9uLiBCcm9rZW4gb3V0IGFzIGEgc2VwYXJhdGUgZnVuY3Rpb24KICAgICAgICAgICAgICogc29sZWx5IHRvIGFsbG93IHRoZSBidWlsZCBzeXN0ZW0gdG8gc2VxdWVuY2UgdGhlIGZpbGVzIGluIHRoZSBidWlsdAogICAgICAgICAgICAgKiBsYXllciBpbiB0aGUgcmlnaHQgc2VxdWVuY2UuCiAgICAgICAgICAgICAqCiAgICAgICAgICAgICAqIEBwcml2YXRlCiAgICAgICAgICAgICAqLwogICAgICAgICAgICBleGVjQ2I6IGZ1bmN0aW9uIChuYW1lLCBjYWxsYmFjaywgYXJncywgZXhwb3J0cykgewogICAgICAgICAgICAgICAgcmV0dXJuIGNhbGxiYWNrLmFwcGx5KGV4cG9ydHMsIGFyZ3MpOwogICAgICAgICAgICB9LAoKICAgICAgICAgICAgLyoqCiAgICAgICAgICAgICAqIGNhbGxiYWNrIGZvciBzY3JpcHQgbG9hZHMsIHVzZWQgdG8gY2hlY2sgc3RhdHVzIG9mIGxvYWRpbmcuCiAgICAgICAgICAgICAqCiAgICAgICAgICAgICAqIEBwYXJhbSB7RXZlbnR9IGV2dCB0aGUgZXZlbnQgZnJvbSB0aGUgYnJvd3NlciBmb3IgdGhlIHNjcmlwdAogICAgICAgICAgICAgKiB0aGF0IHdhcyBsb2FkZWQuCiAgICAgICAgICAgICAqLwogICAgICAgICAgICBvblNjcmlwdExvYWQ6IGZ1bmN0aW9uIChldnQpIHsKICAgICAgICAgICAgICAgIC8vVXNpbmcgY3VycmVudFRhcmdldCBpbnN0ZWFkIG9mIHRhcmdldCBmb3IgRmlyZWZveCAyLjAncyBzYWtlLiBOb3QKICAgICAgICAgICAgICAgIC8vYWxsIG9sZCBicm93c2VycyB3aWxsIGJlIHN1cHBvcnRlZCwgYnV0IHRoaXMgb25lIHdhcyBlYXN5IGVub3VnaAogICAgICAgICAgICAgICAgLy90byBzdXBwb3J0IGFuZCBzdGlsbCBtYWtlcyBzZW5zZS4KICAgICAgICAgICAgICAgIGlmIChldnQudHlwZSA9PT0gJ2xvYWQnIHx8CiAgICAgICAgICAgICAgICAgICAgICAgIChyZWFkeVJlZ0V4cC50ZXN0KChldnQuY3VycmVudFRhcmdldCB8fCBldnQuc3JjRWxlbWVudCkucmVhZHlTdGF0ZSkpKSB7CiAgICAgICAgICAgICAgICAgICAgLy9SZXNldCBpbnRlcmFjdGl2ZSBzY3JpcHQgc28gYSBzY3JpcHQgbm9kZSBpcyBub3QgaGVsZCBvbnRvIGZvcgogICAgICAgICAgICAgICAgICAgIC8vdG8gbG9uZy4KICAgICAgICAgICAgICAgICAgICBpbnRlcmFjdGl2ZVNjcmlwdCA9IG51bGw7CgogICAgICAgICAgICAgICAgICAgIC8vUHVsbCBvdXQgdGhlIG5hbWUgb2YgdGhlIG1vZHVsZSBhbmQgdGhlIGNvbnRleHQuCiAgICAgICAgICAgICAgICAgICAgdmFyIGRhdGEgPSBnZXRTY3JpcHREYXRhKGV2dCk7CiAgICAgICAgICAgICAgICAgICAgY29udGV4dC5jb21wbGV0ZUxvYWQoZGF0YS5pZCk7CiAgICAgICAgICAgICAgICB9CiAgICAgICAgICAgIH0sCgogICAgICAgICAgICAvKioKICAgICAgICAgICAgICogQ2FsbGJhY2sgZm9yIHNjcmlwdCBlcnJvcnMuCiAgICAgICAgICAgICAqLwogICAgICAgICAgICBvblNjcmlwdEVycm9yOiBmdW5jdGlvbiAoZXZ0KSB7CiAgICAgICAgICAgICAgICB2YXIgZGF0YSA9IGdldFNjcmlwdERhdGEoZXZ0KTsKICAgICAgICAgICAgICAgIGlmICghaGFzUGF0aEZhbGxiYWNrKGRhdGEuaWQpKSB7CiAgICAgICAgICAgICAgICAgICAgdmFyIHBhcmVudHMgPSBbXTsKICAgICAgICAgICAgICAgICAgICBlYWNoUHJvcChyZWdpc3RyeSwgZnVuY3Rpb24odmFsdWUsIGtleSkgewogICAgICAgICAgICAgICAgICAgICAgICBpZiAoa2V5LmluZGV4T2YoJ19AcicpICE9PSAwKSB7CiAgICAgICAgICAgICAgICAgICAgICAgICAgICBlYWNoKHZhbHVlLmRlcE1hcHMsIGZ1bmN0aW9uKGRlcE1hcCkgewogICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgIGlmIChkZXBNYXAuaWQgPT09IGRhdGEuaWQpIHsKICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgcGFyZW50cy5wdXNoKGtleSk7CiAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgfQogICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgIHJldHVybiB0cnVlOwogICAgICAgICAgICAgICAgICAgICAgICAgICAgfSk7CiAgICAgICAgICAgICAgICAgICAgICAgIH0KICAgICAgICAgICAgICAgICAgICB9KTsKICAgICAgICAgICAgICAgICAgICByZXR1cm4gb25FcnJvcihtYWtlRXJyb3IoJ3NjcmlwdGVycm9yJywgJ1NjcmlwdCBlcnJvciBmb3IgIicgKyBkYXRhLmlkICsKICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgKHBhcmVudHMubGVuZ3RoID8KICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgJyIsIG5lZWRlZCBieTogJyArIHBhcmVudHMuam9pbignLCAnKSA6CiAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICciJyksIGV2dCwgW2RhdGEuaWRdKSk7CiAgICAgICAgICAgICAgICB9CiAgICAgICAgICAgIH0KICAgICAgICB9OwoKICAgICAgICBjb250ZXh0LnJlcXVpcmUgPSBjb250ZXh0Lm1ha2VSZXF1aXJlKCk7CiAgICAgICAgcmV0dXJuIGNvbnRleHQ7CiAgICB9CgogICAgLyoqCiAgICAgKiBNYWluIGVudHJ5IHBvaW50LgogICAgICoKICAgICAqIElmIHRoZSBvbmx5IGFyZ3VtZW50IHRvIHJlcXVpcmUgaXMgYSBzdHJpbmcsIHRoZW4gdGhlIG1vZHVsZSB0aGF0CiAgICAgKiBpcyByZXByZXNlbnRlZCBieSB0aGF0IHN0cmluZyBpcyBmZXRjaGVkIGZvciB0aGUgYXBwcm9wcmlhdGUgY29udGV4dC4KICAgICAqCiAgICAgKiBJZiB0aGUgZmlyc3QgYXJndW1lbnQgaXMgYW4gYXJyYXksIHRoZW4gaXQgd2lsbCBiZSB0cmVhdGVkIGFzIGFuIGFycmF5CiAgICAgKiBvZiBkZXBlbmRlbmN5IHN0cmluZyBuYW1lcyB0byBmZXRjaC4gQW4gb3B0aW9uYWwgZnVuY3Rpb24gY2FsbGJhY2sgY2FuCiAgICAgKiBiZSBzcGVjaWZpZWQgdG8gZXhlY3V0ZSB3aGVuIGFsbCBvZiB0aG9zZSBkZXBlbmRlbmNpZXMgYXJlIGF2YWlsYWJsZS4KICAgICAqCiAgICAgKiBNYWtlIGEgbG9jYWwgcmVxIHZhcmlhYmxlIHRvIGhlbHAgQ2FqYSBjb21wbGlhbmNlIChpdCBhc3N1bWVzIHRoaW5ncwogICAgICogb24gYSByZXF1aXJlIHRoYXQgYXJlIG5vdCBzdGFuZGFyZGl6ZWQpLCBhbmQgdG8gZ2l2ZSBhIHNob3J0CiAgICAgKiBuYW1lIGZvciBtaW5pZmljYXRpb24vbG9jYWwgc2NvcGUgdXNlLgogICAgICovCiAgICByZXEgPSByZXF1aXJlanMgPSBmdW5jdGlvbiAoZGVwcywgY2FsbGJhY2ssIGVycmJhY2ssIG9wdGlvbmFsKSB7CgogICAgICAgIC8vRmluZCB0aGUgcmlnaHQgY29udGV4dCwgdXNlIGRlZmF1bHQKICAgICAgICB2YXIgY29udGV4dCwgY29uZmlnLAogICAgICAgICAgICBjb250ZXh0TmFtZSA9IGRlZkNvbnRleHROYW1lOwoKICAgICAgICAvLyBEZXRlcm1pbmUgaWYgaGF2ZSBjb25maWcgb2JqZWN0IGluIHRoZSBjYWxsLgogICAgICAgIGlmICghaXNBcnJheShkZXBzKSAmJiB0eXBlb2YgZGVwcyAhPT0gJ3N0cmluZycpIHsKICAgICAgICAgICAgLy8gZGVwcyBpcyBhIGNvbmZpZyBvYmplY3QKICAgICAgICAgICAgY29uZmlnID0gZGVwczsKICAgICAgICAgICAgaWYgKGlzQXJyYXkoY2FsbGJhY2spKSB7CiAgICAgICAgICAgICAgICAvLyBBZGp1c3QgYXJncyBpZiB0aGVyZSBhcmUgZGVwZW5kZW5jaWVzCiAgICAgICAgICAgICAgICBkZXBzID0gY2FsbGJhY2s7CiAgICAgICAgICAgICAgICBjYWxsYmFjayA9IGVycmJhY2s7CiAgICAgICAgICAgICAgICBlcnJiYWNrID0gb3B0aW9uYWw7CiAgICAgICAgICAgIH0gZWxzZSB7CiAgICAgICAgICAgICAgICBkZXBzID0gW107CiAgICAgICAgICAgIH0KICAgICAgICB9CgogICAgICAgIGlmIChjb25maWcgJiYgY29uZmlnLmNvbnRleHQpIHsKICAgICAgICAgICAgY29udGV4dE5hbWUgPSBjb25maWcuY29udGV4dDsKICAgICAgICB9CgogICAgICAgIGNvbnRleHQgPSBnZXRPd24oY29udGV4dHMsIGNvbnRleHROYW1lKTsKICAgICAgICBpZiAoIWNvbnRleHQpIHsKICAgICAgICAgICAgY29udGV4dCA9IGNvbnRleHRzW2NvbnRleHROYW1lXSA9IHJlcS5zLm5ld0NvbnRleHQoY29udGV4dE5hbWUpOwogICAgICAgIH0KCiAgICAgICAgaWYgKGNvbmZpZykgewogICAgICAgICAgICBjb250ZXh0LmNvbmZpZ3VyZShjb25maWcpOwogICAgICAgIH0KCiAgICAgICAgcmV0dXJuIGNvbnRleHQucmVxdWlyZShkZXBzLCBjYWxsYmFjaywgZXJyYmFjayk7CiAgICB9OwoKICAgIC8qKgogICAgICogU3VwcG9ydCByZXF1aXJlLmNvbmZpZygpIHRvIG1ha2UgaXQgZWFzaWVyIHRvIGNvb3BlcmF0ZSB3aXRoIG90aGVyCiAgICAgKiBBTUQgbG9hZGVycyBvbiBnbG9iYWxseSBhZ3JlZWQgbmFtZXMuCiAgICAgKi8KICAgIHJlcS5jb25maWcgPSBmdW5jdGlvbiAoY29uZmlnKSB7CiAgICAgICAgcmV0dXJuIHJlcShjb25maWcpOwogICAgfTsKCiAgICAvKioKICAgICAqIEV4ZWN1dGUgc29tZXRoaW5nIGFmdGVyIHRoZSBjdXJyZW50IHRpY2sKICAgICAqIG9mIHRoZSBldmVudCBsb29wLiBPdmVycmlkZSBmb3Igb3RoZXIgZW52cwogICAgICogdGhhdCBoYXZlIGEgYmV0dGVyIHNvbHV0aW9uIHRoYW4gc2V0VGltZW91dC4KICAgICAqIEBwYXJhbSAge0Z1bmN0aW9ufSBmbiBmdW5jdGlvbiB0byBleGVjdXRlIGxhdGVyLgogICAgICovCiAgICByZXEubmV4dFRpY2sgPSB0eXBlb2Ygc2V0VGltZW91dCAhPT0gJ3VuZGVmaW5lZCcgPyBmdW5jdGlvbiAoZm4pIHsKICAgICAgICBzZXRUaW1lb3V0KGZuLCA0KTsKICAgIH0gOiBmdW5jdGlvbiAoZm4pIHsgZm4oKTsgfTsKCiAgICAvKioKICAgICAqIEV4cG9ydCByZXF1aXJlIGFzIGEgZ2xvYmFsLCBidXQgb25seSBpZiBpdCBkb2VzIG5vdCBhbHJlYWR5IGV4aXN0LgogICAgICovCiAgICBpZiAoIXJlcXVpcmUpIHsKICAgICAgICByZXF1aXJlID0gcmVxOwogICAgfQoKICAgIHJlcS52ZXJzaW9uID0gdmVyc2lvbjsKCiAgICAvL1VzZWQgdG8gZmlsdGVyIG91dCBkZXBlbmRlbmNpZXMgdGhhdCBhcmUgYWxyZWFkeSBwYXRocy4KICAgIHJlcS5qc0V4dFJlZ0V4cCA9IC9eXC98OnxcP3xcLmpzJC87CiAgICByZXEuaXNCcm93c2VyID0gaXNCcm93c2VyOwogICAgcyA9IHJlcS5zID0gewogICAgICAgIGNvbnRleHRzOiBjb250ZXh0cywKICAgICAgICBuZXdDb250ZXh0OiBuZXdDb250ZXh0CiAgICB9OwoKICAgIC8vQ3JlYXRlIGRlZmF1bHQgY29udGV4dC4KICAgIHJlcSh7fSk7CgogICAgLy9FeHBvcnRzIHNvbWUgY29udGV4dC1zZW5zaXRpdmUgbWV0aG9kcyBvbiBnbG9iYWwgcmVxdWlyZS4KICAgIGVhY2goWwogICAgICAgICd0b1VybCcsCiAgICAgICAgJ3VuZGVmJywKICAgICAgICAnZGVmaW5lZCcsCiAgICAgICAgJ3NwZWNpZmllZCcKICAgIF0sIGZ1bmN0aW9uIChwcm9wKSB7CiAgICAgICAgLy9SZWZlcmVuY2UgZnJvbSBjb250ZXh0cyBpbnN0ZWFkIG9mIGVhcmx5IGJpbmRpbmcgdG8gZGVmYXVsdCBjb250ZXh0LAogICAgICAgIC8vc28gdGhhdCBkdXJpbmcgYnVpbGRzLCB0aGUgbGF0ZXN0IGluc3RhbmNlIG9mIHRoZSBkZWZhdWx0IGNvbnRleHQKICAgICAgICAvL3dpdGggaXRzIGNvbmZpZyBnZXRzIHVzZWQuCiAgICAgICAgcmVxW3Byb3BdID0gZnVuY3Rpb24gKCkgewogICAgICAgICAgICB2YXIgY3R4ID0gY29udGV4dHNbZGVmQ29udGV4dE5hbWVdOwogICAgICAgICAgICByZXR1cm4gY3R4LnJlcXVpcmVbcHJvcF0uYXBwbHkoY3R4LCBhcmd1bWVudHMpOwogICAgICAgIH07CiAgICB9KTsKCiAgICBpZiAoaXNCcm93c2VyKSB7CiAgICAgICAgaGVhZCA9IHMuaGVhZCA9IGRvY3VtZW50LmdldEVsZW1lbnRzQnlUYWdOYW1lKCdoZWFkJylbMF07CiAgICAgICAgLy9JZiBCQVNFIHRhZyBpcyBpbiBwbGF5LCB1c2luZyBhcHBlbmRDaGlsZCBpcyBhIHByb2JsZW0gZm9yIElFNi4KICAgICAgICAvL1doZW4gdGhhdCBicm93c2VyIGRpZXMsIHRoaXMgY2FuIGJlIHJlbW92ZWQuIERldGFpbHMgaW4gdGhpcyBqUXVlcnkgYnVnOgogICAgICAgIC8vaHR0cDovL2Rldi5qcXVlcnkuY29tL3RpY2tldC8yNzA5CiAgICAgICAgYmFzZUVsZW1lbnQgPSBkb2N1bWVudC5nZXRFbGVtZW50c0J5VGFnTmFtZSgnYmFzZScpWzBdOwogICAgICAgIGlmIChiYXNlRWxlbWVudCkgewogICAgICAgICAgICBoZWFkID0gcy5oZWFkID0gYmFzZUVsZW1lbnQucGFyZW50Tm9kZTsKICAgICAgICB9CiAgICB9CgogICAgLyoqCiAgICAgKiBBbnkgZXJyb3JzIHRoYXQgcmVxdWlyZSBleHBsaWNpdGx5IGdlbmVyYXRlcyB3aWxsIGJlIHBhc3NlZCB0byB0aGlzCiAgICAgKiBmdW5jdGlvbi4gSW50ZXJjZXB0L292ZXJyaWRlIGl0IGlmIHlvdSB3YW50IGN1c3RvbSBlcnJvciBoYW5kbGluZy4KICAgICAqIEBwYXJhbSB7RXJyb3J9IGVyciB0aGUgZXJyb3Igb2JqZWN0LgogICAgICovCiAgICByZXEub25FcnJvciA9IGRlZmF1bHRPbkVycm9yOwoKICAgIC8qKgogICAgICogQ3JlYXRlcyB0aGUgbm9kZSBmb3IgdGhlIGxvYWQgY29tbWFuZC4gT25seSB1c2VkIGluIGJyb3dzZXIgZW52cy4KICAgICAqLwogICAgcmVxLmNyZWF0ZU5vZGUgPSBmdW5jdGlvbiAoY29uZmlnLCBtb2R1bGVOYW1lLCB1cmwpIHsKICAgICAgICB2YXIgbm9kZSA9IGNvbmZpZy54aHRtbCA/CiAgICAgICAgICAgICAgICBkb2N1bWVudC5jcmVhdGVFbGVtZW50TlMoJ2h0dHA6Ly93d3cudzMub3JnLzE5OTkveGh0bWwnLCAnaHRtbDpzY3JpcHQnKSA6CiAgICAgICAgICAgICAgICBkb2N1bWVudC5jcmVhdGVFbGVtZW50KCdzY3JpcHQnKTsKICAgICAgICBub2RlLnR5cGUgPSBjb25maWcuc2NyaXB0VHlwZSB8fCAndGV4dC9qYXZhc2NyaXB0JzsKICAgICAgICBub2RlLmNoYXJzZXQgPSAndXRmLTgnOwogICAgICAgIG5vZGUuYXN5bmMgPSB0cnVlOwogICAgICAgIHJldHVybiBub2RlOwogICAgfTsKCiAgICAvKioKICAgICAqIERvZXMgdGhlIHJlcXVlc3QgdG8gbG9hZCBhIG1vZHVsZSBmb3IgdGhlIGJyb3dzZXIgY2FzZS4KICAgICAqIE1ha2UgdGhpcyBhIHNlcGFyYXRlIGZ1bmN0aW9uIHRvIGFsbG93IG90aGVyIGVudmlyb25tZW50cwogICAgICogdG8gb3ZlcnJpZGUgaXQuCiAgICAgKgogICAgICogQHBhcmFtIHtPYmplY3R9IGNvbnRleHQgdGhlIHJlcXVpcmUgY29udGV4dCB0byBmaW5kIHN0YXRlLgogICAgICogQHBhcmFtIHtTdHJpbmd9IG1vZHVsZU5hbWUgdGhlIG5hbWUgb2YgdGhlIG1vZHVsZS4KICAgICAqIEBwYXJhbSB7T2JqZWN0fSB1cmwgdGhlIFVSTCB0byB0aGUgbW9kdWxlLgogICAgICovCiAgICByZXEubG9hZCA9IGZ1bmN0aW9uIChjb250ZXh0LCBtb2R1bGVOYW1lLCB1cmwpIHsKICAgICAgICB2YXIgY29uZmlnID0gKGNvbnRleHQgJiYgY29udGV4dC5jb25maWcpIHx8IHt9LAogICAgICAgICAgICBub2RlOwogICAgICAgIGlmIChpc0Jyb3dzZXIpIHsKICAgICAgICAgICAgLy9JbiB0aGUgYnJvd3NlciBzbyB1c2UgYSBzY3JpcHQgdGFnCiAgICAgICAgICAgIG5vZGUgPSByZXEuY3JlYXRlTm9kZShjb25maWcsIG1vZHVsZU5hbWUsIHVybCk7CiAgICAgICAgICAgIGlmIChjb25maWcub25Ob2RlQ3JlYXRlZCkgewogICAgICAgICAgICAgICAgY29uZmlnLm9uTm9kZUNyZWF0ZWQobm9kZSwgY29uZmlnLCBtb2R1bGVOYW1lLCB1cmwpOwogICAgICAgICAgICB9CgogICAgICAgICAgICBub2RlLnNldEF0dHJpYnV0ZSgnZGF0YS1yZXF1aXJlY29udGV4dCcsIGNvbnRleHQuY29udGV4dE5hbWUpOwogICAgICAgICAgICBub2RlLnNldEF0dHJpYnV0ZSgnZGF0YS1yZXF1aXJlbW9kdWxlJywgbW9kdWxlTmFtZSk7CgogICAgICAgICAgICAvL1NldCB1cCBsb2FkIGxpc3RlbmVyLiBUZXN0IGF0dGFjaEV2ZW50IGZpcnN0IGJlY2F1c2UgSUU5IGhhcwogICAgICAgICAgICAvL2Egc3VidGxlIGlzc3VlIGluIGl0cyBhZGRFdmVudExpc3RlbmVyIGFuZCBzY3JpcHQgb25sb2FkIGZpcmluZ3MKICAgICAgICAgICAgLy90aGF0IGRvIG5vdCBtYXRjaCB0aGUgYmVoYXZpb3Igb2YgYWxsIG90aGVyIGJyb3dzZXJzIHdpdGgKICAgICAgICAgICAgLy9hZGRFdmVudExpc3RlbmVyIHN1cHBvcnQsIHdoaWNoIGZpcmUgdGhlIG9ubG9hZCBldmVudCBmb3IgYQogICAgICAgICAgICAvL3NjcmlwdCByaWdodCBhZnRlciB0aGUgc2NyaXB0IGV4ZWN1dGlvbi4gU2VlOgogICAgICAgICAgICAvL2h0dHBzOi8vY29ubmVjdC5taWNyb3NvZnQuY29tL0lFL2ZlZWRiYWNrL2RldGFpbHMvNjQ4MDU3L3NjcmlwdC1vbmxvYWQtZXZlbnQtaXMtbm90LWZpcmVkLWltbWVkaWF0ZWx5LWFmdGVyLXNjcmlwdC1leGVjdXRpb24KICAgICAgICAgICAgLy9VTkZPUlRVTkFURUxZIE9wZXJhIGltcGxlbWVudHMgYXR0YWNoRXZlbnQgYnV0IGRvZXMgbm90IGZvbGxvdyB0aGUgc2NyaXB0CiAgICAgICAgICAgIC8vc2NyaXB0IGV4ZWN1dGlvbiBtb2RlLgogICAgICAgICAgICBpZiAobm9kZS5hdHRhY2hFdmVudCAmJgogICAgICAgICAgICAgICAgICAgIC8vQ2hlY2sgaWYgbm9kZS5hdHRhY2hFdmVudCBpcyBhcnRpZmljaWFsbHkgYWRkZWQgYnkgY3VzdG9tIHNjcmlwdCBvcgogICAgICAgICAgICAgICAgICAgIC8vbmF0aXZlbHkgc3VwcG9ydGVkIGJ5IGJyb3dzZXIKICAgICAgICAgICAgICAgICAgICAvL3JlYWQgaHR0cHM6Ly9naXRodWIuY29tL2pyYnVya2UvcmVxdWlyZWpzL2lzc3Vlcy8xODcKICAgICAgICAgICAgICAgICAgICAvL2lmIHdlIGNhbiBOT1QgZmluZCBbbmF0aXZlIGNvZGVdIHRoZW4gaXQgbXVzdCBOT1QgbmF0aXZlbHkgc3VwcG9ydGVkLgogICAgICAgICAgICAgICAgICAgIC8vaW4gSUU4LCBub2RlLmF0dGFjaEV2ZW50IGRvZXMgbm90IGhhdmUgdG9TdHJpbmcoKQogICAgICAgICAgICAgICAgICAgIC8vTm90ZSB0aGUgdGVzdCBmb3IgIltuYXRpdmUgY29kZSIgd2l0aCBubyBjbG9zaW5nIGJyYWNlLCBzZWU6CiAgICAgICAgICAgICAgICAgICAgLy9odHRwczovL2dpdGh1Yi5jb20vanJidXJrZS9yZXF1aXJlanMvaXNzdWVzLzI3MwogICAgICAgICAgICAgICAgICAgICEobm9kZS5hdHRhY2hFdmVudC50b1N0cmluZyAmJiBub2RlLmF0dGFjaEV2ZW50LnRvU3RyaW5nKCkuaW5kZXhPZignW25hdGl2ZSBjb2RlJykgPCAwKSAmJgogICAgICAgICAgICAgICAgICAgICFpc09wZXJhKSB7CiAgICAgICAgICAgICAgICAvL1Byb2JhYmx5IElFLiBJRSAoYXQgbGVhc3QgNi04KSBkbyBub3QgZmlyZQogICAgICAgICAgICAgICAgLy9zY3JpcHQgb25sb2FkIHJpZ2h0IGFmdGVyIGV4ZWN1dGluZyB0aGUgc2NyaXB0LCBzbwogICAgICAgICAgICAgICAgLy93ZSBjYW5ub3QgdGllIHRoZSBhbm9ueW1vdXMgZGVmaW5lIGNhbGwgdG8gYSBuYW1lLgogICAgICAgICAgICAgICAgLy9Ib3dldmVyLCBJRSByZXBvcnRzIHRoZSBzY3JpcHQgYXMgYmVpbmcgaW4gJ2ludGVyYWN0aXZlJwogICAgICAgICAgICAgICAgLy9yZWFkeVN0YXRlIGF0IHRoZSB0aW1lIG9mIHRoZSBkZWZpbmUgY2FsbC4KICAgICAgICAgICAgICAgIHVzZUludGVyYWN0aXZlID0gdHJ1ZTsKCiAgICAgICAgICAgICAgICBub2RlLmF0dGFjaEV2ZW50KCdvbnJlYWR5c3RhdGVjaGFuZ2UnLCBjb250ZXh0Lm9uU2NyaXB0TG9hZCk7CiAgICAgICAgICAgICAgICAvL0l0IHdvdWxkIGJlIGdyZWF0IHRvIGFkZCBhbiBlcnJvciBoYW5kbGVyIGhlcmUgdG8gY2F0Y2gKICAgICAgICAgICAgICAgIC8vNDA0cyBpbiBJRTkrLiBIb3dldmVyLCBvbnJlYWR5c3RhdGVjaGFuZ2Ugd2lsbCBmaXJlIGJlZm9yZQogICAgICAgICAgICAgICAgLy90aGUgZXJyb3IgaGFuZGxlciwgc28gdGhhdCBkb2VzIG5vdCBoZWxwLiBJZiBhZGRFdmVudExpc3RlbmVyCiAgICAgICAgICAgICAgICAvL2lzIHVzZWQsIHRoZW4gSUUgd2lsbCBmaXJlIGVycm9yIGJlZm9yZSBsb2FkLCBidXQgd2UgY2Fubm90CiAgICAgICAgICAgICAgICAvL3VzZSB0aGF0IHBhdGh3YXkgZ2l2ZW4gdGhlIGNvbm5lY3QubWljcm9zb2Z0LmNvbSBpc3N1ZQogICAgICAgICAgICAgICAgLy9tZW50aW9uZWQgYWJvdmUgYWJvdXQgbm90IGRvaW5nIHRoZSAnc2NyaXB0IGV4ZWN1dGUsCiAgICAgICAgICAgICAgICAvL3RoZW4gZmlyZSB0aGUgc2NyaXB0IGxvYWQgZXZlbnQgbGlzdGVuZXIgYmVmb3JlIGV4ZWN1dGUKICAgICAgICAgICAgICAgIC8vbmV4dCBzY3JpcHQnIHRoYXQgb3RoZXIgYnJvd3NlcnMgZG8uCiAgICAgICAgICAgICAgICAvL0Jlc3QgaG9wZTogSUUxMCBmaXhlcyB0aGUgaXNzdWVzLAogICAgICAgICAgICAgICAgLy9hbmQgdGhlbiBkZXN0cm95cyBhbGwgaW5zdGFsbHMgb2YgSUUgNi05LgogICAgICAgICAgICAgICAgLy9ub2RlLmF0dGFjaEV2ZW50KCdvbmVycm9yJywgY29udGV4dC5vblNjcmlwdEVycm9yKTsKICAgICAgICAgICAgfSBlbHNlIHsKICAgICAgICAgICAgICAgIG5vZGUuYWRkRXZlbnRMaXN0ZW5lcignbG9hZCcsIGNvbnRleHQub25TY3JpcHRMb2FkLCBmYWxzZSk7CiAgICAgICAgICAgICAgICBub2RlLmFkZEV2ZW50TGlzdGVuZXIoJ2Vycm9yJywgY29udGV4dC5vblNjcmlwdEVycm9yLCBmYWxzZSk7CiAgICAgICAgICAgIH0KICAgICAgICAgICAgbm9kZS5zcmMgPSB1cmw7CgogICAgICAgICAgICAvL0ZvciBzb21lIGNhY2hlIGNhc2VzIGluIElFIDYtOCwgdGhlIHNjcmlwdCBleGVjdXRlcyBiZWZvcmUgdGhlIGVuZAogICAgICAgICAgICAvL29mIHRoZSBhcHBlbmRDaGlsZCBleGVjdXRpb24sIHNvIHRvIHRpZSBhbiBhbm9ueW1vdXMgZGVmaW5lCiAgICAgICAgICAgIC8vY2FsbCB0byB0aGUgbW9kdWxlIG5hbWUgKHdoaWNoIGlzIHN0b3JlZCBvbiB0aGUgbm9kZSksIGhvbGQgb24KICAgICAgICAgICAgLy90byBhIHJlZmVyZW5jZSB0byB0aGlzIG5vZGUsIGJ1dCBjbGVhciBhZnRlciB0aGUgRE9NIGluc2VydGlvbi4KICAgICAgICAgICAgY3VycmVudGx5QWRkaW5nU2NyaXB0ID0gbm9kZTsKICAgICAgICAgICAgaWYgKGJhc2VFbGVtZW50KSB7CiAgICAgICAgICAgICAgICBoZWFkLmluc2VydEJlZm9yZShub2RlLCBiYXNlRWxlbWVudCk7CiAgICAgICAgICAgIH0gZWxzZSB7CiAgICAgICAgICAgICAgICBoZWFkLmFwcGVuZENoaWxkKG5vZGUpOwogICAgICAgICAgICB9CiAgICAgICAgICAgIGN1cnJlbnRseUFkZGluZ1NjcmlwdCA9IG51bGw7CgogICAgICAgICAgICByZXR1cm4gbm9kZTsKICAgICAgICB9IGVsc2UgaWYgKGlzV2ViV29ya2VyKSB7CiAgICAgICAgICAgIHRyeSB7CiAgICAgICAgICAgICAgICAvL0luIGEgd2ViIHdvcmtlciwgdXNlIGltcG9ydFNjcmlwdHMuIFRoaXMgaXMgbm90IGEgdmVyeQogICAgICAgICAgICAgICAgLy9lZmZpY2llbnQgdXNlIG9mIGltcG9ydFNjcmlwdHMsIGltcG9ydFNjcmlwdHMgd2lsbCBibG9jayB1bnRpbAogICAgICAgICAgICAgICAgLy9pdHMgc2NyaXB0IGlzIGRvd25sb2FkZWQgYW5kIGV2YWx1YXRlZC4gSG93ZXZlciwgaWYgd2ViIHdvcmtlcnMKICAgICAgICAgICAgICAgIC8vYXJlIGluIHBsYXksIHRoZSBleHBlY3RhdGlvbiBpcyB0aGF0IGEgYnVpbGQgaGFzIGJlZW4gZG9uZSBzbwogICAgICAgICAgICAgICAgLy90aGF0IG9ubHkgb25lIHNjcmlwdCBuZWVkcyB0byBiZSBsb2FkZWQgYW55d2F5LiBUaGlzIG1heSBuZWVkCiAgICAgICAgICAgICAgICAvL3RvIGJlIHJlZXZhbHVhdGVkIGlmIG90aGVyIHVzZSBjYXNlcyBiZWNvbWUgY29tbW9uLgogICAgICAgICAgICAgICAgaW1wb3J0U2NyaXB0cyh1cmwpOwoKICAgICAgICAgICAgICAgIC8vQWNjb3VudCBmb3IgYW5vbnltb3VzIG1vZHVsZXMKICAgICAgICAgICAgICAgIGNvbnRleHQuY29tcGxldGVMb2FkKG1vZHVsZU5hbWUpOwogICAgICAgICAgICB9IGNhdGNoIChlKSB7CiAgICAgICAgICAgICAgICBjb250ZXh0Lm9uRXJyb3IobWFrZUVycm9yKCdpbXBvcnRzY3JpcHRzJywKICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAnaW1wb3J0U2NyaXB0cyBmYWlsZWQgZm9yICcgKwogICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICBtb2R1bGVOYW1lICsgJyBhdCAnICsgdXJsLAogICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgIGUsCiAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgW21vZHVsZU5hbWVdKSk7CiAgICAgICAgICAgIH0KICAgICAgICB9CiAgICB9OwoKICAgIGZ1bmN0aW9uIGdldEludGVyYWN0aXZlU2NyaXB0KCkgewogICAgICAgIGlmIChpbnRlcmFjdGl2ZVNjcmlwdCAmJiBpbnRlcmFjdGl2ZVNjcmlwdC5yZWFkeVN0YXRlID09PSAnaW50ZXJhY3RpdmUnKSB7CiAgICAgICAgICAgIHJldHVybiBpbnRlcmFjdGl2ZVNjcmlwdDsKICAgICAgICB9CgogICAgICAgIGVhY2hSZXZlcnNlKHNjcmlwdHMoKSwgZnVuY3Rpb24gKHNjcmlwdCkgewogICAgICAgICAgICBpZiAoc2NyaXB0LnJlYWR5U3RhdGUgPT09ICdpbnRlcmFjdGl2ZScpIHsKICAgICAgICAgICAgICAgIHJldHVybiAoaW50ZXJhY3RpdmVTY3JpcHQgPSBzY3JpcHQpOwogICAgICAgICAgICB9CiAgICAgICAgfSk7CiAgICAgICAgcmV0dXJuIGludGVyYWN0aXZlU2NyaXB0OwogICAgfQoKICAgIC8vTG9vayBmb3IgYSBkYXRhLW1haW4gc2NyaXB0IGF0dHJpYnV0ZSwgd2hpY2ggY291bGQgYWxzbyBhZGp1c3QgdGhlIGJhc2VVcmwuCiAgICBpZiAoaXNCcm93c2VyICYmICFjZmcuc2tpcERhdGFNYWluKSB7CiAgICAgICAgLy9GaWd1cmUgb3V0IGJhc2VVcmwuIEdldCBpdCBmcm9tIHRoZSBzY3JpcHQgdGFnIHdpdGggcmVxdWlyZS5qcyBpbiBpdC4KICAgICAgICBlYWNoUmV2ZXJzZShzY3JpcHRzKCksIGZ1bmN0aW9uIChzY3JpcHQpIHsKICAgICAgICAgICAgLy9TZXQgdGhlICdoZWFkJyB3aGVyZSB3ZSBjYW4gYXBwZW5kIGNoaWxkcmVuIGJ5CiAgICAgICAgICAgIC8vdXNpbmcgdGhlIHNjcmlwdCdzIHBhcmVudC4KICAgICAgICAgICAgaWYgKCFoZWFkKSB7CiAgICAgICAgICAgICAgICBoZWFkID0gc2NyaXB0LnBhcmVudE5vZGU7CiAgICAgICAgICAgIH0KCiAgICAgICAgICAgIC8vTG9vayBmb3IgYSBkYXRhLW1haW4gYXR0cmlidXRlIHRvIHNldCBtYWluIHNjcmlwdCBmb3IgdGhlIHBhZ2UKICAgICAgICAgICAgLy90byBsb2FkLiBJZiBpdCBpcyB0aGVyZSwgdGhlIHBhdGggdG8gZGF0YSBtYWluIGJlY29tZXMgdGhlCiAgICAgICAgICAgIC8vYmFzZVVybCwgaWYgaXQgaXMgbm90IGFscmVhZHkgc2V0LgogICAgICAgICAgICBkYXRhTWFpbiA9IHNjcmlwdC5nZXRBdHRyaWJ1dGUoJ2RhdGEtbWFpbicpOwogICAgICAgICAgICBpZiAoZGF0YU1haW4pIHsKICAgICAgICAgICAgICAgIC8vUHJlc2VydmUgZGF0YU1haW4gaW4gY2FzZSBpdCBpcyBhIHBhdGggKGkuZS4gY29udGFpbnMgJz8nKQogICAgICAgICAgICAgICAgbWFpblNjcmlwdCA9IGRhdGFNYWluOwoKICAgICAgICAgICAgICAgIC8vU2V0IGZpbmFsIGJhc2VVcmwgaWYgdGhlcmUgaXMgbm90IGFscmVhZHkgYW4gZXhwbGljaXQgb25lLgogICAgICAgICAgICAgICAgaWYgKCFjZmcuYmFzZVVybCkgewogICAgICAgICAgICAgICAgICAgIC8vUHVsbCBvZmYgdGhlIGRpcmVjdG9yeSBvZiBkYXRhLW1haW4gZm9yIHVzZSBhcyB0aGUKICAgICAgICAgICAgICAgICAgICAvL2Jhc2VVcmwuCiAgICAgICAgICAgICAgICAgICAgc3JjID0gbWFpblNjcmlwdC5zcGxpdCgnLycpOwogICAgICAgICAgICAgICAgICAgIG1haW5TY3JpcHQgPSBzcmMucG9wKCk7CiAgICAgICAgICAgICAgICAgICAgc3ViUGF0aCA9IHNyYy5sZW5ndGggPyBzcmMuam9pbignLycpICArICcvJyA6ICcuLyc7CgogICAgICAgICAgICAgICAgICAgIGNmZy5iYXNlVXJsID0gc3ViUGF0aDsKICAgICAgICAgICAgICAgIH0KCiAgICAgICAgICAgICAgICAvL1N0cmlwIG9mZiBhbnkgdHJhaWxpbmcgLmpzIHNpbmNlIG1haW5TY3JpcHQgaXMgbm93CiAgICAgICAgICAgICAgICAvL2xpa2UgYSBtb2R1bGUgbmFtZS4KICAgICAgICAgICAgICAgIG1haW5TY3JpcHQgPSBtYWluU2NyaXB0LnJlcGxhY2UoanNTdWZmaXhSZWdFeHAsICcnKTsKCiAgICAgICAgICAgICAgICAvL0lmIG1haW5TY3JpcHQgaXMgc3RpbGwgYSBwYXRoLCBmYWxsIGJhY2sgdG8gZGF0YU1haW4KICAgICAgICAgICAgICAgIGlmIChyZXEuanNFeHRSZWdFeHAudGVzdChtYWluU2NyaXB0KSkgewogICAgICAgICAgICAgICAgICAgIG1haW5TY3JpcHQgPSBkYXRhTWFpbjsKICAgICAgICAgICAgICAgIH0KCiAgICAgICAgICAgICAgICAvL1B1dCB0aGUgZGF0YS1tYWluIHNjcmlwdCBpbiB0aGUgZmlsZXMgdG8gbG9hZC4KICAgICAgICAgICAgICAgIGNmZy5kZXBzID0gY2ZnLmRlcHMgPyBjZmcuZGVwcy5jb25jYXQobWFpblNjcmlwdCkgOiBbbWFpblNjcmlwdF07CgogICAgICAgICAgICAgICAgcmV0dXJuIHRydWU7CiAgICAgICAgICAgIH0KICAgICAgICB9KTsKICAgIH0KCiAgICAvKioKICAgICAqIFRoZSBmdW5jdGlvbiB0aGF0IGhhbmRsZXMgZGVmaW5pdGlvbnMgb2YgbW9kdWxlcy4gRGlmZmVycyBmcm9tCiAgICAgKiByZXF1aXJlKCkgaW4gdGhhdCBhIHN0cmluZyBmb3IgdGhlIG1vZHVsZSBzaG91bGQgYmUgdGhlIGZpcnN0IGFyZ3VtZW50LAogICAgICogYW5kIHRoZSBmdW5jdGlvbiB0byBleGVjdXRlIGFmdGVyIGRlcGVuZGVuY2llcyBhcmUgbG9hZGVkIHNob3VsZAogICAgICogcmV0dXJuIGEgdmFsdWUgdG8gZGVmaW5lIHRoZSBtb2R1bGUgY29ycmVzcG9uZGluZyB0byB0aGUgZmlyc3QgYXJndW1lbnQncwogICAgICogbmFtZS4KICAgICAqLwogICAgZGVmaW5lID0gZnVuY3Rpb24gKG5hbWUsIGRlcHMsIGNhbGxiYWNrKSB7CiAgICAgICAgdmFyIG5vZGUsIGNvbnRleHQ7CgogICAgICAgIC8vQWxsb3cgZm9yIGFub255bW91cyBtb2R1bGVzCiAgICAgICAgaWYgKHR5cGVvZiBuYW1lICE9PSAnc3RyaW5nJykgewogICAgICAgICAgICAvL0FkanVzdCBhcmdzIGFwcHJvcHJpYXRlbHkKICAgICAgICAgICAgY2FsbGJhY2sgPSBkZXBzOwogICAgICAgICAgICBkZXBzID0gbmFtZTsKICAgICAgICAgICAgbmFtZSA9IG51bGw7CiAgICAgICAgfQoKICAgICAgICAvL1RoaXMgbW9kdWxlIG1heSBub3QgaGF2ZSBkZXBlbmRlbmNpZXMKICAgICAgICBpZiAoIWlzQXJyYXkoZGVwcykpIHsKICAgICAgICAgICAgY2FsbGJhY2sgPSBkZXBzOwogICAgICAgICAgICBkZXBzID0gbnVsbDsKICAgICAgICB9CgogICAgICAgIC8vSWYgbm8gbmFtZSwgYW5kIGNhbGxiYWNrIGlzIGEgZnVuY3Rpb24sIHRoZW4gZmlndXJlIG91dCBpZiBpdCBhCiAgICAgICAgLy9Db21tb25KUyB0aGluZyB3aXRoIGRlcGVuZGVuY2llcy4KICAgICAgICBpZiAoIWRlcHMgJiYgaXNGdW5jdGlvbihjYWxsYmFjaykpIHsKICAgICAgICAgICAgZGVwcyA9IFtdOwogICAgICAgICAgICAvL1JlbW92ZSBjb21tZW50cyBmcm9tIHRoZSBjYWxsYmFjayBzdHJpbmcsCiAgICAgICAgICAgIC8vbG9vayBmb3IgcmVxdWlyZSBjYWxscywgYW5kIHB1bGwgdGhlbSBpbnRvIHRoZSBkZXBlbmRlbmNpZXMsCiAgICAgICAgICAgIC8vYnV0IG9ubHkgaWYgdGhlcmUgYXJlIGZ1bmN0aW9uIGFyZ3MuCiAgICAgICAgICAgIGlmIChjYWxsYmFjay5sZW5ndGgpIHsKICAgICAgICAgICAgICAgIGNhbGxiYWNrCiAgICAgICAgICAgICAgICAgICAgLnRvU3RyaW5nKCkKICAgICAgICAgICAgICAgICAgICAucmVwbGFjZShjb21tZW50UmVnRXhwLCAnJykKICAgICAgICAgICAgICAgICAgICAucmVwbGFjZShjanNSZXF1aXJlUmVnRXhwLCBmdW5jdGlvbiAobWF0Y2gsIGRlcCkgewogICAgICAgICAgICAgICAgICAgICAgICBkZXBzLnB1c2goZGVwKTsKICAgICAgICAgICAgICAgICAgICB9KTsKCiAgICAgICAgICAgICAgICAvL01heSBiZSBhIENvbW1vbkpTIHRoaW5nIGV2ZW4gd2l0aG91dCByZXF1aXJlIGNhbGxzLCBidXQgc3RpbGwKICAgICAgICAgICAgICAgIC8vY291bGQgdXNlIGV4cG9ydHMsIGFuZCBtb2R1bGUuIEF2b2lkIGRvaW5nIGV4cG9ydHMgYW5kIG1vZHVsZQogICAgICAgICAgICAgICAgLy93b3JrIHRob3VnaCBpZiBpdCBqdXN0IG5lZWRzIHJlcXVpcmUuCiAgICAgICAgICAgICAgICAvL1JFUVVJUkVTIHRoZSBmdW5jdGlvbiB0byBleHBlY3QgdGhlIENvbW1vbkpTIHZhcmlhYmxlcyBpbiB0aGUKICAgICAgICAgICAgICAgIC8vb3JkZXIgbGlzdGVkIGJlbG93LgogICAgICAgICAgICAgICAgZGVwcyA9IChjYWxsYmFjay5sZW5ndGggPT09IDEgPyBbJ3JlcXVpcmUnXSA6IFsncmVxdWlyZScsICdleHBvcnRzJywgJ21vZHVsZSddKS5jb25jYXQoZGVwcyk7CiAgICAgICAgICAgIH0KICAgICAgICB9CgogICAgICAgIC8vSWYgaW4gSUUgNi04IGFuZCBoaXQgYW4gYW5vbnltb3VzIGRlZmluZSgpIGNhbGwsIGRvIHRoZSBpbnRlcmFjdGl2ZQogICAgICAgIC8vd29yay4KICAgICAgICBpZiAodXNlSW50ZXJhY3RpdmUpIHsKICAgICAgICAgICAgbm9kZSA9IGN1cnJlbnRseUFkZGluZ1NjcmlwdCB8fCBnZXRJbnRlcmFjdGl2ZVNjcmlwdCgpOwogICAgICAgICAgICBpZiAobm9kZSkgewogICAgICAgICAgICAgICAgaWYgKCFuYW1lKSB7CiAgICAgICAgICAgICAgICAgICAgbmFtZSA9IG5vZGUuZ2V0QXR0cmlidXRlKCdkYXRhLXJlcXVpcmVtb2R1bGUnKTsKICAgICAgICAgICAgICAgIH0KICAgICAgICAgICAgICAgIGNvbnRleHQgPSBjb250ZXh0c1tub2RlLmdldEF0dHJpYnV0ZSgnZGF0YS1yZXF1aXJlY29udGV4dCcpXTsKICAgICAgICAgICAgfQogICAgICAgIH0KCiAgICAgICAgLy9BbHdheXMgc2F2ZSBvZmYgZXZhbHVhdGluZyB0aGUgZGVmIGNhbGwgdW50aWwgdGhlIHNjcmlwdCBvbmxvYWQgaGFuZGxlci4KICAgICAgICAvL1RoaXMgYWxsb3dzIG11bHRpcGxlIG1vZHVsZXMgdG8gYmUgaW4gYSBmaWxlIHdpdGhvdXQgcHJlbWF0dXJlbHkKICAgICAgICAvL3RyYWNpbmcgZGVwZW5kZW5jaWVzLCBhbmQgYWxsb3dzIGZvciBhbm9ueW1vdXMgbW9kdWxlIHN1cHBvcnQsCiAgICAgICAgLy93aGVyZSB0aGUgbW9kdWxlIG5hbWUgaXMgbm90IGtub3duIHVudGlsIHRoZSBzY3JpcHQgb25sb2FkIGV2ZW50CiAgICAgICAgLy9vY2N1cnMuIElmIG5vIGNvbnRleHQsIHVzZSB0aGUgZ2xvYmFsIHF1ZXVlLCBhbmQgZ2V0IGl0IHByb2Nlc3NlZAogICAgICAgIC8vaW4gdGhlIG9uc2NyaXB0IGxvYWQgY2FsbGJhY2suCiAgICAgICAgaWYgKGNvbnRleHQpIHsKICAgICAgICAgICAgY29udGV4dC5kZWZRdWV1ZS5wdXNoKFtuYW1lLCBkZXBzLCBjYWxsYmFja10pOwogICAgICAgICAgICBjb250ZXh0LmRlZlF1ZXVlTWFwW25hbWVdID0gdHJ1ZTsKICAgICAgICB9IGVsc2UgewogICAgICAgICAgICBnbG9iYWxEZWZRdWV1ZS5wdXNoKFtuYW1lLCBkZXBzLCBjYWxsYmFja10pOwogICAgICAgIH0KICAgIH07CgogICAgZGVmaW5lLmFtZCA9IHsKICAgICAgICBqUXVlcnk6IHRydWUKICAgIH07CgogICAgLyoqCiAgICAgKiBFeGVjdXRlcyB0aGUgdGV4dC4gTm9ybWFsbHkganVzdCB1c2VzIGV2YWwsIGJ1dCBjYW4gYmUgbW9kaWZpZWQKICAgICAqIHRvIHVzZSBhIGJldHRlciwgZW52aXJvbm1lbnQtc3BlY2lmaWMgY2FsbC4gT25seSB1c2VkIGZvciB0cmFuc3BpbGluZwogICAgICogbG9hZGVyIHBsdWdpbnMsIG5vdCBmb3IgcGxhaW4gSlMgbW9kdWxlcy4KICAgICAqIEBwYXJhbSB7U3RyaW5nfSB0ZXh0IHRoZSB0ZXh0IHRvIGV4ZWN1dGUvZXZhbHVhdGUuCiAgICAgKi8KICAgIHJlcS5leGVjID0gZnVuY3Rpb24gKHRleHQpIHsKICAgICAgICAvKmpzbGludCBldmlsOiB0cnVlICovCiAgICAgICAgcmV0dXJuIGV2YWwodGV4dCk7CiAgICB9OwoKICAgIC8vU2V0IHVwIHdpdGggY29uZmlnIGluZm8uCiAgICByZXEoY2ZnKTsKfSh0aGlzKSk7Cg==", @@ -897,6 +857,7 @@ } } }, + "cell_type": "code", "source": [ "# Convert inputs and outputs to subwords\n", "inp_text = to_tokens(encoders[\"inputs\"].encode(inputs))\n", @@ -904,7 +865,7 @@ "\n", "# Run eval to collect attention weights\n", "example = encode_eval(inputs, outputs)\n", - "with tfe.restore_variables_on_create(ckpt_path):\n", + "with tfe.restore_variables_on_create(tf.train.latest_checkpoint(checkpoint_dir)):\n", " translate_model.set_mode(Modes.EVAL)\n", " translate_model(example)\n", "# Get normalized attention weights for each layer\n", @@ -913,8 +874,7 @@ "call_html()\n", "attention.show(inp_text, out_text, enc_atts, dec_atts, encdec_atts)" ], - "cell_type": "code", - "execution_count": 16, + "execution_count": 0, "outputs": [ { "output_type": "stream", @@ -1376,10 +1336,10 @@ "id": "i7BZuO7T5BB4", "colab_type": "text" }, + "cell_type": "markdown", "source": [ "# Train a custom model on MNIST" - ], - "cell_type": "markdown" + ] }, { "metadata": { @@ -1392,6 +1352,7 @@ } } }, + "cell_type": "code", "source": [ "# Create your own model\n", "\n", @@ -1411,7 +1372,6 @@ "hparams.hidden_size = 64\n", "model = MySimpleModel(hparams, Modes.TRAIN)" ], - "cell_type": "code", "execution_count": 0, "outputs": [] }, @@ -1424,11 +1384,6 @@ "startup": false, "wait_interval": 0 }, - "output_extras": [ - { - "item_id": 1 - } - ], "base_uri": "https://localhost:8080/", "height": 34 }, @@ -1445,6 +1400,7 @@ } } }, + "cell_type": "code", "source": [ "# Prepare for the training loop\n", "\n", @@ -1462,8 +1418,7 @@ "\n", "optimizer = tf.train.AdamOptimizer()" ], - "cell_type": "code", - "execution_count": 42, + "execution_count": 0, "outputs": [ { "output_type": "stream", @@ -1483,11 +1438,6 @@ "startup": false, "wait_interval": 0 }, - "output_extras": [ - { - "item_id": 11 - } - ], "base_uri": "https://localhost:8080/", "height": 204 }, @@ -1504,6 +1454,7 @@ } } }, + "cell_type": "code", "source": [ "# Train\n", "NUM_STEPS = 500\n", @@ -1518,8 +1469,7 @@ " if count >= NUM_STEPS:\n", " break" ], - "cell_type": "code", - "execution_count": 46, + "execution_count": 0, "outputs": [ { "output_type": "stream", @@ -1549,11 +1499,6 @@ "startup": false, "wait_interval": 0 }, - "output_extras": [ - { - "item_id": 2 - } - ], "base_uri": "https://localhost:8080/", "height": 68 }, @@ -1570,6 +1515,7 @@ } } }, + "cell_type": "code", "source": [ "model.set_mode(Modes.EVAL)\n", "mnist_eval_dataset = mnist_problem.dataset(Modes.EVAL, data_dir)\n", @@ -1597,8 +1543,7 @@ "for name, val in metrics_result().items():\n", " print(\"%s: %.2f\" % (name, val))" ], - "cell_type": "code", - "execution_count": 47, + "execution_count": 0, "outputs": [ { "output_type": "stream", diff --git a/tensor2tensor/rl/envs/simulated_batch_env.py b/tensor2tensor/rl/envs/simulated_batch_env.py index 20fe868f5..9a229b424 100644 --- a/tensor2tensor/rl/envs/simulated_batch_env.py +++ b/tensor2tensor/rl/envs/simulated_batch_env.py @@ -100,7 +100,8 @@ def simulate(self, action): action = tf.concat([action, action], axis=0) inputs = {"inputs": tf.expand_dims(inputs_merged, axis=0), # Add batch. "input_action": tf.expand_dims(action, axis=0)} - model_output = self._model.infer(inputs) + with tf.variable_scope(tf.get_variable_scope(), reuse=tf.AUTO_REUSE): + model_output = self._model.infer(inputs) observ = model_output["targets"] observ = tf.cast(observ[:, 0, :, :, :], tf.float32) reward = model_output["target_reward"][:, 0, 0, 0] - 1 diff --git a/tensor2tensor/rl/envs/tf_atari_wrappers.py b/tensor2tensor/rl/envs/tf_atari_wrappers.py index 9c0af1ae3..83b9a9ae7 100644 --- a/tensor2tensor/rl/envs/tf_atari_wrappers.py +++ b/tensor2tensor/rl/envs/tf_atari_wrappers.py @@ -177,13 +177,11 @@ def __init__(self, batch_env): def simulate(self, action): with tf.name_scope("environment/simulate"): # Do we need this? - observ_copy = self._batch_env.observ.read_value() - with tf.control_dependencies([observ_copy]): - reward, done = self._batch_env.simulate(action) - encoded_image = tf.image.encode_png( - tf.cast(observ_copy[0, ...], tf.uint8)) - with tf.control_dependencies([reward, done]): - enqueue_op = self.speculum.enqueue( - [encoded_image, reward, action, done]) - with tf.control_dependencies([enqueue_op]): - return tf.identity(reward), tf.identity(done) + reward, done = self._batch_env.simulate(action) + encoded_image = tf.image.encode_png( + tf.cast(self._batch_env.observ[0, ...], tf.uint8)) + with tf.control_dependencies([reward, done]): + enqueue_op = self.speculum.enqueue( + [encoded_image, reward, action, done]) + with tf.control_dependencies([enqueue_op]): + return tf.identity(reward), tf.identity(done) diff --git a/tensor2tensor/rl/envs/utils.py b/tensor2tensor/rl/envs/utils.py index cbd4cbb60..26e12eab7 100644 --- a/tensor2tensor/rl/envs/utils.py +++ b/tensor2tensor/rl/envs/utils.py @@ -271,6 +271,7 @@ def _worker(self, constructor, conn): continue if message == self._CLOSE: assert payload is None + env.close() break raise KeyError("Received message of unknown type {}".format(message)) except Exception: # pylint: disable=broad-except diff --git a/tensor2tensor/rl/model_rl_experiment.py b/tensor2tensor/rl/model_rl_experiment.py index 87c07b26e..4fd89022e 100644 --- a/tensor2tensor/rl/model_rl_experiment.py +++ b/tensor2tensor/rl/model_rl_experiment.py @@ -67,7 +67,7 @@ def train(hparams, output_dir): FLAGS.output_dir = output_dir FLAGS.model = hparams.generative_model FLAGS.hparams_set = hparams.generative_model_params - FLAGS.train_steps = hparams.model_train_steps + FLAGS.train_steps = hparams.model_train_steps * (iloop + 2) FLAGS.eval_steps = 10 t2t_trainer.main([]) @@ -108,10 +108,10 @@ def train(hparams, output_dir): def main(_): hparams = tf.contrib.training.HParams( epochs=10, - true_env_generator_num_steps=5000, + true_env_generator_num_steps=10000, generative_model="basic_conv_gen", generative_model_params="basic_conv", - model_train_steps=15000, + model_train_steps=25000, simulated_env_generator_num_steps=300, ppo_epochs_num=200, ppo_epoch_length=300, diff --git a/tensor2tensor/rl/rl_trainer_lib.py b/tensor2tensor/rl/rl_trainer_lib.py index d244bcf2b..4ede6438e 100644 --- a/tensor2tensor/rl/rl_trainer_lib.py +++ b/tensor2tensor/rl/rl_trainer_lib.py @@ -56,11 +56,10 @@ def define_train(hparams, environment_spec, event_dir): "network", functools.partial(policy_lambda, batch_env.action_space, hparams)) - with tf.variable_scope("", reuse=tf.AUTO_REUSE): - memory, collect_summary = collect.define_collect( - policy_factory, batch_env, hparams, eval_phase=False) - ppo_summary = ppo.define_ppo_epoch(memory, policy_factory, hparams) - summary = tf.summary.merge([collect_summary, ppo_summary]) + memory, collect_summary = collect.define_collect( + policy_factory, batch_env, hparams, eval_phase=False) + ppo_summary = ppo.define_ppo_epoch(memory, policy_factory, hparams) + summary = tf.summary.merge([collect_summary, ppo_summary]) with tf.variable_scope("eval", reuse=tf.AUTO_REUSE): eval_env_lambda = env_lambda diff --git a/tensor2tensor/utils/adafactor.py b/tensor2tensor/utils/adafactor.py index 2d328d4ba..6b70fa057 100644 --- a/tensor2tensor/utils/adafactor.py +++ b/tensor2tensor/utils/adafactor.py @@ -34,7 +34,9 @@ class AdafactorOptimizer(tf.train.Optimizer): parameters to maintain the second-moment estimator, instead of AB. This is advantageous on memory-limited systems. In addition, beta1 (momentum) is set to zero by default, saving an additional auxiliary - parameter per weight. + parameter per weight. Variables with >=3 dimensions are treated as + collections of two-dimensional matrices - factorization is over the final + two dimensions. 2. Adafactor incorporates "update-clipping" - a scale-invariant analog of gradient clipping. This adds stability @@ -62,7 +64,7 @@ class AdafactorOptimizer(tf.train.Optimizer): if var is 2-dimensional: v_r <- zeros([num_rows]) v_c <- zeros([num_cols]) - else: + if var is 0-dimensional or 1-dimensional: v <- zeros(shape(var)) ``` @@ -74,10 +76,13 @@ class AdafactorOptimizer(tf.train.Optimizer): v_r <- decay_rate * v_r + (1 - decay_rate) * reduce_mean(grad_squared, 1) v_c <- decay_rate * v_c + (1 - decay_rate) * reduce_mean(grad_squared, 0) v = outer_prod(v_r, v_c) / reduce_mean(v_r) - else: + if var is 0-dimensional or 1-dimensional: v <- decay_rate * v + (1 - decay_rate) * grad_squared ``` + For variables with >=3 dimensions, we factorize the second-moment accumulator + over the final 2 dimensions. See the code for details. + Several parts of this algorithm are configurable from the initializer. @@ -95,8 +100,6 @@ class AdafactorOptimizer(tf.train.Optimizer): factored: whether to factor the second-moment estimator. True means less memory usage. - TODO(noam): we should also apply the 2d logic to the two final dimensions. - of >2d convolutional kernels. """ def __init__(self, @@ -159,7 +162,7 @@ def _should_use_factored_second_moment_estimate(self, shape): Returns: a boolean """ - return self._factored and len(shape) == 2 + return self._factored and len(shape) >= 2 def _create_slots(self, var_list): for var in var_list: @@ -167,8 +170,8 @@ def _create_slots(self, var_list): if self._beta1: self._zeros_slot(var, "m", self._name) if self._should_use_factored_second_moment_estimate(shape): - r_val = tf.zeros([shape[0]], dtype=tf.float32) - c_val = tf.zeros([shape[1]], dtype=tf.float32) + r_val = tf.zeros(shape[:-1], dtype=tf.float32) + c_val = tf.zeros(shape[:-2] + shape[-1:], dtype=tf.float32) self._get_or_make_slot(var, r_val, "vr", self._name) self._get_or_make_slot(var, c_val, "vc", self._name) else: @@ -219,8 +222,8 @@ def _resource_apply_dense(self, grad, var): shape = var.get_shape().as_list() updates = [] if self._should_use_factored_second_moment_estimate(shape): - grad_squared_row_mean = tf.reduce_mean(grad_squared, 1) - grad_squared_col_mean = tf.reduce_mean(grad_squared, 0) + grad_squared_row_mean = tf.reduce_mean(grad_squared, -1) + grad_squared_col_mean = tf.reduce_mean(grad_squared, -2) vr = self.get_slot(var, "vr") new_vr = (decay_rate * vr + mixing_rate * grad_squared_row_mean) vc = self.get_slot(var, "vc") @@ -228,10 +231,10 @@ def _resource_apply_dense(self, grad, var): vr_update = tf.assign(vr, new_vr, use_locking=self._use_locking) vc_update = tf.assign(vc, new_vc, use_locking=self._use_locking) updates = [vr_update, vc_update] - long_term_mean = tf.reduce_mean(new_vr) + long_term_mean = tf.reduce_mean(new_vr, -1, keep_dims=True) r_factor = tf.rsqrt(new_vr / long_term_mean) c_factor = tf.rsqrt(new_vc) - x = grad * tf.expand_dims(r_factor, 1) * tf.expand_dims(c_factor, 0) + x = grad * tf.expand_dims(r_factor, -1) * tf.expand_dims(c_factor, -2) else: v = self.get_slot(var, "v") new_v = decay_rate * v + mixing_rate * grad_squared diff --git a/tensor2tensor/utils/cloud_tpu.py b/tensor2tensor/utils/cloud_tpu.py index ef78458a9..d1e0e875d 100644 --- a/tensor2tensor/utils/cloud_tpu.py +++ b/tensor2tensor/utils/cloud_tpu.py @@ -50,6 +50,7 @@ def __init__(self): def cleanup(self, current_vm_name=None, current_tpu_name=None, skip_confirmation=False): + """Delete old instances and cleanup old trainer and tunnel processes.""" process_pids = os.listdir(self._tmp_dir) for pid in process_pids: try: @@ -72,7 +73,7 @@ def cleanup(self, current_vm_name=None, current_tpu_name=None, del_tpu = False if info["delete_on_done"]: if (info["vm_name"] != current_vm_name and - info["vm_name"] in zip(*list_vm_names_and_ips())[0]): + info["vm_name"] in list(zip(*list_vm_names_and_ips()))[0]): print("Old VM %s found. Delete?" % info["vm_name"]) if skip_confirmation: del_vm = True @@ -80,7 +81,7 @@ def cleanup(self, current_vm_name=None, current_tpu_name=None, if confirm(): del_vm = True if (info["tpu_name"] != current_tpu_name and - info["tpu_name"] in zip(*list_tpu_names_and_ips())[0]): + info["tpu_name"] in list(zip(*list_tpu_names_and_ips()))[0]): print("Old TPU %s found. Delete?" % info["tpu_name"]) if skip_confirmation: del_tpu = True @@ -340,8 +341,8 @@ def create_vm_tpu_pair(vm_name, tpu_name, reuse_if_exists=True, vm_info = list_vm_names_and_ips() tpu_info = list_tpu_names_and_ips() - vm_names = zip(*vm_info)[0] if vm_info else [] - tpu_names = zip(*tpu_info)[0] if tpu_info else [] + vm_names = list(zip(*vm_info))[0] if vm_info else [] + tpu_names = list(zip(*tpu_info))[0] if tpu_info else [] make_vm = False vm_ip = None diff --git a/tensor2tensor/utils/decoding.py b/tensor2tensor/utils/decoding.py index 887c8a191..2856c76ad 100644 --- a/tensor2tensor/utils/decoding.py +++ b/tensor2tensor/utils/decoding.py @@ -51,8 +51,7 @@ def decode_hparams(overrides=""): max_input_size=-1, identity_output=False, num_samples=-1, - delimiter="\n", - force_decode_length=False) + delimiter="\n") hp = hp.parse(overrides) return hp diff --git a/tensor2tensor/utils/expert_utils.py b/tensor2tensor/utils/expert_utils.py index 3cfa3592d..912f56169 100644 --- a/tensor2tensor/utils/expert_utils.py +++ b/tensor2tensor/utils/expert_utils.py @@ -30,34 +30,12 @@ import six from six.moves import range # pylint: disable=redefined-builtin from six.moves import zip # pylint: disable=redefined-builtin -import tensorflow as tf - -from tensorflow.python.framework import function - -DEFAULT_DEV_STRING = "existing_device" +from tensor2tensor.layers import common_layers -@function.Defun( - python_grad_func=lambda x, dy: tf.convert_to_tensor(dy), - shape_func=lambda op: [op.inputs[0].get_shape()]) -def convert_gradient_to_tensor(x): - """Identity operation whose gradient is converted to a `Tensor`. - - Currently, the gradient to `tf.concat` is particularly expensive to - compute if dy is an `IndexedSlices` (a lack of GPU implementation - forces the gradient operation onto CPU). This situation occurs when - the output of the `tf.concat` is eventually passed to `tf.gather`. - It is sometimes faster to convert the gradient to a `Tensor`, so as - to get the cheaper gradient for `tf.concat`. To do this, replace - `tf.concat(x)` with `convert_gradient_to_tensor(tf.concat(x))`. - - Args: - x: A `Tensor`. +import tensorflow as tf - Returns: - The input `Tensor`. - """ - return x +DEFAULT_DEV_STRING = "existing_device" def add_scope(scope=None, scope_fn=None): @@ -476,7 +454,7 @@ def noisy_top_k_gating(x, noisy_logits = clean_logits + ( tf.random_normal(tf.shape(clean_logits)) * noise_stddev) logits = noisy_logits - if should_generate_summaries(): + if common_layers.should_generate_summaries(): tf.summary.histogram("noisy_logits", noisy_logits) tf.summary.histogram("noise_stddev", noise_stddev) else: @@ -495,7 +473,7 @@ def noisy_top_k_gating(x, k), 0) else: load = _gates_to_load(gates) - if should_generate_summaries(): + if common_layers.should_generate_summaries(): tf.summary.histogram("importance", tf.reduce_sum(gates, 0)) tf.summary.histogram("load", load) return gates, load @@ -672,7 +650,7 @@ class SparseDispatcher(object): The inputs and outputs are all two-dimensional [batch, depth]. Caller is responsible for collapsing additional dimensions prior to calling this class and reshaping the output to the original shape. - See reshape_like(). + See common_layers.reshape_like(). Example use: @@ -746,7 +724,8 @@ def combine(self, expert_out, multiply_by_gates=True): a `Tensor` with shape `[batch_size, ]`. """ # see comments on convert_gradient_to_tensor - stitched = convert_gradient_to_tensor(tf.concat(expert_out, 0)) + stitched = common_layers.convert_gradient_to_tensor( + tf.concat(expert_out, 0)) if multiply_by_gates: stitched *= tf.expand_dims(self._nonzero_gates, 1) combined = tf.unsorted_segment_sum(stitched, self._batch_index, @@ -821,8 +800,8 @@ def dispatch(self, inp): dispatched = self._dp(lambda a, b: a.dispatch(b), self._dispatchers, inp) ret = self._ep(tf.concat, transpose_list_of_lists(dispatched), 0) if ret[0].dtype == tf.float32: - # see comments on convert_gradient_to_tensor - ret = self._ep(convert_gradient_to_tensor, ret) + # see comments on common_layers.convert_gradient_to_tensor + ret = self._ep(common_layers.convert_gradient_to_tensor, ret) return ret def combine(self, expert_out, multiply_by_gates=True): @@ -846,7 +825,7 @@ def combine(self, expert_out, multiply_by_gates=True): expert_output_parts_t = transpose_list_of_lists(expert_output_parts) def my_combine(dispatcher, parts): return dispatcher.combine( - convert_gradient_to_tensor(tf.concat(parts, 0)), + common_layers.convert_gradient_to_tensor(tf.concat(parts, 0)), multiply_by_gates=multiply_by_gates) return self._dp(my_combine, self._dispatchers, expert_output_parts_t) @@ -904,14 +883,6 @@ def my_fn(x): return my_fn -def reshape_like(a, b): - """Reshapes a to match the shape of b in all but the last dimension.""" - ret = tf.reshape(a, tf.concat([tf.shape(b)[:-1], tf.shape(a)[-1:]], 0)) - if not tf.contrib.eager.in_eager_mode(): - ret.set_shape(b.get_shape().as_list()[:-1] + a.get_shape().as_list()[-1:]) - return ret - - def flatten_all_but_last(a): """Flatten all dimensions of a except the last.""" ret = tf.reshape(a, [-1, tf.shape(a)[-1]]) @@ -981,7 +952,7 @@ def distributed_moe(data_parallelism, expert_in = dispatcher.dispatch(xs_flat) expert_out = ep(expert_fn, expert_in) ys_flat = dispatcher.combine(expert_out) - ys = dp(reshape_like, ys_flat, xs) + ys = dp(common_layers.reshape_like, ys_flat, xs) # compute some load-balancing losses. load = tf.add_n(load) importance = tf.add_n(dp(tf.reduce_sum, gates, 0)) @@ -1054,7 +1025,7 @@ def local_moe(x, expert_outputs = ep(expert_fn, **expert_kwargs) y_flat = dispatcher.combine(expert_outputs) - y = reshape_like(y_flat, x) + y = common_layers.reshape_like(y_flat, x) importance = tf.reduce_sum(gates, 0) loss = loss_coef * (cv_squared(importance) + cv_squared(load)) @@ -1201,16 +1172,323 @@ def length_coordinate(self): return self._indices -def should_generate_summaries(): - """Is this an appropriate context to generate summaries. +def local_moe_tpu(inputs, + hidden_size, + output_size, + num_experts, + loss_coef=1e-3, + overhead=1.0): + """Local mixture of experts that works well on TPU. + + See https://arxiv.org/abs/1701.06538 + + There are num_experts expert networks, each containing a relu-activated + hidden layer of size hidden_size, followed by an output projection. + + The number of parameters is thus: + num_experts * (input_size * hidden_size + hidden_size * output_size) + + The input is 3d: [batch, length, depth], consisting of the representations + of all positions in a batch of sequences. + + Each position of each sequence is sent to 0-2 experts. The expert + choices and the combination weights are determined by a learned gating + function. + + This function returns a small auxiliary loss that should be added to the + training loss of the model. This loss helps to balance expert usage. + Without the loss, it is very likely that a few experts will be trained and + the rest will starve. + + Several hacks are necessary to get around current TPU limitations: + + - To ensure static shapes, we enforce (by truncation/padding) + that each sequence send the same number of elements to each expert. + + It would make more sense to enforce this equality over the entire batch, + as opposed to on individual sequences. This would allow more freedom + for individual sequences to be unbalanced. Unfortunately, that would + slow down our hacked-up gather-by-matmul implementation. + + TODO(noam): There is no real reason for a single sequence to be the unit + of equal allocation. Reshaping the inputs would allow us to pick a + different unit of equal allocation. + + TODO(noam): Factor this code better. We want to be able to substitute + different code for the experts themselves. We also want to integrate this + gating/dispatching logic into multi-device mixtures-of-experts. + + Args: + inputs: a Tensor with shape [batch, length, depth] + hidden_size: an integer + output_size: an integer + num_experts: an integer + loss_coef: a float scalar + overhead: multiplicative factor of how much spare capacity to assign + + Returns: + outputs: a Tensor with shape [batch, length, output_size] + loss: a scalar + """ + batch, length, input_size = common_layers.shape_list(inputs)[:] + # Each sequence sends expert_capacity positions to each expert. + if isinstance(length, int): + expert_capacity = min( + length, int((length * 2 * overhead) / num_experts)) + else: + expert_capacity = tf.minimum( + length, tf.to_int32( + tf.to_float(length) * 2 * overhead / num_experts)) + expert_capacity_f = tf.to_float(expert_capacity) + + # This is the learned gating function. + gates = tf.nn.softmax( + tf.to_float(common_layers.dense(inputs, num_experts, name="logits"))) + + # Find the top expert for each position. + gate_1, index_1 = common_layers.top_1_tpu(gates) + # [batch, length, num_experts] + mask_1 = tf.one_hot(index_1, num_experts) + # [batch, length, num_experts] + # This is the position within the expert's mini-batch for this sequence + position_in_expert_1 = common_layers.cumsum( + mask_1, axis=1, exclusive=True) * mask_1 + # Remove the elements that don't fit. + mask_1 *= tf.to_float(tf.less(position_in_expert_1, expert_capacity_f)) + # [batch, 1, num_experts] + # How many examples in this sequence go to this expert + mask_1_count = tf.reduce_sum(mask_1, axis=1, keep_dims=True) + # [batch, length] - mostly ones, but zeros where something didn't fit + mask_1_flat = tf.reduce_sum(mask_1, axis=2) + position_in_expert_1 = tf.reduce_sum(position_in_expert_1, axis=2) + # Weight assigned to first expert. + gate_1 *= mask_1_flat + + # Pick a second-place expert for each position. + # We first mask out the experts that we expect to be over-capacity + space_remaining = expert_capacity_f - mask_1_count + use_rate = (mask_1_count + 1.0) / tf.to_float(length) + # At what point in the sequence do we expect the expert to be full. + expected_exhaustion_pos = space_remaining / use_rate + # A Tensor with shape [batch, length, num_experts] representing a boolean + # - whether we expect that the expert will already be full. + expected_exhausted = tf.to_float(tf.greater( + tf.reshape(tf.to_float(tf.range(length)), [1, length, 1]), + expected_exhaustion_pos)) + masked_gates = gates - mask_1 - expected_exhausted + # This section is similar to the section above. + gate_2, index_2 = common_layers.top_1_tpu(masked_gates) + # [batch, length, num_experts] + mask_2 = tf.one_hot(index_2, num_experts) + position_in_expert_2 = ( + common_layers.cumsum(mask_2, axis=1, exclusive=True) + mask_1_count) + position_in_expert_2 *= mask_2 + mask_2 *= tf.to_float(tf.less(position_in_expert_2, expert_capacity_f)) + mask_2_count = tf.reduce_sum(mask_2, axis=1, keep_dims=True) + mask_2_flat = tf.reduce_sum(mask_2, axis=2) + position_in_expert_2 = tf.reduce_sum(position_in_expert_2, axis=2) + gate_2 *= mask_2_flat + + # What fraction didn't fit - show summaries + miss_rate_1 = 1.0 - tf.reduce_sum(mask_1_count) / tf.to_float(batch * length) + miss_rate_2 = 1.0 - tf.reduce_sum(mask_2_count) / tf.to_float(batch * length) + tf.summary.scalar("miss_rate_1", miss_rate_1) + tf.summary.scalar("miss_rate_2", miss_rate_2) + + # renormalize the two gate values to add up to 1 + denom = gate_1 + gate_2 + 1e-9 + gate_1 /= denom + gate_2 /= denom + + # inputs: [batch, length, input_size] + # forward_assignment: [batch, length, num_experts * expert_capacity] + # expert_inputs: [batch, num_experts * expert_capacity, input_size] + + segment_ids_forward_1 = ( + (index_1 * expert_capacity) + + tf.to_int32(position_in_expert_1) + + tf.to_int32(1.0 - mask_1_flat) * (num_experts * expert_capacity)) + + segment_ids_forward_2 = ( + (index_2 * expert_capacity) + + tf.to_int32(position_in_expert_2) + + tf.to_int32(1.0 - mask_2_flat) * (num_experts * expert_capacity)) + + # Gather and scatter are painfully slow on TPU. + # We will use one_hot and matmul instead. + + # [batch, length, num_experts * expert_capacity] + one_hot_1 = tf.one_hot( + segment_ids_forward_1, num_experts * expert_capacity, dtype=inputs.dtype) + one_hot_2 = tf.one_hot( + segment_ids_forward_2, num_experts * expert_capacity, dtype=inputs.dtype) + + forward_assignment = (one_hot_1 + one_hot_2) + + # [batch, num_experts * expert_capacity, input_size] + expert_inputs = tf.matmul(forward_assignment, inputs, transpose_a=True) + + # [batch, num_experts, expert_capacity, input_size] + expert_inputs = tf.reshape( + expert_inputs, [batch, num_experts, expert_capacity, input_size]) + # [num_experts, batch, expert_capacity, input_size] + expert_inputs = tf.transpose(expert_inputs, [1, 0, 2, 3]) + + # [num_experts, batch * expert_capacity, input_size] + expert_inputs = tf.reshape( + expert_inputs, [num_experts, batch * expert_capacity, input_size]) + + # Now feed the expert inputs through the experts. + h = common_layers.batch_dense( + expert_inputs, hidden_size, activation=tf.nn.relu, name="x0") + expert_output = common_layers.batch_dense(h, output_size, name="x1") + expert_output = tf.reshape( + expert_output, [num_experts, batch, expert_capacity, output_size]) + + # [batch, num_experts, expert_capacity, output_size] + expert_output = tf.transpose(expert_output, [1, 0, 2, 3]) + expert_output = tf.reshape( + expert_output, [batch, num_experts * expert_capacity, output_size]) + + # Again, use matmul instead of unsorted_segment_sum. This time, we need + # to multiply by the combination weights gate_1 and gate_2. + + # expert_output: [batch, num_experts * expert_capacity, output_size] + # backward_assigmnent: [batch, length, num_experts * expert_capacity] + # output: [batch, length, output_size] + backward_assigmnent = ( + one_hot_1 * tf.cast(tf.expand_dims(gate_1, 2), inputs.dtype) + + one_hot_2 * tf.cast(tf.expand_dims(gate_2, 2), inputs.dtype)) + output = tf.matmul(backward_assigmnent, expert_output) + + # Compute a loss equal to the coefficient ov variation of the + # total gate value per expert per sequence. + # This loss causes the experts to be used about equally used per sequence. + importance = tf.reduce_sum(gates * (mask_1 + mask_2), 1) + loss = loss_coef * cv_squared(importance) + return output, loss + + +def reduce_by_device(parallelism, data, reduce_fn): + """Reduces data per device. + + This can be useful, for example, if we want to all-reduce n tensors on k