[FSDP][optim_state_dict][3/N] Support use_orig_param optim_state_dict (non-broadcast version) #89900
Conversation
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/89900. Note: links to docs will display an error until the docs builds have completed.
❌ 1 failure as of commit dc15a9c: the following jobs failed, but were likely due to broken trunk (merge base 41c3b41).
This comment was automatically generated by Dr. CI and updates every 15 minutes.
Big thanks for working on this and getting through the crazy Cartesian product of code paths!
I made an initial pass and will continue to revisit.
```python
FullyShardedDataParallel._warn_optim_input(optim_input)
using_optim_input = FullyShardedDataParallel._is_using_optim_input(
    optim_input,
    optim,
)
use_orig_params: bool = False
```
Note to self: It looks like we assume that `use_orig_params` is uniform for a given FSDP tree. We should check for that in `_lazy_init()` and raise an error as needed.
Agreed. I think we should just enforce this; I can't see a use case where we'd want to support mixing and matching.
cc @rohan-varma: We were talking about which FSDP constructor args we assume to be uniform. This is one place evidencing that `use_orig_params` must be uniform (even if we already intuitively assumed so).
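For illustration, a hedged sketch of what such an enforcement could look like; the tree traversal and the `_use_orig_params` attribute are assumptions about the internals, not the actual `_lazy_init()` code:

```python
import torch.nn as nn
from torch.distributed.fsdp import FullyShardedDataParallel

def _check_uniform_use_orig_params(root_module: nn.Module) -> None:
    # Collect the flag from every FSDP instance in the tree and require
    # that they all agree (attribute name is an assumption).
    flags = {
        m._use_orig_params
        for m in root_module.modules()
        if isinstance(m, FullyShardedDataParallel)
    }
    if len(flags) > 1:
        raise ValueError(
            "use_orig_params must be uniform across all FSDP instances "
            f"in the tree, but found {sorted(flags)}"
        )
```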
```python
value = orig_state[state_name]
if not isinstance(value, list) or not torch.is_tensor(value[0]):
    continue
value = torch.concat(value)[: flat_param._numels[param_idx]].reshape(
```
Conceptual: In what case would `torch.concat(value)` have more numel than `flat_param._numels[param_idx]` (meaning that we are truncating)?
Excellent point. The padding shouldn't be visible here. However, I'm not sure that will always hold: we were discussing the cost of `F.pad` and may change the implementation, so this is just a safety guard. If you believe padding will never be seen, I can remove it.
Ah, good catch. Let's keep the trimming logic here; the safeguard is good.
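For intuition, a minimal standalone sketch of why the trim can matter; the padding scheme here is illustrative, not the exact `FlatParameter` layout:

```python
import torch

world_size = 4
orig_shape = (3, 3)  # the original parameter has 9 elements
orig_numel = 9
# Pad so the flat parameter divides evenly across ranks: 9 -> 12.
padded_numel = ((orig_numel + world_size - 1) // world_size) * world_size

flat = torch.arange(padded_numel, dtype=torch.float32)
shards = list(flat.chunk(world_size))  # one shard of 3 elements per rank

# Concatenating the gathered shards yields 12 elements, so we trim the
# padding before reshaping back to the original shape.
full = torch.cat(shards)[:orig_numel].reshape(orig_shape)
assert full.numel() == orig_numel
```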
```python
param_id_to_param: List[nn.Parameter],
param_to_fqns: Dict[nn.Parameter, List[str]],
fqn_to_fsdp_param_info: Dict[str, FSDPParamInfo],
merge_key: bool = False,
```
Do we need `merge_key=True` for `use_orig_params=True` to handle the fact that ranks that do not have any part of an original parameter will not have any optimizer state for that parameter? Then, we also have `merge_key=False` to preserve the existing behavior for `use_orig_params=False`?
That's correct.
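A toy sketch of that merging behavior; the helper is hypothetical, not the actual implementation:

```python
from typing import Any, Dict, List

def merge_state_keys(all_rank_states: List[Dict[str, Any]]) -> List[str]:
    # Each rank only has optimizer state for parameters it owns a shard
    # of, so the union of keys across ranks recovers the full key set.
    merged: List[str] = []
    for rank_state in all_rank_states:
        for key in rank_state:
            if key not in merged:
                merged.append(key)  # preserve first-seen order
    return merged

# Rank 1 holds no shard of "net.bias", so its state dict lacks that key.
assert merge_state_keys(
    [{"net.weight": 0, "net.bias": 0}, {"net.weight": 0}]
) == ["net.weight", "net.bias"]
```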
This looks good to me! Hopefully, you can get another review from someone else as well, but it is not strictly necessary.
I left some minor comments.
```python
def _shard_orig_param_state(
    fqn: str,
```
nit: For this function and `_gather_orig_param_state`, can we specify whether `fqn` is the local FQN (as stored in `flat_param._fqns`) or the global FQN (which may require prepending a prefix starting from the local FSDP root)? Ugh, we might need to clarify our terminology at some point and just write it down somewhere.
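To illustrate the distinction; the helper and the prefix handling are hypothetical:

```python
def to_global_fqn(prefix: str, local_fqn: str) -> str:
    # A local FQN such as "weight" (as stored in flat_param._fqns) becomes
    # a global FQN such as "layer1.weight" by prepending the module path
    # from the local FSDP root.
    return f"{prefix}.{local_fqn}" if prefix else local_fqn

assert to_global_fqn("layer1", "weight") == "layer1.weight"
assert to_global_fqn("", "bias") == "bias"
```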
```python
object_list: List[Dict[str, Any]] = [
    {} for _ in range(cast(int, fsdp_state.world_size))
]
dist.all_gather_object(object_list, state_objects)
```
I forget; for the `use_orig_params=False` case, what collective do we use? `all_gather_into_tensor`?
That's correct.
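For reference, a hedged sketch contrasting the two collectives; process-group setup and shapes are assumed, and this is not the actual FSDP code path:

```python
from typing import Any, Dict, List

import torch
import torch.distributed as dist

def gather_both_ways(
    local_shard: torch.Tensor, local_state: Dict[str, Any], world_size: int
):
    # use_orig_params=False path: dense tensors are gathered directly,
    # staying on device and avoiding pickling.
    out = torch.empty(
        world_size * local_shard.numel(),
        dtype=local_shard.dtype,
        device=local_shard.device,
    )
    dist.all_gather_into_tensor(out, local_shard)

    # use_orig_params=True path (this PR): arbitrary Python objects are
    # pickled and exchanged, which is simpler but slower.
    object_list: List[Dict[str, Any]] = [{} for _ in range(world_size)]
    dist.all_gather_object(object_list, local_state)
    return out, object_list
```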
LGTM, will let @awgu accept it as well, thanks!
```python
self._test_load_optim_state(
    _ModelClass.NESTED,
    use_multiple_param_groups=False,
    halve_world_size=False,
```
noob q: are we interested in testing other configs, such as `halve_world_size=True` or any of the other flags?
More tests will be introduced after we make `FSDP.optim_state_dict` support all use cases.
```python
is_fsdp_managed = isinstance(param, FlatParameter)
if is_fsdp_managed:
    assert fqns[0] in fqn_to_fsdp_param_info
is_fsdp_managed = fqns[0] in fqn_to_fsdp_param_info
```
If `param` is not a `FlatParameter`, can this ever be True? If not, can we just omit this line, since if it is a `FlatParameter` we've already asserted this above?
Yes, for the `use_orig_params` case, this line is required: the optimizer then holds the original parameters rather than `FlatParameter`s, so the FQN lookup is the only way to tell whether a parameter is FSDP-managed.
```python
) -> Dict[str, Any]:
    """
    Gather the optimizer state for the original parameter with the name ``fqn``.
    This API should only be used when ``use_orig_params`` is True.
```
Would it be valuable to check this with an assert?
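For example, something like the following inside the function; where the flag lives on the FSDP state is an assumption, not the actual internals:

```python
# Hypothetical guard for the docstring's precondition.
assert getattr(fsdp_state, "_use_orig_params", False), (
    "_gather_orig_param_state() should only be called when "
    "use_orig_params=True"
)
```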
```python
torch.cuda.device_count(),
cast(dist.ProcessGroup, fsdp_state.process_group),
)
value = value.cpu()
```
Curious: do we always CPU offload?
Yes, we do in the original code path. This is an inconsistency between `state_dict` and `optim_state_dict`. I had a PR to fix this, but it never landed. I will introduce the fix later, when I complete `optim_state_dict()`.
```python
    you know why this API exists and how this API works.

    Returns the optimizer state. The state will be sharded or consolidated
    based on ``state_dict_type`` set by :meth:`set_state_dict_type` or
```
Do we couple model and optimizer state dict types for now, and is there interest in changing this?
We have not enforced this yet, but it will be enforced once `optim_state_dict` supports all use cases.
@pytorchbot merge
Merge started. Your change will be merged once all checks pass (ETA 0-4 hours).
Merge failed. Reason: the following mandatory check(s) failed. Dig deeper by viewing the failures on hud.
@pytorchbot merge -f "The failing test is unrelated."
Merge started. Your change will be merged immediately since you used the force (-f) flag, bypassing any CI checks (ETA: 1-5 minutes).
Stack from ghstack (oldest at bottom):

**What:** This PR adds optim state_dict support for `use_orig_params` when `rank0_only` is False; `rank0_only` support will be added in a follow-up PR. The design in this PR favors simplicity and may not perform well, especially for optim state_dict loading. Since optim state_dict loading is called only once, at the beginning of training, performance is not the major concern.
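For context, a hedged usage sketch of the flow this stack is building toward; the exact API signatures, especially on the load side, are assumptions and may differ from what finally landed:

```python
import torch
import torch.nn as nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

# Assumes torch.distributed is already initialized (e.g. via torchrun).
model = FSDP(nn.Linear(8, 8).cuda(), use_orig_params=True)
optim = torch.optim.Adam(model.parameters(), lr=1e-3)

# ... a few training steps ...

# Gather the optimizer state on every rank (rank0_only=False, per this PR).
osd = FSDP.optim_state_dict(model, optim)

# Loading happens once at startup, which is why this PR favors simplicity
# over speed (the signature below is an assumption, for illustration only).
flattened_osd = FSDP.optim_state_dict_to_load(osd, model, optim)
optim.load_state_dict(flattened_osd)
```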