
Commit

Update on "[Resubmission][Gradient Compression] Refactor default_hook…
Browse files Browse the repository at this point in the history
…s.py and powerSGD_hook.py by creating a util function that make a vanilla allreduce future"

Resubmission of #51094

Addresses #50973 (comment)

Original PR issue: Investigate Applying PowerSGD to Communication Hook for Gradient Compression #47202

Differential Revision: [D26162333](https://our.internmc.facebook.com/intern/diff/D26162333/)

[ghstack-poisoned]
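
The diff shown on this page covers only powerSGD_hook.py; the shared util described in the title lives in default_hooks.py and is not displayed here. Below is a minimal sketch of what such a vanilla-allreduce helper could look like. The name `_allreduce_fut`, its `(process_group, tensor)` signature, and the divide-before-allreduce averaging step are assumptions based on the commit description, not code taken from this commit.

```python
import torch.distributed as dist


def _allreduce_fut(process_group, tensor):
    # Use the given process group, or fall back to the default (world) group.
    group_to_use = process_group if process_group is not None else dist.group.WORLD

    # Average rather than sum: divide before the allreduce to limit overflow risk,
    # which matters especially for FP16 gradients.
    tensor.div_(group_to_use.size())

    # Launch an async allreduce and return a Future that resolves to the
    # allreduced tensor, so DDP communication hooks can chain further callbacks.
    fut = dist.all_reduce(tensor, group=group_to_use, async_op=True).get_future()
    return fut.then(lambda f: f.value()[0])
```

Both the default allreduce hook and the PowerSGD hook can then return this future (or chain `.then()` callbacks onto it) instead of re-implementing the same allreduce-and-average boilerplate.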
wayi committed Jan 30, 2021
1 parent ea29d8e commit f08cd46
Showing 1 changed file with 5 additions and 2 deletions.
torch/distributed/algorithms/ddp_comm_hooks/powerSGD_hook.py (5 additions, 2 deletions)

```diff
@@ -4,7 +4,8 @@
 import numpy as np
 import torch
 import torch.distributed as dist
-import torch.distributed.algorithms.ddp_comm_hooks.default_hooks as default
+
+from . import default_hooks as default
 
 
 def _orthogonalize(matrix, epsilon=1e-8):
@@ -204,7 +205,9 @@ def powerSGD_hook(state: PowerSGDState, bucket) -> torch.futures.Future:
                     total_length
                 )
             )
-            state.error_dict[bucket_index] = torch.zeros(total_length, device=device, dtype=dtype)
+            state.error_dict[bucket_index] = torch.zeros(
+                total_length, device=device, dtype=dtype
+            )
 
         # Keep a copy of the input tensor,
         # so that we can compute the local error caused by compression later,
```
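
With the relative import in place, powerSGD_hook.py can reach the shared util as `default._allreduce_fut(...)`. The fragment below is a hypothetical illustration of that call pattern, not code from this commit; the fallback condition and the helper name are assumptions, and `dist` and `default` refer to the module-level imports visible in the diff above.

```python
def _vanilla_allreduce_fallback(state, input_tensor):
    # Hypothetical fallback: a bucket that is too small to benefit from
    # low-rank compression simply reuses the shared vanilla-allreduce util.
    group_to_use = (
        state.process_group if state.process_group is not None else dist.group.WORLD
    )
    return default._allreduce_fut(group_to_use, input_tensor)
```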
