[RLlib] Removes device infos from state when saving RModules to checkpoints/states. #43906

Merged

Conversation

@simonsays1980 (Collaborator) commented Mar 12, 2024

Why are these changes needed?

When loading an RLModule on CPU from a checkpoint/state that was created from a replica on GPU, an error occurs. This PR fixes the error by making the module save its state in the form of numpy.ndarrays.
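
For illustration, a minimal standalone sketch (plain PyTorch, not RLlib code; the file name is made up) of the failure mode this PR avoids:

import torch

# On a GPU machine: saving the raw state dict embeds the device ("cuda:0")
# in the serialized tensors.
model = torch.nn.Linear(4, 2).to("cuda")
torch.save(model.state_dict(), "module_state.pt")

# On a CPU-only machine, a plain torch.load() of that file raises
# "RuntimeError: Attempting to deserialize object on a CUDA device but
# torch.cuda.is_available() is False ..." unless devices are remapped:
state = torch.load("module_state.pt", map_location="cpu")

# Saving the state as numpy arrays instead (as this PR does) strips the
# device info from the checkpoint, so no remapping is needed on load.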

Related issue number

Closes #43905

Checks

  • I've signed off every commit (by using the -s flag, i.e., git commit -s) in this PR.
  • I've run scripts/format.sh to lint the changes in this PR.
  • I've included any doc changes needed for https://docs.ray.io/en/master/.
    • I've added any new APIs to the API Reference. For example, if I added a
      method in Tune, I've added it in doc/source/tune/api/ under the
      corresponding .rst file.
  • I've made sure the tests are passing. Note that there might be a few flaky tests, see the recent failures at https://flakey-tests.ray.io/
  • Testing Strategy
    • Unit tests
    • Release tests
    • This PR is not tested :(

Signed-off-by: Simon Zehnder <simon.zehnder@gmail.com>
@sven1977 (Contributor) commented:

I'm not sure I understand the exact need for this additional option to the API. My main argument against this would be: When stuff gets stored to a checkpoint, it should be stored in a device-independent fashion. So the issue at hand here is NOT the loading from the checkpoint, but the saving to the checkpoint beforehand, which - I'm guessing - probably happened in torch.cuda tensors, NOT in numpy format.

Can we rather take the opposite approach to keep the mental model of what a checkpoint should be clean? Always save weights (and other tensor/matrix states) as numpy arrays, never as torch or tf tensors. When loading from a checkpoint, the sequence should be something like the following (see the sketch after this list):
0) RLModule.from_checkpoint(dir=...)

  1. load numpy arrays from dir file (pickle?)
  2. pack these numpy arrays into a state dict mapping
  3. Call RLModule.set_state(state_dict=...) -> This automatically converts the numpy contents of state_dict into torch tensors (with self.device as the device of the already existing RLModule object), then performs a torch.nn.Module.load_state_dict() operation using these (cuda?) tensors.
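
A rough sketch of the sequence proposed above; the helper names and the pickle file name are placeholders for illustration, not the actual RLlib API:

import pathlib
import pickle

import torch


def load_numpy_state(checkpoint_dir):
    # 1) Load the numpy-only state from the checkpoint directory (pickled dict).
    with open(pathlib.Path(checkpoint_dir) / "module_state.pkl", "rb") as f:
        return pickle.load(f)  # dict: parameter name -> numpy array


def set_state(module: torch.nn.Module, numpy_state, device):
    # 2) + 3) Convert the numpy arrays into tensors on the module's own
    # device, then hand them to torch.nn.Module.load_state_dict().
    tensor_state = {
        k: torch.as_tensor(v, device=device) for k, v in numpy_state.items()
    }
    module.load_state_dict(tensor_state)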

@simonsays1980 (Collaborator, Author) commented:

> (quoting @sven1977's comment above in full)

I agree with your argument that we should ensure that checkpointing is device-independent. This should be the cleanest way of doing it. We should investigate where exactly this device-dependent checkpointing takes place and fix the problem there.

I am, however, not so sure that step 3 describes how the workflow runs right now. Here is what makes me wonder: if torch.nn.Module.load_state_dict() used self.device, then it should not matter whether the state contains numpy.ndarrays or torch.Tensors (with their own device attribute), because it would always use the module's device. But no matter what, always checkpointing numpy.ndarrays will avoid this error anyway.
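
For what it's worth, a small plain-PyTorch sketch of where the device mismatch actually bites: load_state_dict() copies values into the module's existing parameters (so cross-device copies are fine), but torch.load() of a file containing CUDA tensors already fails on a CPU-only machine unless map_location is passed. A numpy-only checkpoint never hits that code path:

import torch

cpu_module = torch.nn.Linear(4, 2)  # parameters live on CPU

if torch.cuda.is_available():
    # load_state_dict() copies into the module's existing parameters, so it
    # happily accepts a state dict of CUDA tensors for a CPU module.
    gpu_state = {k: v.cuda() for k, v in cpu_module.state_dict().items()}
    cpu_module.load_state_dict(gpu_state)

# The error from the linked issue happens one step earlier: on a machine
# without a GPU, torch.load() cannot even deserialize a file whose tensors
# were saved on "cuda:0" unless map_location="cpu" is given.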

Signed-off-by: Simon Zehnder <simon.zehnder@gmail.com>
…tore state dict now in numpy format which makes it device-agnostic.

Signed-off-by: Simon Zehnder <simon.zehnder@gmail.com>
@simonsays1980 simonsays1980 changed the title Adds device placement when loading 'RModule's from checkpoints/states. Removes device infos from state when saving 'RModule's to checkpoints/states. Mar 13, 2024
@simonsays1980 simonsays1980 changed the title Removes device infos from state when saving 'RModule's to checkpoints/states. Removes device infos from state when saving RModules to checkpoints/states. Mar 13, 2024
"""Loads the module from a checkpoint directory.

Args:
checkpoint_dir_path: The directory to load the checkpoint from.
map_location: The device on which the module resides.
Contributor

Remove this line?

@@ -367,6 +367,7 @@ def load_state(
            modules_to_load: The modules whose state is to be loaded from the path. If
                this is None, all modules that are checkpointed will be loaded into this
                marl module.
+           map_location: The device the module resides on.
Contributor

Remove this line?

@@ -117,10 +118,13 @@ def _module_state_file_name(self) -> pathlib.Path:
    @override(RLModule)
    def save_state(self, dir: Union[str, pathlib.Path]) -> None:
        path = str(pathlib.Path(dir) / self._module_state_file_name())
-       torch.save(self.state_dict(), path)
+       torch.save(convert_to_numpy(self.state_dict()), path)
Contributor

Perfect! This should work.
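
A small round-trip sketch of the change above (module and file name are illustrative), assuming convert_to_numpy from ray.rllib.utils.numpy, which turns a (nested) struct of tensors into numpy arrays:

import torch
from ray.rllib.utils.numpy import convert_to_numpy

# Saving: the state dict is converted to numpy arrays first, so the file
# carries no device information at all.
module = torch.nn.Linear(4, 2)
torch.save(convert_to_numpy(module.state_dict()), "module_state.pt")

# Loading: identical on CPU-only and GPU machines. weights_only=False is
# only required on newer torch versions where it defaults to True.
numpy_state = torch.load("module_state.pt", weights_only=False)
module.load_state_dict(
    {k: torch.as_tensor(v) for k, v in numpy_state.items()}
)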

@sven1977 (Contributor) left a comment

Looks good! Thanks for this important fix, @simonsays1980!

Just two nits on the docstrings.

@sven1977 sven1977 changed the title Removes device infos from state when saving RModules to checkpoints/states. [RLlib] Removes device infos from state when saving RModules to checkpoints/states. Mar 19, 2024
Comment on lines +691 to +694
    def load_state(
        self,
        dir: Union[str, pathlib.Path],
    ) -> None:
Contributor

Suggested change
-    def load_state(
-        self,
-        dir: Union[str, pathlib.Path],
-    ) -> None:
+    def load_state(self, dir: Union[str, pathlib.Path]) -> None:

@sven1977 sven1977 merged commit 94fd80f into ray-project:master Mar 19, 2024
5 checks passed
stephanie-wang pushed a commit to stephanie-wang/ray that referenced this pull request Mar 27, 2024
Development

Successfully merging this pull request may close these issues.

[RLlib] - TorchRLModule cannot be loaded on CPU after training on GPU