Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Gradient Compression] Refactor default_hooks.py and powerSGD_hook.py by creating a util function that make a vanilla allreduce future #51094

Closed
wants to merge 4 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
25 changes: 15 additions & 10 deletions torch/distributed/algorithms/ddp_comm_hooks/default_hooks.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,20 @@
import torch.distributed as dist


def allreduce_fut(
wayi1 marked this conversation as resolved.
Show resolved Hide resolved
process_group: dist.ProcessGroup, tensor: torch.Tensor
) -> torch.futures.Future:
group_to_use = process_group if process_group is not None else dist.group.WORLD

"Averages the input gradient tensor by allreduce and returns a future."
fut = dist.all_reduce(tensor, group=group_to_use, async_op=True).get_future()

def div_by_group_size(fut):
return [fut.value()[0].div_(group_to_use.size())]

return fut.then(div_by_group_size)


def allreduce_hook(
process_group: dist.ProcessGroup, bucket: dist._GradBucket
) -> torch.futures.Future:
Expand All @@ -17,16 +31,7 @@ def allreduce_hook(
Example::
>>> ddp_model.register_comm_hook(process_group, allreduce_hook)
"""
group_to_use = process_group if process_group is not None else dist.group.WORLD
world_size = group_to_use.size()

tensor = bucket.get_tensors()[0]
fut = dist.all_reduce(tensor, group=group_to_use, async_op=True).get_future()

def then_callback(fut):
return [fut.value()[0].div_(world_size)]

return fut.then(then_callback)
return allreduce_fut(process_group, bucket.get_tensors()[0])


def fp16_compress_hook(
Expand Down
12 changes: 3 additions & 9 deletions torch/distributed/algorithms/ddp_comm_hooks/powerSGD_hook.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import numpy as np
import torch
import torch.distributed as dist
import torch.distributed.algorithms.ddp_comm_hooks.default_hooks as default


def _orthogonalize(matrix, epsilon=1e-8):
Expand Down Expand Up @@ -127,7 +128,7 @@ def maybe_increase_iter(self, bucket):

if self.iter == self.start_powerSGD_iter:
logging.info(
"Starting to apply PowerSGD after {} iterations.".format(self.iter)
"Start to apply PowerSGD after {} iterations.".format(self.iter)
)


Expand Down Expand Up @@ -183,15 +184,8 @@ def powerSGD_hook(state: PowerSGDState, bucket) -> torch.futures.Future:

# Run vanilla allreduce in the first `start_powerSGD_iter` iterations.
if state.iter < state.start_powerSGD_iter:
fut = dist.all_reduce(
input_tensor, group=group_to_use, async_op=True
).get_future()

def div_callback(fut):
return [fut.value()[0].div_(world_size)]

state.maybe_increase_iter(bucket)
return fut.then(div_callback)
return default.allreduce_fut(group_to_use, input_tensor)

# Apply PowerSGD after `start_powerSGD_iter` iterations.
device = input_tensor.device
Expand Down