[FSDP][optim_state_dict] Move local optimizer state to FSDP compute_device

Pull Request resolved: #110929

This will ensure all the local optimizer state tensors are on the FSDP compute_device.
ghstack-source-id: 203487927
@exported-using-ghexport

Differential Revision: [D50059492](https://our.internmc.facebook.com/intern/diff/D50059492/)
fegin committed Oct 10, 2023
1 parent f952551 commit 1cc2517
Showing 1 changed file with 5 additions and 0 deletions.
torch/distributed/fsdp/_optim_utils.py: 5 additions & 0 deletions

@@ -1372,6 +1372,11 @@ def _convert_all_state_info(
                 None for _ in fsdp_param_info.param_indices
             ]
             local_state = input_states[fqn].get(state_name, None)
+            # N.B. We need to move the state to compute_device. The reason is
+            # not yet clear and we need to figure out why the state may be on a
+            # different device.
+            if local_state is not None:
+                local_state = local_state.to(fsdp_param_info.state.compute_device)
             state_buffers[state_name][fsdp_param_info.param_indices[fqn]] = local_state
 
             # Restoring the scalar and non-tensor states. If the corresponding
