diff --git a/torch/distributed/algorithms/ddp_comm_hooks/__init__.py b/torch/distributed/algorithms/ddp_comm_hooks/__init__.py
index 11aec8ab8a61..f25f3a8caad8 100644
--- a/torch/distributed/algorithms/ddp_comm_hooks/__init__.py
+++ b/torch/distributed/algorithms/ddp_comm_hooks/__init__.py
@@ -15,13 +15,17 @@ def _ddp_comm_hook_wrapper(comm_hook, model, state):
     model.register_comm_hook(state, comm_hook)
 
 
-def _powerSGD_comm_hook_wrapper(comm_hook, model, state, matrix_approximation_rank):
+def _powerSGD_comm_hook_wrapper(
+    comm_hook, model, state, matrix_approximation_rank, random_seed=0
+):
     """
     To be consistent with the wrappers of other DDP comm hooks, the input state only needs to be a process group,
     which will be wrapped up with other state info.
     """
     powerSGD_state = powerSGD.PowerSGDState(
-        process_group=state, matrix_approximation_rank=matrix_approximation_rank
+        process_group=state,
+        matrix_approximation_rank=matrix_approximation_rank,
+        random_seed=random_seed,
     )
     model.register_comm_hook(powerSGD_state, comm_hook)
 
diff --git a/torch/distributed/algorithms/ddp_comm_hooks/powerSGD_hook.py b/torch/distributed/algorithms/ddp_comm_hooks/powerSGD_hook.py
index 814a24cf262a..9a6fbb4a31dd 100644
--- a/torch/distributed/algorithms/ddp_comm_hooks/powerSGD_hook.py
+++ b/torch/distributed/algorithms/ddp_comm_hooks/powerSGD_hook.py
@@ -1,5 +1,6 @@
 import math
 
+import numpy as np
 import torch
 import torch.distributed as dist
 
@@ -29,11 +30,17 @@ def _orthogonalize(matrix, epsilon=1e-8):
 
 
 class PowerSGDState(object):
-    __slots__ = ["process_group", "matrix_approximation_rank"]
+    __slots__ = ["process_group", "matrix_approximation_rank", "rng"]
 
-    def __init__(self, process_group, matrix_approximation_rank=1):
+    def __init__(self, process_group, matrix_approximation_rank=1, random_seed=0):
         self.process_group = process_group
         self.matrix_approximation_rank = matrix_approximation_rank
+        # The purpose of this RNG is to generate different random seeds for initializing Q across iterations,
+        # in the same order on all the DDP replicas.
+        # Different random seeds across iterations yield different 'projections' of the gradients at different SGD steps.
+        # If the same random projection were reused at every step,
+        # the components of the gradients outside that projection would never be synchronized.
+        self.rng = np.random.RandomState(random_seed)
 
 
 def powerSGD_hook(
@@ -46,14 +53,15 @@
     Once gradient tensors are aggregated across all workers, this hook applies compression as follows:
     1) Views the input flattened 1D gradient tensor as a square-shaped tensor M with 0 paddings;
-    2) Decomposes M into two low-rank tensors P and Q,
+    2) Creates two low-rank tensors P and Q for decomposing M,
     such that M = PQ^T, where Q is initialized from a standard normal distribution and orthogonalized;
-    2) Allreduces P;
-    3) Orthogonizes P;
-    4) Compute Q, which is approximately equal to M^TP;
-    5) Allreduces Q;
-    6) Computes M, which is approximately equal to PQ^T.
-    7) Truncates the input tensor to the original length.
+    3) Computes P, which is equal to MQ;
+    4) Allreduces P;
+    5) Orthogonalizes P;
+    6) Computes Q, which is approximately equal to M^TP;
+    7) Allreduces Q;
+    8) Computes M, which is approximately equal to PQ^T.
+    9) Truncates the input tensor to the original length.
     TODO(wayi@): The above procedure does two matmul+allreduce steps per iteration --
     one left multiplication and one right multiplication.
@@ -92,25 +100,26 @@ def powerSGD_hook(
     input_tensor[total_length:padded_total_length].fill_(0)
     matrix = input_tensor.view(square_side_length, square_side_length)
 
-    def create_low_rank_tensor(fill_random_values):
+    def create_low_rank_tensor(fill_random_values, rng):
         "Returns a low-rank 2D tensor of square_side_length * matrix_approximation_rank."
         if fill_random_values:
-            with torch.random.fork_rng(devices=[device]):
+            with torch.random.fork_rng(devices=[]):
+                # Fork this RNG to avoid changing the seed globally and affecting the random sampling anywhere else in the training.
                 # The seed makes sure that the initial random values are the same across all the DDP replicas.
                 # Such seed should differ at every step.
-                # Currently use the length of input tensor as the seed, which should be mostly different.
-                # TODO(wayi@): Should read the random seed from the state of this hook provided by the constructor.
-                torch.manual_seed(total_length)
+                # Since it is very slow to fork RNG state across all the CUDA devices,
+                # only fork on CPU and then move the generated tensor to the CUDA device.
+                torch.manual_seed(rng.randint(1_000_000_000))
                 return torch.randn(
-                    square_side_length, state.matrix_approximation_rank, device=device
-                )
+                    square_side_length, state.matrix_approximation_rank, device="cpu"
+                ).to(device)
         else:
             return torch.empty(
                 square_side_length, state.matrix_approximation_rank, device=device
             )
 
-    p = create_low_rank_tensor(fill_random_values=False)
-    q = create_low_rank_tensor(fill_random_values=True)
+    p = create_low_rank_tensor(fill_random_values=False, rng=state.rng)
+    q = create_low_rank_tensor(fill_random_values=True, rng=state.rng)
     _orthogonalize(q, 0)
 
     torch.matmul(matrix, q, out=p)
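
For reviewers who want to see the numbered steps of the updated docstring end to end, here is a minimal single-process sketch of the compression math. The helper names are made up for illustration, the allreduce steps 4) and 7) are omitted because no process group is involved, and the Gram-Schmidt helper only mirrors `_orthogonalize` rather than reproducing it exactly:

```python
import math

import torch


def _orthogonalize_sketch(matrix, epsilon=1e-8):
    # Column-wise Gram-Schmidt; a stand-in for _orthogonalize in powerSGD_hook.py.
    for i in range(matrix.shape[1]):
        col = matrix[:, i : i + 1]
        col /= torch.sqrt(torch.sum(col ** 2)) + epsilon
        if i + 1 < matrix.shape[1]:
            rest = matrix[:, i + 1 :]
            rest -= torch.sum(col * rest, dim=0) * col


def compress_decompress(grad_1d, matrix_approximation_rank=1, step_seed=0):
    total_length = grad_1d.shape[0]
    # 1) View the flattened gradient as a square matrix M, padded with zeros.
    side = int(math.ceil(math.sqrt(total_length)))
    padded = torch.zeros(side * side, dtype=grad_1d.dtype)
    padded[:total_length] = grad_1d
    matrix = padded.view(side, side)

    # 2) Create Q from a standard normal distribution and orthogonalize it,
    #    seeding a forked CPU RNG so every replica would build the same Q.
    with torch.random.fork_rng(devices=[]):
        torch.manual_seed(step_seed)
        q = torch.randn(side, matrix_approximation_rank)
    _orthogonalize_sketch(q)

    # 3) P = MQ.  (4) would allreduce P across workers here.)
    p = torch.matmul(matrix, q)
    # 5) Orthogonalize P.
    _orthogonalize_sketch(p)
    # 6) Q ~= M^T P.  (7) would allreduce Q here.)
    q = torch.matmul(matrix.t(), p)
    # 8) Reconstruct M ~= P Q^T and 9) truncate back to the original length.
    return torch.matmul(p, q.t()).view(-1)[:total_length]
```

With `matrix_approximation_rank=1`, each of the two allreduces in the real hook exchanges a `side x 1` tensor rather than all `side x side` entries of M, which is where the bandwidth saving comes from.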
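
The reason the state now carries an `np.random.RandomState(random_seed)` instead of seeding from the bucket length is that every replica must draw the same seed at the same step while the seed still changes from step to step. A small standalone check of that property, with hypothetical replica names and no DDP involved, could look like:

```python
import numpy as np
import torch

# Two replicas construct their PowerSGD state with the same random_seed,
# as the updated _powerSGD_comm_hook_wrapper now allows via random_seed=...
rng_replica_a = np.random.RandomState(0)
rng_replica_b = np.random.RandomState(0)

for step in range(3):
    seed_a = rng_replica_a.randint(1_000_000_000)
    seed_b = rng_replica_b.randint(1_000_000_000)
    assert seed_a == seed_b  # identical across replicas at each step

    # Mirror the hook: fork only the CPU RNG, seed it, and sample Q on CPU,
    # so the global RNG used elsewhere in training is left untouched.
    before = torch.random.get_rng_state()
    with torch.random.fork_rng(devices=[]):
        torch.manual_seed(seed_a)
        q = torch.randn(8, 1)
    assert torch.equal(before, torch.random.get_rng_state())
```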