[FSDP] Refactor comm hooks directory
Pull Request resolved: #85496

Move FSDP comm hooks into the FSDP folder and expose the reduced-precision
hooks under the torch.distributed.fsdp.default_comm_hooks namespace.
ghstack-source-id: 168203289

Differential Revision: [D39743042](https://our.internmc.facebook.com/intern/diff/D39743042/)
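
For illustration, a minimal sketch of how the relocated hooks are used after this change, mirroring the test updates in this diff (`model` is a placeholder module assumed to be defined elsewhere, with the default process group already initialized):

from torch.distributed.distributed_c10d import _get_default_group
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import default_comm_hooks

# Wrap an existing module (assumed to already live on the local CUDA device).
fsdp_model = FSDP(model)

# Build the low-precision hook state from the default process group and
# register the fp16 compression hook from its new namespace.
state = default_comm_hooks.LowPrecisionState(process_group=_get_default_group())
fsdp_model.register_comm_hook(state, default_comm_hooks.fp16_compress_hook)
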
rohan-varma committed Sep 22, 2022
1 parent 3805af4 commit 2b24328
Showing 6 changed files with 37 additions and 27 deletions.
3 changes: 3 additions & 0 deletions docs/source/fsdp.rst
@@ -26,3 +26,6 @@ FullyShardedDataParallel

.. autoclass:: torch.distributed.fsdp.OptimStateKeyType
:members:

.. autoclass:: torch.distributed.fsdp.default_comm_hooks
:members:
45 changes: 25 additions & 20 deletions test/distributed/fsdp/test_fsdp_comm_hooks.py
@@ -8,7 +8,7 @@
import torch.nn.functional as F
from torch import distributed as dist
from torch.distributed.distributed_c10d import _get_default_group
from torch.distributed.algorithms._comm_hooks import default_hooks
from torch.distributed.fsdp import default_comm_hooks
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import MixedPrecision
from torch.distributed.fsdp.fully_sharded_data_parallel import ShardingStrategy
@@ -116,6 +116,21 @@ def dummy_hook_for_sharded_fsdp(self, state: DummyState, grad: torch.Tensor, out
output, grad, group=state.process_group
)

def _get_default_hook(sharding_strategy):
    return (
        default_comm_hooks.reduce_scatter_hook
        if sharding_strategy != ShardingStrategy.NO_SHARD
        else default_comm_hooks.allreduce_hook
    )

def _get_dummy_hook(sharding_strategy):
    return (
        DummyHook.dummy_hook_for_sharded_fsdp
        if sharding_strategy != ShardingStrategy.NO_SHARD
        else DummyHook.dummy_hook_for_no_shard_fsdp
    )


class TestCommunicationHooks(FSDPTest):

@skip_if_lt_x_gpu(2)
@@ -153,9 +168,7 @@ def test_default_communication_hook_behavior(

# Check that default hook is set to `all_reduce` for `NO_SHARD`
# or `reduce_scatter` for sharded cases
default_hook = default_hooks.reduce_scatter_hook\
if sharding_strategy != ShardingStrategy.NO_SHARD\
else default_hooks.allreduce_hook
default_hook = _get_default_hook(sharding_strategy)

for entry in FSDP.fsdp_modules(net_default_hook):
self.assertEqual(entry._communication_hook, default_hook)
@@ -224,17 +237,13 @@ def test_default_communication_hook_initialization(

# Check that default hook is set to `all_reduce` for `NO_SHARD`
# or `reduce_scatter` for sharded cases
default_hook = default_hooks.reduce_scatter_hook\
if sharding_strategy != ShardingStrategy.NO_SHARD\
else default_hooks.allreduce_hook
default_hook = _get_default_hook(sharding_strategy)

for entry in FSDP.fsdp_modules(fsdp_model_with_hook):
self.assertEqual(entry._communication_hook, default_hook)

dummy_state = DummyState(process_group=None, noise=1234)
dummy_hook = DummyHook.dummy_hook_for_sharded_fsdp\
if sharding_strategy != ShardingStrategy.NO_SHARD\
else DummyHook.dummy_hook_for_no_shard_fsdp
dummy_hook = _get_dummy_hook(sharding_strategy)

fsdp_model_with_hook.register_comm_hook(
dummy_state,
@@ -304,9 +313,7 @@ def test_registering_hook_non_root(
sharding_strategy=sharding_strategy
)
dummy_state = DummyState(process_group=None, noise=1234)
dummy_hook = DummyHook.dummy_hook_for_sharded_fsdp\
if sharding_strategy != ShardingStrategy.NO_SHARD\
else DummyHook.dummy_hook_for_no_shard_fsdp
dummy_hook = _get_dummy_hook(sharding_strategy)
# Creating a list of non-root submodules to test
submodules = self._get_submodules(fsdp_model_with_hook)
# Check that assertion is raised for registering a comm hook on a non-root
@@ -339,9 +346,7 @@ def test_registering_hook_submodules(
sharding_strategy=sharding_strategy
)
dummy_state = DummyState(process_group=None, noise=1234)
dummy_hook = DummyHook.dummy_hook_for_sharded_fsdp\
if sharding_strategy != ShardingStrategy.NO_SHARD\
else DummyHook.dummy_hook_for_no_shard_fsdp
dummy_hook = _get_dummy_hook(sharding_strategy)
submodules = self._get_submodules(fsdp_model_with_hook)

# Simulate a registration of a hook on a submodule
@@ -422,8 +427,8 @@ def test_fp16_hook(
sharding_strategy: Optional[ShardingStrategy]
):

state = default_hooks.LowPrecisionState(process_group=_get_default_group())
hook = default_hooks.fp16_compress_hook
state = default_comm_hooks.LowPrecisionState(process_group=_get_default_group())
hook = default_comm_hooks.fp16_compress_hook

self._check_low_precision_hook(state, hook, sharding_strategy, torch.float16, has_wrapping)

@@ -449,8 +454,8 @@ def test_bf16_hook(
sharding_strategy: Optional[ShardingStrategy]
):

state = default_hooks.LowPrecisionState(process_group=_get_default_group())
hook = default_hooks.bf16_compress_hook
state = default_comm_hooks.LowPrecisionState(process_group=_get_default_group())
hook = default_comm_hooks.bf16_compress_hook

self._check_low_precision_hook(state, hook, sharding_strategy, torch.bfloat16, has_wrapping)

7 changes: 0 additions & 7 deletions torch/distributed/algorithms/_comm_hooks/__init__.py

This file was deleted.

2 changes: 2 additions & 0 deletions torch/distributed/fsdp/__init__.py
@@ -11,3 +11,5 @@
StateDictType,
)
from .wrap import ParamExecOrderWrapPolicy

from .comm import default_hooks as default_comm_hooks
7 changes: 7 additions & 0 deletions torch/distributed/fsdp/comm/__init__.py
@@ -0,0 +1,7 @@

from . import default_hooks

LOW_PRECISION_HOOKS = [
    default_hooks.fp16_compress_hook,
    default_hooks.bf16_compress_hook,
]
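
As an aside, a minimal sketch of how the new LOW_PRECISION_HOOKS constant can be consumed; the `is_low_precision_hook` helper is hypothetical and not part of this diff:

from torch.distributed.fsdp.comm import LOW_PRECISION_HOOKS, default_hooks

def is_low_precision_hook(hook) -> bool:
    # fp16/bf16 compression hooks communicate gradients in reduced precision
    # and need the results cast back to the parameter dtype afterwards.
    return hook in LOW_PRECISION_HOOKS

assert is_low_precision_hook(default_hooks.fp16_compress_hook)
assert not is_low_precision_hook(default_hooks.reduce_scatter_hook)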
