Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[FSDP] Fix input grad propagation when using param mixed precision #90921

Closed
wants to merge 5 commits into from

Conversation

awgu
Copy link
Contributor

@awgu awgu commented Dec 15, 2022

Stack from ghstack:

For parameter mixed precision, we cast the inputs to the low precision parameter dtype. If the input has tensors that require gradient, then we must cast them in place in order for them to receive a gradient. The cast should be tracked by autograd (e.g. with grad_fn equal to ToCopyBackward0). This removes the torch.no_grad context when calling _apply_to_tensors.

@pytorch-bot
Copy link

pytorch-bot bot commented Dec 15, 2022

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/90921

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit cd938ce:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

awgu added a commit that referenced this pull request Dec 15, 2022
ghstack-source-id: 9f301e7252ae7e38cc4eecfcbe73ad3afb0dc459
Pull Request resolved: #90921
awgu added a commit that referenced this pull request Dec 15, 2022
ghstack-source-id: 30fbaa288f4cf524360f4e8db459d1fe09ae22b0
Pull Request resolved: #90921
awgu added a commit that referenced this pull request Dec 15, 2022
ghstack-source-id: 7b779515038cb08910611f663e7da589e98d0935
Pull Request resolved: #90921
@awgu awgu added the topic: bug fixes topic category label Dec 15, 2022
…recision"


For parameter mixed precision, we cast the inputs to the low precision parameter dtype. If the input has tensors that require gradient, then we must cast them in place in order for them to receive a gradient. Otherwise, the tensor that resulted from the out-of-place cast receives the gradient and is not in scope to the user. To preserve BC as much as possible, this PR only does the in-place cast if the tensor requires gradient.

[ghstack-poisoned]
awgu added a commit that referenced this pull request Dec 15, 2022
ghstack-source-id: a6b8861faa3fcc0ee326c35ab0bcdc25508d688b
Pull Request resolved: #90921
Copy link
Contributor

@mrshenli mrshenli left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

@awgu awgu added the ciflow/trunk Trigger trunk jobs on your pull request label Dec 15, 2022
Copy link
Member

@rohan-varma rohan-varma left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM, thanks for the fix!


with torch.no_grad():
return (_apply_to_tensors(cast_fn, args), _apply_to_tensors(cast_fn, kwargs))
return x.to(dtype)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

hmm, I am a bit concerned about unforseen BC issues, but also our testing surface is quite solid. Did we consider doing this only for inputs that require grad to err on the safe side?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think if x does not require gradient, then x.to(dtype) will not be tracked by autograd. if x does require gradient, then x.to(dtype) will be tracked by autograd. This should be exactly the behavior we want. In other words, the casing should be handled naturally already if I understand correctly.

test/distributed/fsdp/test_fsdp_mixed_precision.py Outdated Show resolved Hide resolved
…recision"


For parameter mixed precision, we cast the inputs to the low precision parameter dtype. If the input has tensors that require gradient, then we must cast them in place in order for them to receive a gradient. The cast should be tracked by autograd (e.g. with `grad_fn` equal to `ToCopyBackward0`). This removes the `torch.no_grad` context when calling `_apply_to_tensors`.

[ghstack-poisoned]
awgu added a commit that referenced this pull request Dec 15, 2022
ghstack-source-id: ad234ed096a8c598818af2dfe6b77ae2ef1ebd54
Pull Request resolved: #90921
@awgu
Copy link
Contributor Author

awgu commented Dec 15, 2022

@pytorchbot merge

@pytorchmergebot
Copy link
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status
here

@facebook-github-bot facebook-github-bot deleted the gh/awgu/280/head branch June 8, 2023 15:32
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
ciflow/trunk Trigger trunk jobs on your pull request Merged release notes: distributed (fsdp) release notes category topic: bug fixes topic category
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

4 participants