[Gradient Compression] Refactor default_hooks.py and powerSGD_hook.py by creating a util function that makes a vanilla allreduce future

Pull Request resolved: #51094

Address #50973 (comment)

Original PR issue: Investigate Applying PowerSGD to Communication Hook for Gradient Compression #47202
ghstack-source-id: 120376248

Differential Revision: [D26070147](https://our.internmc.facebook.com/intern/diff/D26070147/)
wayi committed Jan 26, 2021
1 parent 1d697a8 commit af4a348
Showing 2 changed files with 16 additions and 17 deletions.
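
Concretely, the new allreduce_fut helper wraps an asynchronous allreduce in a torch.futures.Future that resolves to a single-element list holding the averaged tensor. A minimal sketch of how a caller could consume such a "vanilla allreduce future" (assuming torch.distributed is already initialized; the tensor and group below are illustrative, not from the PR):

import torch
import torch.distributed as dist
from torch.distributed.algorithms.ddp_comm_hooks.default_hooks import allreduce_fut

# Illustrative standalone use: average one tensor across the default group.
grad = torch.ones(4)
fut = allreduce_fut(dist.group.WORLD, grad)
averaged = fut.wait()[0]  # wait() returns the single-element list produced by the callback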
21 changes: 13 additions & 8 deletions torch/distributed/algorithms/ddp_comm_hooks/default_hooks.py
@@ -2,6 +2,18 @@
 import torch.distributed as dist
 
 
+def allreduce_fut(
+    process_group: dist.ProcessGroup, tensor: torch.Tensor
+) -> torch.futures.Future:
+    "Averages the input gradient tensor by allreduce and returns a future."
+    fut = dist.all_reduce(tensor, group=process_group, async_op=True).get_future()
+
+    def div_by_group_size(fut):
+        return [fut.value()[0].div_(process_group.size())]
+
+    return fut.then(div_by_group_size)
+
+
 def allreduce_hook(
     process_group: dist.ProcessGroup, bucket: dist._GradBucket
 ) -> torch.futures.Future:
@@ -18,15 +30,8 @@ def allreduce_hook(
         >>> ddp_model.register_comm_hook(process_group, allreduce_hook)
     """
     group_to_use = process_group if process_group is not None else dist.group.WORLD
-    world_size = group_to_use.size()
 
-    tensor = bucket.get_tensors()[0]
-    fut = dist.all_reduce(tensor, group=group_to_use, async_op=True).get_future()
-
-    def then_callback(fut):
-        return [fut.value()[0].div_(world_size)]
-
-    return fut.then(then_callback)
+    return allreduce_fut(group_to_use, bucket.get_tensors()[0])
 
 
 def fp16_compress_hook(
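
As an aside (not part of this commit), the returned future keeps the existing single-element-list contract, so a hypothetical custom hook could reuse allreduce_fut and chain extra work onto it; the hook name and the clamping step below are invented for illustration, while allreduce_fut, dist._GradBucket, and register_comm_hook come from the code above:

import torch
import torch.distributed as dist
import torch.distributed.algorithms.ddp_comm_hooks.default_hooks as default


def clamping_allreduce_hook(
    process_group: dist.ProcessGroup, bucket: dist._GradBucket
) -> torch.futures.Future:
    # Hypothetical hook: delegate the averaging allreduce to the shared helper,
    # then clamp the averaged gradient in a chained callback.
    group_to_use = process_group if process_group is not None else dist.group.WORLD
    fut = default.allreduce_fut(group_to_use, bucket.get_tensors()[0])

    def clamp(fut):
        # allreduce_fut resolves to a single-element list holding the averaged tensor.
        return [fut.value()[0].clamp_(-1.0, 1.0)]

    return fut.then(clamp)

# Registered the same way as in the docstring example above:
# ddp_model.register_comm_hook(process_group, clamping_allreduce_hook)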
12 changes: 3 additions & 9 deletions torch/distributed/algorithms/ddp_comm_hooks/powerSGD_hook.py
@@ -4,6 +4,7 @@
 import numpy as np
 import torch
 import torch.distributed as dist
+import torch.distributed.algorithms.ddp_comm_hooks.default_hooks as default
 
 
 def _orthogonalize(matrix, epsilon=1e-8):
@@ -126,7 +127,7 @@ def maybe_increase_iter(self, bucket):
 
         if self.iter == self.start_powerSGD_iter:
             logging.info(
-                "Starting to apply PowerSGD after {} iterations.".format(self.iter)
+                "Start to apply PowerSGD after {} iterations.".format(self.iter)
             )
 
 
@@ -178,15 +179,8 @@ def powerSGD_hook(state: PowerSGDState, bucket) -> torch.futures.Future:
 
     # Run vanilla allreduce in the first `start_powerSGD_iter` iterations.
    if state.iter < state.start_powerSGD_iter:
-        fut = dist.all_reduce(
-            input_tensor, group=group_to_use, async_op=True
-        ).get_future()
-
-        def div_callback(fut):
-            return [fut.value()[0].div_(world_size)]
-
         state.maybe_increase_iter(bucket)
-        return fut.then(div_callback)
+        return default.allreduce_fut(group_to_use, input_tensor)
 
     # Apply PowerSGD after `start_powerSGD_iter` iterations.
     device = input_tensor.device
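
For completeness, a rough usage sketch (not from this PR) of registering the hook so that the first start_powerSGD_iter iterations take the default.allreduce_fut path above and later iterations apply PowerSGD compression; the PowerSGDState constructor arguments shown are assumptions about this version of the API, and ddp_model stands for an already-constructed DistributedDataParallel model:

import torch.distributed.algorithms.ddp_comm_hooks.powerSGD_hook as powerSGD

# Assumed constructor arguments: rank-1 low-rank approximation, and plain
# allreduce (via default.allreduce_fut) for the first 10 iterations.
state = powerSGD.PowerSGDState(
    process_group=None,  # None falls back to the default (WORLD) process group
    matrix_approximation_rank=1,
    start_powerSGD_iter=10,
)
ddp_model.register_comm_hook(state, powerSGD.powerSGD_hook)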
