Add support for VariablePolicy to saved_model.load #54319
Conversation
@ccrusius Can you please review this PR? Thank you!
Thanks for the change! Left a few comments and I will ask the team about the API changes.
load_with_device = (
    self._save_options.experimental_variable_policy
    ._save_variable_devices()
    and config.get_soft_device_placement()
Why does soft placement matter here?
To keep the behavior consistent when config.get_soft_device_placement() is disabled: an error is raised when an op cannot be placed onto its intended device, so this avoids silently falling back to CPU.
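The gating condition above can be sketched in plain Python (the two stub functions are stand-ins for the TensorFlow APIs in the diff; their return values are illustrative):

```python
def save_variable_devices():
    # Stand-in for SaveOptions.experimental_variable_policy
    # ._save_variable_devices(): True when the policy is
    # VariablePolicy.SAVE_VARIABLE_DEVICES.
    return True

def get_soft_device_placement():
    # Stand-in for tf.config.get_soft_device_placement(). When this is
    # False, an op that cannot be placed on its requested device raises
    # an error instead of silently running on CPU.
    return True

# Restore variables on their saved devices only when both conditions hold,
# so a bad device assignment cannot sneak onto CPU without an error.
load_with_device = save_variable_devices() and get_soft_device_placement()
print(load_with_device)
```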
Makes sense
Additionally, could you add tests?
Since the tests in keras/tests/ have just been moved to keras-team/keras, a unit test is added there; see PR #16525.
Force-pushed from c908077 to 5c90c08
load_with_device = (
    self._save_options.experimental_variable_policy
    ._save_variable_devices()
    and config.get_soft_device_placement()
Makes sense
@@ -47,6 +49,9 @@
from tensorflow.python.framework import tensor_spec
from tensorflow.python.framework import test_util
from tensorflow.python.framework import versions
from tensorflow.python.keras import Model
tensorflow/python/keras is no longer maintained. Could you change the test so the objects inherit from the low-level tf.Module instead?
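A minimal sketch of what the suggested change looks like, assuming the test object only needs to hold variables (class and variable names here are illustrative):

```python
import tensorflow as tf

# Instead of tensorflow.python.keras.Model, inherit from the low-level
# tf.Module, which is sufficient for variable tracking in SavedModel tests.
class VariableHolder(tf.Module):
    def __init__(self):
        super().__init__()
        self.v = tf.Variable(1.0, name="v")

holder = VariableHolder()
# tf.Module tracks assigned variables automatically, like keras.Model does.
print(len(holder.variables))
```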
Force-pushed from 8eb251c to ec9900c
        .SAVE_VARIABLE_DEVICES
    ),
)
loaded = load.load(
A test that runs a subprocess raises flags in my head. I would drop the OOM part of the test and instead validate that this variable has been placed on the GPU as expected.
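A sketch of such a placement check, assuming the LoadOptions plumbing this PR adds (pinned to CPU here so it runs anywhere; the real test would pin the variable to "/gpu:0" and assert on that device string):

```python
import tempfile
import tensorflow as tf

class Exported(tf.Module):
    def __init__(self):
        super().__init__()
        with tf.device("/cpu:0"):  # the real test would use "/gpu:0"
            self.v = tf.Variable(1.0)

path = tempfile.mkdtemp()
policy = tf.saved_model.experimental.VariablePolicy.SAVE_VARIABLE_DEVICES
tf.saved_model.save(
    Exported(), path,
    options=tf.saved_model.SaveOptions(experimental_variable_policy=policy))
loaded = tf.saved_model.load(
    path,
    options=tf.saved_model.LoadOptions(experimental_variable_policy=policy))
# Validate placement directly instead of provoking an OOM in a subprocess.
print(loaded.v.device)
```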
Thanks!
@wenscarl Hi again, the tests are failing due to the build. Can you fix the load_options build rules by adding the :save_options dependency? https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/saved_model/BUILD#L645-L651
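The requested fix would look roughly like this (a Bazel BUILD fragment; the rule kind and other attributes at the linked lines may differ from this sketch):

```
py_library(
    name = "load_options",
    srcs = ["load_options.py"],
    srcs_version = "PY3",
    deps = [
        # Added: load_options now refers to VariablePolicy from save_options.
        ":save_options",
    ],
)
```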
Thanks!
Cf. upstream PR [#54319](tensorflow#54319) and nvbug [3460007](https://nvbugswb.nvidia.com/NvBugs5/SWBug.aspx?bugid=3460007&cmtNo=).
This PR adds Python-level support for VariablePolicy to saved_model.load when soft placement is disabled. A related issue is #53743.