Update on "[Gradient Compression] Refactor default_hooks.py and powerSGD_hook.py by creating a util function that makes a vanilla allreduce future"


Address #50973 (comment)

Original PR issue: Investigate Applying PowerSGD to Communication Hook for Gradient Compression #47202

Differential Revision: [D26070147](https://our.internmc.facebook.com/intern/diff/D26070147/)

[ghstack-poisoned]
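
For context, a minimal sketch of what the refactor enables (an illustration only; the corresponding powerSGD_hook.py change is not part of the diff shown below, and the hook name my_fallback_hook is hypothetical). Any DDP communication hook that needs a plain averaged allreduce can delegate to the shared util instead of duplicating the all_reduce-plus-division boilerplate, assuming allreduce_fut is importable from default_hooks as shown in the diff:

import torch
import torch.distributed as dist
from torch.distributed.algorithms.ddp_comm_hooks.default_hooks import allreduce_fut


def my_fallback_hook(process_group: dist.ProcessGroup, bucket) -> torch.futures.Future:
    # Delegate the vanilla path to the shared util: it falls back to
    # dist.group.WORLD when process_group is None, launches an async
    # all_reduce on the bucket's first tensor, and divides the result by
    # the group size in a `then` callback.
    return allreduce_fut(process_group, bucket.get_tensors()[0])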
wayi committed Jan 29, 2021
1 parent be1af96 commit 2ca21e5
Showing 1 changed file with 5 additions and 5 deletions.
10 changes: 5 additions & 5 deletions torch/distributed/algorithms/ddp_comm_hooks/default_hooks.py
@@ -5,11 +5,13 @@
 def allreduce_fut(
     process_group: dist.ProcessGroup, tensor: torch.Tensor
 ) -> torch.futures.Future:
+    group_to_use = process_group if process_group is not None else dist.group.WORLD
+
     "Averages the input gradient tensor by allreduce and returns a future."
-    fut = dist.all_reduce(tensor, group=process_group, async_op=True).get_future()
+    fut = dist.all_reduce(tensor, group=group_to_use, async_op=True).get_future()

     def div_by_group_size(fut):
-        return [fut.value()[0].div_(process_group.size())]
+        return [fut.value()[0].div_(group_to_use.size())]

     return fut.then(div_by_group_size)

@@ -29,9 +31,7 @@ def allreduce_hook(
     Example::
         >>> ddp_model.register_comm_hook(process_group, allreduce_hook)
     """
-    group_to_use = process_group if process_group is not None else dist.group.WORLD
-
-    return allreduce_fut(group_to_use, bucket.get_tensors()[0])
+    return allreduce_fut(process_group, bucket.get_tensors()[0])


 def fp16_compress_hook(
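
For reference, the registration pattern from the docstring above. After this change a None process group is resolved inside allreduce_fut (it falls back to dist.group.WORLD), so the hook itself no longer needs its own group_to_use branch; ddp_model and process_group below are assumed to be constructed elsewhere:

from torch.distributed.algorithms.ddp_comm_hooks.default_hooks import allreduce_hook

# ddp_model: an existing torch.nn.parallel.DistributedDataParallel model, and
# process_group: a dist.ProcessGroup or None (both assumed to exist already).
# After this refactor, None is handled inside allreduce_fut rather than in
# the hook.
ddp_model.register_comm_hook(process_group, allreduce_hook)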
