This repository was archived by the owner on Jan 21, 2025. It is now read-only.
Merged

mesh_tensorflow/transformer/dataset.py (6 changes: 3 additions & 3 deletions)

@@ -261,8 +261,8 @@ def _encode_fn(features):  # pylint: disable=missing-docstring
     inputs_enc = inputs_vocabulary.encode_tf(features["inputs"])
     targets_enc = targets_vocabulary.encode_tf(features["targets"])
     if append_eos:
-      inputs_enc = tf.concat([tf.to_int64(inputs_enc), [eos_id]], 0)
-      targets_enc = tf.concat([tf.to_int64(targets_enc), [eos_id]], 0)
+      inputs_enc = tf.concat([tf.cast(inputs_enc, tf.int64), [eos_id]], 0)
+      targets_enc = tf.concat([tf.cast(targets_enc, tf.int64), [eos_id]], 0)
     return {"inputs": inputs_enc, "targets": targets_enc}

   dataset = dataset.map(

@@ -366,7 +366,7 @@ def my_fn(features):
     for k, v in features.items():
       if v.dtype == tf.string:
         v = vocabulary.encode_tf(v)
-        v = tf.concat([tf.to_int64(v), [1]], 0)
+        v = tf.concat([tf.cast(v, tf.int64), [1]], 0)
         ret[k] = v
       else:
         tf.logging.info(
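
A note on the two hunks above: `tf.to_int64(x)` is the deprecated TF 1.x
shorthand for `tf.cast(x, tf.int64)` (the alias was dropped from the TF 2
namespace), so the change is behavior-preserving. A minimal sketch of the
EOS-appending pattern, assuming TF 1.x:

    import tensorflow as tf

    ids = tf.constant([3, 7, 5], dtype=tf.int32)
    ids64 = tf.cast(ids, tf.int64)         # was: tf.to_int64(ids)
    with_eos = tf.concat([ids64, [1]], 0)  # EOS id 1; [1] is converted to int64
    with tf.Session() as sess:
      print(sess.run(with_eos))            # [3 7 5 1]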

mesh_tensorflow/transformer/utils.py (121 changes: 59 additions & 62 deletions)

@@ -244,6 +244,7 @@ def variable_filter_max_size(v, max_size=1e7):
 @gin.configurable
 def tpu_estimator_model_fn(model_type,
                            transformer_model,
+                           vocabulary,
                            model_dir,
                            use_tpu,
                            mesh_shape,

@@ -269,6 +270,8 @@ def tpu_estimator_model_fn(model_type,
     model_type: a string. One of "bitransformer", "lm", "aligned", or
       "bi_teacher_student"
     transformer_model: a transformer.Unitransformer or transformer.Bitransformer
+    vocabulary: a vocabulary.Vocabulary or (inputs_vocabulary,
+      targets_vocabulary) tuple. Used for decoding in predict mode.
     model_dir: a string, directory to save the model to.
     use_tpu: a boolean
     mesh_shape: a mtf.Shape

@@ -289,7 +292,7 @@ def tpu_estimator_model_fn(model_type,
       models
     tpu_summaries: a boolean, use rewrites to make summaries work on TPU. This
       may be slow, since it uses a host call hack.
-    predict_fn: an optional function, see docs for `run` for more information
+    predict_fn: an optional function, see docs for `run` for more information.
     variable_filter: controls which variables are trained.
       If None (default), train all trainable variables.
       If a string regex, train all variables that match this regex.

@@ -414,8 +417,18 @@ def _feature_shape(key):
       mtf_samples = mtf.anonymize(mtf_samples)
       inputs = mtf.anonymize(inputs)
       lowering = mtf.Lowering(graph, {mesh: mesh_impl}, autostack=autostack)
-      inputs = lowering.export_to_tf_tensor(inputs)
-      outputs = lowering.export_to_tf_tensor(mtf_samples)
+      inputs = clean_decodes(lowering.export_to_tf_tensor(inputs))
+      outputs = clean_decodes(lowering.export_to_tf_tensor(mtf_samples))
+
+      # Detokenize in the graph if supported by the vocabulary and accelerator.
+      def _maybe_detokenize(ids, vocab):
+        if not use_tpu and hasattr(vocab, "decode_tf"):
+          return vocab.decode_tf(ids)
+        return ids
+
+      inputs = _maybe_detokenize(inputs, inputs_vocabulary(vocabulary))
+      outputs = _maybe_detokenize(outputs, targets_vocabulary(vocabulary))
+
       predictions = {
           "inputs": inputs,
           "outputs": outputs}

@@ -832,24 +845,28 @@ def decode(estimator,
   Args:
     estimator: a TPUEstimator
     input_fn: function that returns a tf.Dataset
-    vocabulary: a mtf.transformer.vocabulary.Vocabulary
+    vocabulary: a vocabulary.Vocabulary or (inputs_vocabulary,
+      targets_vocabulary) tuple
     checkpoint_path: an optional string

   Returns:
     list of decoded strings
   """
-  result_iter = estimator.predict(input_fn,
-                                  checkpoint_path=checkpoint_path)
-  vocab_size = targets_vocabulary(vocabulary).vocab_size
+  result_iter = estimator.predict(
+      input_fn, checkpoint_path=checkpoint_path)
+
+  def _maybe_detokenize(value, vocab):
+    if isinstance(value, six.binary_type):
+      return value
+    return vocab.decode([int(x) for x in value])
+
   decodes = []
   for i, result in enumerate(result_iter):
-    output_ids = clean_decodes(list(result["outputs"]), vocab_size)
-    output_string = targets_vocabulary(vocabulary).decode(
-        [int(x) for x in output_ids])
+    input_string = _maybe_detokenize(
+        result["inputs"], inputs_vocabulary(vocabulary))
+    output_string = _maybe_detokenize(
+        result["outputs"], targets_vocabulary(vocabulary))
     decodes.append(output_string)
-    input_ids = clean_decodes(list(result["inputs"]), vocab_size)
-    input_string = targets_vocabulary(vocabulary).decode(
-        [int(x) for x in input_ids])
     if i & (i - 1) == 0:
       # LOG every power of 2.
       tf.logging.info("decoded {}: {}".format(i, input_string))
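
Two details in the loop above. `_maybe_detokenize` accepts either raw bytes
(the graph already detokenized via `decode_tf`) or an array of integer ids
that still needs host-side `vocab.decode`. And `i & (i - 1) == 0` is the
standard bit trick for powers of two (a power of two has exactly one set
bit), so logging thins out exponentially; it also fires at i == 0, logging
the very first decode:

    # i:           0  1  2  3  4  5  6  7  8
    # i & (i-1):   0  0  0  2  0  4  4  6  0
    logged = [i for i in range(9) if i & (i - 1) == 0]
    print(logged)  # [0, 1, 2, 4, 8]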

@@ -935,8 +952,7 @@ def input_fn(params):

     checkpoint_step = get_step_from_checkpoint_path(checkpoint_path)
     decodes = decode(
-        estimator, input_fn, vocabulary, checkpoint_path=checkpoint_path
-    )
+        estimator, input_fn, vocabulary, checkpoint_path=checkpoint_path)
     # Remove any padded examples
     dataset_size = len(inputs) * repeats
     decodes = decodes[:dataset_size]

@@ -945,28 +961,22 @@ def input_fn(params):


 @gin.configurable
-def clean_decodes(ids, vocab_size, eos_id=1):
-  """Stop at EOS or padding or OOV.
+def clean_decodes(ids, eos_id=1, pad_id=0):
+  """Replaces everything after EOS with PAD.

   Args:
-    ids: a list of integers
-    vocab_size: an integer
-    eos_id: EOS id
+    ids: an int Tensor of shape [batch, length].
+    eos_id: int, EOS id.
+    pad_id: int, PAD id.

   Returns:
-    a list of integers
+    an int Tensor of ids.
   """
-  ret = []
-  for i in ids:
-    if i == eos_id:
-      break
-    if i >= vocab_size:
-      break
-    ret.append(int(i))
-  return ret
+  eos_and_after = tf.cumsum(tf.cast(tf.equal(ids, eos_id), tf.int32),
+                            exclusive=True, axis=1)
+  return tf.where_v2(tf.greater(eos_and_after, 0), pad_id, ids)


-def get_estimator(model_type, input_vocab_size, output_vocab_size, mesh_shape,
+def get_estimator(model_type, vocabulary, mesh_shape,
                   layout_rules, model_dir, batch_size, sequence_length,
                   autostack, learning_rate_schedule, keep_checkpoint_max,
                   save_checkpoints_steps, optimizer, predict_fn,
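
A worked example of the rewritten `clean_decodes` above (a sketch assuming
int32 ids and the defaults eos_id=1, pad_id=0; with `exclusive=True`, each
position counts the EOS tokens strictly before it, so EOS itself survives
and everything after it becomes PAD). Note the old list-based version also
stopped at out-of-vocabulary ids; the tensor version drops that check,
since the ids it now sees come straight from the model and are in range.

    import tensorflow as tf

    ids = tf.constant([[5, 3, 1, 7, 1, 4],
                       [2, 2, 2, 2, 2, 2]], dtype=tf.int32)
    marks = tf.cast(tf.equal(ids, 1), tf.int32)       # [[0 0 1 0 1 0], ...]
    after = tf.cumsum(marks, exclusive=True, axis=1)  # [[0 0 0 1 1 2], ...]
    cleaned = tf.where_v2(tf.greater(after, 0), 0, ids)
    with tf.Session() as sess:
      print(sess.run(cleaned))
    # [[5 3 1 0 0 0]
    #  [2 2 2 2 2 2]]  <- no EOS in row 2, left untouched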

@@ -978,8 +988,8 @@ def get_estimator(model_type, input_vocab_size, output_vocab_size, mesh_shape,
   Args:
     model_type: a string - either "bitransformer", "bi_student_teacher", "lm"
       or "aligned"
-    input_vocab_size: an integer, size of the input vocabulary.
-    output_vocab_size: an integer, size of the output vocabulary.
+    vocabulary: a vocabulary.Vocabulary or (inputs_vocabulary,
+      targets_vocabulary) tuple
     mesh_shape: a function passed in through gin that returns a mtf.Shape
     layout_rules: an input to mtf.convert_to_layout_rules()
     model_dir: a string, model directory path.

@@ -1033,14 +1043,15 @@ def get_estimator(model_type, input_vocab_size, output_vocab_size, mesh_shape,

   transformer_model = build_model(
       model_type=model_type,
-      input_vocab_size=input_vocab_size,
-      output_vocab_size=output_vocab_size,
+      input_vocab_size=inputs_vocabulary(vocabulary).vocab_size,
+      output_vocab_size=targets_vocabulary(vocabulary).vocab_size,
       layout_rules=layout_rules,
       mesh_shape=mesh_shape)

   model_fn = tpu_estimator_model_fn(
       model_type=model_type,
       transformer_model=transformer_model,
+      vocabulary=vocabulary,
       model_dir=model_dir,
       use_tpu=use_tpu,
       mesh_shape=mesh_shape,
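
For readers tracing the new `vocabulary` plumbing: `inputs_vocabulary` and
`targets_vocabulary` are small helpers already defined in utils.py that
unpack the optional (inputs, targets) pair. Their behavior is essentially
this sketch:

    def inputs_vocabulary(vocabulary):
      """Returns the inputs vocabulary of a pair, or the vocabulary itself."""
      if isinstance(vocabulary, tuple):
        return vocabulary[0]
      return vocabulary

    def targets_vocabulary(vocabulary):
      """Returns the targets vocabulary of a pair, or the vocabulary itself."""
      if isinstance(vocabulary, tuple):
        return vocabulary[1]
      return vocabulary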

@@ -1317,7 +1328,7 @@ def input_fn(params):


 def export_model(estimator, export_dir, vocabulary, sequence_length,
-                 checkpoint_path=None):
+                 batch_size=1, checkpoint_path=None):
   """Export a model in TF SavedModel format to be used for inference on CPUs.

   Args:

@@ -1328,6 +1339,7 @@ def export_model(estimator, export_dir, vocabulary, sequence_length,
     vocabulary: sentencepiece vocab, vocabulary instance to use for encoding.
     sequence_length: an integer or a dict from feature-key to integer
       the (packed) sequence length, e.g. {"inputs": 512, "targets": 128}
+    batch_size: int, number of sequences per batch. Should match estimator.
     checkpoint_path: str, path to checkpoint. If None (default), use the most
       recent in the model directory.


@@ -1338,32 +1350,20 @@ def export_model(estimator, export_dir, vocabulary, sequence_length,
   def serving_input_fn():
     """Constructs input portion of Graph in serving.

-    Input is a batch of a single serialized tf.Example proto.
+    Input is a batch of strings.

     Returns:
       a ServingInputReceiver
     """
-    serialized_example = tf.placeholder(
+    inputs = tf.placeholder(
         dtype=tf.string,
         shape=[None],
-        name="serialized_example")
+        name="inputs")

-    def parse_example(serialized_example):
-      """Function to parse serialized example with default features."""
-      # For text2text models, "inputs" provides conditioning text. ("targets"
-      # is only used for train and eval).
-      #
-      # For text2self models, "inputs" provide partial sequences that are used
-      # to generate outputs.
-      example_specs = {
-          "inputs": tfds.features.Text().get_serialized_info(),
-      }
-
-      parser = tfds.core.example_parser.ExampleParser(example_specs)
-      return parser.parse_example(serialized_example)
-
-    dataset = tf.data.Dataset.from_tensor_slices(serialized_example)
-    dataset = dataset.map(parse_example)
+    padded_inputs = tf.pad(inputs, [(0, tf.mod(-tf.size(inputs), batch_size))])
+
+    dataset = tf.data.Dataset.from_tensor_slices(padded_inputs)
+    dataset = dataset.map(lambda x: {"inputs": x})
     dataset = transformer_dataset.encode_all_features(dataset, vocabulary)
     dataset = transformer_dataset.pack_or_pad(
         dataset=dataset,

@@ -1372,20 +1372,17 @@ def parse_example(serialized_example):
         dataset=dataset,
         feature_keys=["inputs"]
     )

-    dataset = dataset.padded_batch(
-        tf.shape(serialized_example, out_type=tf.int64)[0],
-        dataset.output_shapes)
+    dataset = dataset.batch(batch_size)

     features = tf.data.experimental.get_single_element(dataset)
-
     return tf.estimator.export.ServingInputReceiver(
-        features=features, receiver_tensors=serialized_example)
+        features=features, receiver_tensors=inputs)

   tpu_estimator.export_estimator_savedmodel(
       estimator=estimator,
       export_dir_base=export_dir,
       serving_input_receiver_fn=serving_input_fn,
-      as_text=True,
+      as_text=False,
       checkpoint_path=checkpoint_path,
   )

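A note on the serving changes: `tf.mod(-tf.size(inputs), batch_size)`
computes how many elements are needed to round the request up to a full
batch (TF's floor-mod keeps the result non-negative), and `tf.pad` on a
string tensor pads with empty strings. Since the receiver now takes raw
strings, clients no longer build serialized tf.Example protos, and
`as_text=False` writes the binary SavedModel format, which is smaller and
faster to load. A small sketch of the padding arithmetic, assuming
batch_size=4:

    import tensorflow as tf

    batch_size = 4
    inputs = tf.constant(["translate: hello", "translate: world", "one more"])
    pad_amount = tf.mod(-tf.size(inputs), batch_size)  # (-3) mod 4 == 1
    padded = tf.pad(inputs, [(0, pad_amount)])         # appends one "" element
    with tf.Session() as sess:
      print(sess.run(pad_amount), sess.run(tf.size(padded)))  # 1 4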

@@ -1706,8 +1703,7 @@ def run(tpu_job_name,

   estimator = get_estimator(
       model_type=model_type,
-      input_vocab_size=inputs_vocabulary(vocabulary).vocab_size,
-      output_vocab_size=targets_vocabulary(vocabulary).vocab_size,
+      vocabulary=vocabulary,
       layout_rules=layout_rules,
       mesh_shape=mesh_shape,
       model_dir=model_dir,

@@ -1772,7 +1768,8 @@ def _input_fn(params, eval_dataset):
         model_dir, eval_checkpoint_step)

   elif mode == "export":
-    export_model(estimator, export_path, vocabulary, sequence_length)
+    export_model(estimator, export_path, vocabulary, sequence_length,
+                 batch_size)

   else:
     raise ValueError(

setup.py (2 changes: 1 addition & 1 deletion)

@@ -5,7 +5,7 @@

 setup(
     name='mesh-tensorflow',
-    version='0.1.9',
+    version='0.1.10',
     description='Mesh TensorFlow',
     author='Google Inc.',
     author_email='no-reply@google.com',