
[FSDP] Break up _post_backward_hook into smaller funcs #106068

Closed · 13 commits

Conversation

@awgu (Contributor) commented Jul 26, 2023

Stack from ghstack (oldest at bottom):

The post-backward hook has some complexity due to the different paths: {no communication hook, communication hook} x {`NO_SHARD`, `FULL_SHARD`/`SHARD_GRAD_OP`, `HYBRID_SHARD`/`_HYBRID_SHARD_ZERO2`} plus some options like CPU offloading and `use_orig_params=True` (requiring using sharded gradient views).

The PR following this one that adds async all-reduce for HSDP further complicates this since the bottom-half after all-reduce must still be run in the separate all-reduce stream, making it more unwieldy to unify with the existing bottom-half.

Nonetheless, this PR breaks up the post-backward hook into smaller logical functions to hopefully help readability.
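As a rough, hypothetical sketch of what such a decomposition can look like, the snippet below splits a monolithic hook into a top half (gradient accumulation), a reduction step, and a bottom half (post-reduction bookkeeping). The helper names and the `_sharded_grad` attribute are illustrative only and are not taken from the actual FSDP source.

```python
# Minimal, hypothetical sketch of splitting a post-backward hook into smaller
# logical functions; not the actual FSDP implementation.
import torch
import torch.distributed as dist


def _accumulate_unsharded_grad(param: torch.nn.Parameter) -> torch.Tensor:
    # Top half: take the full (unsharded) gradient produced by autograd.
    return param.grad.detach()


def _reduce_grad(unsharded_grad: torch.Tensor, group: dist.ProcessGroup) -> torch.Tensor:
    # Middle: reduce-scatter across the sharding group (an all-reduce would be
    # used instead for NO_SHARD or for the HSDP replication dimension).
    world_size = dist.get_world_size(group)
    sharded_grad = torch.empty(
        unsharded_grad.numel() // world_size,
        dtype=unsharded_grad.dtype,
        device=unsharded_grad.device,
    )
    dist.reduce_scatter_tensor(sharded_grad, unsharded_grad, group=group)
    return sharded_grad


def _post_reduce_grad_callback(param: torch.nn.Parameter, sharded_grad: torch.Tensor) -> None:
    # Bottom half: post-reduction work such as CPU offloading or writing the
    # sharded gradient views used when use_orig_params=True.
    param._sharded_grad = sharded_grad  # illustrative attribute only


def _post_backward_hook(param: torch.nn.Parameter, group: dist.ProcessGroup) -> None:
    # The previously monolithic hook, now composed of the smaller pieces above.
    if param.grad is None:
        return
    unsharded_grad = _accumulate_unsharded_grad(param)
    sharded_grad = _reduce_grad(unsharded_grad, group)
    _post_reduce_grad_callback(param, sharded_grad)
```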

Differential Revision: [D47852461](https://our.internmc.facebook.com/intern/diff/D47852461)

@pytorch-bot bot commented Jul 26, 2023

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/106068

Note: Links to docs will display an error until the docs builds have been completed.

✅ 1 Unrelated Failure

As of commit 143d12a with merge base 2b6249e:

UNSTABLE - The following job failed but was likely due to flakiness present on trunk and has been marked as unstable:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@pytorch-bot pytorch-bot bot added the release notes: distributed (fsdp) release notes category label Jul 26, 2023
awgu added a commit that referenced this pull request Jul 26, 2023
ghstack-source-id: 82d44c0b7611db29f6dafc822c92a2230fb4c931
Pull Request resolved: #106068
awgu added a commit that referenced this pull request Jul 26, 2023
ghstack-source-id: 8f56bffc362a877e49582d68195d6f324224c496
Pull Request resolved: #106068
awgu added a commit to awgu/pytorch that referenced this pull request Jul 26, 2023
ghstack-source-id: d483b939f109ed31e15abc0541caa9ff54988844
Pull Request resolved: pytorch#106068
awgu added a commit to awgu/pytorch that referenced this pull request Jul 26, 2023
ghstack-source-id: 851ecbda86d1f295445a20f0b83bba273b62c8d2
Pull Request resolved: pytorch#106068
awgu added a commit to awgu/pytorch that referenced this pull request Jul 27, 2023
ghstack-source-id: 727f9fb29a1a9efea440c58b6bc338634ece752a
Pull Request resolved: pytorch#106068
@awgu (Contributor, Author) commented Jul 27, 2023

@awgu has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.

awgu added a commit to awgu/pytorch that referenced this pull request Jul 28, 2023
ghstack-source-id: 9ed700392040d8e7830128e6e450d7524d60d1ea
Pull Request resolved: pytorch#106068
awgu added a commit to awgu/pytorch that referenced this pull request Aug 1, 2023
ghstack-source-id: 43c10535d50bb25f940732c38c14399651c5619d
Pull Request resolved: pytorch#106068
awgu added a commit to awgu/pytorch that referenced this pull request Aug 1, 2023
ghstack-source-id: 43c10535d50bb25f940732c38c14399651c5619d
Pull Request resolved: pytorch#106068
awgu added a commit to awgu/pytorch that referenced this pull request Aug 21, 2023
ghstack-source-id: 43c10535d50bb25f940732c38c14399651c5619d
Pull Request resolved: pytorch#106068
@awgu awgu added the topic: not user facing topic category label Aug 21, 2023
A contributor commented on the diff hunk `@@ -951,6 +833,154 @@` near `def _should_free_in_backward(...)`:

Are you perhaps willing to add small top level comments for each of the new functions? :)

@awgu (Contributor, Author) replied:

I added the comments in the subsequent PR since the async HSDP refactors these a bit more now.

@ezyang (Contributor) left a review comment:

Verified code motion only

@awgu awgu added the ciflow/trunk Trigger trunk jobs on your pull request label Aug 23, 2023
@pytorchmergebot (Collaborator) commented:
Rebased gh/awgu/443/orig onto refs/remotes/origin/viable/strict because #107784 was rebased, please pull locally before adding more changes (for example, via ghstack checkout https://github.com/pytorch/pytorch/pull/106068)

pytorchmergebot pushed a commit that referenced this pull request Aug 23, 2023
**Overview**
This PR makes the HSDP all-reduce asynchronous so that it can overlap with both all-gather and reduce-scatter, which can yield modest end-to-end speedups when the sharding process group is fully intra-node. Previously, the all-reduce serialized with the reduce-scatter, so it could only overlap with one all-gather.

For some clusters (e.g. our AWS cluster), `NCCL_CROSS_NIC=1` improves inter-node all-reduce times when overlapped with intra-node all-gather/reduce-scatter.
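For illustration, here is a simplified, hypothetical sketch (not the actual FSDP implementation) of the overlap idea: issue the intra-node reduce-scatter as usual, then launch the inter-node all-reduce in a dedicated CUDA stream so that later collectives on the default streams can run concurrently. `shard_group` and `replicate_group` are assumed names for the intra-node sharding and inter-node replication process groups.

```python
# Hypothetical sketch of overlapping an inter-node all-reduce with later
# intra-node collectives by running it in a separate CUDA stream.
import torch
import torch.distributed as dist

all_reduce_stream = torch.cuda.Stream()


def reduce_grad_hsdp(unsharded_grad, shard_group, replicate_group):
    world_size = dist.get_world_size(shard_group)
    sharded_grad = torch.empty(
        unsharded_grad.numel() // world_size,
        dtype=unsharded_grad.dtype,
        device=unsharded_grad.device,
    )
    # Intra-node reduce-scatter on the current stream.
    dist.reduce_scatter_tensor(sharded_grad, unsharded_grad, group=shard_group)
    # Launch the inter-node all-reduce in its own stream so that subsequent
    # all-gathers/reduce-scatters can overlap with it.
    all_reduce_stream.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(all_reduce_stream):
        dist.all_reduce(sharded_grad, group=replicate_group)
        # Any "bottom half" work that consumes the reduced gradient must also
        # run in (or synchronize with) this stream.
    return sharded_grad
```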

**Experiment**
Example 'before' trace: hsdp_32gpus_old (https://github.com/pytorch/pytorch/assets/31054793/15222b6f-2b64-4e0b-a212-597335f05ba5)

Example 'after' trace: hsdp_32gpus_new (https://github.com/pytorch/pytorch/assets/31054793/94f63a1d-4255-4035-9e6e-9e10733f4e44)

For a 6-encoder-layer, 6-decoder-layer transformer with `d_model=8192` and `nhead=64` on 4 nodes / 32× 40 GB A100s on AWS, the end-to-end iteration times are as follows (AG = all-gather, RS = reduce-scatter, AR = all-reduce; bandwidth reported as algorithmic bandwidth):
- Reference FSDP:
    - **1160 ms / iteration**
    - ~23 ms / encoder AG/RS --> 24.46 GB/s bandwidth
    - ~40 ms / decoder AG/RS --> 26.5 GB/s bandwidth
    - 50 GB/s theoretical inter-node bandwidth
- Baseline 8-way HSDP (only overlap AR with AG) -- intra-node AG/RS, inter-node AR:
    - **665 ms / iteration**
    - ~3 ms / encoder AG/RS --> 187.5 GB/s bandwidth
    - ~5 ms / decoder AG/RS --> 212 GB/s bandwidth
    - ~30 ms / encoder AR --> 2.34 GB/s bandwidth
    - ~55 ms / decoder AR --> 2.65 GB/s bandwidth
    - 300 GB/s theoretical intra-node bandwidth
- New 8-way HSDP (overlap AR with AG and RS) -- intra-node AG/RS, inter-node AR:
    - **597 ms / iteration**
    - ~3 ms / encoder AG/RS --> 187.5 GB/s bandwidth
    - ~6.2 ms / decoder AG/RS --> 170.97 GB/s bandwidth (slower)
    - ~23 ms / encoder AR (non-overlapped) --> 3.057 GB/s bandwidth (faster)
    - ~49 ms / decoder AR (non-overlapped) --> 2.70 GB/s bandwidth (faster)
    - ~100 ms / decoder AR (overlapped) --> 1.325 GB/s bandwidth (slower)
    - Overlapping with reduce-scatter reduces all-reduce bandwidth utilization even though the all-reduce is inter-node and reduce-scatter is intra-node!
- New 8-way HSDP (overlap AR with AG and RS) with `NCCL_CROSS_NIC=1`:
    - **556 ms / iteration**
    - Speedup comes from faster overlapped AR

Thus, for this particular workload, the async all-reduce enables a 16% iteration-time speedup over the existing HSDP and a 52% speedup over FSDP. These speedups are pronounced because the workload is communication-bound, so any reduction in communication time translates directly into speedup.
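As a sanity check on the bandwidth figures above, algorithmic bandwidth is simply bytes moved divided by collective time. The back-of-envelope sketch below infers an approximate per-collective message size from the reported numbers (the size is inferred, not measured).

```python
# Back-of-envelope check: algorithmic bandwidth = bytes moved / collective time.
def algorithmic_bandwidth_gb_per_s(message_bytes: float, time_s: float) -> float:
    return message_bytes / time_s / 1e9

# Reference FSDP encoder AG/RS: ~23 ms at ~24.46 GB/s implies ~0.56 GB moved.
message_bytes = 24.46e9 * 0.023  # ~5.6e8 bytes (inferred)
print(algorithmic_bandwidth_gb_per_s(message_bytes, 0.023))  # ~24.46 GB/s

# The same ~0.56 GB collective at ~3 ms (intra-node baseline HSDP encoder)
# gives ~187.5 GB/s, consistent with the baseline HSDP numbers above.
print(algorithmic_bandwidth_gb_per_s(message_bytes, 0.003))  # ~187.5 GB/s
```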

**Unit Test**
This requires >= 4 GPUs:
```
python -m pytest test/distributed/fsdp/test_fsdp_hybrid_shard.py -k test_fsdp_hybrid_shard_parity
```

Differential Revision: [D47852456](https://our.internmc.facebook.com/intern/diff/D47852456)
Pull Request resolved: #106080
Approved by: https://github.com/ezyang
ghstack dependencies: #106068
pytorchmergebot pushed a commit that referenced this pull request Aug 23, 2023
awgu added a commit to awgu/pytorch that referenced this pull request Aug 23, 2023
ghstack-source-id: 4f0d016028bbc829f5df753db8b2768f251c879f
Pull Request resolved: pytorch#106068
Labels: ciflow/trunk (Trigger trunk jobs on your pull request) · Merged · release notes: distributed (fsdp) · topic: not user facing