[FSDP][Perf] Do not call `pad` in no-padding case #88769

Conversation
[ghstack-poisoned]
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/88769
Note: Links to docs will display an error until the docs builds have been completed.
✅ No Failures as of commit 1761e6a.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
ghstack-source-id: da8c8378455e723a42b6d4b23145b1536331eee9 Pull Request resolved: #88769
- Calling `F.pad()` issues a pad kernel from the CPU even if there is no padding needed, which can incur some non-negligible overhead. This PR removes that unnecessary call for the no-padding case.
- This PR also does not zero the newly allocated sharded gradient tensor before the reduce-scatter if `use_orig_params=True`, because there is no need: the reduce-scatter fills the tensor anyway, and we do not care about the values in the padding. For `use_orig_params=False`, the padding is exposed to the user, so we preserve the existing semantics of zeroing it. I left a to-do to follow up, since we may optimize that.

[ghstack-poisoned]
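To make the first point concrete, here is a minimal sketch of the guard, assuming a 1D flat tensor that must be padded to a multiple of the world size; `pad_to_world_size` and its shapes are illustrative, not the actual FSDP internals:

```python
# Minimal sketch, not the actual FSDP code: skip F.pad() entirely when the
# tensor already divides evenly across ranks.
import torch
import torch.nn.functional as F

def pad_to_world_size(tensor: torch.Tensor, world_size: int) -> torch.Tensor:
    numel_to_pad = (world_size - tensor.numel() % world_size) % world_size
    if numel_to_pad == 0:
        # No-padding case: returning the tensor as-is avoids launching a
        # pad kernel from the CPU for what would be a no-op.
        return tensor
    return F.pad(tensor, [0, numel_to_pad])
```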
ghstack-source-id: 4a1cca1ef167964173892bc16cf54a5716c0191e Pull Request resolved: #88769
@pytorchbot merge
Merge started
Your change will be merged once all checks pass (ETA 0-4 Hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
Merge failed
Reason: New commits were pushed while merging. Please rerun the merge command.
Details for Dev Infra team: raised by workflow job.
ghstack-source-id: da8c8378455e723a42b6d4b23145b1536331eee9 Pull Request resolved: pytorch#88769
ghstack-source-id: 0f8c479efccb74b0f1f0e43e940bc85de774e0a5 Pull Request resolved: pytorch#88769
@pytorchbot merge
Merge started
Your change will be merged once all checks pass (ETA 0-4 Hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
This allocation happened previously in the post-backward stream, which induced cross-stream memory fragmentation. (Only the sharded gradient needs to be allocated in the post-backward stream, not the unsharded gradient.) For T5-11B on 2 nodes and batch size 6, eliminating the unnecessary […]
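As a rough sketch of how these pieces fit together (the function and stream names are illustrative, not the actual FSDP internals): the sharded gradient can be allocated in the post-backward stream without a zero-fill when `use_orig_params=True`, since the reduce-scatter overwrites every element anyway.

```python
# Hedged sketch, not the actual FSDP code: only the sharded gradient is
# allocated in the post-backward stream, and for use_orig_params=True it can
# use torch.empty because dist.reduce_scatter_tensor writes every element.
import torch
import torch.distributed as dist

post_backward_stream = torch.cuda.Stream()

def reduce_scatter_grad(padded_unsharded_grad: torch.Tensor,
                        use_orig_params: bool) -> torch.Tensor:
    world_size = dist.get_world_size()
    shard_numel = padded_unsharded_grad.numel() // world_size
    with torch.cuda.stream(post_backward_stream):
        # For use_orig_params=False the padding is user-visible, so keep zeroing it.
        alloc = torch.empty if use_orig_params else torch.zeros
        sharded_grad = alloc(shard_numel,
                             dtype=padded_unsharded_grad.dtype,
                             device=padded_unsharded_grad.device)
        dist.reduce_scatter_tensor(sharded_grad, padded_unsharded_grad)
    return sharded_grad
```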
Pull Request resolved: pytorch#88769 Approved by: https://github.com/zhaojuanmao
Stack from ghstack:
- #88453 [Dynamo][FSDP] Migrate to `ModuleWrapPolicy`
- #88450 [FSDP] Introduce `ModuleWrapPolicy` for simplicity
- #88769 [FSDP][Perf] Do not call `pad` in no-padding case (this PR)