[Gradient Compression] Allow BatchedPowerSGD to run vanilla allreduce for the first K iterations (#51270)

Summary:
Pull Request resolved: #51270

Similar to #50973, allow the batched version to run vanilla allreduce for the first K iterations.

This may be useful for use cases where the accuracy requirement is not very strict, so the batched version can be applied.
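As an illustration only (not part of this commit), a minimal sketch of how a user might register the batched hook with a deferred start; the `ddp_model` variable and the hyperparameter values below are assumptions:

    from torch.distributed.algorithms.ddp_comm_hooks import powerSGD_hook as powerSGD

    # `ddp_model` is assumed to be an already-constructed DistributedDataParallel instance.
    state = powerSGD.PowerSGDState(
        process_group=None,            # use the default process group
        matrix_approximation_rank=1,   # low-rank approximation used once compression kicks in
        start_powerSGD_iter=1000,      # run vanilla allreduce for the first 1000 iterations
    )
    ddp_model.register_comm_hook(state, powerSGD.batched_powerSGD_hook)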

Original PR issue: Investigate Applying PowerSGD to Communication Hook for Gradient Compression #47202
ghstack-source-id: 120725858

Test Plan:
buck test mode/dev-nosan caffe2/test/distributed:c10d -- test_powerSGD_ddp_comm_hook_nccl

baseline: f248001754
batched PowerSGD: f246960752

The training time was reduced from 54m48s to 30m33s, and the accuracy was approximately the same: 44.21 vs. 44.35.

Reviewed By: rohan-varma

Differential Revision: D26077709

fbshipit-source-id: 6afeefad7a3fbdd7da2cbffb56dfbad855a96cb5
Yi Wang authored and facebook-github-bot committed Feb 1, 2021
1 parent 718e4b1 commit c080780
Showing 1 changed file with 17 additions and 15 deletions.
32 changes: 17 additions & 15 deletions torch/distributed/algorithms/ddp_comm_hooks/powerSGD_hook.py
@@ -399,6 +399,10 @@ def batched_powerSGD_hook(state: PowerSGDState, bucket) -> torch.futures.Future:
     7) Computes M, which is approximately equal to PQ^T.
     8) Truncates the input tensor to the original length.
+
+    Note that this communication hook enforces vanilla allreduce for the first `state.start_powerSGD_iter` iterations.
+    This can not only allow the user to have a finer tuning over the tradeoff between speedup and accuracy,
+    but also help abstract away some complexity of the internal optimization of DDP for future communication hook developers.
     TODO(wayi@): The above procedure does two matmul+allreduce steps per iteration --
     one left multiplication and one right multiplication.
     For warm-start, can take one such step at a time, and alternate between them.
@@ -422,6 +426,13 @@ def batched_powerSGD_hook(state: PowerSGDState, bucket) -> torch.futures.Future:
     # The input tensor is a flattened 1D tensor.
     input_tensor = bucket.get_tensors()[0]
+
+    # Run vanilla allreduce in the first `start_powerSGD_iter` iterations.
+    if state.iter < state.start_powerSGD_iter:
+        state.maybe_increase_iter(bucket)
+        return default._allreduce_fut(group_to_use, input_tensor)
+
+    # Apply PowerSGD after `start_powerSGD_iter` iterations.
     device = input_tensor.device
     total_length = input_tensor.shape[0]
@@ -435,13 +446,7 @@ def batched_powerSGD_hook(state: PowerSGDState, bucket) -> torch.futures.Future:
     bucket_index = bucket.get_index()
     input_tensor_cp = None
     if state.use_error_feedback:
-        # The buckets can be rebuilt during training.
-        # In this case, the error tensor shape will not be aligned with the input tensor,
-        # and the error will be re-initialized as zeros.
-        if (
-            bucket_index in state.error_dict
-            and state.error_dict[bucket_index].shape[0] == padded_total_length
-        ):
+        if bucket_index in state.error_dict:
             input_tensor.add_(state.error_dict[bucket_index])
         else:
             logging.info(
@@ -460,14 +465,8 @@ def batched_powerSGD_hook(state: PowerSGDState, bucket) -> torch.futures.Future:
     matrix = input_tensor.view(square_side_length, square_side_length)
 
     # Reuse P and Q from the previous iteration if possible.
-    # The memory spaces of P and Q need to be (re)allocated at the beginning,
-    # as well as later whenever the buckets are rebuilt during training.
-    if (
-        not state.warm_start
-        or bucket_index not in state.p_memory_dict
-        or state.p_memory_dict[bucket_index].shape
-        != (square_side_length, state.matrix_approximation_rank)
-    ):
+    # The memory spaces of P and Q need to be allocated in the first iteration when PowerSGD is applied.
+    if not state.warm_start or bucket_index not in state.p_memory_dict:
         # If warm-start is disabled, low-rank tensors will be initialized at every step.
         # Only log this if warm-start to avoid spamming.
         if state.warm_start:
@@ -552,6 +551,9 @@ def decompress(fut):
             state.p_memory_dict.clear()
             state.q_memory_dict.clear()
         ret = input_tensor.resize_(total_length)
+
+        state.maybe_increase_iter(bucket)
+
         return [ret]
 
     return allreduce_p_fut.then(compute_q).then(decompress)
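Both the new vanilla-allreduce branch and the decompress callback finish by calling `state.maybe_increase_iter(bucket)`. As a rough sketch only (this helper is not part of the diff above, and its exact body is an assumption), it is expected to advance the iteration counter once per training step and log when PowerSGD compression is about to start:

    import logging

    def maybe_increase_iter(state, bucket):
        # Only bump the counter when the last bucket of the step is processed,
        # so `state.iter` counts training iterations, not per-bucket hook calls.
        # `is_the_last_bucket_to_allreduce()` is assumed to be the GradBucket API
        # available at the time of this PR.
        if bucket.is_the_last_bucket_to_allreduce():
            state.iter += 1
        if state.iter == state.start_powerSGD_iter:
            logging.info("Starting to apply PowerSGD after %s iterations.", state.iter)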
