Revert D26070147: [Gradient Compression] Refactor default_hooks.py and powerSGD_hook.py by creating a util function that makes a vanilla allreduce future

Test Plan: revert-hammer

Differential Revision: D26070147 (e7b3496)

Original commit changeset: 8c9339f1511e

fbshipit-source-id: fa1e9582baec9759a73b3004be9bb19bdeb6cd34
izdeby authored and facebook-github-bot committed Jan 29, 2021
1 parent 270111b commit 5a406c0
Showing 2 changed files with 19 additions and 18 deletions.
25 changes: 10 additions & 15 deletions torch/distributed/algorithms/ddp_comm_hooks/default_hooks.py
@@ -2,20 +2,6 @@
 import torch.distributed as dist
 
 
-def allreduce_fut(
-    process_group: dist.ProcessGroup, tensor: torch.Tensor
-) -> torch.futures.Future:
-    group_to_use = process_group if process_group is not None else dist.group.WORLD
-
-    "Averages the input gradient tensor by allreduce and returns a future."
-    fut = dist.all_reduce(tensor, group=group_to_use, async_op=True).get_future()
-
-    def div_by_group_size(fut):
-        return [fut.value()[0].div_(group_to_use.size())]
-
-    return fut.then(div_by_group_size)
-
-
 def allreduce_hook(
     process_group: dist.ProcessGroup, bucket: dist._GradBucket
 ) -> torch.futures.Future:
@@ -31,7 +17,16 @@ def allreduce_hook(
     Example::
         >>> ddp_model.register_comm_hook(process_group, allreduce_hook)
     """
-    return allreduce_fut(process_group, bucket.get_tensors()[0])
+    group_to_use = process_group if process_group is not None else dist.group.WORLD
+    world_size = group_to_use.size()
+
+    tensor = bucket.get_tensors()[0]
+    fut = dist.all_reduce(tensor, group=group_to_use, async_op=True).get_future()
+
+    def then_callback(fut):
+        return [fut.value()[0].div_(world_size)]
+
+    return fut.then(then_callback)
 
 
 def fp16_compress_hook(
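For context, `allreduce_hook` is registered on a DistributedDataParallel model, as the docstring example above hints. Below is a minimal, illustrative sketch of such a registration; the toy model, the function name, the single-GPU-per-process assumption, and the choice of passing None as the state (so the hook falls back to dist.group.WORLD) are assumptions for illustration, not part of this commit.

import torch.nn as nn
from torch.distributed.algorithms.ddp_comm_hooks import default_hooks
from torch.nn.parallel import DistributedDataParallel as DDP

def register_allreduce_hook(rank):
    # Assumes dist.init_process_group(...) has already been called and that
    # each process drives a single GPU identified by `rank` (illustrative setup).
    model = nn.Linear(10, 10).to(rank)      # stand-in for a real model
    ddp_model = DDP(model, device_ids=[rank])
    # Passing None as the state makes the hook fall back to dist.group.WORLD,
    # matching the `process_group is not None` check in allreduce_hook above.
    ddp_model.register_comm_hook(None, default_hooks.allreduce_hook)
    return ddp_model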
12 changes: 9 additions & 3 deletions torch/distributed/algorithms/ddp_comm_hooks/powerSGD_hook.py
@@ -4,7 +4,6 @@
 import numpy as np
 import torch
 import torch.distributed as dist
-import torch.distributed.algorithms.ddp_comm_hooks.default_hooks as default
 
 
 def _orthogonalize(matrix, epsilon=1e-8):
@@ -128,7 +127,7 @@ def maybe_increase_iter(self, bucket):
 
         if self.iter == self.start_powerSGD_iter:
             logging.info(
-                "Start to apply PowerSGD after {} iterations.".format(self.iter)
+                "Starting to apply PowerSGD after {} iterations.".format(self.iter)
             )
 
 
@@ -184,8 +183,15 @@ def powerSGD_hook(state: PowerSGDState, bucket) -> torch.futures.Future:
 
     # Run vanilla allreduce in the first `start_powerSGD_iter` iterations.
    if state.iter < state.start_powerSGD_iter:
+        fut = dist.all_reduce(
+            input_tensor, group=group_to_use, async_op=True
+        ).get_future()
+
+        def div_callback(fut):
+            return [fut.value()[0].div_(world_size)]
+
         state.maybe_increase_iter(bucket)
-        return default.allreduce_fut(group_to_use, input_tensor)
+        return fut.then(div_callback)
 
     # Apply PowerSGD after `start_powerSGD_iter` iterations.
     device = input_tensor.device
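Similarly, the PowerSGD hook touched by this revert is registered together with a PowerSGDState that carries `start_powerSGD_iter` and the approximation rank, both of which appear in the diff above. A minimal sketch under the same assumptions as the previous example (toy model, illustrative parameter values, constructor keywords assumed from the fields referenced in this file):

import torch.nn as nn
from torch.distributed.algorithms.ddp_comm_hooks import powerSGD_hook as powerSGD
from torch.nn.parallel import DistributedDataParallel as DDP

def register_powersgd_hook(rank):
    # Assumes dist.init_process_group(...) has already been called.
    model = nn.Linear(10, 10).to(rank)      # stand-in for a real model
    ddp_model = DDP(model, device_ids=[rank])
    # Run the vanilla allreduce path for the first 1000 iterations,
    # then switch to rank-1 PowerSGD gradient compression.
    state = powerSGD.PowerSGDState(
        process_group=None,
        matrix_approximation_rank=1,
        start_powerSGD_iter=1000,
    )
    ddp_model.register_comm_hook(state, powerSGD.powerSGD_hook)
    return ddp_model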
