diff --git a/torch/distributed/algorithms/ddp_comm_hooks/default_hooks.py b/torch/distributed/algorithms/ddp_comm_hooks/default_hooks.py
index c6f981de772d..59491a868be4 100644
--- a/torch/distributed/algorithms/ddp_comm_hooks/default_hooks.py
+++ b/torch/distributed/algorithms/ddp_comm_hooks/default_hooks.py
@@ -2,20 +2,6 @@
 import torch.distributed as dist
 
 
-def allreduce_fut(
-    process_group: dist.ProcessGroup, tensor: torch.Tensor
-) -> torch.futures.Future:
-    group_to_use = process_group if process_group is not None else dist.group.WORLD
-
-    "Averages the input gradient tensor by allreduce and returns a future."
-    fut = dist.all_reduce(tensor, group=group_to_use, async_op=True).get_future()
-
-    def div_by_group_size(fut):
-        return [fut.value()[0].div_(group_to_use.size())]
-
-    return fut.then(div_by_group_size)
-
-
 def allreduce_hook(
     process_group: dist.ProcessGroup, bucket: dist._GradBucket
 ) -> torch.futures.Future:
@@ -31,7 +17,16 @@ def allreduce_hook(
     Example::
         >>> ddp_model.register_comm_hook(process_group, allreduce_hook)
     """
-    return allreduce_fut(process_group, bucket.get_tensors()[0])
+    group_to_use = process_group if process_group is not None else dist.group.WORLD
+    world_size = group_to_use.size()
+
+    tensor = bucket.get_tensors()[0]
+    fut = dist.all_reduce(tensor, group=group_to_use, async_op=True).get_future()
+
+    def then_callback(fut):
+        return [fut.value()[0].div_(world_size)]
+
+    return fut.then(then_callback)
 
 
 def fp16_compress_hook(
diff --git a/torch/distributed/algorithms/ddp_comm_hooks/powerSGD_hook.py b/torch/distributed/algorithms/ddp_comm_hooks/powerSGD_hook.py
index 0961d4bd656f..db7a5a50f51b 100644
--- a/torch/distributed/algorithms/ddp_comm_hooks/powerSGD_hook.py
+++ b/torch/distributed/algorithms/ddp_comm_hooks/powerSGD_hook.py
@@ -4,7 +4,6 @@
 import numpy as np
 import torch
 import torch.distributed as dist
-import torch.distributed.algorithms.ddp_comm_hooks.default_hooks as default
 
 
 def _orthogonalize(matrix, epsilon=1e-8):
@@ -128,7 +127,7 @@ def maybe_increase_iter(self, bucket):
 
         if self.iter == self.start_powerSGD_iter:
             logging.info(
-                "Start to apply PowerSGD after {} iterations.".format(self.iter)
+                "Starting to apply PowerSGD after {} iterations.".format(self.iter)
             )
 
 
@@ -184,8 +183,15 @@ def powerSGD_hook(state: PowerSGDState, bucket) -> torch.futures.Future:
 
     # Run vanilla allreduce in the first `start_powerSGD_iter` iterations.
    if state.iter < state.start_powerSGD_iter:
+        fut = dist.all_reduce(
+            input_tensor, group=group_to_use, async_op=True
+        ).get_future()
+
+        def div_callback(fut):
+            return [fut.value()[0].div_(world_size)]
+
         state.maybe_increase_iter(bucket)
-        return default.allreduce_fut(group_to_use, input_tensor)
+        return fut.then(div_callback)
 
     # Apply PowerSGD after `start_powerSGD_iter` iterations.
     device = input_tensor.device
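
Usage sketch (not part of the patch): with the averaging logic now inlined in allreduce_hook, callers register the hook exactly as the docstring's Example shows. The backend, rendezvous address, and toy model below are illustrative placeholders, not anything mandated by this change.

    # Minimal sketch, assuming single-machine CPU ranks; the backend,
    # address, and model here are placeholders for illustration only.
    import torch
    import torch.distributed as dist
    from torch.nn.parallel import DistributedDataParallel as DDP
    from torch.distributed.algorithms.ddp_comm_hooks.default_hooks import allreduce_hook

    def demo(rank: int, world_size: int) -> None:
        dist.init_process_group(
            backend="gloo",
            init_method="tcp://127.0.0.1:29500",  # placeholder rendezvous address
            rank=rank,
            world_size=world_size,
        )
        ddp_model = DDP(torch.nn.Linear(8, 8))
        # Passing None as the state/process group makes the hook fall back to
        # dist.group.WORLD, per the `group_to_use` logic in this patch.
        ddp_model.register_comm_hook(None, allreduce_hook)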