
Fix resuming from checkpoint when using RayFSDPStrategy #43594

Merged: 2 commits merged into ray-project:master on Mar 4, 2024

Conversation

dabauxi (Contributor) commented Mar 1, 2024

Why are these changes needed?

Restoring from a checkpoint when using FSDP is currently broken: the state_dict keys for each layer get modified, so torch cannot associate the weights with the layer names when loading. The current implementation always assumes that the layer keys in the state dict are prefixed with `_forward_module.` and then slices each key by the length of that prefix.

The underlying reason for removing `_forward_module.` is unclear to me, but we should check that a key actually has the prefix before stripping it. This PR implements that check and fixes loading checkpoints when using RayFSDPStrategy.
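A minimal sketch of the guarded prefix handling this PR adds (the constant and helper names below are illustrative, not the exact code in Ray Train):

```python
_PREFIX = "_forward_module."

def _strip_forward_module_prefix(state_dict: dict) -> dict:
    """Strip the `_forward_module.` prefix only from keys that actually carry it.

    Unconditionally slicing len(_PREFIX) characters corrupts already-correct keys,
    e.g. "model.model.decoder.embed_tokens.weight" -> "der.embed_tokens.weight".
    """
    return {
        (key[len(_PREFIX):] if key.startswith(_PREFIX) else key): value
        for key, value in state_dict.items()
    }
```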

The following is an example of the wrong state_dict keys for a checkpoint:

```
# Keys in checkpoint["state_dict"]
der.embed_tokens.weight
der.embed_positions.weight
der.final_layer_norm.weight
...
```

Correct keys:

```
model.model.decoder.embed_tokens.weight
model.model.decoder.embed_positions.weight
model.model.decoder.final_layer_norm.weight
...
```

Signed-off-by: Paul Angerer <pangerer@canva.com>
woshiyyya (Member) commented Mar 4, 2024

Hi @dabauxi, the reason we did this is that Lightning previously had a `_LightningModuleWrapperBase`, which added the extra `_forward_module.` prefix to the state_dict keys, and users needed to manually trim it to load the checkpoint correctly. See Lightning-AI/pytorch-lightning#16526.

In recent versions, Lightning trims the prefix internally, so we no longer need to do it ourselves.
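For context, here is a minimal, illustrative sketch of how a wrapper module like the old `_LightningModuleWrapperBase` ends up adding the prefix; this is not Lightning's actual code, just the standard `nn.Module` attribute-naming behavior:

```python
import torch.nn as nn

class _WrapperLikeOldLightning(nn.Module):
    """Illustrative stand-in for the old wrapper; the attribute name becomes the key prefix."""

    def __init__(self, module: nn.Module):
        super().__init__()
        self._forward_module = module  # child is registered under "_forward_module"

    def forward(self, *args, **kwargs):
        return self._forward_module(*args, **kwargs)

wrapped = _WrapperLikeOldLightning(nn.Linear(4, 2))
print(list(wrapped.state_dict().keys()))
# ['_forward_module.weight', '_forward_module.bias']
```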

Thanks for the fix!

woshiyyya (Member) left a comment


LGTM

@galyna-anyscale

@matthewdeng Please review this PR so it can be merged.

matthewdeng (Contributor) left a comment


LGTM, seems like this branching logic is needed to account for different versions of Lightning which do/don't have the prefix.

@matthewdeng merged commit b5aee36 into ray-project:master on Mar 4, 2024
9 checks passed