[Gradient Compression] Allow PowerSGD to run vanilla allreduce for the first K iterations

This extends the original PowerSGD method to a hybrid approach: vanilla allreduce + PowerSGD. This can further improve the accuracy, at the cost of a lower speedup.

Also add more comments on the fields in `PowerSGDState`.
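For illustration, here is a minimal usage sketch of the resulting hybrid scheme (not part of this diff; it assumes a `DistributedDataParallel`-wrapped model `ddp_model` and an already-initialized default process group):

```python
import torch.distributed as dist
from torch.distributed.algorithms.ddp_comm_hooks import powerSGD_hook as powerSGD

# Run vanilla allreduce for the first 10 iterations, then switch to PowerSGD compression.
state = powerSGD.PowerSGDState(
    process_group=dist.group.WORLD,  # communicate over the default process group
    matrix_approximation_rank=1,     # low rank used by PowerSGD compression
    start_powerSGD_iter=10,          # defer compression until this iteration
    use_error_feedback=True,
    warm_start=True,
)
ddp_model.register_comm_hook(state, powerSGD.powerSGD_hook)
```

With these settings, gradients are averaged with plain allreduce for the first 10 iterations, and PowerSGD compression is applied from iteration 10 onward.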

Original PR issue: Investigate Applying PowerSGD to Communication Hook for Gradient Compression #47202

Differential Revision: [D26031478](https://our.internmc.facebook.com/intern/diff/D26031478/)

ghstack-source-id: 120245539
Pull Request resolved: #50973
wayi committed Jan 23, 2021
1 parent 5c1c858 commit bd7b1bb
Showing 1 changed file with 66 additions and 3 deletions.
69 changes: 66 additions & 3 deletions torch/distributed/algorithms/ddp_comm_hooks/powerSGD_hook.py
@@ -33,27 +33,52 @@ def _orthogonalize(matrix, epsilon=1e-8):
class PowerSGDState(object):
__slots__ = [
"process_group",
# The two fields below are the configs that usually need to be tuned by the user.
"matrix_approximation_rank",
"start_powerSGD_iter",
# The two fields below are the configs that usually need to be turned on for performance.
"use_error_feedback",
"warm_start",
# The fields below are not configs.
"rng",
"error_dict",
"p_memory_dict",
"q_memory_dict",
"iter",
]

def __init__(
self,
process_group,
matrix_approximation_rank=1,
start_powerSGD_iter=10,
use_error_feedback=True,
warm_start=True,
random_seed=0,
):
self.process_group = process_group
# The low rank for matrix approximation.
# Typically only 1 or 2 is used. See https://arxiv.org/pdf/1905.13727.pdf.
# The low rank for matrix approximation controls the size of the compressed low-rank tensors,
# which determines the compression rate.
# Typically only a small value 1-4 is used.
# For some NLP tasks (as shown in Appendix D of the original paper
# https://arxiv.org/pdf/1905.13727.pdf), the rank value has been increased to 32.
# A high rank value will increase the computation costs of compression exponentially.
# A good choice depends on how much extra computation can be hidden by the dominating communication costs.
self.matrix_approximation_rank = matrix_approximation_rank
# This defers PowerSGD compression until step 'start_powerSGD_iter',
# and vanilla allreduce runs before step 'start_powerSGD_iter'.
# This hybrid scheme of vanilla allreduce + PowerSGD can have two advantages:
# 1) It turns out that PowerSGD may lead to a non-trivial accuracy loss,
# even if the matrix approximation rank is increased to a large value.
# To mitigate the accuracy loss, a simple yet effective way is mixing vanilla allreduce
# (or a more conservative compression such as FP16 compression) with PowerSGD.
# 2) DDP has an internal optimization that rebuilds the gradient buckets,
# in order to save memory.
# This step takes place after the first iteration.
# However, this means that the shape of input bucketized tensors is subject to change,
# which will complicate the implementations of error feedback and warm-up.
# Running vanilla allreduce in the first few iterations can avoid this complexity.
self.start_powerSGD_iter = start_powerSGD_iter
# Error feedback is usually crucial for both convergence and generalization,
# because PowerSGD is a biased compressor,
# i.e., compressing and decompressing a random gradient does not yield the original in expectation.
@@ -80,6 +105,28 @@ def __init__(
self.error_dict = {}
self.p_memory_dict = {}
self.q_memory_dict = {}
# Iteration/step in the training loop.
self.iter = 0

logging.info(
"PowerSGD config: matrix_approximation_rank = {}; start_powerSGD_iter = {}; use_error_feedback = {}; warm_start = {}.".format(
self.matrix_approximation_rank,
self.start_powerSGD_iter,
self.use_error_feedback,
self.warm_start,
)
)

def maybe_increase_iter(self, bucket):
# Since bucket 0 is the last bucket to be allreduced in an iteration,
# only increase `iter` when bucket 0 is processed.
if bucket.get_index() == 0:
self.iter += 1

if self.iter == self.start_powerSGD_iter:
logging.info(
"Starting to apply PowerSGD after {} iterations.".format(self.iter)
)


def powerSGD_hook(state: PowerSGDState, bucket) -> torch.futures.Future:
@@ -118,7 +165,7 @@ def powerSGD_hook(state: PowerSGDState, bucket) -> torch.futures.Future:
Future handler of the communication, which updates the gradients in place.
Example::
state = PowerSGDState(process_group=process_group, matrix_approximation_rank=1)
state = PowerSGDState(process_group=process_group, matrix_approximation_rank=1, start_powerSGD_iter=10)
>>> ddp_model.register_comm_hook(state, powerSGD_hook)
"""
process_group = state.process_group
@@ -127,6 +174,20 @@ def powerSGD_hook(state: PowerSGDState, bucket) -> torch.futures.Future:

# The input tensor is a flattened 1D tensor.
input_tensor = bucket.get_tensors()[0]

# Run vanilla allreduce in the first `start_powerSGD_iter` iterations.
if state.iter < state.start_powerSGD_iter:
fut = dist.all_reduce(
input_tensor, group=group_to_use, async_op=True
).get_future()
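# Allreduce (with the default SUM op) sums the gradients across ranks;
# the callback below divides by the world size to average them.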

def div_callback(fut):
return [fut.value()[0].div_(world_size)]

state.maybe_increase_iter(bucket)
return fut.then(div_callback)

# Apply PowerSGD after `start_powerSGD_iter` iterations.
device = input_tensor.device
dtype = input_tensor.dtype

@@ -317,6 +378,8 @@ def decompress(fut):
state.p_memory_dict.clear()
state.q_memory_dict.clear()

state.maybe_increase_iter(bucket)

return [input_tensor]

return (
