[Gradient Compression] Add a random generator to PowerSGD state for initializing low-rank matrix Q #48507

Closed · wants to merge 4 commits

8 changes: 6 additions & 2 deletions torch/distributed/algorithms/ddp_comm_hooks/__init__.py
@@ -15,13 +15,17 @@ def _ddp_comm_hook_wrapper(comm_hook, model, state):
     model.register_comm_hook(state, comm_hook)
 
 
-def _powerSGD_comm_hook_wrapper(comm_hook, model, state, matrix_approximation_rank):
+def _powerSGD_comm_hook_wrapper(
+    comm_hook, model, state, matrix_approximation_rank, random_seed=0
+):
     """
     To be consistent with the wrappers of other DDP comm hooks, the input state only needs to be a process group,
     which will be wrapped up with other state info.
     """
     powerSGD_state = powerSGD.PowerSGDState(
-        process_group=state, matrix_approximation_rank=matrix_approximation_rank
+        process_group=state,
+        matrix_approximation_rank=matrix_approximation_rank,
+        random_seed=random_seed,
     )
     model.register_comm_hook(powerSGD_state, comm_hook)

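For readers who want to exercise the new parameter end to end, here is a minimal, hypothetical sketch (not part of this PR) of registering the PowerSGD hook on a DDP model with an explicit random_seed. It assumes a single-node NCCL setup with the usual env:// variables already exported; the Linear model and the rank of 1 are purely illustrative.

import torch
import torch.distributed as dist
import torch.nn as nn
from torch.distributed.algorithms.ddp_comm_hooks import powerSGD_hook as powerSGD

dist.init_process_group("nccl")  # assumes RANK/WORLD_SIZE/MASTER_ADDR/MASTER_PORT are set
model = nn.parallel.DistributedDataParallel(
    nn.Linear(1024, 1024).cuda(),
    device_ids=[torch.cuda.current_device()],
)
state = powerSGD.PowerSGDState(
    process_group=dist.group.WORLD,
    matrix_approximation_rank=1,
    random_seed=0,  # identical on every replica, so every replica seeds Q the same way
)
model.register_comm_hook(state, powerSGD.powerSGD_hook)

The wrapper above does exactly this plumbing; changing random_seed only changes which sequence of per-iteration seeds the replicas share, not the fact that they share it.
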
45 changes: 27 additions & 18 deletions torch/distributed/algorithms/ddp_comm_hooks/powerSGD_hook.py
@@ -1,5 +1,6 @@
 import math
 
+import numpy as np
 import torch
 import torch.distributed as dist
 
@@ -29,11 +30,17 @@ def _orthogonalize(matrix, epsilon=1e-8):


 class PowerSGDState(object):
-    __slots__ = ["process_group", "matrix_approximation_rank"]
+    __slots__ = ["process_group", "matrix_approximation_rank", "rng"]
 
-    def __init__(self, process_group, matrix_approximation_rank=1):
+    def __init__(self, process_group, matrix_approximation_rank=1, random_seed=0):
         self.process_group = process_group
         self.matrix_approximation_rank = matrix_approximation_rank
+        # The purpose of this RNG is to generate different random seeds for initializing Q across iterations,
+        # but in the same order for all the DDP replicas.
+        # Different random seeds across iterations indicate different 'projections' of the gradients at different SGD steps.
+        # If the same random projection is used,
+        # there will be differences between the gradients that are never synchronized.
+        self.rng = np.random.RandomState(random_seed)
 
 
 def powerSGD_hook(
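
The comment block above is the heart of the change: the state owns a numpy RandomState, so every replica constructed with the same random_seed draws the same seed at the same iteration, while consecutive iterations draw different seeds. A toy sketch of that property (not part of the diff; the 1_000_000_000 bound mirrors the rng.randint call further down):

import numpy as np

def q_seeds(random_seed, num_iterations):
    # Plays the role of PowerSGDState.rng: one RandomState per replica,
    # queried once per iteration to seed the initialization of Q.
    rng = np.random.RandomState(random_seed)
    return [rng.randint(1_000_000_000) for _ in range(num_iterations)]

# Two replicas built with the same random_seed see identical seed sequences ...
assert q_seeds(0, 5) == q_seeds(0, 5)
# ... while the seed (and hence the random projection Q) changes from step to step.
assert len(set(q_seeds(0, 5))) > 1
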
@@ -46,14 +53,15 @@ def powerSGD_hook(
     Once gradient tensors are aggregated across all workers, this hook applies
     compression as follows:
     1) Views the input flattened 1D gradient tensor as a square-shaped tensor M with 0 paddings;
-    2) Decomposes M into two low-rank tensors P and Q,
+    2) Creates two low-rank tensors P and Q for decomposing M,
     such that M = PQ^T, where Q is initialized from a standard normal distribution and orthogonalized;
-    2) Allreduces P;
-    3) Orthogonizes P;
-    4) Compute Q, which is approximately equal to M^TP;
-    5) Allreduces Q;
-    6) Computes M, which is approximately equal to PQ^T.
-    7) Truncates the input tensor to the original length.
+    2) Computes P, which is equal to MQ;
+    3) Allreduces P;
+    4) Orthogonizes P;
+    5) Computes Q, which is approximately equal to M^TP;
+    6) Allreduces Q;
+    7) Computes M, which is approximately equal to PQ^T.
+    8) Truncates the input tensor to the original length.
 
     TODO(wayi@): The above procedure does two matmul+allreduce steps per iteration --
     one left multiplication and one right multiplication.
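
To make the numbered steps in the docstring concrete, here is a hedged, single-process sketch of the same rank-r approximation with the two allreduces left out; torch.linalg.qr stands in for the hook's Gram-Schmidt _orthogonalize, and the exactly low-rank test matrix is only there so the reconstruction error comes out visibly small.

import torch

n, r = 64, 4                                  # square_side_length and matrix_approximation_rank
M = torch.randn(n, r) @ torch.randn(r, n)     # stands in for the padded, squared gradient

Q = torch.randn(n, r)                         # step 2: Q from a standard normal distribution ...
Q, _ = torch.linalg.qr(Q)                     # ... then orthogonalized
P = M @ Q                                     # step 2: P = MQ        (allreduce of P omitted)
P, _ = torch.linalg.qr(P)                     # step 4: orthogonalize P
Q = M.t() @ P                                 # step 5: Q ~= M^T P    (allreduce of Q omitted)
M_approx = P @ Q.t()                          # step 7: M ~= P Q^T

print(torch.norm(M - M_approx) / torch.norm(M))   # tiny here, since M really has rank r

Only P and Q (2*n*r numbers) need to be allreduced instead of the n*n entries of M, which is where the communication saving comes from.
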
@@ -92,25 +100,26 @@ def powerSGD_hook(
     input_tensor[total_length:padded_total_length].fill_(0)
     matrix = input_tensor.view(square_side_length, square_side_length)
 
-    def create_low_rank_tensor(fill_random_values):
+    def create_low_rank_tensor(fill_random_values, rng):
         "Returns a low-rank 2D tensor of square_side_length * matrix_approximation_rank."
         if fill_random_values:
-            with torch.random.fork_rng(devices=[device]):
+            with torch.random.fork_rng(devices=[]):
                 # Fork this RNG to avoid changing the seed globally and affecting the random sampling anywhere else in the training.
                 # The seed makes sure that the initial random values are the same across all the DDP replicas.
                 # Such seed should differ at every step.
-                # Currently use the length of input tensor as the seed, which should be mostly different.
-                # TODO(wayi@): Should read the random seed from the state of this hook provided by the constructor.
-                torch.manual_seed(total_length)
+                # Since it is very slow to fork RNG state across all the CUDA devices,
+                # only fork on CPU and then move the generated tensor to the CUDA device.
+                torch.manual_seed(rng.randint(1_000_000_000))
                 return torch.randn(
-                    square_side_length, state.matrix_approximation_rank, device=device
-                )
+                    square_side_length, state.matrix_approximation_rank, device="cpu"
+                ).to(device)
         else:
             return torch.empty(
                 square_side_length, state.matrix_approximation_rank, device=device
             )
 
-    p = create_low_rank_tensor(fill_random_values=False)
-    q = create_low_rank_tensor(fill_random_values=True)
+    p = create_low_rank_tensor(fill_random_values=False, rng=state.rng)
+    q = create_low_rank_tensor(fill_random_values=True, rng=state.rng)
     _orthogonalize(q, 0)
 
     torch.matmul(matrix, q, out=p)
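
A final hedged sketch (not from the PR) of the seeding pattern used in create_low_rank_tensor above: fork only the CPU RNG, seed it from the state's numpy generator, sample Q on CPU, and move it to the device, so the training's global CPU RNG state is left exactly as it was. The rng and the shapes below are stand-ins for state.rng and the real square_side_length / matrix_approximation_rank.

import numpy as np
import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
rng = np.random.RandomState(0)                     # plays the role of state.rng

cpu_rng_before = torch.random.get_rng_state()
with torch.random.fork_rng(devices=[]):            # devices=[]: no per-CUDA-device fork, so this is cheap
    torch.manual_seed(rng.randint(1_000_000_000))  # same seed on every replica at this step
    q = torch.randn(64, 4, device="cpu").to(device)
cpu_rng_after = torch.random.get_rng_state()

assert torch.equal(cpu_rng_before, cpu_rng_after)  # global CPU RNG untouched by the sampling of Q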