Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Update base for Update on "[FSDP][Proof of Concept] Add limiter using…
… CUDA events" ### High-GPU Reserved Memory As I was generalizing PyTorch FSDP's prefetching, I inadvertently stumbled into the high GPU reserved memory issue (first surfaced from Fairscale FSDP). #### [Fairscale FSDP Approach 1](facebookresearch/fairscale#972) - [Record pre-forward order](https://github.com/facebookresearch/fairscale/blob/191553190d73a5ef4a48687c889d4b1d94532135/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L1392-L1394) - [Use pre-forward prefetching](https://github.com/facebookresearch/fairscale/blob/191553190d73a5ef4a48687c889d4b1d94532135/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L1406-L1412) (pre-forward order index + 1) - [Use pre-backward prefetching](https://github.com/facebookresearch/fairscale/blob/191553190d73a5ef4a48687c889d4b1d94532135/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L1502-L1507) (pre-forward order index - 1) - Prefetch before freeing the padded unsharded flattened parameter - [Synchronize the current stream after freeing the padded unsharded flattened parameter](https://github.com/facebookresearch/fairscale/blob/191553190d73a5ef4a48687c889d4b1d94532135/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L2073) (regardless of prefetch) #### [Fairscale FSDP Approach 2](facebookresearch/fairscale#1052) - [Record post-forward order](https://github.com/facebookresearch/fairscale/blob/7d46cba0ac2bc2d69922d75a454c08edf07bb6ce/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L1430-L1431) - [Record pre-backward order](https://github.com/facebookresearch/fairscale/blob/7d46cba0ac2bc2d69922d75a454c08edf07bb6ce/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L1521-L1522) - [Use post-forward prefetching](https://github.com/facebookresearch/fairscale/blob/7d46cba0ac2bc2d69922d75a454c08edf07bb6ce/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L1434) (post-forward index + 1) - [Use post-backward prefetching](https://github.com/facebookresearch/fairscale/blob/7d46cba0ac2bc2d69922d75a454c08edf07bb6ce/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L1675) (pre-backward index + 1) - [Prefetch after freeing the padded unsharded flattened parameter](https://github.com/facebookresearch/fairscale/blob/7d46cba0ac2bc2d69922d75a454c08edf07bb6ce/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L2106) - [Synchronize the current stream after freeing the padded unsharded flattened parameter](https://github.com/facebookresearch/fairscale/blob/7d46cba0ac2bc2d69922d75a454c08edf07bb6ce/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L2132) (regardless of prefetch) #### [WIP] PT-D FSDP Approach - In `_reshard()`, record a CUDA event after freeing the padded unsharded flattened parameter - In `_unshard()`, before actually unsharding, check if the number of saved free events exceeds a max number and if so, synchronize the earliest event, blocking the CPU thread until that event completes #### Observations & Discussion 1. Blocking the CPU thread is critical for preventing the over-all-gathering. 2. For static graphs, pre-forward prefetching should use the pre-forward order, and post-forward prefetching should use the post-forward order. - Fairscale and PT-D FSDPs all follow this. - Post-forward prefetching is more conservative than pre-forward prefetching. Post-forward prefetching targets sibling-level prefetching only. Pre-forward prefetching follows the execution order. - We should investigate the performance difference between pre- and post-forward prefetching. - It seems that the post-forward prefetching is motivated by having the `current_stream().synchronize()` _after_ the unsharded parameter is freed. 3. For static graphs, backward prefetching should use the pre-backward order. 4. A mistargeted prefetch may be either (1) targeting an already unsharded parameter, (2) targeting a not yet unsharded, or (3) targeting an already resharded parameter. - Since `_rebuild_full_params()` has side effects (e.g. for mixed precision and CPU offloading), even (1) may cause performance degradation. - The previous PR makes `FullyShardedDataParallel._unshard()` a no-op for sharded strategies if already unsharded. This addresses case (1). - We may want to add some logic to guard against case (3). #### T5 (500M) 2 Nodes 16 A100 GPUs 256 Batch Size <details> <summary> `allow_over_all_gather=True` </summary> ![Screen Shot 2022-08-16 at 4 51 25 PM](https://user-images.githubusercontent.com/31054793/184982990-166e97e9-b0af-4bd7-ae9a-2716bf5b8f48.png) Peak GPU reserved memory: 6784 MB = 6.784 GB Time / batch: 3.4 s </details> <details> <summary> `allow_over_all_gather=False` </summary> ![Screen Shot 2022-08-16 at 4 51 14 PM](https://user-images.githubusercontent.com/31054793/184983007-5e81ae54-fcb0-4a06-a4af-73f0e52b5949.png) Peak GPU reserved memory: 5846 MB = 5.846 GB Time / batch: 3.4 s </details> [ghstack-poisoned]
- Loading branch information