[FSDP] Allow different `optim_input` orders across ranks #78599
Conversation
❌ 1 new failure as of commit edef523 (more details on the Dr. CI page).
🕵️ 1 new failure recognized by patterns: the failure does not appear to be due to upstream breakages.
This enables the first argument to the optimizer constructor, `optim_input`, to have different orders across ranks, e.g. if the parameters in one parameter group are permuted. This requires modifications to `full_optim_state_dict()`, `shard_full_optim_state_dict()`, and `scatter_full_optim_state_dict()`. The high-level algorithmic change is that the state dicts remain keyed by unflattened parameter name through sharding/unsharding and flattening/unflattening and are rekeyed by parameter ID according to each rank's own `optim_input` only at the end.

Differential Revision: [D36798482](https://our.internmc.facebook.com/intern/diff/D36798482)
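To illustrate the rekeying idea, here is a minimal sketch (not the PR's actual code): it assumes the intermediate state dict is keyed by unflattened parameter name, treats `optim_input` as a flat list of parameters rather than parameter groups, and the function name and `param_to_name` mapping are hypothetical.

```python
# Hypothetical sketch of the final rekeying step: the consolidated optimizer
# state stays keyed by unflattened parameter name, and only at the end is it
# rekeyed to integer parameter IDs derived from *this rank's* optim_input order.
from typing import Any, Dict, List

import torch


def rekey_by_param_id(
    name_keyed_state: Dict[str, Dict[str, Any]],
    param_to_name: Dict[torch.nn.Parameter, str],
    optim_input: List[torch.nn.Parameter],
) -> Dict[int, Dict[str, Any]]:
    """Rekey a name-keyed optimizer state dict to parameter IDs.

    Parameter IDs follow the order in which parameters appear in this rank's
    `optim_input`, so two ranks with permuted `optim_input` get different ID
    assignments but equivalent state.
    """
    id_keyed_state: Dict[int, Dict[str, Any]] = {}
    for param_id, param in enumerate(optim_input):
        id_keyed_state[param_id] = name_keyed_state[param_to_name[param]]
    return id_keyed_state
```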
LGTM! Let's test to ensure that it is fixed for the use case we're looking at. Thanks so much for the quick fix!
```python
raise RuntimeError(
    "FSDP currently requires each rank to have at least the "
    "optimizer states needed by rank 0's optimizer but some ranks "
    "are missing some of those states"
)
```
Will it also be useful to log the missing keys?
I can add that at the cost of an `all_gather_object()`.
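For reference, a minimal sketch of what gathering the missing keys could look like; `torch.distributed.all_gather_object()` is the real collective being referred to, but the surrounding helper (`report_missing_states`, its `missing_keys` input) is hypothetical and not the PR's actual implementation.

```python
# Illustrative sketch: gather every rank's locally missing unflattened
# parameter names so the error message can say which ranks are missing what.
import torch.distributed as dist


def report_missing_states(missing_keys, group=None):
    world_size = dist.get_world_size(group)
    gathered = [None for _ in range(world_size)]
    # This is the extra collective mentioned above.
    dist.all_gather_object(gathered, missing_keys, group=group)
    lines = [
        f"Rank {rank} is missing states for the parameters: {sorted(keys)}"
        for rank, keys in enumerate(gathered)
        if keys
    ]
    if lines:
        raise RuntimeError(
            "FSDP currently requires each rank to have at least the "
            "optimizer states needed by rank 0's optimizer but some ranks "
            "are missing some of those states\n" + "\n".join(lines)
        )
```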
Update: The error now looks like:
```
RuntimeError: FSDP currently requires each rank to have at least the optimizer states needed by rank 0's optimizer but some ranks are missing some of those states
Rank 1 is missing states for the parameters: [('block2.2.weight', 'block2.2.bias_module0.bias', 'block2.2.bias_module1.bias')]
Rank 2 is missing states for the parameters: [('block2.2.weight', 'block2.2.bias_module0.bias', 'block2.2.bias_module1.bias')]
Rank 3 is missing states for the parameters: [('block2.2.weight', 'block2.2.bias_module0.bias', 'block2.2.bias_module1.bias')]
```
Backward compatibility error due to broken ops seems unrelated.
@pytorchbot merge
Hey @awgu. |
Summary: Pull Request resolved: #78599
Approved by: https://github.com/rohan-varma
Test Plan: contbuild & OSS CI, see https://hud.pytorch.org/commit/pytorch/pytorch/4615738a3d5aee256a9f1929c846dcae7d20a041
Reviewed By: rohan-varma
Differential Revision: D36798482
fbshipit-source-id: de05b37db8ed41b6cf11a9fc526a0e03a98f570d
cc @fegin as sharded optimizer states are built on top of it
Stack from ghstack:

- #78599 [FSDP] Allow different `optim_input` orders across ranks (this PR)
- #78784 [FSDP][Docs] Fix typo in `full_optim_state_dict()`

This enables the first argument to the optimizer constructor, `optim_input`, to have different orders across ranks, e.g. if the parameters in one parameter group are permuted. This requires modifications to `full_optim_state_dict()`, `shard_full_optim_state_dict()`, and `scatter_full_optim_state_dict()`.

The high-level algorithmic change is that the state dicts remain keyed by unflattened parameter name through sharding/unsharding and flattening/unflattening and are rekeyed by parameter ID according to each rank's own `optim_input` only at the end.

Because this PR adds non-parameter-specific collectives to `full_optim_state_dict()`, it adds a `group=None` argument to the method to provide a process group to default to when running those common collectives.
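A hedged usage sketch of what this allows (not taken from the PR or its tests): it assumes a nested FSDP wrapping so there are multiple flat parameters, assumes the process group is already initialized (e.g. via torchrun), and permutes `optim_input` on odd ranks purely for illustration.

```python
# Sketch: ranks may pass `optim_input` to the optimizer in different orders
# and still consolidate / reshard optimizer state with FSDP's helpers.
import torch
import torch.distributed as dist
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

inner = FSDP(torch.nn.Linear(8, 8).cuda())
model = FSDP(torch.nn.Sequential(inner, torch.nn.Linear(8, 8).cuda()))

params = list(model.parameters())
# Odd ranks see the (flat) parameters in reversed order.
optim_input = params if dist.get_rank() % 2 == 0 else list(reversed(params))
optim = torch.optim.Adam(optim_input, lr=1e-3)

# ... training ...

# Consolidate the full optimizer state; `group=None` is the new argument for
# the non-parameter-specific collectives mentioned above.
full_osd = FSDP.full_optim_state_dict(model, optim, optim_input, group=None)

# Later, assuming `full_osd` has been made available on every rank (e.g. saved
# on rank 0 and loaded from a checkpoint), reshard it according to this rank's
# own `optim_input` order before loading it into the optimizer.
sharded_osd = FSDP.shard_full_optim_state_dict(full_osd, model, optim_input)
optim.load_state_dict(sharded_osd)
```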