Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[FSDP][2/N] Fix grad zero vs. None edge case #87308

Closed
wants to merge 6 commits into from

Conversation

awgu
Copy link
Contributor

@awgu awgu commented Oct 19, 2022

Stack from ghstack:

Some original parameters corresponding to one FlatParameter may have None gradient while others do not. In that case, the flat_param.grad must be non-None. However, FSDP should take care to expose the original parameters' gradients regardless. To achieve this, we track a _is_grad_none mask over the parameters' gradients.

  • _is_grad_none is initialized to False for all.
  • _is_grad_none[i] is set to True when writing zeros in place of None when writing back the ith gradient.
  • _is_grad_none[i] is set to False via _reset_is_grad_none(), which should be called in the post-backward. See the docstring for details.
  • _is_grad_none[i] must be False in order to set param.grad to be a view into flat_param.grad.

@pytorch-bot
Copy link

pytorch-bot bot commented Oct 19, 2022

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/87308

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures, 3 Pending

As of commit 08508fa:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@pytorch-bot pytorch-bot bot added the release notes: distributed (fsdp) release notes category label Oct 19, 2022
awgu added a commit that referenced this pull request Oct 19, 2022
ghstack-source-id: fd742a592feaea4b549d92ce5e025f4964cf2429
Pull Request resolved: #87308
awgu added a commit that referenced this pull request Oct 19, 2022
ghstack-source-id: 85614e658ab8a39020ed3908279ab1e80e0e4bc6
Pull Request resolved: #87308
Some original parameters corresponding to one `FlatParameter` may have `None` gradient while others do not. In that case, the `flat_param.grad` must be non-`None`. However, FSDP should take care to expose the original parameters' gradients regardless. To achieve this, we track a `_is_grad_none` mask over the parameters' gradients.
- `_is_grad_none` is initialized to `False` for all.
- `_is_grad_none[i]` is set to `True` when writing zeros in place of `None` when writing back the `i`th gradient.
- `_is_grad_none[i]` is set to `False` via `_reset_is_grad_none()`, which should be called in the post-backward. See the docstring for details.
- `_is_grad_none[i]` must be `False` in order to set `param.grad` to be a view into `flat_param.grad`.

This PR additionally changes `summon_full_params(with_grads=True)`'s behavior to be such that if all ranks have `flat_param.grad = None`, then the original parameters will correctly have `orig_param.grad = None`. This is achieved with a preliminary all-reduce. Note that if a particular original parameter's gradient is `None` on all of the containing ranks, but not all ranks' `flat_param.grad = None`, then that particular gradient is still going to be set to zeros. This can be handled if desired in follow-up work.


[ghstack-poisoned]
@awgu awgu changed the title [FSDP][1/N] Fix grad zero vs. None edge case [FSDP][2/N] Fix grad zero vs. None edge case Oct 19, 2022
awgu added a commit that referenced this pull request Oct 19, 2022
ghstack-source-id: 85614e658ab8a39020ed3908279ab1e80e0e4bc6
Pull Request resolved: #87308
Some original parameters corresponding to one `FlatParameter` may have `None` gradient while others do not. In that case, the `flat_param.grad` must be non-`None`. However, FSDP should take care to expose the original parameters' gradients regardless. To achieve this, we track a `_is_grad_none` mask over the parameters' gradients.
- `_is_grad_none` is initialized to `False` for all.
- `_is_grad_none[i]` is set to `True` when writing zeros in place of `None` when writing back the `i`th gradient.
- `_is_grad_none[i]` is set to `False` via `_reset_is_grad_none()`, which should be called in the post-backward. See the docstring for details.
- `_is_grad_none[i]` must be `False` in order to set `param.grad` to be a view into `flat_param.grad`.




[ghstack-poisoned]
awgu added a commit that referenced this pull request Oct 19, 2022
ghstack-source-id: 83774374607bde9dd423c0c517c00eb84c8240c2
Pull Request resolved: #87308
Some original parameters corresponding to one `FlatParameter` may have `None` gradient while others do not. In that case, the `flat_param.grad` must be non-`None`. However, FSDP should take care to expose the original parameters' gradients regardless. To achieve this, we track a `_is_grad_none` mask over the parameters' gradients.
- `_is_grad_none` is initialized to `False` for all.
- `_is_grad_none[i]` is set to `True` when writing zeros in place of `None` when writing back the `i`th gradient.
- `_is_grad_none[i]` is set to `False` via `_reset_is_grad_none()`, which should be called in the post-backward. See the docstring for details.
- `_is_grad_none[i]` must be `False` in order to set `param.grad` to be a view into `flat_param.grad`.




