[Gradient Compression] Allow BatchedPowerSGD to run vanilla allreduce for the first K iterations

Pull Request resolved: #51270

Similar to #50973, allow the batched version to run vanilla allreduce for the first K iterations.

This may be useful for extending the batched version to use cases where the accuracy requirements are not very strict.

Original PR issue: Investigate Applying PowerSGD to Communication Hook for Gradient Compression #47202
ghstack-source-id: 120617938

Differential Revision: [D26077709](https://our.internmc.facebook.com/intern/diff/D26077709/)
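
For context, a minimal sketch of how a user might enable this behavior. The DDP-wrapped `model` and the chosen `start_powerSGD_iter` value are illustrative assumptions, not part of this commit:

import torch.distributed.algorithms.ddp_comm_hooks.powerSGD_hook as powerSGD

# `model` is assumed to be a torch.nn.parallel.DistributedDataParallel instance.
state = powerSGD.PowerSGDState(
    process_group=None,           # use the default process group
    matrix_approximation_rank=1,
    start_powerSGD_iter=1000,     # run vanilla allreduce for the first 1000 iterations
)
model.register_comm_hook(state, powerSGD.batched_powerSGD_hook)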
wayi committed Jan 29, 2021
1 parent b09cd5c commit 86f2798
Showing 1 changed file with 17 additions and 15 deletions: torch/distributed/algorithms/ddp_comm_hooks/powerSGD_hook.py
@@ -396,6 +396,10 @@ def batched_powerSGD_hook(state: PowerSGDState, bucket) -> torch.futures.Future:
     7) Computes M, which is approximately equal to PQ^T.
     8) Truncates the input tensor to the original length.
+    Note that this communication hook enforces vanilla allreduce for the first `state.start_powerSGD_iter` iterations.
+    This not only allows the user to tune the tradeoff between speedup and accuracy more finely,
+    but also helps abstract away some complexity of the internal optimization of DDP for future communication hook developers.
+
     TODO(wayi@): The above procedure does two matmul+allreduce steps per iteration --
     one left multiplication and one right multiplication.
     For warm-start, can take one such step at a time, and alternate between them.
@@ -419,6 +423,13 @@ def batched_powerSGD_hook(state: PowerSGDState, bucket) -> torch.futures.Future:
 
     # The input tensor is a flattened 1D tensor.
     input_tensor = bucket.get_tensors()[0]
+
+    # Run vanilla allreduce in the first `start_powerSGD_iter` iterations.
+    if state.iter < state.start_powerSGD_iter:
+        state.maybe_increase_iter(bucket)
+        return default.allreduce_fut(group_to_use, input_tensor)
+
+    # Apply PowerSGD after `start_powerSGD_iter` iterations.
     device = input_tensor.device
     total_length = input_tensor.shape[0]
 
@@ -432,13 +443,7 @@ def batched_powerSGD_hook(state: PowerSGDState, bucket) -> torch.futures.Future:
     bucket_index = bucket.get_index()
     input_tensor_cp = None
     if state.use_error_feedback:
-        # The buckets can be rebuilt during training.
-        # In this case, the error tensor shape will not be aligned with the input tensor,
-        # and the error will be re-initialized as zeros.
-        if (
-            bucket_index in state.error_dict
-            and state.error_dict[bucket_index].shape[0] == padded_total_length
-        ):
+        if bucket_index in state.error_dict:
             input_tensor.add_(state.error_dict[bucket_index])
         else:
             logging.info(
@@ -457,14 +462,8 @@ def batched_powerSGD_hook(state: PowerSGDState, bucket) -> torch.futures.Future:
     matrix = input_tensor.view(square_side_length, square_side_length)
 
     # Reuse P and Q from the previous iteration if possible.
-    # The memory spaces of P and Q need to be (re)allocated at the beginning,
-    # as well as later whenever the buckets are rebuilt during training.
-    if (
-        not state.warm_start
-        or bucket_index not in state.p_memory_dict
-        or state.p_memory_dict[bucket_index].shape
-        != (square_side_length, state.matrix_approximation_rank)
-    ):
+    # The memory spaces of P and Q need to be allocated in the first iteration when PowerSGD is applied.
+    if not state.warm_start or bucket_index not in state.p_memory_dict:
         # If warm-start is disabled, low-rank tensors will be initialized at every step.
         # Only log this if warm-start to avoid spamming.
         if state.warm_start:
@@ -549,6 +548,9 @@ def decompress(fut):
             state.p_memory_dict.clear()
             state.q_memory_dict.clear()
         ret = input_tensor.resize_(total_length)
+
+        state.maybe_increase_iter(bucket)
+
         return [ret]
 
     return allreduce_p_fut.then(compute_q).then(decompress)
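
To make the docstring's compression steps concrete, here is a single-process sketch of the pad-to-square, rank-r factorization, and truncation round trip. It is illustrative only: the real hook also allreduces P and Q across ranks, and `batched_powersgd_roundtrip` is a hypothetical helper name, not PyTorch API:

import math
import torch

def batched_powersgd_roundtrip(flat_grad, rank=1):
    # 1) Pad the flattened gradient so it can be viewed as a square matrix.
    total_length = flat_grad.shape[0]
    side = math.ceil(math.sqrt(total_length))
    padded = torch.zeros(side * side)
    padded[:total_length] = flat_grad
    matrix = padded.view(side, side)

    # 2) One power-iteration step: random Q, P = MQ, orthogonalize P, Q = M^T P.
    q = torch.randn(side, rank)
    p = matrix @ q
    p, _ = torch.linalg.qr(p)  # orthonormalize the columns of P
    q = matrix.t() @ p

    # 3) M is approximately P Q^T; truncate back to the original length.
    approx = p @ q.t()
    return approx.reshape(-1)[:total_length]

grad = torch.randn(10)
print(batched_powersgd_roundtrip(grad, rank=4))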
