[Gradient Compression] Error feedback for PowerSGD (still need to fix the key in error_dict)

Pull Request resolved: #48670

Support optional error feedback for PowerSGD: store the difference (i.e., the local error caused by compression) between the input gradient (adjusted by the existing error) and the gradient after decompression, and re-add it to the input gradient at the next iteration.
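To make the mechanism concrete, here is a minimal, framework-agnostic sketch of error feedback, using a made-up top-k compressor purely for illustration (this is not the PowerSGD code in this diff):

```python
import torch

def error_feedback_step(grad, error, compress, decompress):
    adjusted = grad + error                   # re-add the error left over from the last iteration
    approx = decompress(compress(adjusted))   # what survives the lossy compression
    new_error = adjusted - approx             # local error caused by compression
    return approx, new_error

# Crude top-2 "compressor", only to make the sketch runnable.
def compress(t):
    _, indices = t.abs().topk(2)
    return t[indices], indices, t.shape

def decompress(c):
    values, indices, shape = c
    out = torch.zeros(shape)
    out[indices] = values
    return out

error = torch.zeros(5)
grad = torch.tensor([0.5, -2.0, 0.1, 3.0, -0.2])
approx, error = error_feedback_step(grad, error, compress, decompress)
# `error` now holds the components dropped by compression; it is re-added at the next step.
```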

An index field still needs to be added to GradBucket to serve as the key of error_dict. The current key, the bucket's input tensor, can change across steps, because the buckets may be rebuilt in the forward pass to reduce peak memory usage.

This PR implements only the first half of error feedback; the new index field will be added in a separate PR.

Original PR issue: Investigate Applying PowerSGD to Communication Hook for Gradient Compression #47202
ghstack-source-id: 117636492

Differential Revision: [D25240290](https://our.internmc.facebook.com/intern/diff/D25240290/)
wayi committed Dec 2, 2020
1 parent 44016e6 commit 575d903
Showing 2 changed files with 51 additions and 4 deletions.
8 changes: 7 additions & 1 deletion torch/distributed/algorithms/ddp_comm_hooks/__init__.py
@@ -16,7 +16,12 @@ def _ddp_comm_hook_wrapper(comm_hook, model, state):


def _powerSGD_comm_hook_wrapper(
comm_hook, model, state, matrix_approximation_rank, random_seed=0
comm_hook,
model,
state,
matrix_approximation_rank,
use_error_feedback=True,
random_seed=0,
):
"""
To be consistent with the wrappers of other DDP comm hooks, the input state only needs to be a process group,
@@ -25,6 +30,7 @@ def _powerSGD_comm_hook_wrapper(
powerSGD_state = powerSGD.PowerSGDState(
process_group=state,
matrix_approximation_rank=matrix_approximation_rank,
use_error_feedback=use_error_feedback,
random_seed=random_seed,
)
model.register_comm_hook(powerSGD_state, comm_hook)
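For context, a hedged usage sketch based only on the APIs visible in this diff; `net` and `rank` are hypothetical placeholders, and the process group is assumed to have been initialized elsewhere:

```python
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
import torch.distributed.algorithms.ddp_comm_hooks.powerSGD_hook as powerSGD

# `net` and `rank` are placeholders for an already-built model and the local rank;
# dist.init_process_group(...) is assumed to have been called already.
model = DDP(net.to(rank), device_ids=[rank])

state = powerSGD.PowerSGDState(
    process_group=dist.group.WORLD,   # the default process group
    matrix_approximation_rank=1,
    use_error_feedback=True,          # enable the error feedback added in this PR
    random_seed=0,
)
model.register_comm_hook(state, powerSGD.powerSGD_hook)
```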
47 changes: 44 additions & 3 deletions torch/distributed/algorithms/ddp_comm_hooks/powerSGD_hook.py
@@ -30,17 +30,44 @@ def _orthogonalize(matrix, epsilon=1e-8):


class PowerSGDState(object):
__slots__ = ["process_group", "matrix_approximation_rank", "rng"]

def __init__(self, process_group, matrix_approximation_rank=1, random_seed=0):
__slots__ = [
"process_group",
"matrix_approximation_rank",
"use_error_feedback",
"rng",
"error_dict",
]

def __init__(
self,
process_group,
matrix_approximation_rank=1,
use_error_feedback=True,
random_seed=0,
):
self.process_group = process_group
self.matrix_approximation_rank = matrix_approximation_rank
# Error feedback is usually crucial for both convergence and generalization,
# because PowerSGD is a biased compressor,
# i.e., compressing and decompressing a random gradient does not yield the original in expectation.
# This mechanism requires a temporary copy of the input gradients,
# so it increases the peak memory consumption by the size of the gradient tensor.
# However, if the target matrices are known to be exactly low-rank (instead of just having a low stable rank),
# sometimes it is possible to converge to the optima without error feedback.
# See: http://proceedings.mlr.press/v54/yurtsever17a/yurtsever17a.pdf
self.use_error_feedback = use_error_feedback
# The purpose of this RNG is to generate different random seeds for initializing Q across iterations,
# but in the same order for all the DDP replicas.
# Different random seeds across iterations mean different 'projections' of the gradients at different SGD steps.
# If the same random projection were reused at every step,
# the components of the gradients outside that projection would never be synchronized.
self.rng = np.random.RandomState(random_seed)
# Since there is only a single state instance for all the input buckets,
# need to maintain a dictionary that maps each bucket to its local error.
# TODO(wayi): Currently the key is the (hashcode of the) input tensor, which may change across steps,
# since the bucket can be rebuilt in the forward pass (to save peak memory usage).
# Need to add an index field to the input bucket of the comm hook.
self.error_dict = {}
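As a small aside on why the current key is fragile, here is an assumption-level illustration (not code from this PR): tensor hashing is identity-based, so a rebuilt bucket tensor will not find the error stored under the old tensor object.

```python
import torch

error_dict = {}
bucket_tensor = torch.zeros(4)
error_dict[bucket_tensor] = torch.ones(4)   # store the "error" for this bucket

rebuilt_tensor = torch.zeros(4)             # same values, but a new tensor object
print(bucket_tensor in error_dict)          # True  (same object, identity-based hash)
print(rebuilt_tensor in error_dict)         # False (the stored error is effectively lost)
```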


def powerSGD_hook(
@@ -98,6 +125,17 @@ def powerSGD_hook(
padded_total_length = square_side_length ** 2
input_tensor.resize_(padded_total_length)
input_tensor[total_length:padded_total_length].fill_(0)

# Incorporate the error from the previous iteration back into the gradient.
if state.use_error_feedback:
if input_tensor in state.error_dict:
input_tensor.add_(state.error_dict[input_tensor])
else:
state.error_dict[input_tensor] = torch.zeros(padded_total_length, device=device)
# Keep a copy of the input tensor,
# so that we can compute the local error caused by compression later,
# by comparing this copy with the input tensor after it has been overwritten by the decompressed result.
input_tensor_cp = torch.clone(input_tensor).detach()
matrix = input_tensor.view(square_side_length, square_side_length)

def create_low_rank_tensor(fill_random_values, rng):
Expand Down Expand Up @@ -141,6 +179,9 @@ def decompress(fut):
q = fut.value()[0].div_(world_size)
torch.matmul(p, q.t(), out=matrix)

if state.use_error_feedback:
# Memorize the local error.
state.error_dict[input_tensor] = input_tensor_cp - input_tensor
ret = input_tensor.resize_(total_length)
return [ret]
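For intuition about what `decompress` reconstructs, a simplified, single-process sketch of a rank-r approximation via one power-iteration step (the orthogonalization details and the allreduces of the real hook are omitted; this is not the hook's exact algorithm):

```python
import torch

def rank_r_approximation(matrix, r, seed=0):
    n, m = matrix.shape
    g = torch.Generator().manual_seed(seed)
    q = torch.randn(m, r, generator=g)   # random projection Q
    p = matrix @ q                       # compress: P = M Q
    p, _ = torch.linalg.qr(p)            # orthogonalize P
    q = matrix.t() @ p                   # refine:   Q = M^T P
    return p @ q.t()                     # decompress: M is approximated by P Q^T

m = torch.randn(8, 8)
approx = rank_r_approximation(m, r=1)
local_error = m - approx                 # what error feedback would store and re-add
```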
