Update base for Update on "[FSDP][Proof of Concept] Add limiter using CUDA events"


### High-GPU Reserved Memory
While generalizing PyTorch FSDP's prefetching, I inadvertently stumbled into the high GPU reserved memory issue (first surfaced in Fairscale FSDP).

#### [Fairscale FSDP Approach 1](https://github.com/facebookresearch/fairscale/issues/972)
- [Record pre-forward order](https://github.com/facebookresearch/fairscale/blob/191553190d73a5ef4a48687c889d4b1d94532135/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L1392-L1394)
- [Use pre-forward prefetching](https://github.com/facebookresearch/fairscale/blob/191553190d73a5ef4a48687c889d4b1d94532135/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L1406-L1412) (pre-forward order index + 1; see the sketch after this list)
- [Use pre-backward prefetching](https://github.com/facebookresearch/fairscale/blob/191553190d73a5ef4a48687c889d4b1d94532135/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L1502-L1507) (pre-forward order index - 1)
- Prefetch before freeing the padded unsharded flattened parameter
- [Synchronize the current stream after freeing the padded unsharded flattened parameter](https://github.com/facebookresearch/fairscale/blob/191553190d73a5ef4a48687c889d4b1d94532135/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L2073) (regardless of prefetch)
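
Below is a minimal sketch of Approach 1's ordering bookkeeping, assuming per-module pre-forward and pre-backward hooks; the attribute `_pre_fwd_index` and the helper `_unshard_params()` are hypothetical stand-ins, not Fairscale's actual identifiers:

```python
# A minimal sketch of Approach 1's bookkeeping (hypothetical names,
# not Fairscale's actual code).
_pre_forward_order = []  # shared across all FSDP instances

def _pre_forward_hook(fsdp_module):
    if not hasattr(fsdp_module, "_pre_fwd_index"):
        # First iteration: record this module's position in the
        # pre-forward execution order.
        fsdp_module._pre_fwd_index = len(_pre_forward_order)
        _pre_forward_order.append(fsdp_module)
    # Pre-forward prefetch: all-gather the next module's parameters
    # (pre-forward order index + 1).
    nxt = fsdp_module._pre_fwd_index + 1
    if nxt < len(_pre_forward_order):
        _pre_forward_order[nxt]._unshard_params()

def _pre_backward_hook(fsdp_module):
    # Pre-backward prefetch: backward visits modules roughly in reverse
    # forward order, so target pre-forward order index - 1.
    prv = fsdp_module._pre_fwd_index - 1
    if prv >= 0:
        _pre_forward_order[prv]._unshard_params()
```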

#### [Fairscale FSDP Approach 2](https://github.com/facebookresearch/fairscale/issues/1052)
- [Record post-forward order](https://github.com/facebookresearch/fairscale/blob/7d46cba0ac2bc2d69922d75a454c08edf07bb6ce/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L1430-L1431)
- [Record pre-backward order](https://github.com/facebookresearch/fairscale/blob/7d46cba0ac2bc2d69922d75a454c08edf07bb6ce/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L1521-L1522)
- [Use post-forward prefetching](https://github.com/facebookresearch/fairscale/blob/7d46cba0ac2bc2d69922d75a454c08edf07bb6ce/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L1434) (post-forward index + 1; see the sketch after this list)
- [Use post-backward prefetching](https://github.com/facebookresearch/fairscale/blob/7d46cba0ac2bc2d69922d75a454c08edf07bb6ce/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L1675) (pre-backward index + 1)
- [Prefetch after freeing the padded unsharded flattened parameter](https://github.com/facebookresearch/fairscale/blob/7d46cba0ac2bc2d69922d75a454c08edf07bb6ce/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L2106)
- [Synchronize the current stream after freeing the padded unsharded flattened parameter](https://github.com/facebookresearch/fairscale/blob/7d46cba0ac2bc2d69922d75a454c08edf07bb6ce/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L2132) (regardless of prefetch)
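
The sketch below shows how Approach 2's ordering differs, under the same hypothetical names as the Approach 1 sketch: orders are recorded at post-forward and pre-backward time, and each prefetch fires one step later than its Approach 1 counterpart.

```python
# A sketch of Approach 2's bookkeeping (hypothetical names, not
# Fairscale's actual code).
_post_forward_order = []
_pre_backward_order = []

def _post_forward_hook(fsdp_module):
    if not hasattr(fsdp_module, "_post_fwd_index"):
        fsdp_module._post_fwd_index = len(_post_forward_order)
        _post_forward_order.append(fsdp_module)
    # Post-forward prefetch: post-forward index + 1.
    nxt = fsdp_module._post_fwd_index + 1
    if nxt < len(_post_forward_order):
        _post_forward_order[nxt]._unshard_params()

def _pre_backward_hook(fsdp_module):
    if not hasattr(fsdp_module, "_pre_bwd_index"):
        fsdp_module._pre_bwd_index = len(_pre_backward_order)
        _pre_backward_order.append(fsdp_module)

def _post_backward_hook(fsdp_module):
    # Post-backward prefetch: pre-backward index + 1.
    nxt = fsdp_module._pre_bwd_index + 1
    if nxt < len(_pre_backward_order):
        _pre_backward_order[nxt]._unshard_params()
```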

#### [WIP] PT-D FSDP Approach
- In `_reshard()`, record a CUDA event after freeing the padded unsharded flattened parameter
- In `_unshard()`, before actually unsharding, check whether the number of saved free events exceeds a maximum and, if so, synchronize on the earliest event, blocking the CPU thread until that event completes (see the sketch below)
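
The following is a minimal sketch of this limiter, assuming a deque of `torch.cuda.Event`s shared by `_reshard()` and `_unshard()`; the cap value and the helpers `_free_unsharded_flat_param()` / `_all_gather_flat_param()` are illustrative, not the PR's actual code:

```python
from collections import deque

import torch

_MAX_INFLIGHT_FREES = 2  # illustrative cap, not the PR's actual value
_free_events = deque()   # CUDA events recorded after each free

def _reshard(fsdp_module):
    fsdp_module._free_unsharded_flat_param()  # hypothetical helper
    free_event = torch.cuda.Event()
    free_event.record()  # recorded in the current stream, after the free
    _free_events.append(free_event)

def _unshard(fsdp_module):
    # Rate limit: if too many frees are still pending on the GPU, block
    # the CPU thread on the earliest recorded event before issuing the
    # next all-gather.
    if len(_free_events) > _MAX_INFLIGHT_FREES:
        _free_events.popleft().synchronize()  # blocks the CPU thread
    fsdp_module._all_gather_flat_param()  # hypothetical helper
```

Because `Event.synchronize()` blocks the CPU thread, the limiter throttles how far ahead the all-gathers can run, which bounds how many unsharded parameters are alive at once.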

#### Observations & Discussion
1. Blocking the CPU thread is critical for preventing over-all-gathering (i.e., issuing so many all-gathers ahead of use that unsharded parameters accumulate and inflate reserved memory).
2. For static graphs, pre-forward prefetching should use the pre-forward order, and post-forward prefetching should use the post-forward order.
    - The Fairscale approaches and PT-D FSDP all follow this.
    - Post-forward prefetching is more conservative than pre-forward prefetching: it only achieves sibling-level prefetching, whereas pre-forward prefetching follows the execution order.
    - We should investigate the performance difference between pre- and post-forward prefetching.
        - It seems that post-forward prefetching is motivated by having the `current_stream().synchronize()` run _after_ the unsharded parameter is freed.
3. For static graphs, backward prefetching should use the pre-backward order.
4. A mistargeted prefetch may be (1) targeting an already unsharded parameter, (2) targeting a not-yet-unsharded parameter, or (3) targeting an already resharded parameter.
    - Since `_rebuild_full_params()` has side effects (e.g. for mixed precision and CPU offloading), even case (1) may cause performance degradation.
    - The previous PR makes `FullyShardedDataParallel._unshard()` a no-op for sharded strategies if already unsharded, which addresses case (1).
    - We may want to add some logic to guard against case (3); a hypothetical guard is sketched below.
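
As a sketch, a guard against cases (1) and (3) might check per-module state before issuing the prefetch; the flags `_is_unsharded` and `_post_backward_called` below are hypothetical, not PT-D FSDP's actual attributes:

```python
# A hypothetical guard for mistargeted prefetches.
def _safe_prefetch(target):
    if getattr(target, "_is_unsharded", False):
        # Case (1): already unsharded; skipping avoids the side effects
        # of rebuilding the full parameters entirely.
        return
    if getattr(target, "_post_backward_called", False):
        # Case (3): already resharded after finishing backward this
        # iteration; re-all-gathering would be pure waste.
        return
    target._unshard_params()  # hypothetical helper
```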

#### T5 (500M), 2 Nodes, 16 A100 GPUs, Batch Size 256

<details>
  <summary> `allow_over_all_gather=True` </summary>
  
![Screen Shot 2022-08-16 at 4 51 25 PM](https://user-images.githubusercontent.com/31054793/184982990-166e97e9-b0af-4bd7-ae9a-2716bf5b8f48.png)

- Peak GPU reserved memory: 6784 MB = 6.784 GB
- Time / batch: 3.4 s

</details>

<details>
  <summary> `allow_over_all_gather=False` </summary>
  
![Screen Shot 2022-08-16 at 4 51 14 PM](https://user-images.githubusercontent.com/31054793/184983007-5e81ae54-fcb0-4a06-a4af-73f0e52b5949.png)

- Peak GPU reserved memory: 5846 MB = 5.846 GB
- Time / batch: 3.4 s

</details>

