Add official flag-parsing and benchmarking logging utils to Transformer #4163
Conversation
FLAGS, unparsed = parser.parse_known_args()
main(sys.argv)
tf.app.run()
Since this is a standalone library that downloads and processes the data, tf.app.run() does not do much here. If we want to use the TF libraries here, we should probably also change the flag-parsing part to absl flags, so that it's consistent across the module.
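For reference, a minimal sketch of the absl-style flag handling this would imply for the download script (the flag name and helper below are illustrative, not the actual diff):

```python
from absl import app as absl_app
from absl import flags

# Illustrative flag; the real script defines its own set of flags.
flags.DEFINE_string(
    name="data_dir", default="/tmp/translate_ende",
    help="Directory where the raw and processed data are written.")

FLAGS = flags.FLAGS


def main(_):
  # download_and_preprocess(FLAGS.data_dir)  # hypothetical helper in this script
  pass


if __name__ == "__main__":
  absl_app.run(main)
```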
])
bleu_writer.add_summary(summary, global_step)
bleu_writer.flush()
benchmark_logger.log_metric(
What's the difference between the eval_results and the value here?
The results from estimator.evaluate() are based on an approximate translation (long story short, the approximate translations rely heavily on the provided golden target values). The function evaluate_and_log_bleu uses the estimator.predict() path to compute the translations, where the golden values are not provided. The translations are then compared to the reference file to get the actual BLEU score.
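Roughly, the path being described looks like this (a sketch for clarity, not the exact code in the PR; translate_file stands in for whatever drives estimator.predict()):

```python
# Illustrative sketch of the predict-based BLEU evaluation described above.
def evaluate_and_log_bleu(estimator, bleu_source, bleu_ref, vocab_file_path):
  # Translate the source file via estimator.predict(); no golden targets given.
  translations_file = translate_file(estimator, bleu_source, vocab_file_path)
  # Score the model's own translations against the reference translations.
  uncased_score = bleu_wrapper(bleu_ref, translations_file, False)
  cased_score = bleu_wrapper(bleu_ref, translations_file, True)
  return uncased_score, cased_score
```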
Can this be added as a comment?
raise ValueError("Vocabulary file %s does not exist" % vocab_file_path)

def main(_):
I think you can follow Taylor's change in the ImageNet model, which creates a wrapper around the main function so that the flags object can be injected.
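Presumably something like this (a sketch of the wrapper pattern with absl flags; run_transformer is a hypothetical name for the extracted body):

```python
from absl import app as absl_app
from absl import flags


def run_transformer(flags_obj):
  """Runs the Transformer training/eval with an explicitly injected flags object."""
  # ... build params and the estimator from flags_obj, then run the training schedule ...


def main(_):
  run_transformer(flags.FLAGS)


if __name__ == "__main__":
  define_transformer_flags()
  absl_app.run(main)
```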
Indeed.
batch_size=params.batch_size  # for ExamplesPerSecondHook
)
benchmark_logger = logger.config_benchmark_logger(flags_obj.benchmark_log_dir)
benchmark_logger.log_run_info("transformer")
The interface has been updated with some extra params; please rebase and add values if needed.
Oooh +1 to saving the params. Very nice
* `--steps_between_evals`: Number of training steps to run between evaluations.

Only one of `train_epochs` or `train_steps` may be set. Since the default option is to evaluate the model after training for an epoch, it may take 4 or more hours between model evaluations. To get more frequent evaluations, use the flags `--train_steps=250000 --steps_between_evals=1000`.
Absl will enforce constraints on flags. For instance, after defining the flags, the validator code looks like:

msg = "--train_steps and --train_epochs were set. Only one may be defined."

@flags.multi_flags_validator(["train_epochs", "train_steps"], message=msg)
def _check_train_limits(flag_dict):
  return flag_dict["train_epochs"] is None or flag_dict["train_steps"] is None

And similarly for your other checks. (There is also a single-flag validator.)
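For comparison, the single-flag validator mentioned above would look roughly like this (the predicate and message are illustrative):

```python
# Illustrative single-flag check; the real constraint would match the model's needs.
@flags.validator("train_steps", message="--train_steps must be positive when set.")
def _check_train_steps(train_steps):
  return train_steps is None or train_steps > 0
```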
official/transformer/compute_bleu.py (outdated)
FLAGS, unparsed = parser.parse_known_args()
main(sys.argv)
tf.app.run()
Why is this file using:
- argparse? (Just for consistency.)
- tf.app.run instead of absl.app.run? (tf.app.run silently swallows typos.)
Changed to use absl
FLAGS, unparsed = parser.parse_known_args()
main(sys.argv)
tf.app.run()
ditto on argparse and tf.app.run.
# Print details of training schedule.
print("Training schedule:")
For consistency with the rest of official, these prints should probably be tf.logging.info calls.
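e.g., the print above would presumably become something like:

```python
# Sketch: use the TF logger instead of a bare print, as elsewhere in official.
tf.logging.info("Training schedule:")
```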
def define_transformer_flags():
  """Add flags for running transformer_main."""
  # Add common flags (data_dir, model_dir, train_epochs, etc.).
  flags_core.define_base(multi_gpu=False, export_dir=False)
you also want num_gpu=False.
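i.e., presumably something like the following, assuming define_base accepts a num_gpu switch alongside the others:

```python
# Assumption: define_base exposes a num_gpu toggle like the other base flags.
flags_core.define_base(multi_gpu=False, num_gpu=False, export_dir=False)
```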
train_epochs=None)

def validate_flags(flags_obj):
As noted above this can go away.
if not tf.gfile.Exists(FLAGS.bleu_ref):
  raise ValueError("BLEU source file %s does not exist" % FLAGS.bleu_ref)
# Define parameters based on flags
if flags_obj.params == "base":
This might be clearer as a global dict:

PARAMS_MAP = {
    "base": model_params.TransformerBaseParams,
    "big": model_params.TransformerBigParams,
}
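...so the branching above collapses to a simple lookup, something like:

```python
# Sketch of the lookup once PARAMS_MAP is defined at module level.
params = PARAMS_MAP[flags_obj.params]
```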
  raise ValueError("BLEU source file %s does not exist" % FLAGS.bleu_source)
if not tf.gfile.Exists(FLAGS.bleu_ref):
  raise ValueError("BLEU source file %s does not exist" % FLAGS.bleu_ref)
# Define parameters based on flags
I really like this pattern of packaging hyperparameters. Very clean.
params.epochs_between_evals = flags_obj.epochs_between_evals
params.repeat_dataset = single_iteration_train_epochs

if flags_obj.batch_size is not None:
How about `params.batch_size = flags_obj.batch_size or params.batch_size`? (Since we don't need to respect 0 as a legitimate batch size.)
estimator = tf.estimator.Estimator(
    model_fn=model_fn, model_dir=flags_obj.model_dir, params=params)
train_schedule(
As long as these are already nice and broken out, may I request that they be kwargs?
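i.e., something along these lines (the argument names are illustrative, loosely taken from the surrounding code, not the exact signature):

```python
# Illustrative kwargs-style call; the real parameter names may differ.
train_schedule(
    estimator=estimator,
    train_eval_iterations=train_eval_iterations,
    single_iteration_train_epochs=single_iteration_train_epochs,
    bleu_source=flags_obj.bleu_source,
    bleu_ref=flags_obj.bleu_ref,
    vocab_file_path=vocab_file_path)
```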
return tf.gfile.Exists(flags_dict["bleu_source"]) and (
    tf.gfile.Exists(flags_dict["bleu_ref"])) and (
        tf.gfile.Exists(vocab_file_path))
return (
Does anyone know how to make this look better? It is difficult to please the lint overlords.
return all([
    tf.gfile.Exists(flags_dict["bleu_source"]),
    tf.gfile.Exists(flags_dict["bleu_ref"]),
    tf.gfile.Exists(vocab_file_path),
])
thank you
We should start thinking about the checklist for official models...
- Are all base flags enabled?
- Is benchmarking enabled?
- Are all file references GFile?
- Is a SavedModel exported?
official/transformer/compute_bleu.py (outdated)
FLAGS, unparsed = parser.parse_known_args()
main(sys.argv)
flags.DEFINE_string(
    name="translation", short_name="t", default=None,
The collision of abbreviations is inevitable at this point. Maybe we should make a rule that we don't have abbreviations for flags defined within model modules. Otherwise, we might add a flag to the base set, get a conflict, and not realize it until someone tries to run an inheriting model and it either errors out at arg load time or treats flags in strange ways. Thoughts, @robieta?
I think at this point one-letter abbreviations are too dangerous to add, because they may collide not only with our own flags but with flags defined who knows where. Two letters seem reasonably safe.
If we're doing an end-to-end test, it will blow up on collisions, so I'm less worried that we will accidentally have internal collisions.
Is there an e2e test for every model? If not, should that be a requirement in our Official Model checklist?
Right now MNIST is missing it. It requires synthetic data to be hermetic, which I think is why MNIST doesn't have it yet.
I definitely think it is something we want for every model.
])
bleu_writer.add_summary(summary, global_step)
bleu_writer.flush()
benchmark_logger.log_metric(
Can this be added as a comment?
)
benchmark_logger = logger.config_benchmark_logger(flags_obj.benchmark_log_dir)
benchmark_logger.log_run_info(
    "transformer", "wmt_translate_ende", params.__dict__)
maybe include kwarg names here for clarity?
@karmel Thanks for the review. I'll add the SavedModel export in a separate PR, to keep the changes here related to flag parsing + logging.
Bump, PTAL
Just a few minor comments around flags.
official/transformer/compute_bleu.py (outdated)
    help=flags_core.help_wrap("File containing reference translation."))
flags.mark_flag_as_required("reference")

flags.DEFINE_multi_enum(
Can you add a "both" option and propagate that through (and make this just a DEFINE_enum)? With the multi-enum, one would have to input `--bleu_variant cased --bleu_variant uncased`, and may be confused when `--bleu_variant cased uncased` doesn't compute uncased. Unfortunately absl doesn't have a great way to define a list of enumerables.
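Perhaps something like this (a sketch; the help wording is illustrative):

```python
# Single enum with an explicit "both" value, as suggested above.
flags.DEFINE_enum(
    name="bleu_variant", default="both",
    enum_values=["both", "cased", "uncased"],
    help=flags_core.help_wrap(
        "Which BLEU variant(s) to compute: cased, uncased, or both."))
```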
"Specify one or more BLEU variants to calculate. Variants: \"cased\" " | ||
"or \"uncased\".")) | ||
|
||
FLAGS = flags.FLAGS |
Can you encapsulate everything the same way you did in transformer_main?
What is the reason for encapsulating everything?
In part it is conceptually easier if functions are pure, and in part because we may well want to use this in a way other than calling the file directly.
Oooh I see. Thanks for the explanation.
if FLAGS.train_epochs is None:
  FLAGS.train_epochs = DEFAULT_TRAIN_EPOCHS
train_eval_iterations = FLAGS.train_epochs // FLAGS.epochs_between_eval
if flags_obj.train_epochs is None:
We have adopted a convention of not modifying flags objects, and instead calling getter functions to retrieve values. See official.utils.flags._performance.get_loss_scale() as an example.
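A hedged sketch of that convention applied here (the getter name is hypothetical; get_loss_scale is the canonical example):

```python
# Read the flag and apply the default without ever mutating the flags object.
def get_train_epochs(flags_obj):
  if flags_obj.train_epochs is None:
    return DEFAULT_TRAIN_EPOCHS
  return flags_obj.train_epochs
```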
Yes, it makes sense not to alter the values of the flags afterwards. It would be great if these conventions could be listed (maybe as a sub-checkbox in karmel's list above).
Would you mind adding an "immutability" section to flags/README.md?
import random
import sys
import tarfile
import urllib
This is breaking Python 3. Need to use six.moves.urllib.
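i.e., only the import needs to change; six.moves resolves to the right module on both Python versions:

```python
# Python 2/3 compatible import; call sites like urllib.request.urlretrieve stay the same.
from six.moves import urllib
```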
@robieta Thanks for the comments. I edited compute_bleu, but I'm not sure I prefer this version to how it was before. Currently there isn't really a use case for calling the compute_bleu functions from other modules. In general, I don't think we should require an encapsulating function for scripts like compute_bleu.
You know what, I was wrong on the compute_bleu encapsulation. You have my blessing to change it back. Sorry for making you go through the trouble.
official/transformer/compute_bleu.py (outdated)
"""Print out the BLEU scores calculated from the files defined in flags.""" | ||
if flags_obj.bleu_variant in ("both", "uncased"): | ||
score = bleu_wrapper(flags_obj.reference, flags_obj.translation, False) | ||
print("Case-insensitive results:", score) |
I think a follow-up change will probably be needed to change all the prints to tf.logging for consistency.
inprogress_filepath, _ = urllib.request.urlretrieve(
    url, inprogress_filepath, reporthook=download_report_hook)
# Print newline to clear the carriage return from the download progress.
print()
Let's grep and replace all the print() calls with tf.logging.info.
The download progress hook rewrites the line to update the progress. I don't think that works well with tf.logging.info.
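For context, the progress hook is roughly of this shape (an illustrative sketch, not the file's exact implementation), which is why it rewrites the line with a carriage return instead of logging:

```python
import sys

# Illustrative reporthook for urlretrieve: overwrite the same terminal line.
def download_report_hook(count, block_size, total_size):
  percent = int(count * block_size * 100 / total_size)
  sys.stdout.write("\r%d%% downloaded" % percent)
  sys.stdout.flush()
```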
ah, sorry for missing that.
# Add transformer-specific flags
flags.DEFINE_enum(
    name="params", short_name="mp", default="big",
Naming a parameter "params" is quite evil since it does not provide much context. How about renaming it to param_set or param_template so that it's more explicit?
Evil is a bit harsh, don't you think?
I do agree that the argument name is vague. I'll change it to param_set.
Sorry for being too harsh. param_set SGTM.
name="params", short_name="mp", default="big", | ||
enum_values=["base", "big"], | ||
help=flags_core.help_wrap( | ||
"Parameter set to use when creating and training the model.")) |
I think it's also worth mentioning all the individual param values under this umbrella as well.
I left it out because the absl -h option shows the possible enum values. This is what the help text looks like:

-mp,--params: <base|big>:
    Parameter set to use when creating and training the model.
    (default: 'big')
What I am trying to say is that param_set=base will populate other params (a, b, and c), which is not showing up in the help text.
My bad, I completely misread that. Yes, I think it would be good to see the individual param values. Maybe not all of them, but at least the ones that change between the big and base parameter sets.
Something like "setting param_set=big increases the default batch size, as well as the hidden_size, filter_size, and num_heads topology hyperparameters. See transformer/model/model_params.py for details."?
Yes, like that. I've pushed the change that updates the help text
if FLAGS.train_epochs is None:
  FLAGS.train_epochs = DEFAULT_TRAIN_EPOCHS
train_eval_iterations = FLAGS.train_epochs // FLAGS.epochs_between_eval
if flags_obj.train_epochs is not None:
I think this if-else can be combined into:

train_epochs = flags_obj.train_epochs or DEFAULT_TRAIN_EPOCHS
That's much cleaner, thanks
benchmark_logger.log_run_info(
    model_name="transformer",
    dataset_name="wmt_translate_ende",
    run_params=params.__dict__)
I think it's better not to log all the params, since they include noise like data_dir. Maybe a more explicit set of params would be better here.
I'm thinking logging all the params might be useful (including data_dir) because it allows the run to be reproduced with the same dataset files as well as hyperparameters.
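If the noise does become a problem, a middle ground might be logging an explicit subset (a sketch; the key list is illustrative):

```python
# Illustrative: log only selected hyperparameters instead of the full __dict__.
HPARAMS_TO_LOG = ("batch_size", "hidden_size", "filter_size", "num_heads")
run_params = {k: v for k, v in params.__dict__.items() if k in HPARAMS_TO_LOG}
benchmark_logger.log_run_info(
    model_name="transformer",
    dataset_name="wmt_translate_ende",
    run_params=run_params)
```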
@qlzh727 Thanks again for all of the comments. They were very helpful. I've pushed the changes requested.
Please wait for the comments from karmel@, if there are any.