[Gradient Compression] Replace the key of error_dict in PowerSGD state with bucket index (#48867)

Summary:
Pull Request resolved: #48867

Previously, the key of error_dict was the hashcode of the input tensor. It is now replaced with the bucket index.

A bucket index has a few advantages over a tensor hashcode:
1) The error dict in the state never removes any key. Since every bucket rebuild produces new tensors (and hence new hashcode keys), the size of the error dict can grow if the rebuild process occurs frequently. For now, such rebuilds are infrequent, so it is probably fine.

2) An integer index is more readable than a hashcode, and it facilitates debugging.
If the user wants to debug the tensor values, usually only a specific bucket needs to be targeted. It is easy to specify such a condition in advance (e.g., bucket_index == 0), but hard to specify a hashcode, as it can only be determined at runtime; see the sketch after this list.
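To illustrate the debuggability point, a hypothetical snippet (not part of this change; `bucket_index` and `state` as used in the hook below): an integer key makes it trivial to write a condition targeting one bucket before the run starts.

```python
# Hypothetical debugging snippet: bucket indices are known ahead of time,
# so one bucket can be singled out with a plain equality check. A hashcode
# key, by contrast, would only be known at runtime.
if bucket_index == 0:
    print("bucket 0 local error norm:", state.error_dict[bucket_index].norm())
```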

Note that sometimes the buckets can be rebuilt in the forward pass. In this case, the shape of the bucket with the same index will not be consistent with the one in the previous iteration, and hence the error tensor will be re-initialized as a zero tensor of the new shape. Therefore, `and state.error_dict[bucket_index].shape[0] == padded_total_length` is added to the condition for applying the local error from the previous iteration.
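In condensed form, the guarded error-feedback update described above reads as follows (a sketch of the logic in the diff below, with the logging call and surrounding hook code elided):

```python
bucket_index = bucket.get_index()
if state.use_error_feedback:
    # Reuse the previous error only if the bucket at this index still has
    # the same padded length; a rebuilt bucket gets a fresh zero error.
    if (
        bucket_index in state.error_dict
        and state.error_dict[bucket_index].shape[0] == padded_total_length
    ):
        input_tensor.add_(state.error_dict[bucket_index])
    else:
        state.error_dict[bucket_index] = torch.zeros(
            padded_total_length, device=device
        )
```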

Deleted the `dist._GradBucket` arg type annotation in powerSGD_hook.py, because otherwise test_run_mypy - TestTypeHints failed:
AssertionError: mypy failed: torch/distributed/algorithms/ddp_comm_hooks/powerSGD_hook.py:128: error: "_GradBucket" has no attribute "get_index"  [attr-defined]

Original PR issue: Investigate Applying PowerSGD to Communication Hook for Gradient Compression #47202
ghstack-source-id: 117951402

Test Plan: buck test mode/dev-nosan caffe2/test/distributed:c10d -- test_powerSGD_ddp_comm_hook_nccl

Reviewed By: rohan-varma

Differential Revision: D25346347

fbshipit-source-id: 8348aa103002ec1c69e3ae759504b431140b3b0d
Yi Wang authored and facebook-github-bot committed Dec 6, 2020
1 parent 2e600fe commit 17f53bf
Showing 1 changed file with 22 additions and 9 deletions:
torch/distributed/algorithms/ddp_comm_hooks/powerSGD_hook.py
@@ -1,3 +1,4 @@
+import logging
 import math

 import numpy as np
@@ -63,16 +64,13 @@ def __init__(
         # there will be differences between the gradients that are never synchronized.
         self.rng = np.random.RandomState(random_seed)
         # Since there is only a single state instance for all the input buckets,
-        # need to maintain a dictionary that maps each bucket to the local error.
-        # TODO(wayi): Currently the key is the (hashcode of) input tensor, which may change across steps,
-        # since the bucket can be rebuilt in the forward pass (to save peak memory usage).
-        # Need to add an index field to the input bucket of comm hook.
+        # need to maintain a dictionary that maps each bucket index to the local error.
         self.error_dict = {}


 def powerSGD_hook(
     state: PowerSGDState,
-    bucket: dist._GradBucket,
+    bucket,
 ) -> torch.futures.Future:
     """
     This DDP communication hook implements a simplified PowerSGD gradient compression
@@ -127,11 +125,26 @@ def powerSGD_hook(
     input_tensor[total_length:padded_total_length].fill_(0)

     # Incorporate the error from the previous state into the gradients.
+    bucket_index = bucket.get_index()
     if state.use_error_feedback:
-        if input_tensor in state.error_dict:
-            input_tensor.add_(state.error_dict[input_tensor])
+        # The buckets can be rebuilt during training.
+        # In this case, the error tensor shape will not be aligned with the input tensor,
+        # and the error will be re-initialized as zeros.
+        if (
+            bucket_index in state.error_dict
+            and state.error_dict[bucket_index].shape[0] == padded_total_length
+        ):
+            input_tensor.add_(state.error_dict[bucket_index])
         else:
-            state.error_dict[input_tensor] = torch.zeros(padded_total_length, device=device)
+            logging.info(
+                "A zero tensor of length {} that represents local error is created.".format(
+                    padded_total_length
+                )
+            )
+            state.error_dict[bucket_index] = torch.zeros(
+                padded_total_length, device=device
+            )

     # Keep a copy of the input tensor,
     # so that we can compute the local error caused by compression later,
     # by comparing this copy and the input tensor updated after decompression.
@@ -181,7 +194,7 @@ def decompress(fut):

         if state.use_error_feedback:
             # Memorize the local errors.
-            state.error_dict[input_tensor] = input_tensor_cp - input_tensor
+            state.error_dict[bucket_index] = input_tensor_cp - input_tensor
         ret = input_tensor.resize_(total_length)
         return [ret]
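For context, a minimal sketch of how a hook like this is registered on a DDP model. `register_comm_hook` is the real DDP API; the exact `PowerSGDState` constructor arguments shown are assumptions inferred from the fields this file uses (`random_seed`, `use_error_feedback`), not a confirmed signature.

```python
import torch
import torch.distributed as dist
from torch.distributed.algorithms.ddp_comm_hooks.powerSGD_hook import (
    PowerSGDState,
    powerSGD_hook,
)
from torch.nn.parallel import DistributedDataParallel as DDP

# Inside a standard DDP setup (rank and process group init elided).
model = torch.nn.Linear(10, 10).to(rank)
ddp_model = DDP(model, device_ids=[rank])

# Assumed constructor arguments: only the rng seed and the error-feedback
# flag are known from this diff; treat the rest of the signature as such.
state = PowerSGDState(
    process_group=None,  # None typically falls back to the default group
    random_seed=0,
    use_error_feedback=True,
)
ddp_model.register_comm_hook(state, powerSGD_hook)
```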
