[FSDP] Break up `_post_backward_hook` into smaller funcs #106068
Conversation
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/106068
Note: Links to docs will display an error until the doc builds have completed. ✅ 1 Unrelated Failure. As of commit 143d12a with merge base 2b6249e: UNSTABLE - the following job failed but was likely due to flakiness present on trunk and has been marked as unstable.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
@awgu has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.
Review thread on diff hunk `@@ -951,6 +833,154 @@ def _should_free_in_backward(`:
Are you perhaps willing to add small top level comments for each of the new functions? :)
I added the comments in the subsequent PR since the async HSDP refactors these a bit more now.
Verified code motion only
The post-backward hook has some complexity due to the different paths: {no communication hook, communication hook} x {`NO_SHARD`, `FULL_SHARD`/`SHARD_GRAD_OP`, `HYBRID_SHARD`/`_HYBRID_SHARD_ZERO2`} plus some options like CPU offloading and `use_orig_params=True` (requiring using sharded gradient views). The PR following this one that adds async all-reduce for HSDP further complicates this since the bottom-half after all-reduce must still be run in the separate all-reduce stream, making it more unwieldy to unify with the existing bottom-half. Nonetheless, this PR breaks up the post-backward hook into smaller logical functions to hopefully help readability. Differential Revision: [D47852461](https://our.internmc.facebook.com/intern/diff/D47852461) [ghstack-poisoned]
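The shape of the refactor can be illustrated with a small sketch. This is not FSDP's actual code: the helper names, the `State`/`Strategy` types, and the toy list-based "collectives" are all hypothetical stand-ins showing how a monolithic post-backward hook can dispatch to small per-path bottom halves.

```python
# Illustrative sketch only (hypothetical names, not FSDP's real API):
# a top half with common bookkeeping that dispatches to one small helper
# per sharding path, instead of one monolithic hook body.
from dataclasses import dataclass, field
from enum import Enum, auto
from typing import List


class Strategy(Enum):
    NO_SHARD = auto()
    FULL_SHARD = auto()
    SHARD_GRAD_OP = auto()
    HYBRID_SHARD = auto()


@dataclass
class State:
    strategy: Strategy
    world_size: int
    trace: List[str] = field(default_factory=list)  # records which helpers ran


def post_backward_hook(state: State, grad: List[float]) -> List[float]:
    # Top half: bookkeeping common to every path, then dispatch.
    state.trace.append("top_half")
    if state.strategy is Strategy.NO_SHARD:
        return _reduce_grad_no_shard(state, grad)
    return _reduce_grad_sharded(state, grad)


def _reduce_grad_no_shard(state: State, grad: List[float]) -> List[float]:
    # NO_SHARD keeps the full gradient; a real impl would all-reduce here.
    state.trace.append("no_shard")
    return grad


def _reduce_grad_sharded(state: State, grad: List[float]) -> List[float]:
    # Sharded paths reduce-scatter down to a 1/world_size shard, then run
    # the post-reduce bottom half (CPU offload, sharded grad views, etc.).
    state.trace.append("reduce_scatter")
    shard = grad[: len(grad) // state.world_size]  # stand-in for this rank's output
    return _post_reduce_grad_callback(state, shard)


def _post_reduce_grad_callback(state: State, shard: List[float]) -> List[float]:
    state.trace.append("bottom_half")
    return shard
```

Keeping the bottom half as its own function is what lets a later PR run it on a different stream without duplicating the top half.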
Rebased
**Overview**

This PR runs the HSDP all-reduce as async so that it can overlap with both all-gather and reduce-scatter, which can lead to slight end-to-end speedups when the sharding process group is fully intra-node. Previously, the all-reduce serialized with reduce-scatter, so it could only overlap with one all-gather.

For some clusters (e.g. our AWS cluster), `NCCL_CROSS_NIC=1` improves inter-node all-reduce times when overlapped with intra-node all-gather/reduce-scatter.

**Experiment**

<details> <summary> Example 'before' trace </summary> <img width="559" alt="hsdp_32gpus_old" src="https://github.com/pytorch/pytorch/assets/31054793/15222b6f-2b64-4e0b-a212-597335f05ba5"> </details>
<details> <summary> Example 'after' trace </summary> <img width="524" alt="hsdp_32gpus_new" src="https://github.com/pytorch/pytorch/assets/31054793/94f63a1d-4255-4035-9e6e-9e10733f4e44"> </details>

For the 6-encoder-layer, 6-decoder-layer transformer with `d_model=8192`, `nhead=64` on 4 nodes / 32 40 GB A100s via AWS, the end-to-end iteration times are as follows (with AG == all-gather, RS == reduce-scatter, AR == all-reduce; bandwidth reported as algorithmic bandwidth):

- Reference FSDP:
  - **1160 ms / iteration**
  - ~23 ms / encoder AG/RS --> 24.46 GB/s bandwidth
  - ~40 ms / decoder AG/RS --> 26.5 GB/s bandwidth
  - 50 GB/s theoretical inter-node bandwidth
- Baseline 8-way HSDP (only overlap AR with AG) -- intra-node AG/RS, inter-node AR:
  - **665 ms / iteration**
  - ~3 ms / encoder AG/RS --> 187.5 GB/s bandwidth
  - ~5 ms / decoder AG/RS --> 212 GB/s bandwidth
  - ~30 ms / encoder AR --> 2.34 GB/s bandwidth
  - ~55 ms / decoder AR --> 2.65 GB/s bandwidth
  - 300 GB/s theoretical intra-node bandwidth
- New 8-way HSDP (overlap AR with AG and RS) -- intra-node AG/RS, inter-node AR:
  - **597 ms / iteration**
  - ~3 ms / encoder AG/RS --> 187.5 GB/s bandwidth
  - ~6.2 ms / decoder AG/RS --> 170.97 GB/s bandwidth (slower)
  - ~23 ms / encoder AR (non-overlapped) --> 3.057 GB/s bandwidth (faster)
  - ~49 ms / decoder AR (non-overlapped) --> 2.70 GB/s bandwidth (faster)
  - ~100 ms / decoder AR (overlapped) --> 1.325 GB/s bandwidth (slower)
  - Overlapping with reduce-scatter reduces all-reduce bandwidth utilization even though the all-reduce is inter-node and the reduce-scatter is intra-node!
- New 8-way HSDP (overlap AR with AG and RS) with `NCCL_CROSS_NIC=1`:
  - **556 ms / iteration**
  - Speedup comes from the faster overlapped AR

Thus, for this particular workload, the async all-reduce enables a 16% iteration-time speedup compared to the existing HSDP and a 52% speedup compared to FSDP. These speedups are pronounced because the workload is communication-bound, so any communication-time reduction translates directly into speedup.

**Unit Test**

This requires >= 4 GPUs:

```
python -m pytest test/distributed/fsdp/test_fsdp_hybrid_shard.py -k test_fsdp_hybrid_shard_parity
```

Differential Revision: [D47852456](https://our.internmc.facebook.com/intern/diff/D47852456)
Pull Request resolved: #106080
Approved by: https://github.com/ezyang
ghstack dependencies: #106068
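The scheduling change can be mimicked in plain Python. This is a stand-in only (real FSDP issues NCCL collectives on separate CUDA streams, not threads, and the names below are made up): the point is that the inter-node all-reduce is launched asynchronously, the intra-node reduce-scatter proceeds while it is in flight, and only the bottom half blocks on the handle.

```python
# Stand-in for the async all-reduce scheduling; all names are hypothetical.
# In real PyTorch this corresponds to collectives on a dedicated stream
# (or dist.all_reduce(..., async_op=True) returning a waitable handle).
import time
from concurrent.futures import ThreadPoolExecutor


def fake_inter_node_all_reduce(xs):
    time.sleep(0.05)              # pretend inter-node latency
    return [x * 2 for x in xs]    # two "nodes" contribute equal gradients


def fake_intra_node_reduce_scatter(xs, world_size=2):
    time.sleep(0.05)              # pretend intra-node latency
    return xs[: len(xs) // world_size]  # this rank's shard


all_reduce_stream = ThreadPoolExecutor(max_workers=1)  # stand-in "stream"

start = time.perf_counter()
# Launch the all-reduce for the previous layer's gradients asynchronously...
handle = all_reduce_stream.submit(fake_inter_node_all_reduce, [1.0, 2.0])
# ...so the next layer's reduce-scatter overlaps with it...
shard = fake_intra_node_reduce_scatter([3.0, 4.0, 5.0, 6.0])
# ...and only the post-all-reduce bottom half waits on the handle.
reduced = handle.result()
elapsed = time.perf_counter() - start  # ~0.05 s overlapped, vs. ~0.10 s serialized
```

Serializing the two calls would take the sum of their latencies; overlapping them takes roughly the max, which is the effect the traces above show.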
Pull Request resolved: #107784 Approved by: https://github.com/fegin ghstack dependencies: #106068, #106080
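For reference, the bandwidths quoted in the experiment above are algorithmic bandwidths, i.e. bytes moved per collective divided by wall time (as opposed to NCCL-style bus bandwidth, which corrects for the protocol's extra traffic). A minimal helper, with a made-up message size for illustration:

```python
def algo_bandwidth_gb_per_s(message_bytes: float, seconds: float) -> float:
    # Algorithmic bandwidth: message size / wall time, ignoring protocol
    # overhead. The per-collective figures above have this shape.
    return message_bytes / seconds / 1e9


# Hypothetical example: a 0.5 GB message completing in 23 ms.
bw = algo_bandwidth_gb_per_s(0.5e9, 23e-3)  # ~21.7 GB/s
```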
Stack from ghstack (oldest at bottom):

- `NCCL_CROSS_NIC=1` for HSDP #107784
- [FSDP] Break up `_post_backward_hook` into smaller funcs #106068

The post-backward hook has some complexity due to the different paths: {no communication hook, communication hook} x {`NO_SHARD`, `FULL_SHARD`/`SHARD_GRAD_OP`, `HYBRID_SHARD`/`_HYBRID_SHARD_ZERO2`}, plus some options like CPU offloading and `use_orig_params=True` (requiring using sharded gradient views).

The PR following this one, which adds async all-reduce for HSDP, further complicates this since the bottom half after all-reduce must still be run in the separate all-reduce stream, making it more unwieldy to unify with the existing bottom half.

Nonetheless, this PR breaks up the post-backward hook into smaller logical functions to hopefully help readability.
Differential Revision: D47852461