[RLlib] Removes device infos from state when saving RLModules to checkpoints/states. #43906
Conversation
Signed-off-by: Simon Zehnder <simon.zehnder@gmail.com>
I'm not sure I understand the exact need for this additional option to the API. My main argument against this would be: When stuff gets stored to a checkpoint, it should be stored in a device-independent fashion. So the issue at hand here is NOT the loading from the checkpoint, but the saving to the checkpoint beforehand, which - I'm guessing - probably happened in torch.cuda tensors, NOT in numpy format. Can we rather take the opposite approach to keep the mental model of what a checkpoint should be clean? Always save weights (and other tensor/matrix states) as numpy arrays, never as torch or tf tensors. When loading from a checkpoint, the sequence should be something like:
I agree with your argument that we should ensure that checkpointing is device-independent. This should be the cleanest way of doing this. We should investigate where exactly this device-dependent checkpointing takes place and fix the problem there. I am, however, not so sure if 3. describes how the workflow runs right now. Here is what makes me wonder: If the
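For illustration, here is a minimal sketch of the save-as-numpy / restore-on-load pattern discussed above (hypothetical helper names, a plain torch.nn.Module standing in for an RLModule; this is not RLlib's actual checkpoint code):

import torch
import torch.nn as nn

def save_state_device_independent(module: nn.Module, path: str) -> None:
    # Convert every tensor to a numpy array before writing, so the file
    # carries no device information (no torch.cuda tensors).
    numpy_state = {k: v.detach().cpu().numpy() for k, v in module.state_dict().items()}
    torch.save(numpy_state, path)

def load_state_device_independent(module: nn.Module, path: str, device: str = "cpu") -> None:
    # The stored values are plain numpy arrays, so loading works on any machine.
    numpy_state = torch.load(path)
    module.load_state_dict({k: torch.as_tensor(v) for k, v in numpy_state.items()})
    # Only after loading is the module moved to the device it should run on.
    module.to(device)

A GPU worker could then call save_state_device_independent(net, "state.pt") and a CPU-only process could call load_state_device_independent(net, "state.pt", device="cpu") without any map_location handling.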
Signed-off-by: Simon Zehnder <simon.zehnder@gmail.com>
…tore state dict now in numpy format which makes it device-agnostic. Signed-off-by: Simon Zehnder <simon.zehnder@gmail.com>
rllib/core/rl_module/rl_module.py
Outdated
"""Loads the module from a checkpoint directory. | ||
|
||
Args: | ||
checkpoint_dir_path: The directory to load the checkpoint from. | ||
map_location: The device on which the module resides. |
Remove this line?
rllib/core/rl_module/marl_module.py
Outdated
@@ -367,6 +367,7 @@ def load_state(
         modules_to_load: The modules whose state is to be loaded from the path. If
             this is None, all modules that are checkpointed will be loaded into this
             marl module.
+        map_location: The device the module resides on.
Remove this line?
@@ -117,10 +118,13 @@ def _module_state_file_name(self) -> pathlib.Path:
     @override(RLModule)
     def save_state(self, dir: Union[str, pathlib.Path]) -> None:
         path = str(pathlib.Path(dir) / self._module_state_file_name())
-        torch.save(self.state_dict(), path)
+        torch.save(convert_to_numpy(self.state_dict()), path)
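As a rough round-trip sketch of what this change enables (hypothetical file name, a plain nn.Linear standing in for a TorchRLModule; the GPU branch only runs where CUDA is available):

import torch
import torch.nn as nn
from ray.rllib.utils.numpy import convert_to_numpy

# Stand-in network; any state_dict behaves the same way.
net = nn.Linear(4, 2)
if torch.cuda.is_available():
    net = net.cuda()

# As in the patched save_state(): only numpy arrays end up in the file,
# so it carries no device information.
torch.save(convert_to_numpy(net.state_dict()), "module_state.pt")

# Loading on a CPU-only process works without map_location, because the
# file holds numpy arrays rather than torch.cuda tensors.
loaded = torch.load("module_state.pt")
cpu_net = nn.Linear(4, 2)
cpu_net.load_state_dict({k: torch.as_tensor(v) for k, v in loaded.items()})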
Perfect! This should work.
Looks good! Thanks for this important fix @simonsays1980 !
Just two nits on the docstrings.
    def load_state(
        self,
        dir: Union[str, pathlib.Path],
    ) -> None:
-    def load_state(
-        self,
-        dir: Union[str, pathlib.Path],
-    ) -> None:
+    def load_state(self, dir: Union[str, pathlib.Path]) -> None:
Signed-off-by: Sven Mika <sven@anyscale.io>
Why are these changes needed?
When loading an RLModule on CPU from a checkpoint/state that was created from a replica on GPU, an error occurs. This PR fixes this error by forcing the module to save its state in the form of numpy.ndarrays.
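For context, a minimal, hypothetical repro of that failure mode (the save step assumes a CUDA machine, the load step a CPU-only process):

import torch
import torch.nn as nn

# On a GPU worker: saving the raw state_dict() writes torch.cuda tensors.
net = nn.Linear(4, 2).cuda()
torch.save(net.state_dict(), "gpu_state.pt")

# On a CPU-only machine: deserializing torch.cuda tensors fails with a
# RuntimeError ("Attempting to deserialize object on a CUDA device ...")
# unless a map_location is passed explicitly.
# state = torch.load("gpu_state.pt")                      # raises here
state = torch.load("gpu_state.pt", map_location="cpu")    # manual workaround

Saving the state as numpy arrays removes the need for that workaround.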
Related issue number
Closes #43905
Checks
I've signed off every commit (git commit -s) in this PR.
I've run scripts/format.sh to lint the changes in this PR.
If I added a method in Tune, I've added it in doc/source/tune/api/ under the corresponding .rst file.