Merged
54 changes: 33 additions & 21 deletions official/transformer/README.md
@@ -14,7 +14,7 @@ The model also applies embeddings on the input and output tokens, and adds a con
* [Training times](#training-times)
* [Evaluation results](#evaluation-results)
* [Detailed instructions](#detailed-instructions)
-* [Export variables (optional)](#export-variables-optional)
+* [Environment preparation](#environment-preparation)
* [Download and preprocess datasets](#download-and-preprocess-datasets)
* [Model training and evaluation](#model-training-and-evaluation)
* [Translate using the model](#translate-using-the-model)
@@ -31,46 +31,53 @@ The model also applies embeddings on the input and output tokens, and adds a con
Below are the commands for running the Transformer model. See the [Detailed instructions](#detailed-instructions) for more details on running the model.

```
-PARAMS=big
+cd /path/to/models/official/transformer
+
+# Ensure that PYTHONPATH is correctly defined as described in
+# https://github.com/tensorflow/models/tree/master/official#running-the-models
+# export PYTHONPATH="$PYTHONPATH:/path/to/models"
+
+# Export variables
+PARAM_SET=big
DATA_DIR=$HOME/transformer/data
-MODEL_DIR=$HOME/transformer/model_$PARAMS
+MODEL_DIR=$HOME/transformer/model_$PARAM_SET

# Download training/evaluation datasets
python data_download.py --data_dir=$DATA_DIR

# Train the model for 10 epochs, and evaluate after every epoch.
python transformer_main.py --data_dir=$DATA_DIR --model_dir=$MODEL_DIR \
-  --params=$PARAMS --bleu_source=test_data/newstest2014.en --bleu_ref=test_data/newstest2014.de
+  --param_set=$PARAM_SET --bleu_source=test_data/newstest2014.en --bleu_ref=test_data/newstest2014.de

# Run during training in a separate process to get continuous updates,
# or after training is complete.
tensorboard --logdir=$MODEL_DIR

# Translate some text using the trained model
python translate.py --data_dir=$DATA_DIR --model_dir=$MODEL_DIR \
-  --params=$PARAMS --text="hello world"
+  --param_set=$PARAM_SET --text="hello world"

# Compute model's BLEU score using the newstest2014 dataset.
python translate.py --data_dir=$DATA_DIR --model_dir=$MODEL_DIR \
-  --params=$PARAMS --file=test_data/newstest2014.en --file_out=translation.en
+  --param_set=$PARAM_SET --file=test_data/newstest2014.en --file_out=translation.en
python compute_bleu.py --translation=translation.en --reference=test_data/newstest2014.de
```

## Benchmarks
### Training times

-Currently, both big and base params run on a single GPU. The measurements below
+Currently, both big and base parameter sets run on a single GPU. The measurements below
are reported from running the model on a P100 GPU.

-Params | batches/sec | batches per epoch | time per epoch
+Param Set | batches/sec | batches per epoch | time per epoch
--- | --- | --- | ---
base | 4.8 | 83244 | 4 hr
big | 1.1 | 41365 | 10 hr

### Evaluation results
Below are the case-insensitive BLEU scores after 10 epochs.

-Params | Score
+Param Set | Score
--- | --- |
base | 27.7
big | 28.9
@@ -79,13 +86,18 @@ big | 28.9
## Detailed instructions


-0. ### Export variables (optional)
+0. ### Environment preparation

+#### Add models repo to PYTHONPATH
+Follow the instructions described in the [Running the models](https://github.com/tensorflow/models/tree/master/official#running-the-models) section to add the models folder to the python path.
+
+#### Export variables (optional)

Export the following variables, or modify the values in each of the snippets below:
```
-PARAMS=big
+PARAM_SET=big
DATA_DIR=$HOME/transformer/data
-MODEL_DIR=$HOME/transformer/model_$PARAMS
+MODEL_DIR=$HOME/transformer/model_$PARAM_SET
```

1. ### Download and preprocess datasets
@@ -109,26 +121,26 @@

Command to run:
```
-python transformer_main.py --data_dir=$DATA_DIR --model_dir=$MODEL_DIR --params=$PARAMS
+python transformer_main.py --data_dir=$DATA_DIR --model_dir=$MODEL_DIR --param_set=$PARAM_SET
```

Arguments:
* `--data_dir`: This should be set to the same directory given to the `data_download`'s `data_dir` argument.
* `--model_dir`: Directory to save Transformer model training checkpoints.
-* `--params`: Parameter set to use when creating and training the model. Options are `base` and `big` (default).
+* `--param_set`: Parameter set to use when creating and training the model. Options are `base` and `big` (default).
* Use the `--help` or `-h` flag to get a full list of possible arguments.

#### Customizing training schedule

By default, the model will train for 10 epochs, and evaluate after every epoch. The training schedule may be defined through the flags:
* Training with epochs (default):
* `--train_epochs`: The total number of complete passes to make through the dataset
-* `--epochs_between_eval`: The number of epochs to train between evaluations.
+* `--epochs_between_evals`: The number of epochs to train between evaluations.
* Training with steps:
* `--train_steps`: sets the total number of training steps to run.
-* `--steps_between_eval`: Number of training steps to run between evaluations.
+* `--steps_between_evals`: Number of training steps to run between evaluations.

-Only one of `train_epochs` or `train_steps` may be set. Since the default option is to evaluate the model after training for an epoch, it may take 4 or more hours between model evaluations. To get more frequent evaluations, use the flags `--train_steps=250000 --steps_between_eval=1000`.
+Only one of `train_epochs` or `train_steps` may be set. Since the default option is to evaluate the model after training for an epoch, it may take 4 or more hours between model evaluations. To get more frequent evaluations, use the flags `--train_steps=250000 --steps_between_evals=1000`.
**Contributor:** Absl will enforce constraints on flags. For instance, after defining the flags, the validator code looks like:

```python
msg = "--train_steps and --train_epochs were set. Only one may be defined."

@flags.multi_flags_validator(["train_epochs", "train_steps"], message=msg)
def _check_train_limits(flag_dict):
  return flag_dict["train_epochs"] is None or flag_dict["train_steps"] is None
```

And a similar validator applies for your other checks. (There is also a single flag validator.)
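For reference, the single flag validator mentioned above might be used like this (a sketch, not part of this PR; the flag name and the positivity check are illustrative):

```python
from absl import flags

flags.DEFINE_integer(name="train_steps", default=None,
                     help="Total number of training steps to run.")

# A single-flag validator receives just this flag's value and returns a bool.
# This one allows the flag to stay unset but rejects non-positive step counts.
@flags.validator("train_steps", message="--train_steps must be positive.")
def _check_train_steps(train_steps):
  return train_steps is None or train_steps > 0
```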


Note: At the beginning of each training session, the training dataset is reloaded and shuffled. Stopping the training before completing an epoch may result in worse model quality, due to the chance that some examples may be seen more often than others. Therefore, it is recommended to use epochs when model quality is important.

@@ -137,7 +149,7 @@ big | 28.9
Use these flags to compute the BLEU when the model evaluates:
* `--bleu_source`: Path to file containing text to translate.
* `--bleu_ref`: Path to file containing the reference translation.
-* `--bleu_threshold`: Train until the BLEU score reaches this lower bound. This setting overrides the `--train_steps` and `--train_epochs` flags.
+* `--stop_threshold`: Train until the BLEU score reaches this lower bound. This setting overrides the `--train_steps` and `--train_epochs` flags.

The test source and reference files located in the `test_data` directory are extracted from the preprocessed dataset from the [NMT Seq2Seq tutorial](https://google.github.io/seq2seq/nmt/#download-data).
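As an example of combining these flags, training until the model reaches an uncased BLEU of 25 (the threshold value here is illustrative) might look like:

```
python transformer_main.py --data_dir=$DATA_DIR --model_dir=$MODEL_DIR \
    --param_set=$PARAM_SET --bleu_source=test_data/newstest2014.en \
    --bleu_ref=test_data/newstest2014.de --stop_threshold=25
```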

@@ -155,12 +167,12 @@ big | 28.9

Command to run:
```
-python translate.py --data_dir=$DATA_DIR --model_dir=$MODEL_DIR --params=$PARAMS --text="hello world"
+python translate.py --data_dir=$DATA_DIR --model_dir=$MODEL_DIR --param_set=$PARAM_SET --text="hello world"
```

Arguments for initializing the Subtokenizer and trained model:
* `--data_dir`: Used to locate the vocabulary file to create a Subtokenizer, which encodes the input and decodes the model output.
-* `--model_dir` and `--params`: These parameters are used to rebuild the trained model
+* `--model_dir` and `--param_set`: These parameters are used to rebuild the trained model

Arguments for specifying what to translate:
* `--text`: Text to translate
@@ -170,7 +182,7 @@ big | 28.9
To translate the newstest2014 data, run:
```
python translate.py --data_dir=$DATA_DIR --model_dir=$MODEL_DIR \
-  --params=$PARAMS --file=test_data/newstest2014.en --file_out=translation.en
+  --param_set=$PARAM_SET --file=test_data/newstest2014.en --file_out=translation.en
```

Translating the file takes around 15 minutes on a GTX1080, or 5 minutes on a P100.
54 changes: 31 additions & 23 deletions official/transformer/compute_bleu.py
@@ -22,17 +22,19 @@
from __future__ import division
from __future__ import print_function

-import argparse
import re
import sys
import unicodedata

+# pylint: disable=g-bad-import-order
import six
+from absl import app as absl_app
+from absl import flags
import tensorflow as tf
+# pylint: enable=g-bad-import-order

from official.transformer.utils import metrics
+from official.utils.flags import core as flags_core


class UnicodeRegex(object):
@@ -99,31 +101,37 @@ def bleu_wrapper(ref_filename, hyp_filename, case_sensitive=False):


def main(unused_argv):
if FLAGS.bleu_variant is None or "uncased" in FLAGS.bleu_variant:
if FLAGS.bleu_variant in ("both", "uncased"):
score = bleu_wrapper(FLAGS.reference, FLAGS.translation, False)
print("Case-insensitive results:", score)
tf.logging.info("Case-insensitive results: %f" % score)

if FLAGS.bleu_variant is None or "cased" in FLAGS.bleu_variant:
if FLAGS.bleu_variant in ("both", "cased"):
score = bleu_wrapper(FLAGS.reference, FLAGS.translation, True)
print("Case-sensitive results:", score)
tf.logging.info("Case-sensitive results: %f" % score)


+def define_compute_bleu_flags():
+  """Add flags for computing BLEU score."""
+  flags.DEFINE_string(
+      name="translation", default=None,
+      help=flags_core.help_wrap("File containing translated text."))
+  flags.mark_flag_as_required("translation")
+
+  flags.DEFINE_string(
+      name="reference", default=None,
+      help=flags_core.help_wrap("File containing reference translation."))
+  flags.mark_flag_as_required("reference")
+
+  flags.DEFINE_enum(
+      name="bleu_variant", short_name="bv", default="both",
+      enum_values=["both", "uncased", "cased"], case_sensitive=False,
+      help=flags_core.help_wrap(
+          "Specify one or more BLEU variants to calculate. Variants: \"cased\""
+          ", \"uncased\", or \"both\"."))


if __name__ == "__main__":
-  parser = argparse.ArgumentParser()
-  parser.add_argument(
-      "--translation", "-t", type=str, default=None, required=True,
-      help="[default: %(default)s] File containing translated text.",
-      metavar="<T>")
-  parser.add_argument(
-      "--reference", "-r", type=str, default=None, required=True,
-      help="[default: %(default)s] File containing reference translation",
-      metavar="<R>")
-  parser.add_argument(
-      "--bleu_variant", "-bv", type=str, choices=["uncased", "cased"],
-      nargs="*", default=None,
-      help="Specify one or more BLEU variants to calculate (both are "
-           "calculated by default. Variants: \"cased\" or \"uncased\".",
-      metavar="<BV>")
-
-  FLAGS, unparsed = parser.parse_known_args()
-  main(sys.argv)
+  tf.logging.set_verbosity(tf.logging.INFO)
+  define_compute_bleu_flags()
+  FLAGS = flags.FLAGS
+  absl_app.run(main)
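With the flags defined above, a typical invocation of the converted script (file paths reuse the README example) might be:

```
python compute_bleu.py --translation=translation.en \
    --reference=test_data/newstest2014.de --bleu_variant=uncased
```

The `short_name="bv"` in the enum definition also permits the shorthand `--bv=uncased`, and `case_sensitive=False` means values such as `Uncased` are accepted as well.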
52 changes: 27 additions & 25 deletions official/transformer/data_download.py
@@ -18,19 +18,20 @@
from __future__ import division
from __future__ import print_function

-import argparse
import os
import random
import sys
import tarfile
-import urllib

+# pylint: disable=g-bad-import-order
import six
+from six.moves import urllib
+from absl import app as absl_app
+from absl import flags
import tensorflow as tf
+# pylint: enable=g-bad-import-order

from official.transformer.utils import tokenizer
+from official.utils.flags import core as flags_core

# Data sources for training/evaluating the transformer translation model.
# If any of the training sources are changed, then either:
@@ -156,7 +157,7 @@ def download_from_url(path, url):
filename = os.path.join(path, filename)
tf.logging.info("Downloading from %s to %s." % (url, filename))
inprogress_filepath = filename + ".incomplete"
-  inprogress_filepath, _ = urllib.urlretrieve(
+  inprogress_filepath, _ = urllib.request.urlretrieve(
url, inprogress_filepath, reporthook=download_report_hook)
# Print newline to clear the carriage return from the download progress.
print()
**Member:** Let's grep and replace all the print() calls with tf.logging.info.

**Contributor (author):** The download progress rewrites the line to update the progress. I don't think it works well with tf.logging.info.

**Member:** Ah, sorry for missing that.

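For context, a progress hook of the sort the comment describes (a sketch; the actual download_report_hook in this file is not shown in the diff) depends on the carriage return to redraw a single line in place, which a line-oriented logger such as tf.logging.info cannot do:

```python
from __future__ import print_function

import sys


def download_report_hook(count, block_size, total_size):
  """Report hook for urllib.request.urlretrieve; redraws one line in place."""
  percent = int(count * block_size * 100 / total_size)
  # "\r" moves the cursor back to the start of the line so each update
  # overwrites the previous percentage instead of printing a new line.
  print("\r%d%%" % percent, end="")
  sys.stdout.flush()
```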
@@ -302,7 +303,7 @@ def encode_and_save_files(
for tmp_name, final_name in zip(tmp_filepaths, filepaths):
tf.gfile.Rename(tmp_name, final_name)

tf.logging.info("Saved %d Examples", counter)
tf.logging.info("Saved %d Examples", counter + 1)
return filepaths


@@ -363,8 +364,6 @@ def make_dir(path):

def main(unused_argv):
"""Obtain training and evaluation data for the Transformer model."""
-  tf.logging.set_verbosity(tf.logging.INFO)
-
make_dir(FLAGS.raw_dir)
make_dir(FLAGS.data_dir)

@@ -398,22 +397,25 @@ def main(unused_argv):
shuffle_records(fname)


+def define_data_download_flags():
+  """Add flags specifying data download arguments."""
+  flags.DEFINE_string(
+      name="data_dir", short_name="dd", default="/tmp/translate_ende",
+      help=flags_core.help_wrap(
+          "Directory for where the translate_ende_wmt32k dataset is saved."))
+  flags.DEFINE_string(
+      name="raw_dir", short_name="rd", default="/tmp/translate_ende_raw",
+      help=flags_core.help_wrap(
+          "Path where the raw data will be downloaded and extracted."))
+  flags.DEFINE_bool(
+      name="search", default=False,
+      help=flags_core.help_wrap(
+          "If set, use binary search to find the vocabulary set with size "
+          "closest to the target size (%d)." % _TARGET_VOCAB_SIZE))


if __name__ == "__main__":
-  parser = argparse.ArgumentParser()
-  parser.add_argument(
-      "--data_dir", "-dd", type=str, default="/tmp/translate_ende",
-      help="[default: %(default)s] Directory for where the "
-           "translate_ende_wmt32k dataset is saved.",
-      metavar="<DD>")
-  parser.add_argument(
-      "--raw_dir", "-rd", type=str, default="/tmp/translate_ende_raw",
-      help="[default: %(default)s] Path where the raw data will be downloaded "
-           "and extracted.",
-      metavar="<RD>")
-  parser.add_argument(
-      "--search", action="store_true",
-      help="If set, use binary search to find the vocabulary set with size"
-           "closest to the target size (%d)." % _TARGET_VOCAB_SIZE)
-
-  FLAGS, unparsed = parser.parse_known_args()
-  main(sys.argv)
+  tf.logging.set_verbosity(tf.logging.INFO)
+  define_data_download_flags()
+  FLAGS = flags.FLAGS
+  absl_app.run(main)
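Both converted scripts end with the same absl entry-point pattern; a minimal standalone skeleton of that pattern (the flag name and default below are illustrative, not from this PR) might look like:

```python
from absl import app as absl_app
from absl import flags

import tensorflow as tf

FLAGS = flags.FLAGS


def define_flags():
  """Defines this script's flags; the sample flag below is illustrative."""
  flags.DEFINE_string(name="data_dir", default="/tmp/data",
                      help="Directory where the dataset is saved.")


def main(unused_argv):
  # Flags are parsed by absl_app.run before main is invoked.
  tf.logging.info("data_dir=%s" % FLAGS.data_dir)


if __name__ == "__main__":
  tf.logging.set_verbosity(tf.logging.INFO)
  define_flags()
  absl_app.run(main)
```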