[Gradient Compression] Add a random generator to PowerSGD state for initializing low-rank matrix Q

Pull Request resolved: #48507

Previously, the random seed was the length of the input tensor, which is not guaranteed to differ across batches. Now a random generator is initialized in the PowerSGD state and used to create a fresh random seed that randomizes the low-rank tensor Q at every step.

Therefore, the initial tensor Q should be the same across all replicas at the same step, but different across steps.

'torch.manual_seed' is used in the same way as https://github.com/epfml/powersgd/blob/master/gradient_reducers.py#L675
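As a minimal standalone sketch of this seeding scheme (not the hook implementation itself; make_q is a hypothetical helper, and the PowerSGDState fields mirror the diff below):

import numpy as np
import torch

class PowerSGDState(object):
    # A NumPy generator seeded once; replicas constructed with the same
    # random_seed draw the same sequence of per-step seeds.
    def __init__(self, process_group, matrix_approximation_rank=1, random_seed=0):
        self.process_group = process_group
        self.matrix_approximation_rank = matrix_approximation_rank
        self.rng = np.random.RandomState(random_seed)

def make_q(state, square_side_length, device="cpu"):
    # Hypothetical helper: each call draws a fresh seed from the shared
    # generator, so Q is identical across replicas within a step but
    # changes from step to step.
    with torch.random.fork_rng(devices=[]):
        torch.manual_seed(state.rng.randint(1_000_000_000))
        return torch.randn(
            square_side_length, state.matrix_approximation_rank, device="cpu"
        ).to(device)

# Two "replicas" with the same random_seed agree within a step and
# (with overwhelming probability) differ across steps.
s1 = PowerSGDState(process_group=None, matrix_approximation_rank=1, random_seed=0)
s2 = PowerSGDState(process_group=None, matrix_approximation_rank=1, random_seed=0)
q_step1_r1 = make_q(s1, 4)
q_step1_r2 = make_q(s2, 4)
assert torch.equal(q_step1_r1, q_step1_r2)       # same step, same Q
q_step2_r1 = make_q(s1, 4)
assert not torch.equal(q_step1_r1, q_step2_r1)   # next step, different Q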

Original PR issue: Investigate Applying PowerSGD to Communication Hook for Gradient Compression #47202
ghstack-source-id: 117402238

Differential Revision: [D25191589](https://our.internmc.facebook.com/intern/diff/D25191589/)
wayi committed Nov 27, 2020
1 parent c5ce995 commit 2c49736
Showing 2 changed files with 27 additions and 20 deletions.
8 changes: 6 additions & 2 deletions torch/distributed/algorithms/ddp_comm_hooks/__init__.py
@@ -15,13 +15,17 @@ def _ddp_comm_hook_wrapper(comm_hook, model, state):
model.register_comm_hook(state, comm_hook)


def _powerSGD_comm_hook_wrapper(comm_hook, model, state, matrix_approximation_rank):
def _powerSGD_comm_hook_wrapper(
comm_hook, model, state, matrix_approximation_rank, random_seed=0
):
"""
To be consistent with the wrappers of other DDP comm hooks, the input state only needs to be a process group,
which will be wrapped up with other state info.
"""
powerSGD_state = powerSGD.PowerSGDState(
process_group=state, matrix_approximation_rank=matrix_approximation_rank
process_group=state,
matrix_approximation_rank=matrix_approximation_rank,
random_seed=random_seed,
)
model.register_comm_hook(powerSGD_state, comm_hook)

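For context, a usage sketch of what this wrapper ends up registering, assuming an initialized default process group and an existing DistributedDataParallel model named ddp_model (both assumptions; the import path follows the file paths in this diff):

import torch.distributed as dist
from torch.distributed.algorithms.ddp_comm_hooks import powerSGD_hook as powerSGD

# Every rank builds the state with the same random_seed, so all ranks draw
# the same per-step seeds for initializing Q.
state = powerSGD.PowerSGDState(
    process_group=dist.group.WORLD,
    matrix_approximation_rank=1,
    random_seed=0,
)
ddp_model.register_comm_hook(state, powerSGD.powerSGD_hook)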
39 changes: 21 additions & 18 deletions torch/distributed/algorithms/ddp_comm_hooks/powerSGD_hook.py
@@ -1,5 +1,6 @@
import math

import numpy as np
import torch
import torch.distributed as dist

@@ -29,11 +30,12 @@ def _orthogonalize(matrix, epsilon=1e-8):


class PowerSGDState(object):
__slots__ = ["process_group", "matrix_approximation_rank"]
__slots__ = ["process_group", "matrix_approximation_rank", "rng"]

def __init__(self, process_group, matrix_approximation_rank=1):
def __init__(self, process_group, matrix_approximation_rank=1, random_seed=0):
self.process_group = process_group
self.matrix_approximation_rank = matrix_approximation_rank
self.rng = np.random.RandomState(random_seed)


def powerSGD_hook(
@@ -46,14 +48,15 @@ def powerSGD_hook(
Once gradient tensors are aggregated across all workers, this hook applies
compression as follows:
1) Views the input flattened 1D gradient tensor as a square-shaped tensor M with 0 paddings;
2) Decomposes M into two low-rank tensors P and Q,
2) Creates two low-rank tensors P and Q for decomposing M,
such that M = PQ^T, where Q is initialized from a standard normal distribution and orthogonalized;
2) Allreduces P;
3) Orthogonizes P;
4) Compute Q, which is approximately equal to M^TP;
5) Allreduces Q;
6) Computes M, which is approximately equal to PQ^T.
7) Truncates the input tensor to the original length.
3) Computes P, which is equal to MQ;
4) Allreduces P;
5) Orthogonalizes P;
6) Computes Q, which is approximately equal to M^TP;
7) Allreduces Q;
8) Computes M, which is approximately equal to PQ^T.
9) Truncates the input tensor to the original length.
TODO(wayi@): The above procedure does two matmul+allreduce steps per iteration --
one left multiplication and one right multiplication.
@@ -92,25 +95,25 @@ def powerSGD_hook(
input_tensor[total_length:padded_total_length].fill_(0)
matrix = input_tensor.view(square_side_length, square_side_length)

def create_low_rank_tensor(fill_random_values):
def create_low_rank_tensor(fill_random_values, rng):
"Returns a low-rank 2D tensor of square_side_length * matrix_approximation_rank."
if fill_random_values:
with torch.random.fork_rng(devices=[device]):
with torch.random.fork_rng(devices=[]):
# The seed makes sure that the initial random values are the same across all the DDP replicas.
# Such seed should differ at every step.
# Currently use the length of input tensor as the seed, which should be mostly different.
# TODO(wayi@): Should read the random seed from the state of this hook provided by the constructor.
torch.manual_seed(total_length)
# Since it is very slow to fork RNG state across all the CUDA devices,
# only fork on CPU and then move the generated tensor to the CUDA device.
torch.manual_seed(rng.randint(1_000_000_000))
return torch.randn(
square_side_length, state.matrix_approximation_rank, device=device
)
square_side_length, state.matrix_approximation_rank, device="cpu"
).to(device)
else:
return torch.empty(
square_side_length, state.matrix_approximation_rank, device=device
)

p = create_low_rank_tensor(fill_random_values=False)
q = create_low_rank_tensor(fill_random_values=True)
p = create_low_rank_tensor(fill_random_values=False, rng=state.rng)
q = create_low_rank_tensor(fill_random_values=True, rng=state.rng)
_orthogonalize(q, 0)

torch.matmul(matrix, q, out=p)
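To make the numbered steps in the powerSGD_hook docstring above concrete, here is a single-process sketch of one rank-r compression round; the allreduce steps are marked as comments (they are no-ops without multiple workers), names like grad and rank are illustrative, and torch.linalg.qr stands in for the hook's _orthogonalize helper:

import math
import torch

def powersgd_round_sketch(grad, rank=1, seed=0):
    # 1) View the flattened gradient as a zero-padded square matrix M.
    n = grad.numel()
    side = int(math.ceil(math.sqrt(n)))
    padded = torch.zeros(side * side, dtype=grad.dtype)
    padded[:n] = grad.view(-1)
    m = padded.view(side, side)

    # 2) Create low-rank Q from a seeded standard normal distribution,
    #    then orthogonalize it.
    with torch.random.fork_rng(devices=[]):
        torch.manual_seed(seed)
        q = torch.randn(side, rank, dtype=grad.dtype)
    q, _ = torch.linalg.qr(q)

    # 3) P = MQ.
    p = m @ q
    # 4) allreduce(P) would average P across workers here.

    # 5) Orthogonalize P.
    p, _ = torch.linalg.qr(p)

    # 6) Q ≈ M^T P.
    q = m.t() @ p
    # 7) allreduce(Q) would average Q across workers here.

    # 8) M ≈ P Q^T; 9) truncate back to the original gradient length.
    approx = p @ q.t()
    return approx.view(-1)[:n]

# Example: compress a toy 10-element "gradient" with rank 1.
g = torch.randn(10)
g_hat = powersgd_round_sketch(g, rank=1, seed=42)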
