From b619d37bb47c0caa00bad3bb8b8e654172c887e4 Mon Sep 17 00:00:00 2001
From: Yi Wang
Date: Thu, 28 Jan 2021 18:57:12 -0800
Subject: [PATCH] [Gradient Compression] Simplify the implementation of error feedback and warm-start (#50981)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/50981

Since vanilla allreduce will be applied in the first few iterations, the bucket rebuilding process will not affect the caching of per-variable tensors.

Previously, the cached tensors used for error feedback and warm-up needed to be rebuilt later, because the shape of their corresponding input tensors changes after the bucket rebuild process.

Original PR issue: Investigate Applying PowerSGD to Communication Hook for Gradient Compression #47202

ghstack-source-id: 120617971

Test Plan: real run

Reviewed By: rohan-varma

Differential Revision: D26034418

fbshipit-source-id: e8744431c7f3142d75b77b60110e6861c2ff5c14
---
 .../ddp_comm_hooks/powerSGD_hook.py | 25 +++++++------------
 1 file changed, 9 insertions(+), 16 deletions(-)

diff --git a/torch/distributed/algorithms/ddp_comm_hooks/powerSGD_hook.py b/torch/distributed/algorithms/ddp_comm_hooks/powerSGD_hook.py
index 32ef2f7c35c2..fcc42f82f692 100644
--- a/torch/distributed/algorithms/ddp_comm_hooks/powerSGD_hook.py
+++ b/torch/distributed/algorithms/ddp_comm_hooks/powerSGD_hook.py
@@ -78,6 +78,7 @@ def __init__(
         # However, this means that the shape of input bucketized tensors is subject to change,
         # which will complicate the implementations of error feedback and warm-up.
         # Running vanilla allreduce in the first few iterations can avoid this complexity.
+        assert start_powerSGD_iter >= 1
         self.start_powerSGD_iter = start_powerSGD_iter
         # Error feedback is usually crucial for both convergence and generalization,
         # because PowerSGD is a biased compressor,
@@ -152,6 +153,10 @@ def powerSGD_hook(state: PowerSGDState, bucket) -> torch.futures.Future:
         3.6) Allreduces Qs as a batch;
         3.7) Computes each M among all the high-rank tensors, which is approximately equal to PQ^T.
 
+    Note that this communication hook enforces vanilla allreduce for the first `state.start_powerSGD_iter` iterations.
+    This can not only allow the user to have a finer tuning over the tradeoff between speedup and accuracy,
+    but also help abstract away some complexity of the internal optimization of DDP for future communication hook developers.
+
     TODO(wayi@): The above procedure does two matmul+allreduce steps per iteration --
     one left multiplication and one right multiplication.
     For warm-start, can take one such step at a time, and alternate between them.
@@ -197,13 +202,7 @@ def div_callback(fut):
     input_tensor_cp = None
     total_length = input_tensor.shape[0]
     if state.use_error_feedback:
-        # The buckets can be rebuilt during training.
-        # In this case, the error tensor shape will not be aligned with the input tensor,
-        # and the error will be re-initialized as zeros.
-        if (
-            bucket_index in state.error_dict
-            and state.error_dict[bucket_index].shape[0] == total_length
-        ):
+        if bucket_index in state.error_dict:
             input_tensor.add_(state.error_dict[bucket_index])
         else:
             logging.info(
@@ -250,15 +249,9 @@ def div_callback(fut):
         total_Ps_size += n * matrix_approximation_rank
         total_Qs_size += m * matrix_approximation_rank
     # If warm-start is enabled, reuse Ps and Qs from the previous iteration if possible.
-    # The memory spaces of Ps and Qs need to be (re)allocated at the beginning,
-    # as well as later whenever the buckets are rebuilt during training.
+    # The memory spaces of Ps and Qs need to be allocated in the first iteration when PowerSGD is applied.
     need_randomize_qs = False
-    if (
-        not state.warm_start
-        or bucket_index not in state.p_memory_dict
-        or state.p_memory_dict[bucket_index].shape[0] != total_Ps_size
-        or state.q_memory_dict[bucket_index].shape[0] != total_Qs_size
-    ):
+    if not state.warm_start or bucket_index not in state.p_memory_dict:
         need_randomize_qs = True
         # If warm-start is disabled, low-rank tensors will be initialized at every step.
         # Only log this if warm-start to avoid spamming.
@@ -297,7 +290,7 @@ def div_callback(fut):
             q_idx += m * matrix_approximation_rank
 
     # If warm-start is enabled, reuse Qs from the previous iteration if possible and skip filling random values.
-    # The exceptions are the first time and when the buckets are rebuilt.
+    # The exception is the first iteration when PowerSGD is applied.
     if not need_randomize_qs:
         for q in qs:
             _orthogonalize(q)
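
For reference, a minimal usage sketch of registering this hook on a DDP model. This is not part of the patch; it assumes the process group is initialized by a standard launcher (e.g. torchrun setting MASTER_ADDR/MASTER_PORT/RANK/WORLD_SIZE) and that PowerSGDState accepts the keyword arguments referenced in this file (process_group, matrix_approximation_rank, start_powerSGD_iter, use_error_feedback, warm_start); the concrete values shown are illustrative only:

    import torch
    import torch.distributed as dist
    import torch.nn as nn
    from torch.distributed.algorithms.ddp_comm_hooks import powerSGD_hook as powerSGD

    # Assumes the launcher provided the rendezvous environment variables.
    dist.init_process_group(backend="nccl")
    local_rank = dist.get_rank() % torch.cuda.device_count()
    torch.cuda.set_device(local_rank)

    model = nn.parallel.DistributedDataParallel(
        nn.Linear(1024, 1024).cuda(),
        device_ids=[local_rank],
    )

    state = powerSGD.PowerSGDState(
        process_group=None,        # None means the default process group
        matrix_approximation_rank=1,
        start_powerSGD_iter=10,    # vanilla allreduce for the first 10 iterations
        use_error_feedback=True,
        warm_start=True,
    )
    model.register_comm_hook(state, powerSGD.powerSGD_hook)

Because start_powerSGD_iter is at least 1, DDP rebuilds its buckets during the vanilla-allreduce phase, so the entries of error_dict, p_memory_dict, and q_memory_dict created once compression starts keep a stable shape; this is what allows the shape checks removed in this patch.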