[ghstack-poisoned]
awgu added a commit that referenced this pull request Oct 20, 2022
ghstack-source-id: b983aa8042152bec5f233812a4ebe05135df823b
Pull Request resolved: #87308
@pytorch-bot pytorch-bot bot added the ciflow/trunk Trigger trunk jobs on your pull request label Oct 21, 2022
Some original parameters corresponding to one `FlatParameter` may have `None` gradient while others do not. In that case, the `flat_param.grad` must be non-`None`. However, FSDP should take care to expose the original parameters' gradients regardless. To achieve this, we track a `_is_grad_none` mask over the parameters' gradients.
- `_is_grad_none` is initialized to `False` for all.
- `_is_grad_none[i]` is set to `True` when writing zeros in place of `None` when writing back the `i`th gradient.
- `_is_grad_none[i]` is set to `False` via `_reset_is_grad_none()`, which should be called in the post-backward. See the docstring for details.
- `_is_grad_none[i]` must be `False` in order to set `param.grad` to be a view into `flat_param.grad`.




[ghstack-poisoned]
awgu added a commit that referenced this pull request Oct 21, 2022
ghstack-source-id: e392f6c6752548932f41ae921773de1b941a77b4
Pull Request resolved: #87308
@awgu
Copy link
Contributor Author

awgu commented Oct 21, 2022

@pytorchbot merge

@pytorchmergebot
Copy link
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status
here

@github-actions
Copy link

Hey @awgu.
You've committed this PR, but it does not have both a 'release notes: ...' and 'topics: ...' label. Please add one of each to the PR. The 'release notes: ...' label should represent the part of PyTorch that this PR changes (fx, autograd, distributed, etc) and the 'topics: ...' label should represent the kind of PR it is (not user facing, new feature, bug fix, perf improvement, etc). The list of valid labels can be found here for the 'release notes: ...' and here for the 'topics: ...'.
For changes that are 'topic: not user facing' there is no need for a release notes label.

awgu added a commit to awgu/pytorch that referenced this pull request Oct 22, 2022
ghstack-source-id: e392f6c6752548932f41ae921773de1b941a77b4
Pull Request resolved: pytorch#87308
sgrigory pushed a commit to sgrigory/pytorch that referenced this pull request Oct 28, 2022
Some original parameters corresponding to one `FlatParameter` may have `None` gradient while others do not. In that case, the `flat_param.grad` must be non-`None`. However, FSDP should take care to expose the original parameters' gradients regardless. To achieve this, we track a `_is_grad_none` mask over the parameters' gradients.
- `_is_grad_none` is initialized to `False` for all.
- `_is_grad_none[i]` is set to `True` when writing zeros in place of `None` when writing back the `i`th gradient.
- `_is_grad_none[i]` is set to `False` via `_reset_is_grad_none()`, which should be called in the post-backward. See the docstring for details.
- `_is_grad_none[i]` must be `False` in order to set `param.grad` to be a view into `flat_param.grad`.

Pull Request resolved: pytorch#87308
Approved by: https://github.com/zhaojuanmao
kulinseth pushed a commit to kulinseth/pytorch that referenced this pull request Nov 5, 2022
Some original parameters corresponding to one `FlatParameter` may have `None` gradient while others do not. In that case, the `flat_param.grad` must be non-`None`. However, FSDP should take care to expose the original parameters' gradients regardless. To achieve this, we track a `_is_grad_none` mask over the parameters' gradients.
- `_is_grad_none` is initialized to `False` for all.
- `_is_grad_none[i]` is set to `True` when writing zeros in place of `None` when writing back the `i`th gradient.
- `_is_grad_none[i]` is set to `False` via `_reset_is_grad_none()`, which should be called in the post-backward. See the docstring for details.
- `_is_grad_none[i]` must be `False` in order to set `param.grad` to be a view into `flat_param.grad`.

Pull Request resolved: pytorch#87308
Approved by: https://github.com/zhaojuanmao
kulinseth pushed a commit to kulinseth/pytorch that referenced this pull request Dec 10, 2022
Some original parameters corresponding to one `FlatParameter` may have `None` gradient while others do not. In that case, the `flat_param.grad` must be non-`None`. However, FSDP should take care to expose the original parameters' gradients regardless. To achieve this, we track a `_is_grad_none` mask over the parameters' gradients.
- `_is_grad_none` is initialized to `False` for all.
- `_is_grad_none[i]` is set to `True` when writing zeros in place of `None` when writing back the `i`th gradient.
- `_is_grad_none[i]` is set to `False` via `_reset_is_grad_none()`, which should be called in the post-backward. See the docstring for details.
- `_is_grad_none[i]` must be `False` in order to set `param.grad` to be a view into `flat_param.grad`.

Pull Request resolved: pytorch#87308
Approved by: https://github.com/zhaojuanmao
@facebook-github-bot facebook-github-bot deleted the gh/awgu/134/head branch June 8, 2023 15:23
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
ciflow/trunk Trigger trunk jobs on your pull request Merged release notes: distributed (fsdp) release notes category
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

3 participants