
[FSDP][optim_state_dict][3/N] Support use_orig_param optim_state_dict (non-broadcast version) #89900

Closed
wants to merge 13 commits

Conversation

@fegin (Contributor) commented Nov 30, 2022

Stack from ghstack (oldest at bottom):

What:
This PR adds optim state_dict support for `use_orig_params` with `rank0_only=False`; `rank0_only` support will be added in a follow-up PR. The design of this PR focuses on simplicity and may not have good performance, especially for optim state_dict loading. Since optim state_dict loading is called only once at the beginning of training, performance is not the major concern.
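For context, a minimal usage sketch of what this feature enables (the toy model, optimizer, and training step are illustrative; `FSDP.optim_state_dict` is the API this stack is building toward, so treat the exact call as an assumption):

```python
# Sketch: gathering an optimizer state dict with use_orig_params=True and
# rank0_only=False semantics (every rank receives the full, unflattened dict).
# Assumes torch.distributed is already initialized, e.g. via torchrun.
import torch
import torch.nn as nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

model = FSDP(nn.Linear(8, 8).cuda(), use_orig_params=True)
optim = torch.optim.Adam(model.parameters(), lr=1e-3)

# One step so the optimizer actually has state (exp_avg, exp_avg_sq, step).
model(torch.randn(4, 8, device="cuda")).sum().backward()
optim.step()

# Every rank gets the consolidated optimizer state keyed by original FQNs.
osd = FSDP.optim_state_dict(model, optim)
```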

@pytorch-bot (bot) commented Nov 30, 2022

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/89900

Note: Links to docs will display an error until the docs builds have been completed.

❌ 1 Failure

As of commit dc15a9c:

The following jobs failed, but the failures were likely due to broken trunk (merge base 41c3b41):

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@pytorch-bot (bot) added the release notes: distributed (fsdp) label Nov 30, 2022
fegin added a commit that referenced this pull request Nov 30, 2022
… (non-broadcast version)

ghstack-source-id: abf34819697920be5c92caac648f28c71368b499
Pull Request resolved: #89900
fegin added a commit that referenced this pull request Nov 30, 2022
… (non-broadcast version)

ghstack-source-id: b419bba2391fc3d589c69fa16f67ac95c846c450
Pull Request resolved: #89900
fegin added a commit that referenced this pull request Nov 30, 2022
… (non-broadcast version)

ghstack-source-id: 5b1738e2fe4078281c7820cffad0ffd66c8a4988
Pull Request resolved: #89900
fegin added a commit that referenced this pull request Nov 30, 2022
… (non-broadcast version)

ghstack-source-id: b419bba2391fc3d589c69fa16f67ac95c846c450
Pull Request resolved: #89900
fegin added a commit that referenced this pull request Nov 30, 2022
… (non-broadcast version)

ghstack-source-id: 28a4888d06fb4d378e4403264fe6bfcc0ea9a8c9
Pull Request resolved: #89900
@awgu (Contributor) left a comment


Big thanks for working on this and getting through the crazy Cartesian product of code paths!

I made an initial pass and will continue to revisit.

FullyShardedDataParallel._warn_optim_input(optim_input)
using_optim_input = FullyShardedDataParallel._is_using_optim_input(
    optim_input,
    optim,
)
use_orig_params: bool = False
Contributor

Note to self: It looks like we assume that use_orig_params is uniform for a given FSDP tree. We should check for that in _lazy_init() and raise an error as needed.

Member

Agreed, I think we should just enforce this; I can't see a use case where we'd want to support mixing and matching.

Contributor

cc: @rohan-varma We were talking about which FSDP constructor args we assume to be uniform. This is one place evidencing that use_orig_params must be uniform (even if we already intuitively assumed that).
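For illustration, a minimal sketch of the kind of uniformity check being proposed here (a hypothetical standalone helper relying on the private `_use_orig_params` attribute, not the actual `_lazy_init()` code):

```python
# Hypothetical helper: raise if FSDP instances in one module tree disagree on
# use_orig_params. The real check would live inside FSDP's lazy initialization.
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP


def check_uniform_use_orig_params(root_module) -> None:
    settings = {
        fsdp_module._use_orig_params  # private attribute; name is an assumption
        for fsdp_module in FSDP.fsdp_modules(root_module)
    }
    if len(settings) > 1:
        raise ValueError(
            "use_orig_params must be uniform across all FSDP instances in the "
            f"module tree, but found both settings: {sorted(settings)}"
        )
```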

test/distributed/fsdp/test_fsdp_optim_state.py (outdated review thread, resolved)
torch/distributed/fsdp/_optim_utils.py (outdated review thread, resolved)
value = orig_state[state_name]
if not isinstance(value, list) or not torch.is_tensor(value[0]):
    continue
value = torch.concat(value)[: flat_param._numels[param_idx]].reshape(
Contributor

Conceptual: In what case would torch.concat(value) have more numel than flat_param._numels[param_idx] (meaning that we are truncating)?

Contributor Author

Excellent point. The padding shouldn't be visible here. However, I'm not sure that will always hold, since we were discussing the cost of F.pad and may change the implementation. This is just a safety guard. If you believe padding will never be seen, I can remove it.

Contributor

Ah, good catch. Let us keep the trimming logic here. The safeguard is good.
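A self-contained toy illustration of the trimming discussed in this thread (shard sizes and shapes are made up; `flat_param._numels[param_idx]` is reduced to a plain `numel` variable):

```python
# Toy example: per-rank shards of an original parameter's optimizer state may
# carry trailing padding that made the flat parameter divisible by world_size.
# Concatenating and slicing to the true numel drops that padding before reshape.
import torch

orig_shape = (3, 3)   # original parameter shape
numel = 9             # stands in for flat_param._numels[param_idx]

# Two ranks, 5 elements each: 9 real values plus 1 padding element at the end.
shards = [
    torch.arange(0, 5, dtype=torch.float32),
    torch.arange(5, 10, dtype=torch.float32),
]

full = torch.concat(shards)                   # length 10 (includes padding)
restored = full[:numel].reshape(orig_shape)   # padding trimmed, shape restored
assert restored.shape == torch.Size(orig_shape)
```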

torch/distributed/fsdp/_optim_utils.py (outdated review thread, resolved)
param_id_to_param: List[nn.Parameter],
param_to_fqns: Dict[nn.Parameter, List[str]],
fqn_to_fsdp_param_info: Dict[str, FSDPParamInfo],
merge_key: bool = False,
Contributor

Do we need merge_key=True for use_orig_params=True to handle the fact that ranks who do not have any part of an original parameter will not have any optimizer state for that parameter?

Then, we also have merge_key=False to preserve the existing behavior for use_orig_params=False?

Contributor Author

That's correct.
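A small sketch of the key merging in question (purely local stand-in data; in the PR the per-rank dicts come from the collective gather of each rank's local optimizer state):

```python
# With use_orig_params=True, a rank that holds no shard of a parameter also has
# no optimizer state for it, so the set of state keys must be merged (union)
# across ranks instead of assuming every rank sees every key.
from typing import Any, Dict, List


def merge_state_keys(per_rank_states: List[Dict[str, Any]]) -> List[str]:
    """Union of optimizer-state keys seen by any rank."""
    keys: set = set()
    for rank_state in per_rank_states:
        keys.update(rank_state.keys())
    return sorted(keys)


# Rank 0 only shards "net.weight"; rank 1 only shards "net.bias".
gathered = [
    {"net.weight": {"exp_avg": "..."}},
    {"net.bias": {"exp_avg": "..."}},
]
assert merge_state_keys(gathered) == ["net.bias", "net.weight"]
```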

@fegin requested a review from wanchaol as a code owner November 30, 2022 22:56
fegin added a commit that referenced this pull request Nov 30, 2022
… (non-broadcast version)

ghstack-source-id: 6ccdeaf17cba17bc5f7372d04081e8c26ee954e0
Pull Request resolved: #89900
@awgu (Contributor) left a comment


This looks good to me! Hopefully, you can get another review from someone else as well, but it is not strictly necessary.

I left some minor comments.

torch/distributed/fsdp/_optim_utils.py (outdated review thread, resolved)
torch/distributed/fsdp/_optim_utils.py (review thread, resolved)
torch/distributed/fsdp/_optim_utils.py (outdated review thread, resolved)


torch/distributed/fsdp/fully_sharded_data_parallel.py (outdated review thread, resolved)
torch/distributed/fsdp/fully_sharded_data_parallel.py (outdated review thread, resolved)


def _shard_orig_param_state(
    fqn: str,
Contributor

nit: For this function and _gather_orig_param_state, can we specify if fqn is the local FQN (as stored in flat_param._fqns) or the global FQN (which may require prepending a prefix starting from the local FSDP root).

Ugh, we might need to clarify our terminology at some point and just write it down somewhere.

object_list: List[Dict[str, Any]] = [
    {} for _ in range(cast(int, fsdp_state.world_size))
]
dist.all_gather_object(object_list, state_objects)
Contributor

I forget; for the use_orig_params=False case, what collective do we use? all_gather_into_tensor?

Contributor Author

That's correct.
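For reference, a sketch contrasting the two collectives mentioned in this thread (assumes an initialized NCCL process group; tensor sizes are illustrative):

```python
# all_gather_object (used here for use_orig_params=True) moves arbitrary
# picklable Python objects; all_gather_into_tensor (the use_orig_params=False
# path) moves equally sized tensor shards into one flat output tensor.
import torch
import torch.distributed as dist

world_size = dist.get_world_size()
shard_numel = 4

# Object-based gather: each rank contributes a dict of optimizer state objects.
local_state = {"step": 10, "exp_avg": torch.randn(shard_numel)}
object_list = [None for _ in range(world_size)]
dist.all_gather_object(object_list, local_state)

# Tensor-based gather: each rank contributes a fixed-size CUDA shard.
local_shard = torch.randn(shard_numel, device="cuda")
output = torch.empty(world_size * shard_numel, device="cuda")
dist.all_gather_into_tensor(output, local_shard)
```

The object collective is simpler (no padding or dtype bookkeeping) but goes through pickling, which matches the PR's note that simplicity, not performance, is the goal here.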

@rohan-varma (Member) left a comment


LGTM, will let @awgu accept it as well, thanks!

self._test_load_optim_state(
    _ModelClass.NESTED,
    use_multiple_param_groups=False,
    halve_world_size=False,
Member

noob q: are we interested in testing other configs such as halve_world_size=True or any of the other flags?

Contributor Author

More tests will be introduced after we make FSDP.optim_state_dict support all use cases.

is_fsdp_managed = isinstance(param, FlatParameter)
if is_fsdp_managed:
    assert fqns[0] in fqn_to_fsdp_param_info
is_fsdp_managed = fqns[0] in fqn_to_fsdp_param_info
Member

If param is not a FlatParameter, can this ever be True? If not, can we just omit this line, because if it is a FlatParameter, we've already asserted on this being true above?

Contributor Author

Yes, for the use_orig_params case this line is required.
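A restated version of the snippet above, as a small standalone helper, to make the two cases explicit (the import path for FlatParameter follows the layout at the time of this PR and may differ in later releases):

```python
# Equivalent logic to the lines under discussion, written out per case.
from typing import Dict, List

import torch.nn as nn
from torch.distributed.fsdp.flat_param import FlatParameter


def is_param_fsdp_managed(
    param: nn.Parameter,
    fqns: List[str],
    fqn_to_fsdp_param_info: Dict[str, object],
) -> bool:
    if isinstance(param, FlatParameter):
        # use_orig_params=False: the optimizer holds FlatParameters, which must
        # always be registered in the FQN map.
        assert fqns[0] in fqn_to_fsdp_param_info
        return True
    # use_orig_params=True: the optimizer holds original nn.Parameters, so the
    # FQN lookup is the only way to tell whether FSDP manages this parameter.
    return fqns[0] in fqn_to_fsdp_param_info
```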

) -> Dict[str, Any]:
    """
    Gather the optimizer state for the original parameter with the name ``fqn``.
    This API should only be used when ``use_orig_params`` is True.
Member

would it be valuable to check this with an assert?

    torch.cuda.device_count(),
    cast(dist.ProcessGroup, fsdp_state.process_group),
)
value = value.cpu()
Member

curious: do we always CPU offload?

Contributor Author

Yes, we do in the original code path. This is an inconsistency between state_dict and optim_state_dict. I had a PR to fix this, but it never landed. I will reintroduce it later when I complete optim_state_dict().


you know why this API exists and how this API works.

Returns the optimizer state. The state will be sharded or consolidated
based on ``state_dict_type`` set by :meth:`set_state_dict_type` or
Member

do we couple model and optimizer state dict types for now, and is there interest in changing this?

Contributor Author

We have not enforced this yet, but it will be enforced when optim_state_dict supports all use cases.
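For reference, a sketch of the intended coupling, where a single `set_state_dict_type` call governs both the model and the optimizer state dict (the optimizer side following this setting is the end state described above, so treat it as an assumption rather than what this PR already enforces):

```python
# Sketch: one state_dict_type setting for both model and optimizer state dicts.
# Assumes an initialized process group and CUDA devices.
import torch
import torch.nn as nn
from torch.distributed.fsdp import (
    FullStateDictConfig,
    FullyShardedDataParallel as FSDP,
    StateDictType,
)

model = FSDP(nn.Linear(8, 8).cuda(), use_orig_params=True)
optim = torch.optim.Adam(model.parameters(), lr=1e-3)

FSDP.set_state_dict_type(
    model,
    StateDictType.FULL_STATE_DICT,
    state_dict_config=FullStateDictConfig(offload_to_cpu=True, rank0_only=False),
)

model_sd = model.state_dict()                    # consolidated per FULL_STATE_DICT
optim_sd = FSDP.optim_state_dict(model, optim)   # intended to follow the same type
```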

@fegin (Contributor Author) commented Dec 13, 2022

@pytorchbot merge

@pytorch-bot (bot) added the ciflow/trunk label Dec 13, 2022
@pytorchmergebot (Collaborator)

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status here.

@pytorchmergebot (Collaborator)

Merge failed

Reason: The following mandatory check(s) failed (Rule Distributed):

Dig deeper by viewing the failures on hud


fegin added a commit that referenced this pull request Dec 13, 2022
… (non-broadcast version)

ghstack-source-id: 43777a0dd75d4d26dabb67552dabb1b12c5f03aa
Pull Request resolved: #89900
@fegin added the with-ssh label Dec 13, 2022
@fegin (Contributor Author) commented Dec 13, 2022

@pytorchbot merge -f "The failing test is unrelated."

@pytorchmergebot (Collaborator)

Merge started

Your change will be merged immediately since you used the force (-f) flag, bypassing any CI checks (ETA: 1-5 minutes).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status here.

@facebook-github-bot deleted the gh/fegin/49/head branch June 8, 2023 17:16