[Resubmission][Gradient Compression] Refactor default_hooks.py and powerSGD_hook.py by creating a util function that makes a vanilla allreduce future (#51400)

Summary:
Pull Request resolved: #51400

Resubmission of #51094

Addresses #50973 (comment)

Original PR issue: Investigate Applying PowerSGD to Communication Hook for Gradient Compression #47202
ghstack-source-id: 120725690

Test Plan:
buck test mode/dev-nosan caffe2/test/distributed:c10d -- test_powerSGD_ddp_comm_hook_nccl

buck test mode/dev-nosan caffe2/test/distributed:c10d -- test_default_ddp_comm_hooks_nccl

Reviewed By: rohan-varma

Differential Revision: D26162333

fbshipit-source-id: ccc2eae5383a23673e00d61cb5570fb8bf749cd0
Yi Wang authored and facebook-github-bot committed Feb 1, 2021
1 parent 6c24296 commit 0831984
Showing 2 changed files with 22 additions and 20 deletions.
torch/distributed/algorithms/ddp_comm_hooks/default_hooks.py (25 changes: 15 additions & 10 deletions)

@@ -2,6 +2,20 @@
 import torch.distributed as dist
 
 
+def _allreduce_fut(
+    process_group: dist.ProcessGroup, tensor: torch.Tensor
+) -> torch.futures.Future:
+    group_to_use = process_group if process_group is not None else dist.group.WORLD
+
+    "Averages the input gradient tensor by allreduce and returns a future."
+    fut = dist.all_reduce(tensor, group=group_to_use, async_op=True).get_future()
+
+    def div_by_group_size(fut):
+        return [fut.value()[0].div_(group_to_use.size())]
+
+    return fut.then(div_by_group_size)
+
+
 def allreduce_hook(
     process_group: dist.ProcessGroup, bucket: dist._GradBucket
 ) -> torch.futures.Future:
@@ -17,16 +31,7 @@ def allreduce_hook(
     Example::
         >>> ddp_model.register_comm_hook(process_group, allreduce_hook)
     """
-    group_to_use = process_group if process_group is not None else dist.group.WORLD
-    world_size = group_to_use.size()
-
-    tensor = bucket.get_tensors()[0]
-    fut = dist.all_reduce(tensor, group=group_to_use, async_op=True).get_future()
-
-    def then_callback(fut):
-        return [fut.value()[0].div_(world_size)]
-
-    return fut.then(then_callback)
+    return _allreduce_fut(process_group, bucket.get_tensors()[0])
 
 
 def fp16_compress_hook(
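
(Not part of this commit: a minimal sketch of how the new _allreduce_fut helper can be exercised on its own. The single-process gloo setup, environment variables, and tensor values below are illustrative assumptions, and the example assumes the gloo backend supports Work.get_future(), as recent PyTorch does.)

# Illustrative, standalone use of the new helper (names taken from the diff above);
# the process-group setup is an assumption made to keep the example runnable.
import os

import torch
import torch.distributed as dist
from torch.distributed.algorithms.ddp_comm_hooks import default_hooks as default

os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29500")
dist.init_process_group("gloo", rank=0, world_size=1)

grad = torch.full((4,), 2.0)
# Passing None lets the helper fall back to dist.group.WORLD.
fut = default._allreduce_fut(None, grad)
# The future resolves to a single-element list holding the allreduced tensor
# divided by the group size (world_size == 1 here, so the values are unchanged).
print(fut.wait())  # [tensor([2., 2., 2., 2.])]

dist.destroy_process_group()

The division in the chained callback is what makes every hook built on this helper return averaged, rather than summed, gradients.
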
torch/distributed/algorithms/ddp_comm_hooks/powerSGD_hook.py (17 changes: 7 additions & 10 deletions)

@@ -5,6 +5,8 @@
 import torch
 import torch.distributed as dist
 
+from . import default_hooks as default
+
 
 def _orthogonalize(matrix, epsilon=1e-8):
     """
@@ -127,7 +129,7 @@ def maybe_increase_iter(self, bucket):
 
         if self.iter == self.start_powerSGD_iter:
             logging.info(
-                "Starting to apply PowerSGD after {} iterations.".format(self.iter)
+                "Start to apply PowerSGD after {} iterations.".format(self.iter)
             )
 
 
@@ -183,15 +185,8 @@ def powerSGD_hook(state: PowerSGDState, bucket) -> torch.futures.Future:
 
     # Run vanilla allreduce in the first `start_powerSGD_iter` iterations.
    if state.iter < state.start_powerSGD_iter:
-        fut = dist.all_reduce(
-            input_tensor, group=group_to_use, async_op=True
-        ).get_future()
-
-        def div_callback(fut):
-            return [fut.value()[0].div_(world_size)]
-
         state.maybe_increase_iter(bucket)
-        return fut.then(div_callback)
+        return default._allreduce_fut(group_to_use, input_tensor)
 
     # Apply PowerSGD after `start_powerSGD_iter` iterations.
     device = input_tensor.device
@@ -210,7 +205,9 @@ def div_callback(fut):
                     total_length
                 )
             )
-            state.error_dict[bucket_index] = torch.zeros(total_length, device=device, dtype=dtype)
+            state.error_dict[bucket_index] = torch.zeros(
+                total_length, device=device, dtype=dtype
+            )
 
         # Keep a copy of the input tensor,
         # so that we can compute the local error caused by compression later,
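
(Also not part of this commit: a sketch of registering the refactored allreduce_hook on a DDP model, under the same single-process gloo assumptions as above. powerSGD_hook is registered the same way, with a PowerSGDState instance as the state argument; its constructor options are omitted here because they may differ at this revision.)

# Illustrative end-to-end registration of the refactored allreduce_hook; the toy
# model, backend, and env settings are assumptions, not taken from this PR.
import os

import torch
import torch.distributed as dist
import torch.nn as nn
from torch.distributed.algorithms.ddp_comm_hooks import default_hooks as default
from torch.nn.parallel import DistributedDataParallel as DDP

os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29501")
dist.init_process_group("gloo", rank=0, world_size=1)

ddp_model = DDP(nn.Linear(8, 4))
# With this change, each gradient bucket is averaged through the shared
# _allreduce_fut helper; passing None as the state uses the default group.
ddp_model.register_comm_hook(None, default.allreduce_hook)

loss = ddp_model(torch.randn(2, 8)).sum()
loss.backward()  # backward triggers the comm hook as buckets become ready

dist.destroy_process_group()
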
