diff --git a/torch/distributed/algorithms/ddp_comm_hooks/powerSGD_hook.py b/torch/distributed/algorithms/ddp_comm_hooks/powerSGD_hook.py
index 32ef2f7c35c2..c9b7db1d7bf7 100644
--- a/torch/distributed/algorithms/ddp_comm_hooks/powerSGD_hook.py
+++ b/torch/distributed/algorithms/ddp_comm_hooks/powerSGD_hook.py
@@ -250,14 +250,11 @@ def div_callback(fut):
         total_Ps_size += n * matrix_approximation_rank
         total_Qs_size += m * matrix_approximation_rank
     # If warm-start is enabled, reuse Ps and Qs from the previous iteration if possible.
-    # The memory spaces of Ps and Qs need to be (re)allocated at the beginning,
-    # as well as later whenever the buckets are rebuilt during training.
+    # The memory spaces of Ps and Qs need to be allocated in the first iteration when PowerSGD is applied.
     need_randomize_qs = False
     if (
         not state.warm_start
         or bucket_index not in state.p_memory_dict
-        or state.p_memory_dict[bucket_index].shape[0] != total_Ps_size
-        or state.q_memory_dict[bucket_index].shape[0] != total_Qs_size
     ):
         need_randomize_qs = True
         # If warm-start is disabled, low-rank tensors will be initialized at every step.
@@ -297,7 +294,7 @@ def div_callback(fut):
             q_idx += m * matrix_approximation_rank
 
     # If warm-start is enabled, reuse Qs from the previous iteration if possible and skip filling random values.
-    # The exceptions are the first time and when the buckets are rebuilt.
+    # The exception is the first iteration when PowerSGD is applied.
     if not need_randomize_qs:
         for q in qs:
             _orthogonalize(q)