-
Notifications
You must be signed in to change notification settings - Fork 45.5k
Add official flag-parsing and benchmarking logging utils to Transformer #4163
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
ddd393c
fc8d589
3fe53dd
6be42aa
b0dcfc2
eb776ab
7c788a1
fd79d91
c8e7a16
612bf13
6837808
68e9ba9
a373828
2d8a94b
2faecc4
1c74943
c55af3c
426b480
f19c8a9
a06c8b1
f6b631b
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -18,19 +18,20 @@ | |
from __future__ import division | ||
from __future__ import print_function | ||
|
||
import argparse | ||
import os | ||
import random | ||
import sys | ||
import tarfile | ||
import urllib | ||
|
||
# pylint: disable=g-bad-import-order | ||
import six | ||
from six.moves import urllib | ||
from absl import app as absl_app | ||
from absl import flags | ||
import tensorflow as tf | ||
# pylint: enable=g-bad-import-order | ||
|
||
from official.transformer.utils import tokenizer | ||
from official.utils.flags import core as flags_core | ||
|
||
# Data sources for training/evaluating the transformer translation model. | ||
# If any of the training sources are changed, then either: | ||
|
@@ -156,7 +157,7 @@ def download_from_url(path, url): | |
filename = os.path.join(path, filename) | ||
tf.logging.info("Downloading from %s to %s." % (url, filename)) | ||
inprogress_filepath = filename + ".incomplete" | ||
inprogress_filepath, _ = urllib.urlretrieve( | ||
inprogress_filepath, _ = urllib.request.urlretrieve( | ||
url, inprogress_filepath, reporthook=download_report_hook) | ||
# Print newline to clear the carriage return from the download progress. | ||
print() | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Let's grep for all the print() calls and replace them with tf.logging.info.
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. The download progress hook rewrites the current line to update the progress, so I don't think it works well with tf.logging.info.
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Ah, sorry for missing that. |
||
|
@@ -302,7 +303,7 @@ def encode_and_save_files( | |
for tmp_name, final_name in zip(tmp_filepaths, filepaths): | ||
tf.gfile.Rename(tmp_name, final_name) | ||
|
||
tf.logging.info("Saved %d Examples", counter) | ||
tf.logging.info("Saved %d Examples", counter + 1) | ||
return filepaths | ||
|
||
|
||
|
@@ -363,8 +364,6 @@ def make_dir(path): | |
|
||
def main(unused_argv): | ||
"""Obtain training and evaluation data for the Transformer model.""" | ||
tf.logging.set_verbosity(tf.logging.INFO) | ||
|
||
make_dir(FLAGS.raw_dir) | ||
make_dir(FLAGS.data_dir) | ||
|
||
|
@@ -398,22 +397,25 @@ def main(unused_argv): | |
shuffle_records(fname) | ||
|
||
|
||
def define_data_download_flags():
  """Add flags specifying data download arguments.

  Registers three absl flags on the global flag registry:
    data_dir (-dd): string, defaults to "/tmp/translate_ende".
    raw_dir (-rd): string, defaults to "/tmp/translate_ende_raw".
    search: boolean, defaults to False.

  Every help string is passed through flags_core.help_wrap before being
  handed to absl.
  """
  flags.DEFINE_string(
      name="data_dir", short_name="dd", default="/tmp/translate_ende",
      help=flags_core.help_wrap(
          "Directory for where the translate_ende_wmt32k dataset is saved."))
  flags.DEFINE_string(
      name="raw_dir", short_name="rd", default="/tmp/translate_ende_raw",
      help=flags_core.help_wrap(
          "Path where the raw data will be downloaded and extracted."))
  flags.DEFINE_bool(
      name="search", default=False,
      help=flags_core.help_wrap(
          # Adjacent literals are concatenated verbatim, so the first piece
          # must end with a space to keep "size" and "closest" apart.
          "If set, use binary search to find the vocabulary set with size "
          "closest to the target size (%d)." % _TARGET_VOCAB_SIZE))
|
||
|
||
if __name__ == "__main__": | ||
parser = argparse.ArgumentParser() | ||
parser.add_argument( | ||
"--data_dir", "-dd", type=str, default="/tmp/translate_ende", | ||
help="[default: %(default)s] Directory for where the " | ||
"translate_ende_wmt32k dataset is saved.", | ||
metavar="<DD>") | ||
parser.add_argument( | ||
"--raw_dir", "-rd", type=str, default="/tmp/translate_ende_raw", | ||
help="[default: %(default)s] Path where the raw data will be downloaded " | ||
"and extracted.", | ||
metavar="<RD>") | ||
parser.add_argument( | ||
"--search", action="store_true", | ||
help="If set, use binary search to find the vocabulary set with size" | ||
"closest to the target size (%d)." % _TARGET_VOCAB_SIZE) | ||
|
||
FLAGS, unparsed = parser.parse_known_args() | ||
main(sys.argv) | ||
tf.logging.set_verbosity(tf.logging.INFO) | ||
define_data_download_flags() | ||
FLAGS = flags.FLAGS | ||
absl_app.run(main) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
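Absl will enforce constraints on flags. For instance, after defining the flags, the validator code looks like: (the code sample is missing from this extract; absl's multi-flag validator, `flags.multi_flags_validator`, is presumably what was shown.)
And similarly for your other checks. (There is also a single-flag validator, `flags.validator`.)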