Update on "[Gradient Compression] Refactor default_hooks.py and power…
Browse files Browse the repository at this point in the history
…SGD_hook.py by creating a util function that make a vanilla allreduce future"


Address #50973 (comment)

Original PR issue: Investigate Applying PowerSGD to Communication Hook for Gradient Compression #47202

Differential Revision: [D26070147](https://our.internmc.facebook.com/intern/diff/D26070147/)

[ghstack-poisoned]
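
The util function named in the title lives in default_hooks.py and is not itself shown in this diff. Below is a minimal sketch of what such a helper looks like; the name `_allreduce_fut`, the pre-division for averaging, and the list-return convention are assumptions, not confirmed by this page:

    import torch
    import torch.distributed as dist

    def _allreduce_fut(process_group, tensor):
        # Hypothetical sketch: kick off an async allreduce on the bucketized
        # gradient tensor and return a Future, so both default_hooks.py and
        # powerSGD_hook.py can share one vanilla-allreduce code path.
        # Assumes an initialized process group.
        group_to_use = process_group if process_group is not None else dist.group.WORLD
        # Pre-divide so that the summed allreduce yields the mean gradient.
        tensor.div_(group_to_use.size())
        fut = dist.all_reduce(tensor, group=group_to_use, async_op=True).get_future()
        # DDP comm hooks of this era returned a list holding the flattened tensor.
        return fut.then(lambda f: [f.value()[0]])

PowerSGD can then fall back to this helper for the first `state.start_powerSGD_iter` iterations instead of duplicating the allreduce boilerplate.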
wayi committed Jan 29, 2021
2 parents 0f57dfa + 4f79850 commit be1af96
Showing 1 changed file with 9 additions and 13 deletions.
torch/distributed/algorithms/ddp_comm_hooks/powerSGD_hook.py (9 additions, 13 deletions)
@@ -79,6 +79,7 @@ def __init__(
         # However, this means that the shape of input bucketized tensors is subject to change,
         # which will complicate the implementations of error feedback and warm-up.
         # Running vanilla allreduce in the first few iterations can avoid this complexity.
+        assert start_powerSGD_iter >= 1
         self.start_powerSGD_iter = start_powerSGD_iter
         # Error feedback is usually crucial, both for convergence and generalization,
         # because PowerSGD is a biased compressor,
@@ -153,6 +154,10 @@ def powerSGD_hook(state: PowerSGDState, bucket) -> torch.futures.Future:
         3.6) Allreduces Qs as a batch;
         3.7) Computes each M among all the high-rank tensors, which is approximately equal to PQ^T.
+
+    Note that this communication hook enforces vanilla allreduce for the first `state.start_powerSGD_iter` iterations.
+    This not only lets the user tune the tradeoff between speedup and accuracy more finely,
+    but also helps abstract away some complexity of DDP's internal optimization for future communication hook developers.
 
     TODO(wayi@): The above procedure does two matmul+allreduce steps per iteration --
     one left multiplication and one right multiplication.
     For warm-start, one such step could be taken at a time, alternating between them.
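
The two matmul+allreduce steps mentioned in the TODO can be sketched for a single matrix on one process. Here `torch.linalg.qr` stands in for the file's own orthogonalization helper, and the allreduce points are marked in comments; this is a sketch of the math, not the hook's exact code:

    import torch

    def powersgd_step(M, Q):
        # Right multiplication: P = M Q (the hook allreduces the batched Ps here).
        P = M @ Q
        P, _ = torch.linalg.qr(P)  # orthogonalize P
        # Left multiplication: Q = M^T P (the hook allreduces the batched Qs here).
        Q = M.t() @ P
        return P, Q

    M = torch.randn(64, 32)
    Q = torch.randn(32, 4)  # rank-4 approximation
    for _ in range(2):
        P, Q = powersgd_step(M, Q)
    M_approx = P @ Q.t()  # M is approximately PQ^T
    print(torch.norm(M - M_approx) / torch.norm(M))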
@@ -191,13 +196,7 @@ def powerSGD_hook(state: PowerSGDState, bucket) -> torch.futures.Future:
     input_tensor_cp = None
     total_length = input_tensor.shape[0]
     if state.use_error_feedback:
-        # The buckets can be rebuilt during training.
-        # In this case, the error tensor shape will not be aligned with the input tensor,
-        # and the error will be re-initialized as zeros.
-        if (
-            bucket_index in state.error_dict
-            and state.error_dict[bucket_index].shape[0] == total_length
-        ):
+        if bucket_index in state.error_dict:
             input_tensor.add_(state.error_dict[bucket_index])
         else:
             logging.info(
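
The error-feedback branch above compensates the current gradient with the residual that compression dropped in the previous iteration. A single-process sketch of that cycle, with a rank-1 SVD truncation standing in for PowerSGD's PQ^T compression (names here are illustrative, not from the file):

    import torch

    def compress_rank1(M):
        # Stand-in compressor: best rank-1 approximation via truncated SVD.
        U, S, Vh = torch.linalg.svd(M)
        return S[0] * torch.outer(U[:, 0], Vh[0, :])

    error = torch.zeros(8, 8)
    for step in range(3):
        grad = torch.randn(8, 8)
        grad.add_(error)             # input_tensor.add_(state.error_dict[bucket_index])
        decompressed = compress_rank1(grad)
        error = grad - decompressed  # memorize the local error for the next step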
@@ -246,10 +245,7 @@ def powerSGD_hook(state: PowerSGDState, bucket) -> torch.futures.Future:
     # If warm-start is enabled, reuse Ps and Qs from the previous iteration if possible.
     # The memory spaces of Ps and Qs need to be allocated in the first iteration when PowerSGD is applied.
     need_randomize_qs = False
-    if (
-        not state.warm_start
-        or bucket_index not in state.p_memory_dict
-    ):
+    if not state.warm_start or bucket_index not in state.p_memory_dict:
         need_randomize_qs = True
         # If warm-start is disabled, low-rank tensors will be initialized at every step.
         # Only log this if warm-start is enabled, to avoid spamming.
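
The condition collapsed onto one line above decides between reusing and (re)allocating the low-rank buffers. A simplified sketch of the reuse logic (the dict and function names here are illustrative):

    import torch

    p_memory_dict = {}  # mirrors state.p_memory_dict: bucket_index -> preallocated P

    def get_p(bucket_index, shape, warm_start):
        # Reuse the previous iteration's P when warm start is on and the buffer
        # exists; otherwise (re)allocate, which also forces re-randomizing the Qs.
        if not warm_start or bucket_index not in p_memory_dict:
            p_memory_dict[bucket_index] = torch.empty(shape)
        return p_memory_dict[bucket_index]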
@@ -365,7 +361,7 @@ def decompress(fut):
 
         if state.use_error_feedback:
             # Memorize the local errors.
-            state.error_dict[bucket_index].copy_(input_tensor_cp - input_tensor)
+            state.error_dict[bucket_index] = input_tensor_cp - input_tensor
         if not state.warm_start:
             state.p_memory_dict.clear()
             state.q_memory_dict.clear()
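
The switch from `copy_` back to plain assignment trades an in-place write into a preallocated buffer for rebinding the dict entry to a freshly allocated tensor. The difference in a nutshell:

    import torch

    buf = torch.zeros(4)
    src = torch.ones(4)
    d = {"err": buf}

    d["err"].copy_(src)     # in-place: writes into the existing buffer
    print(d["err"] is buf)  # True -- same storage, no new allocation

    d["err"] = src - d["err"]  # rebinding: allocates a new result tensor
    print(d["err"] is buf)     # False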
@@ -546,7 +542,7 @@ def decompress(fut):
 
         if state.use_error_feedback:
             # Memorize the local errors.
-            state.error_dict[bucket_index].copy_(input_tensor_cp - input_tensor)
+            state.error_dict[bucket_index] = input_tensor_cp - input_tensor
         if torch.cuda.is_available():
             torch.cuda.synchronize(device)
         if not state.warm_start:
