[Gradient Compression] Add error feedback to layerwise PowerSGD (#49418)
Summary:
Pull Request resolved: #49418

Add error feedback to the original implementation of PowerSGD.
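For context, error feedback compensates for what lossy gradient compression throws away: before compressing, each worker folds the residual left over from the previous iteration into its gradient, and after decompressing it memorizes the new residual for the next step. Below is a minimal, self-contained sketch of that loop, mirroring the structure of this diff; the rank1_compress helper and all shapes are illustrative stand-ins, not the actual PowerSGD implementation.

import torch

def rank1_compress(matrix):
    # Toy stand-in for PowerSGD's low-rank step: one power iteration
    # yields a rank-1 approximation matrix ~= p @ q.t().
    q = torch.randn(matrix.shape[1], 1)
    p = matrix @ q
    p = p / (p.norm() + 1e-8)  # normalize p (trivial orthogonalization for rank 1)
    q = matrix.t() @ p
    return p, q

error = torch.zeros(16)                        # persistent per-bucket residual
for step in range(3):
    grad = torch.randn(16)                     # flattened "bucket" of gradients
    grad.add_(error)                           # 1) incorporate last step's error
    grad_cp = grad.clone().detach()            # 2) keep a copy for the residual
    p, q = rank1_compress(grad.view(4, 4))     # 3) compress (communication would happen here)
    torch.matmul(p, q.t(), out=grad.view(4, 4))  # 4) decompress in place
    error = grad_cp - grad                     # 5) memorize what compression lost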

Original PR issue: Investigate Applying PowerSGD to Communication Hook for Gradient Compression #47202
ghstack-source-id: 118670930

Test Plan:
buck test mode/dev-nosan caffe2/test/distributed:c10d -- test_powerSGD_ddp_comm_hook_nccl

buck test mode/dev-nosan caffe2/test/distributed:distributed_nccl_fork -- test_DistributedDataParallel_powerSGD_ddp_comm_hook

Reviewed By: rohan-varma

Differential Revision: D25555538

fbshipit-source-id: c01145cc9acf574a4c6aa337dbbba0ba7d9350b2
Yi Wang authored and facebook-github-bot committed Dec 21, 2020
1 parent 5c25f8f commit 342bfd8
Showing 1 changed file with 31 additions and 0 deletions.
torch/distributed/algorithms/ddp_comm_hooks/powerSGD_hook.py (31 additions, 0 deletions)
@@ -122,6 +122,33 @@ def powerSGD_hook(
    input_tensor = bucket.get_tensors()[0]
    device = input_tensor.device
    dtype = input_tensor.dtype

    # Incorporate the error from the previous state into the gradients.
    bucket_index = bucket.get_index()
    input_tensor_cp = None
    total_length = input_tensor.shape[0]
    if state.use_error_feedback:
        # The buckets can be rebuilt during training.
        # In this case, the error tensor shape will not be aligned with the input tensor,
        # and the error will be re-initialized as zeros.
        if (
            bucket_index in state.error_dict
            and state.error_dict[bucket_index].shape[0] == total_length
        ):
            input_tensor.add_(state.error_dict[bucket_index])
        else:
            logging.info(
                "A zero tensor of length {} that represents local error is created.".format(
                    total_length
                )
            )
            state.error_dict[bucket_index] = torch.zeros(total_length, device=device)

        # Keep a copy of the input tensor,
        # so that we can compute the local error caused by compression later,
        # by comparing this copy and the input tensor updated after decompression.
        input_tensor_cp = torch.clone(input_tensor).detach()

    # Unflatten the input tensor into per-parameter tensors, for layer-wise compression.
    tensors = [
        input_tensor[offset : offset + length].view(sizes)
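The unflattening above (the hunk is truncated here) views slices of the flat bucket as per-parameter tensors, so layer-wise compression operates on the same storage that the error-feedback code just adjusted. A hypothetical sketch with made-up shapes (the real offsets, lengths, and sizes come from the bucket):

import math
import torch

flat = torch.arange(12.0)          # pretend bucket holding two parameters
shapes = [(2, 3), (6,)]            # made-up per-parameter sizes
views, offset = [], 0
for sizes in shapes:
    length = math.prod(sizes)
    views.append(flat[offset : offset + length].view(sizes))
    offset += length
# views[0] is 2x3 and views[1] has length 6; both alias flat's storage,
# so writing into a view (e.g., during decompression) updates the bucket.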
@@ -242,6 +269,10 @@ def decompress(fut):
        for p, q, tensor in zip(ps, qs, high_rank_tensors):
            torch.matmul(p, q.t(), out=tensor)
            assert not torch.any(torch.isnan(tensor))

        if state.use_error_feedback:
            # Memorize the local errors.
            state.error_dict[bucket_index] = input_tensor_cp - input_tensor
        return [input_tensor]

    return (
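For completeness, a hedged usage sketch of turning this feature on when registering the communication hook on a DistributedDataParallel model. The exact PowerSGDState constructor arguments are an assumption based on the API around this commit, and ddp_model is assumed to be an existing DistributedDataParallel instance; verify both against the powerSGD_hook module before relying on them.

from torch.distributed.algorithms.ddp_comm_hooks import powerSGD_hook as powerSGD

# Assumed constructor signature at the time of this commit.
state = powerSGD.PowerSGDState(
    process_group=None,            # None falls back to the default process group
    matrix_approximation_rank=1,
    use_error_feedback=True,       # exercises the code paths added in this diff
)
ddp_model.register_comm_hook(state, powerSGD.powerSGD_hook)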
