From df41d2e36ec89a350025beae7dbda5cc3b6930e5 Mon Sep 17 00:00:00 2001
From: Adam Roberts
Date: Thu, 20 Feb 2020 16:09:21 -0800
Subject: [PATCH] Have exported Transformer models accept raw input strings
 instead of serialized tf.Example protos. Add support for in-graph post-decode
 processing.

PiperOrigin-RevId: 296320440
---
 mesh_tensorflow/transformer/dataset.py |   6 +-
 mesh_tensorflow/transformer/utils.py   | 121 ++++++++++++-------
 setup.py                               |   2 +-
 3 files changed, 63 insertions(+), 66 deletions(-)

diff --git a/mesh_tensorflow/transformer/dataset.py b/mesh_tensorflow/transformer/dataset.py
index 26c245e5..66f70aad 100644
--- a/mesh_tensorflow/transformer/dataset.py
+++ b/mesh_tensorflow/transformer/dataset.py
@@ -261,8 +261,8 @@ def _encode_fn(features):  # pylint: disable=missing-docstring
     inputs_enc = inputs_vocabulary.encode_tf(features["inputs"])
     targets_enc = targets_vocabulary.encode_tf(features["targets"])
     if append_eos:
-      inputs_enc = tf.concat([tf.to_int64(inputs_enc), [eos_id]], 0)
-      targets_enc = tf.concat([tf.to_int64(targets_enc), [eos_id]], 0)
+      inputs_enc = tf.concat([tf.cast(inputs_enc, tf.int64), [eos_id]], 0)
+      targets_enc = tf.concat([tf.cast(targets_enc, tf.int64), [eos_id]], 0)
     return {"inputs": inputs_enc, "targets": targets_enc}

   dataset = dataset.map(
@@ -366,7 +366,7 @@ def my_fn(features):
     for k, v in features.items():
       if v.dtype == tf.string:
         v = vocabulary.encode_tf(v)
-        v = tf.concat([tf.to_int64(v), [1]], 0)
+        v = tf.concat([tf.cast(v, tf.int64), [1]], 0)
         ret[k] = v
       else:
         tf.logging.info(
diff --git a/mesh_tensorflow/transformer/utils.py b/mesh_tensorflow/transformer/utils.py
index f9cb53d5..f5f98740 100644
--- a/mesh_tensorflow/transformer/utils.py
+++ b/mesh_tensorflow/transformer/utils.py
@@ -244,6 +244,7 @@ def variable_filter_max_size(v, max_size=1e7):
 @gin.configurable
 def tpu_estimator_model_fn(model_type,
                            transformer_model,
+                           vocabulary,
                            model_dir,
                            use_tpu,
                            mesh_shape,
@@ -269,6 +270,8 @@ def tpu_estimator_model_fn(model_type,
     model_type: a string. One of "bitransformer", "lm", "aligned", or
       "bi_teacher_student"
     transformer_model: a transformer.Unitransformer or transformer.Bitransformer
+    vocabulary: a vocabulary.Vocabulary or (inputs_vocabulary,
+      targets_vocabulary) tuple. Used for decoding in predict mode.
     model_dir: a string, directory to save the model to.
     use_tpu: a boolean
     mesh_shape: a mtf.Shape
@@ -289,7 +292,7 @@ def tpu_estimator_model_fn(model_type,
       models
     tpu_summaries: a boolean, use rewrites to make summaries work on TPU. This
       may be slow, since it uses a host call hack.
-    predict_fn: an optional function, see docs for `run` for more information
+    predict_fn: an optional function, see docs for `run` for more information.
     variable_filter: controls which variables are trained. If None (default),
       train all trainable variables. If a string regex, train all variables
       that match this regex.
@@ -414,8 +417,18 @@ def _feature_shape(key):
       mtf_samples = mtf.anonymize(mtf_samples)
       inputs = mtf.anonymize(inputs)
       lowering = mtf.Lowering(graph, {mesh: mesh_impl}, autostack=autostack)
-      inputs = lowering.export_to_tf_tensor(inputs)
-      outputs = lowering.export_to_tf_tensor(mtf_samples)
+      inputs = clean_decodes(lowering.export_to_tf_tensor(inputs))
+      outputs = clean_decodes(lowering.export_to_tf_tensor(mtf_samples))
+
+      # Detokenize in the graph if supported by the vocabulary and accelerator.
+      def _maybe_detokenize(ids, vocab):
+        if not use_tpu and hasattr(vocab, "decode_tf"):
+          return vocab.decode_tf(ids)
+        return ids
+
+      inputs = _maybe_detokenize(inputs, inputs_vocabulary(vocabulary))
+      outputs = _maybe_detokenize(outputs, targets_vocabulary(vocabulary))
+
       predictions = {
           "inputs": inputs,
           "outputs": outputs}
@@ -832,24 +845,28 @@ def decode(estimator,
   Args:
     estimator: a TPUEstimator
     input_fn: function that returns a tf.Dataset
-    vocabulary: a mtf.transformer.vocabulary.Vocabulary
+    vocabulary: a vocabulary.Vocabulary or (inputs_vocabulary,
+      targets_vocabulary) tuple
     checkpoint_path: an optional string

   Returns:
     list of decoded strings
   """
-  result_iter = estimator.predict(input_fn,
-                                  checkpoint_path=checkpoint_path)
-  vocab_size = targets_vocabulary(vocabulary).vocab_size
+  result_iter = estimator.predict(
+      input_fn, checkpoint_path=checkpoint_path)
+
+  def _maybe_detokenize(value, vocab):
+    if isinstance(value, six.binary_type):
+      return value
+    return vocab.decode([int(x) for x in value])
+
   decodes = []
   for i, result in enumerate(result_iter):
-    output_ids = clean_decodes(list(result["outputs"]), vocab_size)
-    output_string = targets_vocabulary(vocabulary).decode(
-        [int(x) for x in output_ids])
+    input_string = _maybe_detokenize(
+        result["inputs"], inputs_vocabulary(vocabulary))
+    output_string = _maybe_detokenize(
+        result["outputs"], targets_vocabulary(vocabulary))
     decodes.append(output_string)
-    input_ids = clean_decodes(list(result["inputs"]), vocab_size)
-    input_string = targets_vocabulary(vocabulary).decode(
-        [int(x) for x in input_ids])
     if i & (i - 1) == 0:
       # LOG every power of 2.
       tf.logging.info("decoded {}: {}".format(i, input_string))
@@ -935,8 +952,7 @@ def input_fn(params):
   checkpoint_step = get_step_from_checkpoint_path(checkpoint_path)

   decodes = decode(
-      estimator, input_fn, vocabulary, checkpoint_path=checkpoint_path
-  )
+      estimator, input_fn, vocabulary, checkpoint_path=checkpoint_path)
   # Remove any padded examples
   dataset_size = len(inputs) * repeats
   decodes = decodes[:dataset_size]
@@ -945,28 +961,22 @@


 @gin.configurable
-def clean_decodes(ids, vocab_size, eos_id=1):
-  """Stop at EOS or padding or OOV.
+def clean_decodes(ids, eos_id=1, pad_id=0):
+  """Replaces everything after EOS with PAD.

   Args:
-    ids: a list of integers
-    vocab_size: an integer
-    eos_id: EOS id
+    ids: an int Tensor of shape [batch, length].
+    eos_id: int, EOS id.
+    pad_id: int, PAD id.

   Returns:
-    a list of integers
+    an int Tensor: ids with everything after the first EOS replaced by PAD.
   """
-  ret = []
-  for i in ids:
-    if i == eos_id:
-      break
-    if i >= vocab_size:
-      break
-    ret.append(int(i))
-  return ret
+  eos_and_after = tf.cumsum(tf.cast(tf.equal(ids, eos_id), tf.int32), exclusive=True, axis=1)
+  return tf.where_v2(tf.greater(eos_and_after, 0), pad_id, ids)


-def get_estimator(model_type, input_vocab_size, output_vocab_size, mesh_shape,
+def get_estimator(model_type, vocabulary, mesh_shape,
                   layout_rules, model_dir, batch_size, sequence_length,
                   autostack, learning_rate_schedule, keep_checkpoint_max,
                   save_checkpoints_steps, optimizer, predict_fn,
@@ -978,8 +988,8 @@ def get_estimator(model_type, input_vocab_size, output_vocab_size, mesh_shape,
   Args:
     model_type: a string - either "bitransformer", "bi_student_teacher", lm" or
       "aligned"
-    input_vocab_size: an integer, size of the input vocabulary.
-    output_vocab_size: an integer, size of the output vocabulary.
+    vocabulary: a vocabulary.Vocabulary or (inputs_vocabulary,
+      targets_vocabulary) tuple
     mesh_shape: a function passed in through gin that returns a mtf.Shape
     layout_rules: an input to mtf.convert_to_layout_rules()
     model_dir: a string, model directory path.
@@ -1033,14 +1043,15 @@ def get_estimator(model_type, input_vocab_size, output_vocab_size, mesh_shape,

   transformer_model = build_model(
       model_type=model_type,
-      input_vocab_size=input_vocab_size,
-      output_vocab_size=output_vocab_size,
+      input_vocab_size=inputs_vocabulary(vocabulary).vocab_size,
+      output_vocab_size=targets_vocabulary(vocabulary).vocab_size,
       layout_rules=layout_rules,
       mesh_shape=mesh_shape)

   model_fn = tpu_estimator_model_fn(
       model_type=model_type,
       transformer_model=transformer_model,
+      vocabulary=vocabulary,
       model_dir=model_dir,
       use_tpu=use_tpu,
       mesh_shape=mesh_shape,
@@ -1317,7 +1328,7 @@ def input_fn(params):


 def export_model(estimator, export_dir, vocabulary, sequence_length,
-                 checkpoint_path=None):
+                 batch_size=1, checkpoint_path=None):
   """Export a model in TF SavedModel format to be used for inference on CPUs.

   Args:
@@ -1328,6 +1339,7 @@ def export_model(estimator, export_dir, vocabulary, sequence_length,
     vocabulary: sentencepiece vocab, vocabulary instance to use for encoding.
     sequence_length: an integer or a dict from feature-key to integer the
      (packed) sequence length, e.g. {"inputs": 512, "targets": 128}
+    batch_size: int, number of sequences per batch. Should match estimator.
     checkpoint_path: str, path to checkpoint. If None (default), use the most
       recent in the model directory.

@@ -1338,32 +1350,20 @@ def export_model(estimator, export_dir, vocabulary, sequence_length,

   def serving_input_fn():
     """Constructs input portion of Graph in serving.

-    Input is a batch of a single serialized tf.Example proto.
+    Input is a batch of strings.

     Returns:
       a ServingInputReceiver
     """
-    serialized_example = tf.placeholder(
+    inputs = tf.placeholder(
         dtype=tf.string,
         shape=[None],
-        name="serialized_example")
+        name="inputs")

-    def parse_example(serialized_example):
-      """Function to parse serialized example with default features."""
-      # For text2text models, "inputs" provides conditioning text. ("targets"
-      # is only used for train and eval).
-      #
-      # For text2self models, "inputs" provide partial sequences that are used
-      # to generate outputs.
- example_specs = { - "inputs": tfds.features.Text().get_serialized_info(), - } + padded_inputs = tf.pad(inputs, [(0, tf.mod(-tf.size(inputs), batch_size))]) - parser = tfds.core.example_parser.ExampleParser(example_specs) - return parser.parse_example(serialized_example) - - dataset = tf.data.Dataset.from_tensor_slices(serialized_example) - dataset = dataset.map(parse_example) + dataset = tf.data.Dataset.from_tensor_slices(padded_inputs) + dataset = dataset.map(lambda x: {"inputs": x}) dataset = transformer_dataset.encode_all_features(dataset, vocabulary) dataset = transformer_dataset.pack_or_pad( dataset=dataset, @@ -1372,20 +1372,17 @@ def parse_example(serialized_example): feature_keys=["inputs"] ) - dataset = dataset.padded_batch( - tf.shape(serialized_example, out_type=tf.int64)[0], - dataset.output_shapes) + dataset = dataset.batch(batch_size) features = tf.data.experimental.get_single_element(dataset) - return tf.estimator.export.ServingInputReceiver( - features=features, receiver_tensors=serialized_example) + features=features, receiver_tensors=inputs) tpu_estimator.export_estimator_savedmodel( estimator=estimator, export_dir_base=export_dir, serving_input_receiver_fn=serving_input_fn, - as_text=True, + as_text=False, checkpoint_path=checkpoint_path, ) @@ -1706,8 +1703,7 @@ def run(tpu_job_name, estimator = get_estimator( model_type=model_type, - input_vocab_size=inputs_vocabulary(vocabulary).vocab_size, - output_vocab_size=targets_vocabulary(vocabulary).vocab_size, + vocabulary=vocabulary, layout_rules=layout_rules, mesh_shape=mesh_shape, model_dir=model_dir, @@ -1772,7 +1768,8 @@ def _input_fn(params, eval_dataset): model_dir, eval_checkpoint_step) elif mode == "export": - export_model(estimator, export_path, vocabulary, sequence_length) + export_model(estimator, export_path, vocabulary, sequence_length, + batch_size) else: raise ValueError( diff --git a/setup.py b/setup.py index abce0074..92975b74 100644 --- a/setup.py +++ b/setup.py @@ -5,7 +5,7 @@ setup( name='mesh-tensorflow', - version='0.1.9', + version='0.1.10', description='Mesh TensorFlow', author='Google Inc.', author_email='no-reply@google.com',
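
Two illustrative sketches of the new behavior follow, assuming the TF 1.x
APIs used throughout this repo; they are not part of the patch itself. First,
the rewritten clean_decodes: it keeps everything up to and including the
first EOS and replaces the rest with PAD, because the exclusive cumulative
sum of the EOS indicator is zero up to and including the first EOS and
positive strictly after it (toy ids below, with eos_id=1 and pad_id=0):

    import tensorflow.compat.v1 as tf

    ids = tf.constant([[5, 7, 1, 9, 9],    # EOS (id 1) at position 2
                       [4, 4, 4, 4, 4]])   # no EOS: row passes through
    # Exclusive cumsum of the EOS indicator along the length axis.
    eos_and_after = tf.cumsum(tf.cast(tf.equal(ids, 1), tf.int32),
                              exclusive=True, axis=1)
    cleaned = tf.where_v2(tf.greater(eos_and_after, 0), 0, ids)
    with tf.Session() as sess:
      print(sess.run(cleaned))  # [[5 7 1 0 0]
                                #  [4 4 4 4 4]]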
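Second, a minimal sketch of querying a SavedModel produced by export_model()
with raw strings. The export path is hypothetical, and the signature input
key is looked up dynamically since this diff does not pin it down; only the
"inputs"/"outputs" prediction keys come from tpu_estimator_model_fn above.

    import tensorflow.compat.v1 as tf

    export_dir = "/tmp/exported_model/1582243200"  # hypothetical timestamped export

    with tf.Session(graph=tf.Graph()) as sess:
      meta_graph = tf.saved_model.loader.load(
          sess, [tf.saved_model.tag_constants.SERVING], export_dir)
      signature = meta_graph.signature_def["serving_default"]
      # One string placeholder in; the batch is padded to batch_size in-graph.
      feed_name = list(signature.inputs.values())[0].name
      fetches = {k: v.name for k, v in signature.outputs.items()}
      result = sess.run(fetches, feed_dict={feed_name: ["How are you?"]})
      # On CPU exports, decode_tf runs inside the graph, so "outputs" already
      # holds detokenized byte strings rather than token ids.
      print(result["outputs"])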