Skip to content


Revert D25511543: [Gradient Compression] Implement the original layer…
Browse files Browse the repository at this point in the history
…wise PowerSGD

Test Plan: revert-hammer

Differential Revision:
D25511543 (71f3399)

Original commit changeset: 19ef188bc2d4

fbshipit-source-id: a363641a059aeacc57684884998cf8fb7363d748
  • Loading branch information
mrshenli authored and facebook-github-bot committed Dec 19, 2020
1 parent 5cde23f commit ad9923e
Show file tree
Hide file tree
Showing 3 changed files with 2 additions and 243 deletions.
11 changes: 0 additions & 11 deletions torch/distributed/algorithms/ddp_comm_hooks/
Original file line number Diff line number Diff line change
Expand Up @@ -66,17 +66,6 @@ class DDPCommHookType(Enum):
# Batching can lead to a faster training at the cost of accuracy.

def register_ddp_comm_hook(
Expand Down
187 changes: 2 additions & 185 deletions torch/distributed/algorithms/ddp_comm_hooks/
Original file line number Diff line number Diff line change
Expand Up @@ -73,195 +73,12 @@ def __init__(
def powerSGD_hook(
state: PowerSGDState,
) -> torch.futures.Future:
This DDP communication hook implements the original PowerSGD gradient compression
algorithm described in https://arxiv.org/abs/1905.13727.
Once gradient tensors are aggregated across all workers, this hook applies
compression as follows:
1) Views the input flattened 1D gradient tensor as two groups of per-parameter tensors:
high-rank tensors and vector-like rank-1 tensors (for biases).
2) Handles rank-1 tensors by allreducing them without compression:
2.1) Allocates contiguous memory for those rank-1 tensors,
and allreduces all the rank-1 tensors as a batch, without compression;
2.2) Copies the individual rank-1 tensors from the contiguous memory back to the input tensor.
3) Handles high-rank tensors by PowerSGD compression:
3.1) For each high-rank tensor M, creates two low-rank tensors P and Q for decomposing M,
such that M = PQ^T, where Q is initialized from a standard normal distribution and orthogonalized;
3.2) Computes each P in Ps, which is equal to MQ;
3.3) Allreduces Ps as a batch;
3.4) Orthogonalizes each P in Ps;
3.5) Computes each Q in Qs, which is approximately equal to M^TP;
3.6) Allreduces Qs as a batch;
3.7) Computes each M among all the high-rank tensors, which is approximately equal to PQ^T.
TODO(wayi@): The above procedure does two matmul+allreduce steps per iteration --
one left multiplication and one right multiplication.
For warm start, can take one such step at a time, and alternate between them.
state (PowerSGDState): State information to configure the compression rate and support error feedback, warm start, etc.
bucket (dist._GradBucket): Bucket that stores a 1D flattened gradient tensor that batches multiple per-variable tensors.
Note that since DDP comm hook only supports single process single device mode at this time,
only exactly one tensor is stored in this bucket.
Future handler of the communication, which updates the gradients in place.
state = PowerSGDState(process_group=process_group, matrix_approximation_rank=1)
>>> ddp_model.register_comm_hook(state, powerSGD_hook)
process_group = state.process_group
group_to_use = process_group if process_group is not None else
world_size = (
process_group.size() if process_group is not None else dist.get_world_size()

# The input tensor is a flattened 1D tensor.
input_tensor = bucket.get_tensors()[0]
device = input_tensor.device
dtype = input_tensor.dtype
# Unflatten the input tensor into per-parameter tensors, for layer-wise compression.
tensors = [
input_tensor[offset : offset + length].view(sizes)
for offset, length, sizes in zip(
bucket.get_offsets(), bucket.get_lengths(), bucket.get_sizes_list()

# Step I: Handle rank-1 tensors.
# Allocate contiguous memory for rank-1 tensors to allreduce them without compression efficiently.
rank1_tensors = [tensor for tensor in tensors if tensor.ndimension() <= 1]
rank1_tensors_memory = ([tensor.view(-1) for tensor in rank1_tensors])
if rank1_tensors
else torch.tensor([], device=device)

# Step II: Handle high-rank tensors.
# Allocate contiguous memory for Ps and Qs to allreduce compressed high-rank tensors efficiently.
high_rank_tensors = [
tensor.view(tensor.shape[0], -1)
for tensor in tensors
if tensor.ndimension() > 1
total_Ps_size = 0
ps_memory = None # TODO(wayi): Store it in a dict of PowerState for warm-up.
total_Qs_size = 0
qs_memory = None # TODO(wayi): Store it in a dict of PowerState for warm-up.
for tensor in high_rank_tensors:
n, m = tensor.shape
matrix_approximation_rank = min(n, m, state.matrix_approximation_rank)
total_Ps_size += n * matrix_approximation_rank
total_Qs_size += m * matrix_approximation_rank
ps_memory = torch.empty(total_Ps_size, device=device, dtype=dtype)
qs_memory = torch.empty(total_Qs_size, device=device, dtype=dtype)

# Create Ps and Qs that point to the allocated memory.
ps = []
qs = []
p_idx = 0
q_idx = 0
for tensor in high_rank_tensors:
n, m = tensor.shape
matrix_approximation_rank = min(n, m, state.matrix_approximation_rank)
ps_memory[p_idx : p_idx + n * matrix_approximation_rank].view(
n, matrix_approximation_rank
qs_memory[q_idx : q_idx + m * matrix_approximation_rank].view(
m, matrix_approximation_rank
p_idx += n * matrix_approximation_rank
q_idx += m * matrix_approximation_rank

# Initialize and then orthogonalize Qs.
with torch.random.fork_rng(devices=[]):
# Fork this RNG to avoid changing the seed globally and affecting the random sampling anywhere else in the training.
# The seed makes sure that the initial random values are the same across all the DDP replicas.
# Such seed should differ at every step.
# Since it is very slow to fork RNG state across all the CUDA devices,
# only fork on CPU and then move the generated tensor to the CUDA device.
for q in qs: = torch.randn(

# Compute Ps.
for tensor, q, p in zip(high_rank_tensors, qs, ps):
torch.matmul(tensor, q, out=p)

# This allreduce is only applied to rank-1 tensors,
# so it should have been kicked off before the above computation on the high-rank tensors to hide more communication costs.
# However, this somehow requires a separate future chain at this time.
allreduce_contiguous_rank1_tensors_fut = dist.all_reduce(
rank1_tensors_memory, group=group_to_use, async_op=True

def unpack_rank1_tensors_and_allreduce_ps(fut):
rank1_tensors_memory = fut.value()[0].div_(world_size)
idx = 0
for tensor in rank1_tensors:
tensor.copy_(rank1_tensors_memory[idx : idx + tensor.shape[0]])
idx += tensor.shape[0]

# Since these Ps will be orthogonalized later, no need to divide them by world size.
return [
dist.all_reduce(ps_memory, group=group_to_use, async_op=True)

def compute_qs(fut):
ps_memory = fut.wait()[0]
for p in ps:

# Compute Qs.
for tensor, p, q in zip(high_rank_tensors, ps, qs):
torch.matmul(tensor.t(), p, out=q)

# Allreduce Qs.
return [
dist.all_reduce(qs_memory, group=group_to_use, async_op=True)

def decompress(fut):
qs_memory = fut.wait()[0].div_(world_size)

for p, q, tensor in zip(ps, qs, high_rank_tensors):
torch.matmul(p, q.t(), out=tensor)
assert not torch.any(torch.isnan(tensor))
return [input_tensor]

return (

def batched_powerSGD_hook(
state: PowerSGDState,
) -> torch.futures.Future:
This DDP communication hook implements a simplified PowerSGD gradient compression
algorithm described in https://arxiv.org/abs/1905.13727.
Once gradient tensors are aggregated across all workers, this hook applies
compression to the flattened input tensor that batches per-parameter tensors as follows:
compression as follows:
1) Views the input flattened 1D gradient tensor as a square-shaped tensor M with 0 paddings;
2) Creates two low-rank tensors P and Q for decomposing M,
such that M = PQ^T, where Q is initialized from a standard normal distribution and orthogonalized;
Expand All @@ -288,7 +105,7 @@ def batched_powerSGD_hook(
state = PowerSGDState(process_group=process_group, matrix_approximation_rank=1)
>>> ddp_model.register_comm_hook(state, batched_powerSGD_hook)
>>> ddp_model.register_comm_hook(state, powerSGD_hook)
process_group = state.process_group
group_to_use = process_group if process_group is not None else
Expand Down
47 changes: 0 additions & 47 deletions torch/testing/_internal/distributed/
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@
import torch
import torch.cuda
import torch.distributed as dist
import torch.distributed.algorithms.ddp_comm_hooks.powerSGD_hook as powerSGD
from import DistributedSampler
from torch.nn.parallel.distributed import _dump_DDP_relevant_env_vars
import torch.nn as nn
Expand Down Expand Up @@ -2820,52 +2819,6 @@ def test_DistributedDataParallel_non_default_stream(self):
msg=f"Expected gradient of {expected_grad} but got {avg} on rank {self.rank}",

BACKEND != "nccl",
"Only NCCL backend support DistributedDataParallel",
def test_DistributedDataParallel_powerSGD_ddp_comm_hook(self):
stream = torch.cuda.Stream(self.rank)
rank = self.rank
net = torch.nn.parallel.DistributedDataParallel(
torch.nn.Linear(1, 5).to(rank), device_ids=[rank]
process_group = torch.distributed.new_group([0, 1])
state = powerSGD.PowerSGDState(
process_group=process_group, matrix_approximation_rank=1
net.register_comm_hook(state=state, hook=powerSGD.powerSGD_hook)
# NOTE: batched_powerSGD_hook cannot pass the following test, because it has a lower accuracy.
for i in range(1000):
# Clear gradients manually.
grad = net.module.weight.grad
if grad is not None:
# Forward + BW
batch = torch.tensor([rank]).float().cuda(rank)
loss = net(batch).sum()
# For each worker, the gradient on the weight should be worker_rank.
grad = net.module.weight.grad
avg = grad.clone()
# All-reducing the gradient averages should give us the gradient
# average. If not, then one of the workers has not correctly
# written back the averaged gradient before this all-reduce call.
world_size = int(os.environ["WORLD_SIZE"])
expected_grad = sum(i for i in range(world_size)) / world_size
avg[0, 0],
msg=f"Expected gradient of {expected_grad} but got {avg} on rank {self.rank}",

@unittest.skipIf(BACKEND != 'nccl' and BACKEND != 'gloo',
"Only Nccl & Gloo backend support DistributedDataParallel")
Expand Down

0 comments on commit ad9923e

Please sign in to comment.