[Gradient Compression] Allow PowerSGD to run vanilla allreduce for the first K iterations #50973

Closed · wants to merge 4 commits

70 changes: 67 additions & 3 deletions torch/distributed/algorithms/ddp_comm_hooks/powerSGD_hook.py
@@ -33,27 +33,52 @@ def _orthogonalize(matrix, epsilon=1e-8):
class PowerSGDState(object):
__slots__ = [
"process_group",
# The two fields below are the configs that usually need to be tuned by the user.
"matrix_approximation_rank",
"start_powerSGD_iter",
# The two fields below are the configs that usually need to be turned on for performance.
"use_error_feedback",
"warm_start",
# The fields below are not configs.
"rng",
"error_dict",
"p_memory_dict",
"q_memory_dict",
"iter",
]

def __init__(
self,
process_group,
matrix_approximation_rank=1,
start_powerSGD_iter=10,
use_error_feedback=True,
warm_start=True,
random_seed=0,
):
self.process_group = process_group
# The low rank for matrix approximation.
# Typically only 1 or 2 is used. See https://arxiv.org/pdf/1905.13727.pdf.
# The low rank used for matrix approximation controls the size of the compressed low-rank tensors,
# which determines the compression rate and the computation cost of compression.
# Typically only a small value (1-4) is used.
# For some NLP tasks (as shown in Appendix D of the original paper,
# https://arxiv.org/pdf/1905.13727.pdf), the rank value has been increased to 32.
# A high rank value can greatly increase the computation cost of compression.
# A good choice depends on how much of this extra computation can be hidden by the dominating communication cost.
self.matrix_approximation_rank = matrix_approximation_rank
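# To make the size trade-off concrete (illustrative numbers only): a gradient tensor viewed
# as an n x m matrix is approximated by low-rank factors P (n x rank) and Q (m x rank),
# so with n = m = 1000 and matrix_approximation_rank = 4, roughly
# (1000 + 1000) * 4 = 8,000 elements are communicated per matrix instead of 1,000,000.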
# This defers PowerSGD compression until step `start_powerSGD_iter`;
# vanilla allreduce runs before step `start_powerSGD_iter`.
# This hybrid scheme of vanilla allreduce + PowerSGD can have two advantages:
# 1) It turns out that PowerSGD may lead to a non-trivial accuracy loss,
# even if the matrix approximation rank is increased to a large value.
# To mitigate the accuracy loss, a simple yet effective way is mixing vanilla allreduce
# (or a more conservative compression such as FP16 compression) with PowerSGD.
# 2) DDP has an internal optimization that rebuilds the gradient buckets
# in order to save memory; this step takes place after the first iteration.
# However, it means that the shape of the input bucketized tensors is subject to change,
# which would complicate the implementations of error feedback and warm-up.
# Running vanilla allreduce for the first few iterations avoids this complexity.
self.start_powerSGD_iter = start_powerSGD_iter
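# To make the schedule concrete (illustrative value only): with start_powerSGD_iter=10,
# training iterations 0-9 are communicated with vanilla allreduce, and PowerSGD compression
# is applied from iteration 10 onward.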
# Error feedback is usually crucial both for convergence and generalization,
# because PowerSGD is a biased compressor,
# i.e., compressing and decompressing a random gradient does not yield the original in expectation.
@@ -80,6 +105,29 @@ def __init__(
self.error_dict = {}
self.p_memory_dict = {}
self.q_memory_dict = {}
# Iteration/step in the training loop.
self.iter = 0

logging.info(
"PowerSGD config: matrix_approximation_rank = {}; "
"start_powerSGD_iter = {}; use_error_feedback = {}; warm_start = {}.".format(
self.matrix_approximation_rank,
self.start_powerSGD_iter,
self.use_error_feedback,
self.warm_start,
)
)

def maybe_increase_iter(self, bucket):
# Since bucket 0 is the last bucket to be allreduced in an iteration,
# only increase `iter` when bucket 0 is processed.
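# For illustration (hypothetical bucket count): with 3 gradient buckets, this hook
# runs 3 times per training iteration and bucket 0 arrives last, so `iter` advances
# exactly once per iteration.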
if bucket.get_index() == 0:
self.iter += 1

if self.iter == self.start_powerSGD_iter:
logging.info(
"Starting to apply PowerSGD after {} iterations.".format(self.iter)
)


def powerSGD_hook(state: PowerSGDState, bucket) -> torch.futures.Future:
@@ -118,7 +166,7 @@ def powerSGD_hook(state: PowerSGDState, bucket) -> torch.futures.Future:
Future handler of the communication, which updates the gradients in place.

Example::
state = PowerSGDState(process_group=process_group, matrix_approximation_rank=1)
state = PowerSGDState(process_group=process_group, matrix_approximation_rank=1, start_powerSGD_iter=10)
>>> ddp_model.register_comm_hook(state, powerSGD_hook)
"""
process_group = state.process_group
@@ -127,6 +175,20 @@ def powerSGD_hook(state: PowerSGDState, bucket) -> torch.futures.Future:

# The input tensor is a flattened 1D tensor.
input_tensor = bucket.get_tensors()[0]

# Run vanilla allreduce in the first `start_powerSGD_iter` iterations.
if state.iter < state.start_powerSGD_iter:
fut = dist.all_reduce(
input_tensor, group=group_to_use, async_op=True
).get_future()
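# Allreduce produces the sum of gradients across ranks (the default reduce op);
# the chained callback below divides by `world_size` in place to turn the sum into an average.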

def div_callback(fut):
return [fut.value()[0].div_(world_size)]

state.maybe_increase_iter(bucket)
return fut.then(div_callback)

# Apply PowerSGD after `start_powerSGD_iter` iterations.
device = input_tensor.device
dtype = input_tensor.dtype

@@ -317,6 +379,8 @@ def decompress(fut):
state.p_memory_dict.clear()
state.q_memory_dict.clear()

state.maybe_increase_iter(bucket)

return [input_tensor]

return (
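
A minimal usage sketch of the change in this PR is shown below. It assumes a single-process-per-GPU setup where the default process group has already been initialized; the toy model, device handling, and hyperparameter values are placeholders, and passing process_group=None relies on the hook falling back to the default (world) group.

import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.distributed.algorithms.ddp_comm_hooks import powerSGD_hook as powerSGD

# Assumes dist.init_process_group(...) has already been called for this process.
model = torch.nn.Linear(128, 10).cuda()
ddp_model = DDP(model, device_ids=[torch.cuda.current_device()])

# Run vanilla allreduce for the first 10 iterations, then switch to PowerSGD
# with a rank-2 approximation (error feedback and warm start stay at their defaults).
state = powerSGD.PowerSGDState(
    process_group=None,  # None falls back to the default (world) process group
    matrix_approximation_rank=2,
    start_powerSGD_iter=10,
)
ddp_model.register_comm_hook(state, powerSGD.powerSGD_hook)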