Update on "[FSDP] Add limiter using CUDA events"
This PR tackles the high GPU reserved memory issue for FSDP.

Currently:
- This adds an argument `all_gather_issue_limit: Optional[int]` to the FSDP constructor, where `None` disables the limiter and a positive integer enables the limiter.
- If enabled, this limiter is only meaningful for `FULL_SHARD`, not for `SHARD_GRAD_OP` or `NO_SHARD`, since (1) we track free events rather than all-gather events, and (2) for the non-`FULL_SHARD` strategies, the reserved memory will inevitably be used.
- Given this, ideally each sharding strategy could carry its own attributes, and `all_gather_issue_limit` could become an attribute of `FULL_SHARD` only. The same idea applies to `HYBRID_SHARD`, where one option is to pass the second process group as an attribute.
- I want to discuss this since it does not seem backward compatible; I am not sure whether, with [enums](https://stackoverflow.com/questions/12680080/python-enums-with-attributes), we can attach different attributes to different members (see the sketch after this list).
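
For reference, below is a minimal sketch of the enum-with-attributes pattern from the linked Stack Overflow question. The enum is purely illustrative (it is not the actual `ShardingStrategy` definition) and only demonstrates class-level per-member attributes; whether this could be adopted without breaking the existing `ShardingStrategy` values is exactly the backward-compatibility question above.

```python
# Illustrative only: per-member attributes on a Python enum, following the
# standard pattern from the linked Stack Overflow question. This is NOT the
# actual torch.distributed.fsdp ShardingStrategy definition.
from enum import Enum


class ShardingStrategySketch(Enum):
    # (label, whether an all-gather issue limit is meaningful)
    FULL_SHARD = ("full_shard", True)
    SHARD_GRAD_OP = ("shard_grad_op", False)
    NO_SHARD = ("no_shard", False)

    def __init__(self, label: str, supports_all_gather_limit: bool):
        # Each member's value is the whole tuple; the fields become attributes.
        self.label = label
        self.supports_all_gather_limit = supports_all_gather_limit


assert ShardingStrategySketch.FULL_SHARD.supports_all_gather_limit
assert not ShardingStrategySketch.NO_SHARD.supports_all_gather_limit
```

Note that these are class-level constants; a user-supplied value such as `all_gather_issue_limit=2` would still need to live in something mutable (e.g., a per-strategy config object), which is part of what makes the API question non-trivial.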

### High GPU Reserved Memory

#### [Fairscale FSDP Approach 1](facebookresearch/fairscale#972)
- [Record pre-forward order](https://github.com/facebookresearch/fairscale/blob/191553190d73a5ef4a48687c889d4b1d94532135/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L1392-L1394)
- [Use pre-forward prefetching](https://github.com/facebookresearch/fairscale/blob/191553190d73a5ef4a48687c889d4b1d94532135/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L1406-L1412) (pre-forward order index + 1)
- [Use pre-backward prefetching](https://github.com/facebookresearch/fairscale/blob/191553190d73a5ef4a48687c889d4b1d94532135/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L1502-L1507) (pre-forward order index - 1)
- Prefetch before freeing the padded unsharded flattened parameter
- [Synchronize the current stream after freeing the padded unsharded flattened parameter](https://github.com/facebookresearch/fairscale/blob/191553190d73a5ef4a48687c889d4b1d94532135/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L2073) (regardless of prefetch)

#### [Fairscale FSDP Approach 2](facebookresearch/fairscale#1052)
- [Record post-forward order](https://github.com/facebookresearch/fairscale/blob/7d46cba0ac2bc2d69922d75a454c08edf07bb6ce/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L1430-L1431)
- [Record pre-backward order](https://github.com/facebookresearch/fairscale/blob/7d46cba0ac2bc2d69922d75a454c08edf07bb6ce/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L1521-L1522)
- [Use post-forward prefetching](https://github.com/facebookresearch/fairscale/blob/7d46cba0ac2bc2d69922d75a454c08edf07bb6ce/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L1434) (post-forward index + 1)
- [Use post-backward prefetching](https://github.com/facebookresearch/fairscale/blob/7d46cba0ac2bc2d69922d75a454c08edf07bb6ce/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L1675) (pre-backward index + 1)
- [Prefetch after freeing the padded unsharded flattened parameter](https://github.com/facebookresearch/fairscale/blob/7d46cba0ac2bc2d69922d75a454c08edf07bb6ce/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L2106)
- [Synchronize the current stream after freeing the padded unsharded flattened parameter](https://github.com/facebookresearch/fairscale/blob/7d46cba0ac2bc2d69922d75a454c08edf07bb6ce/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L2132) (regardless of prefetch)
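
Both Fairscale approaches reduce to the same mechanism: record the order in which FSDP units run on the first pass, then prefetch the unit at the recorded index plus or minus one around the free. The sketch below illustrates only that pattern; `Unit`, `PrefetchSchedule`, and the method names are hypothetical stand-ins rather than Fairscale's actual code.

```python
# Minimal sketch of the "record execution order, prefetch by index offset"
# pattern shared by both Fairscale approaches. All names are hypothetical
# stand-ins, not Fairscale or PyTorch APIs.
from typing import List, Optional


class Unit:
    """Placeholder for one FSDP-wrapped module's sharded parameters."""

    def all_gather_params(self) -> None:
        ...  # issue the all-gather for this unit's parameters


class PrefetchSchedule:
    def __init__(self) -> None:
        self._order: List[Unit] = []  # filled during the first iteration
        self._frozen = False

    def record(self, unit: Unit) -> None:
        """Record a unit the first time it runs (pre- or post-forward)."""
        if not self._frozen:
            self._order.append(unit)

    def freeze(self) -> None:
        """Stop recording after the first iteration; enables prefetching."""
        self._frozen = True

    def prefetch(self, current: Unit, offset: int) -> Optional[Unit]:
        """All-gather the unit at `recorded index + offset` (e.g. +1 in the
        forward pass, -1 or +1 in the backward pass, per the links above)."""
        if not self._frozen or current not in self._order:
            return None
        idx = self._order.index(current) + offset
        if 0 <= idx < len(self._order):
            target = self._order[idx]
            target.all_gather_params()
            return target
        return None
```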

#### PT-D FSDP Approach
- In `_reshard()`, record a CUDA event after freeing the padded unsharded flattened parameter
- In `_unshard()`, before actually unsharding, check whether the number of saved free events exceeds the configured maximum and, if so, synchronize the earliest event, blocking the CPU thread until that event completes (sketched below)
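
A condensed sketch of that limiter logic follows. The class, queue, and method names are illustrative; in the PR, the equivalent bookkeeping lives inside FSDP's `_reshard()` and `_unshard()` paths rather than in a standalone helper.

```python
# Condensed, illustrative sketch of the CUDA-event-based all-gather limiter
# described above; the real logic is inlined in FSDP's _reshard()/_unshard().
import collections
from typing import Deque

import torch


class AllGatherIssueLimiter:
    def __init__(self, limit: int) -> None:
        self.limit = limit  # corresponds to `all_gather_issue_limit`
        self.free_events: Deque[torch.cuda.Event] = collections.deque()

    def record_free(self) -> None:
        """_reshard(): record an event on the current stream right after
        freeing the padded unsharded flattened parameter."""
        event = torch.cuda.Event()
        event.record()
        self.free_events.append(event)

    def wait_if_needed(self) -> None:
        """_unshard(): before issuing the next all-gather, if too many frees
        are still in flight, block the CPU thread on the earliest free event."""
        if len(self.free_events) >= self.limit:
            earliest = self.free_events.popleft()
            earliest.synchronize()
```

Blocking the CPU thread this way throttles how far ahead all-gathers can be issued, which is what caps the reserved memory in the `all_gather_issue_limit=2` run below.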


#### T5 (500M), 2 Nodes, 16 A100 GPUs, Batch Size 256

<details>
  <summary> `all_gather_issue_limit=None` </summary>
  
![Screen Shot 2022-08-16 at 4 51 25 PM](https://user-images.githubusercontent.com/31054793/184982990-166e97e9-b0af-4bd7-ae9a-2716bf5b8f48.png)

Peak GPU reserved memory: 6784 MB = 6.784 GB
Time / batch: 3.4 s

</details>

<details>
  <summary> `all_gather_issue_limit=2` </summary>
  
![Screen Shot 2022-08-16 at 4 51 14 PM](https://user-images.githubusercontent.com/31054793/184983007-5e81ae54-fcb0-4a06-a4af-73f0e52b5949.png)

Peak GPU reserved memory: 5846 MB = 5.846 GB
Time / batch: 3.4 s

</details>
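
For completeness, enabling the limiter for runs like the one above would look roughly like the sketch below. It reflects this PR's intermediate API as described at the top (`all_gather_issue_limit` on the FSDP constructor); the wrapper function and model are placeholders, and the keyword may differ from the final released API.

```python
# Rough usage sketch based on this PR's description; `all_gather_issue_limit`
# is the intermediate keyword from this PR and may differ from the final API.
import torch.nn as nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.fully_sharded_data_parallel import ShardingStrategy


def wrap_with_limiter(model: nn.Module) -> FSDP:
    return FSDP(
        model,
        sharding_strategy=ShardingStrategy.FULL_SHARD,  # limiter only matters here
        all_gather_issue_limit=2,  # None disables the limiter
    )
```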


[ghstack-poisoned]
awgu committed Aug 22, 2022
2 parents 7da9ceb + ea2b0b5 commit 3bcb4f5
Showing 2 changed files with 34 additions and 13 deletions.
41 changes: 32 additions & 9 deletions test/distributed/fsdp/test_fsdp_grad_acc.py

@@ -4,13 +4,16 @@
 import itertools
 import sys
 from dataclasses import dataclass
-from typing import List, Optional, Tuple
+from typing import Any, Dict, List, Optional, Tuple
 
 import torch
 from torch import distributed as dist
 from torch.distributed.fsdp import CPUOffload
 from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
-from torch.distributed.fsdp.fully_sharded_data_parallel import BackwardPrefetch
+from torch.distributed.fsdp.fully_sharded_data_parallel import (
+    BackwardPrefetch,
+    ShardingStrategy,
+)
 from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
 from torch.testing._internal.common_fsdp import (
     CUDAInitMode,
@@ -88,6 +91,7 @@ def _test_grad_acc(
         configs: List[_GradAccConfig],
         cpu_offload: CPUOffload,
         backward_prefetch: Optional[BackwardPrefetch],
+        sharding_strategy: ShardingStrategy,
     ):
         """
         Tests gradient accumulation by comparing a run that trains sequentially
@@ -114,8 +118,10 @@
         """
         # Gradient accumulation outside `no_sync()` is not currently compatible
         # with CPU offloading
-        if cpu_offload.offload_params and \
-                any(not config.use_no_sync for config in configs):
+        if (
+            cpu_offload.offload_params
+            and any(not config.use_no_sync for config in configs)
+        ):
             return
         old_allow_tf32 = torch.backends.cuda.matmul.allow_tf32
         try:
@@ -126,6 +132,7 @@
             fsdp_kwargs = {
                 "cpu_offload": cpu_offload,
                 "backward_prefetch": backward_prefetch,
+                "sharding_strategy": sharding_strategy,
             }
             fsdp_model: FSDP = TransformerWithSharedParams.init(
                 self.process_group,
@@ -210,6 +217,16 @@ def permute_tensor(x: torch.Tensor):
         finally:
             torch.backends.cuda.matmul.allow_tf32 = old_allow_tf32
 
+    def _get_subtest_config(self) -> Dict[str, List[Any]]:
+        """Returns a subtest configuration that subtests prefetching."""
+        return {
+            "backward_prefetch": [
+                None,
+                BackwardPrefetch.BACKWARD_PRE,
+                BackwardPrefetch.BACKWARD_POST,
+            ]
+        }
+
     @skip_if_lt_x_gpu(2)
     @parametrize(
         "configs",
@@ -231,14 +248,18 @@ def permute_tensor(x: torch.Tensor):
         [CPUOffload(offload_params=False), CPUOffload(offload_params=True)],
     )
     @parametrize(
-        "backward_prefetch",
-        [BackwardPrefetch.BACKWARD_PRE, BackwardPrefetch.BACKWARD_POST, None],
+        "sharding_strategy",
+        [
+            ShardingStrategy.FULL_SHARD,
+            ShardingStrategy.SHARD_GRAD_OP,
+            ShardingStrategy.NO_SHARD,
+        ]
     )
     def test_grad_acc(
         self,
         configs: _GradAccConfigs,
         cpu_offload: CPUOffload,
-        backward_prefetch: Optional[BackwardPrefetch],
+        sharding_strategy: ShardingStrategy,
     ):
         """
         Tests gradient accumulation.
@@ -255,11 +276,13 @@ def test_grad_acc(
         manager is not currently compatible with CPU offloading, so those tests
         are vacuous.
         """
-        self._test_grad_acc(
+        self.run_subtests(
+            self._get_subtest_config(),
+            self._test_grad_acc,
             batch_dim=1,
             configs=configs.configs,
             cpu_offload=cpu_offload,
-            backward_prefetch=backward_prefetch,
+            sharding_strategy=sharding_strategy,
         )
6 changes: 2 additions & 4 deletions torch/distributed/fsdp/fully_sharded_data_parallel.py

@@ -892,7 +892,8 @@ class FullyShardedDataParallel(nn.Module):
             the sharding strategy is ``FULL_SHARD``, this represents the
             maximum number of actively issued all-gathers. When this limit is
             reached, FSDP blocks the CPU thread to ensure that some FSDP
-            parameters are freed before issuing further all-gathers.
+            parameters are freed before issuing further all-gathers. (Default:
+            ``None``)
     """
 
     def __init__(
@@ -3116,9 +3117,6 @@ def _post_backward_hook(
         unsharded gradient.
         - Otherwise, the ``_saved_grad_shard`` attribute is the reduced sharded
         gradient (accumulating with any existing gradient).
-
-        TODO (awgu): I am not sure if gradient accumulation without
-        ``no_sync()`` works with ``NO_SHARD``.
         """
         param = handle.flat_param
         param._post_backward_called = True
