[Gradient Compression] Simplify the implementation of warm-start
Since PowerSGD is not applied in the first few iterations, the bucket rebuilding process does not affect the caching of per-variable tensors.

Previously, the cached tensors used for error feedback needed to be rebuilt later, because the shapes of their corresponding input tensors change after the bucket rebuild process.
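
To illustrate the reasoning, here is a minimal toy model in plain Python. It is not taken from this PR: `simulate_cache`, `start_iter`, and `rebuild_iter` are hypothetical names, and the model only assumes that bucket sizes change once (at the rebuild) and that nothing is cached before compression starts.

```python
def simulate_cache(start_iter, rebuild_iter, sizes_before, sizes_after, num_iters):
    """Toy model: return (allocation_iters, stale_iters).

    start_iter   -- first iteration at which compression (and caching) runs
    rebuild_iter -- first iteration that sees the rebuilt buckets
    """
    cache = {}  # bucket_index -> cached flat size
    allocation_iters, stale_iters = [], []
    for it in range(num_iters):
        sizes = sizes_before if it < rebuild_iter else sizes_after
        if it < start_iter:
            continue  # uncompressed iteration: nothing is cached
        for bucket_index, size in enumerate(sizes):
            if bucket_index not in cache:
                cache[bucket_index] = size
                allocation_iters.append(it)
            elif cache[bucket_index] != size:
                stale_iters.append(it)  # cached shape no longer matches
    return allocation_iters, stale_iters


# Compression starts after the rebuild: one allocation per bucket, never stale.
late = simulate_cache(start_iter=2, rebuild_iter=1,
                      sizes_before=[10, 20], sizes_after=[15, 15], num_iters=5)

# Compression starts before the rebuild: the cache goes stale once buckets change.
early = simulate_cache(start_iter=0, rebuild_iter=1,
                       sizes_before=[10, 20], sizes_after=[15, 15], num_iters=3)
```

In the first case the shape check on the cached memory can never fire, which is why it can be dropped.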

Original PR issue: Investigate Applying PowerSGD to Communication Hook for Gradient Compression #47202

Differential Revision: [D26034418](https://our.internmc.facebook.com/intern/diff/D26034418/)

ghstack-source-id: 120257256
Pull Request resolved: #50981
wayi committed Jan 23, 2021
1 parent 127e382 commit 59c8077
Showing 1 changed file with 2 additions and 5 deletions.
7 changes: 2 additions & 5 deletions torch/distributed/algorithms/ddp_comm_hooks/powerSGD_hook.py
@@ -250,14 +250,11 @@ def div_callback(fut):
         total_Ps_size += n * matrix_approximation_rank
         total_Qs_size += m * matrix_approximation_rank
     # If warm-start is enabled, reuse Ps and Qs from the previous iteration if possible.
-    # The memory spaces of Ps and Qs need to be (re)allocated at the beginning,
-    # as well as later whenever the buckets are rebuilt during training.
+    # The memory spaces of Ps and Qs need to be allocated in the first iteration when PowerSGD is applied.
     need_randomize_qs = False
     if (
         not state.warm_start
         or bucket_index not in state.p_memory_dict
-        or state.p_memory_dict[bucket_index].shape[0] != total_Ps_size
-        or state.q_memory_dict[bucket_index].shape[0] != total_Qs_size
     ):
         need_randomize_qs = True
         # If warm-start is disabled, low-rank tensors will be initialized at every step.
@@ -297,7 +294,7 @@ def div_callback(fut):
         q_idx += m * matrix_approximation_rank

     # If warm-start is enabled, reuse Qs from the previous iteration if possible and skip filling random values.
-    # The exceptions are the first time and when the buckets are rebuilt.
+    # The exception is the first iteration when PowerSGD is applied.
     if not need_randomize_qs:
         for q in qs:
             _orthogonalize(q)
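
The simplified condition in the hunks above can be sketched in isolation as follows. This is only an illustration under stated assumptions: `needs_random_qs` is a hypothetical helper, the state object is a stand-in built with `SimpleNamespace`, and plain lists replace the real tensors so the snippet runs without PyTorch.

```python
from types import SimpleNamespace


def needs_random_qs(state, bucket_index):
    """Mirror of the simplified check: randomize Qs unless warm-start is on
    and memory for this bucket has already been allocated."""
    return (not state.warm_start) or (bucket_index not in state.p_memory_dict)


# Stand-in for the hook state; only the fields used by the check are modeled.
state = SimpleNamespace(warm_start=True, p_memory_dict={}, q_memory_dict={})

# First compressed iteration: nothing is cached yet, so Qs must be randomized.
first_call = needs_random_qs(state, bucket_index=0)

# Allocate placeholder memory, as the hook would do on that first iteration.
state.p_memory_dict[0] = [0.0] * 8
state.q_memory_dict[0] = [0.0] * 8

# Later iterations reuse the cached Qs and only orthogonalize them.
later_call = needs_random_qs(state, bucket_index=0)

# With warm-start disabled, Qs are randomized at every step.
cold_state = SimpleNamespace(warm_start=False, p_memory_dict={0: [0.0] * 8}, q_memory_dict={})
cold_call = needs_random_qs(cold_state, bucket_index=0)
```

Here `first_call` and `cold_call` are `True`, while `later_call` is `False`.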