-
Notifications
You must be signed in to change notification settings - Fork 135
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[feat] Make saving model more easier when using HvdAllToAllEmbedding by adding save function overwriting patch in tf_save_restore_patch.py. #362
Conversation
try: | ||
import horovod.tensorflow as hvd | ||
try: | ||
hvd.rank() | ||
except: | ||
hvd = None | ||
except: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
try:
import horovod.tensorflow as hvd
hvd.rank()
except:
hvd = None
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Fixed
model, | ||
filepath, | ||
overwrite, | ||
include_optimizer, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It would be better to add comments for each important input arguments.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
fixed
*args, | ||
**kwargs) | ||
|
||
def _traverse_emb_layers_and_save(hvd_rank): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Do we have adequate UT cases to cover this function?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done
@@ -106,6 +107,10 @@ def __init__(self, root_rank=0, device='', local_variables=None): | |||
self.register_local_var(var) | |||
|
|||
|
|||
@deprecated( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is the warning always triggered? It's recommended to show it only when the users actually refer to the AllToAllEmbedding
.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Only show when this class was called init or new.
This callback class was only designed for horovod all2all embedding saving. For now, it's useless after new saving patch function.
def deprecated_wrapper(func_or_class):
"""Deprecation wrapper."""
if isinstance(func_or_class, type):
# If a class is deprecated, you actually want to wrap the constructor.
cls = func_or_class
if cls.__new__ is object.__new__:
# If a class defaults to its parent's constructor, wrap that instead.
func = cls.__init__
constructor_name = '__init__'
decorators, _ = tf_decorator.unwrap(func)
for decorator in decorators:
if decorator.decorator_name == 'deprecated':
# If the parent is already deprecated, there's nothing to do.
return cls
else:
func = cls.__new__
constructor_name = '__new__'
else:
cls = None
constructor_name = None
func = func_or_class
…by adding save function overwriting patch in tf_save_restore_patch.py. Also fix some import bug in tf_save_restore_patch.py. Also adding a save and restore test for HvdAllToAllEmbeeding.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
Make saving model more easier when using HvdAllToAllEmbedding by adding save function overwriting patch in tf_save_restore_patch.py.
Also fix some import bug in tf_save_restore_patch.py.
Also fix the example in demo where the python code for keras horovod synchronous training was wrong.
Description
I have overwritten the keras save function and now it is not necessary to save the embedding shard explicitly, as long as model.save or Keras.model.save_model is called on each rank, but tf.saved_model.save is not supported.
tf.saved_model.save can also be supported in theory, but because the obj object of the save is not necessarily the keras object, I am lazy to write it for the moment, and there is a need to talk about it.
Type of change
Checklist:
How Has This Been Tested?
Adding a test with HvdAllToAllEmbedding.
Follow the demo demo/dynamic_embedding/movielens-1m-keras-with-horovod.