Add unit test, official flags, and benchmark logs for recommendation model #4343
Conversation
Woo, looking so real.
def define_data_download_flags():
  """Add flags specifying data download arguments."""
  flags.DEFINE_string(
      name="data_dir", short_name="dd", default="/tmp/movielens-data/",
@robieta, what did we decide: no short names for flags defined within individual modules? I think that will make us happier in the end.
I will remove the short_name for now. We can always add it back if necessary. :)
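For reference, dropping the short name would look something like this (a sketch; the help text is paraphrased, not copied from the diff):

def define_data_download_flags():
  """Add flags specifying data download arguments."""
  flags.DEFINE_string(
      name="data_dir", default="/tmp/movielens-data/",
      help=flags_core.help_wrap(
          "Directory to download and extract the MovieLens data to."))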
    enum_values=["ml-1m", "ml-20m"], case_sensitive=False,
    help=flags_core.help_wrap(
        "Dataset to be trained and evaluated. Two datasets are available "
        "for now: ml-1m and ml-20m."))
nit: We don't have any plans to add others, so no need to specify "for now"
official/recommendation/ncf_main.py
Outdated
# Return estimator's last checkpoint as global_step for estimator
def _get_global_step(estimator):
  return int(estimator.latest_checkpoint().split("-")[-1])
Hmm. This is fragile and unreliable. Can you try estimator.get_variable_value(tf.GraphKeys.GLOBAL_STEP)? (Sorry, that's from memory, so may not be exactly the right syntax.) Or actually reading in the checkpoint and grabbing the global step?
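A minimal sketch of the suggested alternative, assuming a TF 1.x tf.estimator.Estimator (untested against the PR's code):

import tensorflow as tf

def _get_global_step(estimator):
  # tf.GraphKeys.GLOBAL_STEP is the canonical variable name ("global_step");
  # reading the variable directly avoids parsing the checkpoint filename,
  # which breaks if the naming scheme ever changes.
  return int(estimator.get_variable_value(tf.GraphKeys.GLOBAL_STEP))

Alternatively, tf.train.load_variable(estimator.model_dir, tf.GraphKeys.GLOBAL_STEP) reads the value straight from the checkpoint file.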
official/recommendation/ncf_main.py
Outdated
flags.DEFINE_float(
    name="mf_regularization", short_name="mr", default=0.0,
    help=flags_core.help_wrap("The Regularization for MF embeddings."))
nit: Regularization what? Factor? Feels like it's missing a noun. Also, maybe some advice on reasonable values, defaults, etc?
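One possible rewording that supplies the missing noun (an illustration only; the value guidance is generic, not from the PR):

flags.DEFINE_float(
    name="mf_regularization", default=0.0,
    help=flags_core.help_wrap(
        "The L2 regularization factor for MF embeddings. Defaults to 0.0 "
        "(no regularization); small positive values such as 1e-6 penalize "
        "large embedding weights."))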
official/recommendation/ncf_main.py
Outdated
    enum_values=["ml-1m", "ml-20m"], case_sensitive=False,
    help=flags_core.help_wrap(
        "Dataset to be trained and evaluated. Two datasets are "
        "available for now: ml-1m and ml-20m."))
There should be a way with flags to display the choices and defaults for each of these without hard-coding them.
absl will automatically display allowed values for enums, so it can be removed altogether.
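For example, absl renders an enum flag in --help as <ml-1m|ml-20m>: <help text>, so the choices can be dropped from the help string entirely (a sketch):

flags.DEFINE_enum(
    name="dataset", default="ml-1m",
    enum_values=["ml-1m", "ml-20m"], case_sensitive=False,
    help=flags_core.help_wrap("Dataset to be trained and evaluated."))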
official/recommendation/ncf_main.py
Outdated
"available for now: ml-1m and ml-20m.")) | ||
|
||
flags.DEFINE_integer( | ||
name="num_factors", short_name="nf", default=8, |
See the above discussion re: short names.
official/recommendation/ncf_main.py
Outdated
flags.DEFINE_list(
    name="layers", short_name="ly", default=[64, 32, 16, 8],
    help=flags_core.help_wrap("The size of hidden layers for MLP."))
This should be a list? How does one enter the list? Comma-separated ints? Please define expectations for the user, display default.
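For context: absl list flags are entered on the command line as a single comma-separated string (--layers=64,32,16,8) and arrive in code as a list of strings, so an explicit int cast is needed. A sketch of a more explicit help text plus the conversion (names are illustrative):

flags.DEFINE_list(
    name="layers", default=["64", "32", "16", "8"],
    help=flags_core.help_wrap(
        "Comma-separated sizes of the MLP hidden layers, e.g. "
        "--layers=64,32,16,8. Defaults to 64,32,16,8."))

# Inside main(), after flags are parsed:
hidden_layer_sizes = [int(size) for size in flags.FLAGS.layers]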
official/recommendation/ncf_main.py
Outdated
"If passed, training will stop when the evaluation metric HR is " | ||
"greater than or equal to hr_threshold. For dataset ml-1m, the " | ||
"desired hr_threshold is 0.68; For dataset ml-20m, the threshold can " | ||
"be set as 0.95.")) |
Are these set auto-magically? They should be, and that should be clear from the help text as well.
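One way to make the thresholds automatic, sketched with hypothetical helper names (the dictionary values are the ones quoted in the help text above):

# HR targets per dataset, taken from the help text above.
_HR_THRESHOLDS = {"ml-1m": 0.68, "ml-20m": 0.95}

def resolve_hr_threshold(flags_obj):
  """Fall back to the dataset's default HR target if none was passed."""
  if flags_obj.hr_threshold is None:
    return _HR_THRESHOLDS[flags_obj.dataset]
  return flags_obj.hr_threshold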
Thank you for the comments, Karmel, especially the regularization one, which helped me identify a model-layer bug. Thanks a lot!
A few minor nits.
    help=flags_core.help_wrap(
        "Dataset to be trained and evaluated. Two datasets are available "
-       "for now: ml-1m and ml-20m."))
+       ": ml-1m and ml-20m."))
nit: no defaults
official/recommendation/ncf_main.py
Outdated
flags.DEFINE_list(
    name="mlp_regularization", default=["0.", "0.", "0.", "0."],
    help=flags_core.help_wrap(
        "The regularization factor for each MLP layer. See ml_regularization "
What is ml_regularization? Should this be mf_regularization?
Hi All,
I have added official flags and benchmark logs for the recommendation model. Based on previous comments, a simple unit test for dataset.py has also been added. Since dataset.py needs to read csv files, I added a folder "unittest_data" for those files.