Adds check for FSDPOptimizer wrapper (#788)
Summary:
Pull Request resolved: #788

DCP expects the entire optimizer state to be initialized in the state dict before load is called. Unfortunately, since optimizer states are sometimes lazily initialized, dcp_saver has to check for optimizer objects and ensure they are properly initialized first.

This diff adds a check for optimizers that are hidden inside wrappers. An alternative would be to implement this in the FSDP wrapper under utils/prepare_module, but I found that change to be higher risk, and we would then have the change in both places.
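The unwrapping pattern described above can be sketched as follows. This is a minimal, self-contained illustration, not the actual torchtnt code: `Optimizer` and `OptimizerWrapper` are hypothetical stand-ins for `torch.optim.Optimizer` and `torchtnt.utils.prepare_module.FSDPOptimizerWrapper`, and `init_optim_state_if_needed` stands in for the `_init_optim_state` call in dcp_saver.

```python
# Hedged sketch of the wrapper check: reach through an `optimizer`
# attribute when present, otherwise use the object itself.

class Optimizer:
    """Stand-in for torch.optim.Optimizer."""
    def __init__(self):
        self.state_initialized = False


class OptimizerWrapper:
    """Stand-in for FSDPOptimizerWrapper: exposes the wrapped optimizer
    via an `optimizer` attribute while handling state_dict calls."""
    def __init__(self, optimizer):
        self.optimizer = optimizer


def init_optim_state_if_needed(obj):
    # If obj is a wrapper, getattr returns the inner optimizer;
    # a bare optimizer (or unrelated object) falls back to itself.
    optimizer = getattr(obj, "optimizer", obj)
    if isinstance(optimizer, Optimizer):
        optimizer.state_initialized = True  # proxy for _init_optim_state
        return True
    return False
```

Both a bare optimizer and a wrapped one are handled by the same branch, which is why the check can live in dcp_saver without touching the wrapper itself.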
ghstack-source-id: 222855240
exported-using-ghexport

Reviewed By: JKSenthil

Differential Revision: D56075363

fbshipit-source-id: 68a4086ce322d9453bdd7954d56a455163d7189e
LucasLLC authored and facebook-github-bot committed Apr 17, 2024
1 parent 34b04ae commit c444003
Showing 1 changed file with 6 additions and 2 deletions.
8 changes: 6 additions & 2 deletions torchtnt/framework/callbacks/dcp_saver.py
@@ -244,8 +244,12 @@ def restore(

     # necessary for loading optimizers since states are initialized lazily
     for obj in app_state.values():
-        if isinstance(obj, torch.optim.Optimizer):
-            _init_optim_state(obj)
+        # Sometimes optimizers are actually held in a wrapper which handles calling
+        # state_dict and load_state_dict, as is the case for
+        # `torchtnt.utils.prepare_module.FSDPOptimizerWrapper`. This handles that case.
+        optimizer = getattr(obj, "optimizer", obj)
+        if isinstance(optimizer, torch.optim.Optimizer):
+            _init_optim_state(optimizer)

     dcp.load(
         {"app_state": MultiStateful(app_state)},
