[Gradient Compression] Simplify the implementation of warm-start #50981

Closed · wants to merge 4 commits · Changes from 2 commits
7 changes: 2 additions & 5 deletions torch/distributed/algorithms/ddp_comm_hooks/powerSGD_hook.py
@@ -250,14 +250,11 @@ def div_callback(fut):
         total_Ps_size += n * matrix_approximation_rank
         total_Qs_size += m * matrix_approximation_rank
     # If warm-start is enabled, reuse Ps and Qs from the previous iteration if possible.
-    # The memory spaces of Ps and Qs need to be (re)allocated at the beginning,
-    # as well as later whenever the buckets are rebuilt during training.
+    # The memory spaces of Ps and Qs need to be allocated in the first iteration when PowerSGD is applied.
     need_randomize_qs = False
     if (
         not state.warm_start
         or bucket_index not in state.p_memory_dict
-        or state.p_memory_dict[bucket_index].shape[0] != total_Ps_size
-        or state.q_memory_dict[bucket_index].shape[0] != total_Qs_size
     ):
         need_randomize_qs = True
     # If warm-start is disabled, low-rank tensors will be initialized at every step.
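For context, here is a minimal, self-contained sketch (not the hook's actual code) of the allocation decision this hunk simplifies: with warm-start enabled, the per-bucket P and Q buffers are allocated once, the first time a bucket is seen, and reused on later iterations. The diff drops the per-iteration shape checks, presumably because rebuilt buckets are handled outside this branch. The `_SketchState` class, the `maybe_allocate` helper, and the buffer sizes below are hypothetical stand-ins that only mirror the names in the diff.

import torch


class _SketchState:
    """Hypothetical stand-in for PowerSGDState, holding only the warm-start fields."""

    def __init__(self, warm_start=True):
        self.warm_start = warm_start
        self.p_memory_dict = {}
        self.q_memory_dict = {}


def maybe_allocate(state, bucket_index, total_Ps_size, total_Qs_size, device="cpu"):
    # Allocate (and later randomize) only when warm-start is off or this bucket
    # has never been seen; otherwise the existing buffers are reused as-is.
    need_randomize_qs = not state.warm_start or bucket_index not in state.p_memory_dict
    if need_randomize_qs:
        state.p_memory_dict[bucket_index] = torch.empty(total_Ps_size, device=device)
        state.q_memory_dict[bucket_index] = torch.randn(total_Qs_size, device=device)
    return need_randomize_qs


state = _SketchState(warm_start=True)
print(maybe_allocate(state, 0, total_Ps_size=8, total_Qs_size=8))   # True: first iteration allocates
print(maybe_allocate(state, 0, total_Ps_size=8, total_Qs_size=8))   # False: buffers are reused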
@@ -297,7 +294,7 @@ def div_callback(fut):
         q_idx += m * matrix_approximation_rank

     # If warm-start is enabled, reuse Qs from the previous iteration if possible and skip filling random values.
-    # The exceptions are the first time and when the buckets are rebuilt.
+    # The exception is the first iteration when PowerSGD is applied.
     if not need_randomize_qs:
         for q in qs:
             _orthogonalize(q)
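For the second hunk, a rough sketch of what the warm-start branch does with a reused Q: rather than refilling it with random values, it is only re-orthogonalized. The `gram_schmidt_` routine below is an illustrative stand-in for the file's `_orthogonalize` helper, included just to keep the snippet self-contained.

import torch


def gram_schmidt_(matrix, eps=1e-8):
    # Orthonormalize the columns of `matrix` in place (modified Gram-Schmidt).
    m, n = matrix.shape
    for i in range(n):
        col = matrix[:, i : i + 1]
        col /= col.norm() + eps
        if i + 1 < n:
            rest = matrix[:, i + 1 :]
            rest -= col @ (col.t() @ rest)


q = torch.randn(16, 2)   # stands in for a Q reused from the previous iteration
gram_schmidt_(q)
print(torch.allclose(q.t() @ q, torch.eye(2), atol=1e-5))   # columns are now orthonormal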