[Gradient Compression] Simplify the implementation of error feedback and warm-start (#50981)

Summary:
Pull Request resolved: #50981

Since vanilla allreduce will be applied in the first few iterations, the bucket rebuilding process will not affect the caching of per-variable tensors.

Previously, the cached tensors used for error feedback and warm-start had to be rebuilt later, because the shapes of their corresponding input tensors change after the bucket rebuild process.
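To make the invariant concrete, here is a minimal sketch (not the actual DDP internals; `error_dict` mirrors the per-bucket cache in the hook below, the rest is illustrative): because buckets are only rebuilt during the vanilla-allreduce phase, a tensor cached under a bucket index always matches that bucket's shape once PowerSGD starts.

```python
import torch

# Hypothetical stand-in for PowerSGDState's per-bucket error cache.
error_dict = {}

def apply_error_feedback(bucket_index: int, input_tensor: torch.Tensor) -> None:
    # Buckets are frozen by the time PowerSGD runs, so the cached error
    # always has the same shape as input_tensor; no shape check needed.
    if bucket_index in error_dict:
        input_tensor.add_(error_dict[bucket_index])
    else:
        error_dict[bucket_index] = torch.zeros_like(input_tensor)
```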

Original PR issue: Investigate Applying PowerSGD to Communication Hook for Gradient Compression #47202
ghstack-source-id: 120617971

Test Plan: real run

Reviewed By: rohan-varma

Differential Revision: D26034418

fbshipit-source-id: e8744431c7f3142d75b77b60110e6861c2ff5c14
Yi Wang authored and facebook-github-bot committed Jan 29, 2021
1 parent 00d4ec8 commit b619d37
Showing 1 changed file with 9 additions and 16 deletions.
25 changes: 9 additions & 16 deletions torch/distributed/algorithms/ddp_comm_hooks/powerSGD_hook.py
@@ -78,6 +78,7 @@ def __init__(
         # However, this means that the shape of input bucketized tensors is subject to change,
         # which will complicate the implementations of error feedback and warm-up.
         # Running vanilla allreduce in the first few iterations can avoid this complexity.
+        assert start_powerSGD_iter >= 1
         self.start_powerSGD_iter = start_powerSGD_iter
         # Error feedback is usually crucial for both convergence and generalization,
         # because PowerSGD is a biased compressor,
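A hedged usage sketch of the deferred start (hook and parameter names follow this file; the surrounding setup, including `ddp_model`, is assumed): vanilla allreduce runs until `start_powerSGD_iter`, which the new assert requires to be at least 1.

```python
from torch.distributed.algorithms.ddp_comm_hooks import powerSGD_hook as powerSGD

# `ddp_model` is assumed to be an existing DistributedDataParallel instance.
state = powerSGD.PowerSGDState(
    process_group=None,           # default process group
    matrix_approximation_rank=1,
    start_powerSGD_iter=1000,     # vanilla allreduce until then; must be >= 1
)
ddp_model.register_comm_hook(state, powerSGD.powerSGD_hook)
```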
@@ -152,6 +153,10 @@ def powerSGD_hook(state: PowerSGDState, bucket) -> torch.futures.Future:
     3.6) Allreduces Qs as a batch;
     3.7) Computes each M among all the high-rank tensors, which is approximately equal to PQ^T.
 
+    Note that this communication hook enforces vanilla allreduce for the first `state.start_powerSGD_iter` iterations.
+    This not only allows the user finer tuning over the tradeoff between speedup and accuracy,
+    but also helps abstract away some complexity of the internal optimization of DDP for future communication hook developers.
+
     TODO(wayi@): The above procedure does two matmul+allreduce steps per iteration --
     one left multiplication and one right multiplication.
     For warm-start, can take one such step at a time, and alternate between them.
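For intuition, a single-tensor sketch of steps 3.1-3.7 above (hedged: the real hook batches all Ps/Qs, allreduces them together, and uses its own `_orthogonalize` helper; `torch.linalg.qr` stands in here):

```python
import torch

def powersgd_step(M: torch.Tensor, Q: torch.Tensor):
    P = M @ Q                    # 3.1/3.2) compute P; allreduced as a batch in the hook
    P, _ = torch.linalg.qr(P)    # 3.4) orthogonalize P
    Q = M.t() @ P                # 3.5) compute Q; also allreduced as a batch
    return P @ Q.t(), Q          # 3.7) M is approximated by PQ^T; Q is kept for warm-start
```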
@@ -197,13 +202,7 @@ def div_callback(fut):
     input_tensor_cp = None
     total_length = input_tensor.shape[0]
     if state.use_error_feedback:
-        # The buckets can be rebuilt during training.
-        # In this case, the error tensor shape will not be aligned with the input tensor,
-        # and the error will be re-initialized as zeros.
-        if (
-            bucket_index in state.error_dict
-            and state.error_dict[bucket_index].shape[0] == total_length
-        ):
+        if bucket_index in state.error_dict:
             input_tensor.add_(state.error_dict[bucket_index])
         else:
             logging.info(
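Why error feedback exists at all (PowerSGD is a biased compressor): the residual dropped by compression is saved and re-injected into the next step's input. A minimal self-contained sketch, with a rank-1 SVD round trip standing in for PowerSGD's compress/decompress:

```python
import torch

def low_rank_round_trip(M: torch.Tensor, rank: int = 1) -> torch.Tensor:
    # Lossy stand-in for PowerSGD compression followed by decompression.
    U, S, Vh = torch.linalg.svd(M, full_matrices=False)
    return (U[:, :rank] * S[:rank]) @ Vh[:rank, :]

error = torch.zeros(8, 8)
for step in range(5):
    grad = torch.randn(8, 8)     # stand-in for this step's gradient
    grad.add_(error)             # re-inject the residual from the last step
    approx = low_rank_round_trip(grad)
    error = grad - approx        # remember what compression lost
    # `approx`, not `grad`, is what would be allreduced and applied.
```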
@@ -250,15 +249,9 @@ def div_callback(fut):
         total_Ps_size += n * matrix_approximation_rank
         total_Qs_size += m * matrix_approximation_rank
     # If warm-start is enabled, reuse Ps and Qs from the previous iteration if possible.
-    # The memory spaces of Ps and Qs need to be (re)allocated at the beginning,
-    # as well as later whenever the buckets are rebuilt during training.
+    # The memory spaces of Ps and Qs need to be allocated in the first iteration when PowerSGD is applied.
     need_randomize_qs = False
-    if (
-        not state.warm_start
-        or bucket_index not in state.p_memory_dict
-        or state.p_memory_dict[bucket_index].shape[0] != total_Ps_size
-        or state.q_memory_dict[bucket_index].shape[0] != total_Qs_size
-    ):
+    if not state.warm_start or bucket_index not in state.p_memory_dict:
         need_randomize_qs = True
         # If warm-start is disabled, low-rank tensors will be initialized at every step.
         # Only log this if warm-start is enabled, to avoid spamming.
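The reason reusing the cached Qs is sound (a hedged sketch with made-up shapes): consecutive gradients are correlated, so the previous step's Q already approximates the dominant right subspace, and one power-iteration step per training step keeps refining it instead of restarting from random values.

```python
import torch

n, m, rank = 64, 32, 4
Q = torch.randn(m, rank)           # randomized only on the first PowerSGD iteration
for step in range(3):
    M = torch.randn(n, m)          # stand-in for this step's gradient matrix
    P, _ = torch.linalg.qr(M @ Q)  # one power-iteration step per training step
    Q = M.t() @ P                  # the refined Q is carried into the next step
    approx = P @ Q.t()             # rank-4 approximation of M
```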
@@ -297,7 +290,7 @@ def div_callback(fut):
         q_idx += m * matrix_approximation_rank
 
     # If warm-start is enabled, reuse Qs from the previous iteration if possible and skip filling random values.
-    # The exceptions are the first time and when the buckets are rebuilt.
+    # The exception is the first iteration when PowerSGD is applied.
     if not need_randomize_qs:
         for q in qs:
             _orthogonalize(q)
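For reference, `_orthogonalize` is essentially a Gram-Schmidt pass over the columns. A minimal equivalent sketch (hedged: the file's actual helper may differ in details such as epsilon handling):

```python
import torch

def orthogonalize(matrix: torch.Tensor, eps: float = 1e-8) -> None:
    # Orthonormalize the columns of `matrix` in place (Gram-Schmidt).
    for i in range(matrix.shape[1]):
        col = matrix[:, i : i + 1]
        col.div_(col.norm() + eps)
        if i + 1 < matrix.shape[1]:
            rest = matrix[:, i + 1 :]
            rest.sub_(col @ (col.t() @ rest))
```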
