[FSDP] Fix load_sharded_state_dict FQN mismatches for shared parameters #86524
Conversation
`_sharded_pre_load_state_dict_hook()` should call `_param_fqns()` to ensure shared parameter names are also included. Differential Revision: [D40201304](https://our.internmc.facebook.com/intern/diff/D40201304/) [ghstack-poisoned]
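For context, here is a hedged, minimal sketch (plain `nn.Module`, not the FSDP hook itself) of the FQN mismatch that tied parameters cause: the deduplicated parameter names miss the secondary FQN that the state dict still contains.

```python
import torch.nn as nn

# Minimal illustration of why shared parameters need extra care when building
# fully qualified name (FQN) lists: a tied weight lives under two FQNs, but
# named_parameters() deduplicates by default, while the state dict keeps both keys.
class TiedModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.encoder = nn.Linear(8, 8, bias=False)
        self.decoder = nn.Linear(8, 8, bias=False)
        self.decoder.weight = self.encoder.weight  # weight tying

model = TiedModel()
print([name for name, _ in model.named_parameters()])
# ['encoder.weight']  -- the shared FQN 'decoder.weight' is dropped
print([name for name, _ in model.named_parameters(remove_duplicate=False)])
# ['encoder.weight', 'decoder.weight']
print(list(model.state_dict().keys()))
# ['encoder.weight', 'decoder.weight']
```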
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/86524. Note: links to docs will display an error until the docs builds have been completed.
✅ No Failures, 1 Pending as of commit 5764aae. This comment was automatically generated by Dr. CI and updates every 15 minutes.
`_sharded_pre_load_state_dict_hook()` should call `_param_fqns()` to ensure shared parameter names are also included. Differential Revision: [D40201304](https://our.internmc.facebook.com/intern/diff/D40201304/) ghstack-source-id: 169790107 Pull Request resolved: #86524
This change makes sense to me, but if we want to land this formally, should we include a unit test?
Here are some code examples for how to use `TransformerWithSharedParams.init()`, which has shared parameters (see `common_fsdp.py` for the precise details).
No FSDP:
pytorch/test/distributed/fsdp/test_fsdp_use_orig_params.py, lines 84 to 89 at be682be:

```python
model = TransformerWithSharedParams.init(
    self.process_group,
    FSDPInitMode.NO_FSDP,
    CUDAInitMode.CUDA_BEFORE,
    deterministic=True,
)
```
FSDP:
pytorch/test/distributed/fsdp/test_fsdp_use_orig_params.py, lines 115 to 133 at be682be:

```python
fsdp_kwargs = {
    "auto_wrap_policy": functools.partial(
        transformer_auto_wrap_policy,
        transformer_layer_cls={
            TransformerEncoderLayer,
            TransformerDecoderLayer,
        },
    ),
    "use_orig_params": True,
    "sharding_strategy": sharding_strategy,
    "backward_prefetch": backward_prefetch,
    "cpu_offload": cpu_offload,
}
model = TransformerWithSharedParams.init(
    self.process_group,
    FSDPInitMode.NO_FSDP,
    cuda_init_mode,
    deterministic=True,
)
```
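A hedged sketch of what such a unit test could verify (the helper name and setup below are illustrative assumptions, not the test actually added in this PR): wrap a model containing tied weights in FSDP, take a sharded state dict, and load it back, which failed with missing FQNs before this fix.

```python
import torch
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, StateDictType

# Illustrative round-trip check; assumes torch.distributed and CUDA are already
# initialized (e.g. inside a multi-process test fixture), and that build_model()
# returns a module with shared/tied parameters.
def _check_sharded_state_dict_round_trip(build_model) -> None:
    fsdp_model = FSDP(build_model().cuda())
    with FSDP.state_dict_type(fsdp_model, StateDictType.SHARDED_STATE_DICT):
        sharded_sd = fsdp_model.state_dict()
        # Before this PR, the shared-parameter FQNs were missing from the
        # pre-load hook's FQN list, so load_state_dict() raised a key mismatch.
        fsdp_model.load_state_dict(sharded_sd)
```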
@awgu Yup, that makes sense. What is puzzling me is that I thought …
@awgu Please ignore my previous comment. I didn't find shared parameters testing in …
…ed parameters" `_sharded_pre_load_state_dict_hook()` should call `_param_fqns()` to ensure shared parameter names are also included. Differential Revision: [D40201304](https://our.internmc.facebook.com/intern/diff/D40201304/) [ghstack-poisoned]
Pull Request resolved: #86524 `_sharded_pre_load_state_dict_hook()` should call `_param_fqns()` to ensure shared parameter names are also included. ghstack-source-id: 169844364 Differential Revision: [D40201304](https://our.internmc.facebook.com/intern/diff/D40201304/)
```python
)

fsdp_model = model_creator()
for tensor in itertools.chain(fsdp_model.parameters(), fsdp_model.buffers()):
```
What does this portion add to the unit test?
```python
@property
def _shared_param_fqns(self) -> Iterator[Tuple[str, str, str]]:
    for param_name, module_name in (
        self._fsdp_wrapped_module.handle.shared_parameter_module_names()
```
nit: I remember that `_fsdp_wrapped_module` can have multiple handles, so should this be `self._fsdp_wrapped_module.handles[0]`? @awgu
I think either @fegin or I will need to rebase, but it is not a big deal either way. The preferred approach will be `self._handles[0]`, since I want to get rid of `self._fsdp_wrapped_module`.
I changed to `self._handles[0]` as suggested.
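For readers unfamiliar with the FSDP handle internals, here is a rough public-API analogue of enumerating shared-parameter FQNs (illustrative only; treating "shared" as the non-primary FQNs of tied parameters is an assumption about the intent of `shared_parameter_module_names()`):

```python
from typing import List
import torch.nn as nn

def shared_param_fqns(module: nn.Module) -> List[str]:
    # FQNs kept after deduplication (one name per underlying tensor)
    primary = {name for name, _ in module.named_parameters()}
    # All FQNs, including the duplicates created by weight tying
    all_fqns = [name for name, _ in module.named_parameters(remove_duplicate=False)]
    return [name for name in all_fqns if name not in primary]
```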
…ed parameters" `_sharded_pre_load_state_dict_hook()` should call `_param_fqns()` to ensure shared parameter names are also included. Differential Revision: [D40201304](https://our.internmc.facebook.com/intern/diff/D40201304/) [ghstack-poisoned]
Pull Request resolved: #86524 `_sharded_pre_load_state_dict_hook()` should call `_param_fqns()` to ensure shared parameter names are also included. ghstack-source-id: 170184507 Differential Revision: [D40201304](https://our.internmc.facebook.com/intern/diff/D40201304/)
…ed parameters" `_sharded_pre_load_state_dict_hook()` should call `_param_fqns()` to ensure shared parameter names are also included. Differential Revision: [D40201304](https://our.internmc.facebook.com/intern/diff/D40201304/) [ghstack-poisoned]
Pull Request resolved: #86524 `_sharded_pre_load_state_dict_hook()` should call `_param_fqns()` to ensure shared parameter names are also included. ghstack-source-id: 170321602 Differential Revision: [D40201304](https://our.internmc.facebook.com/intern/diff/D40201304/)
@pytorchbot merge (Initiating merge automatically since Phabricator Diff has merged)
Merge started. Your change will be merged once all checks pass (ETA 0-4 hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
Hey @fegin.
Stack from ghstack (oldest at bottom):

`_sharded_pre_load_state_dict_hook()` should call `_param_fqns()` to ensure shared parameter names are also included.

Differential Revision: D40201304