Update on "Resubmit: [Gradient Compression] Implement the original la…
Browse files Browse the repository at this point in the history
…yerwise PowerSGD"

Differential Revision: [D25654961](https://our.internmc.facebook.com/intern/diff/D25654961/)

[ghstack-poisoned]
wayi committed Dec 19, 2020
1 parent 36b1713 · commit 4ca1014
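
For context on the technique named in the title: layerwise PowerSGD (Vogels et al., 2019) compresses each layer's gradient, viewed as a matrix M, into a low-rank product P Qᵀ computed with a single power-iteration step, so peers all-reduce the thin factors P and Q instead of the full gradient. A minimal illustrative sketch of that step in Python (the helper name powersgd_compress and the shapes are hypothetical, not this commit's code):

import torch

def powersgd_compress(M, Q):
    # One power-iteration step: P = M Q, orthonormalize P, then Q = M^T P.
    # Hypothetical helper for illustration; not the hook's actual internals.
    P = M @ Q
    P, _ = torch.linalg.qr(P)  # orthonormalize the columns of P
    Q = M.T @ P
    return P, Q

# Rank-1 example, matching matrix_approximation_rank=1 in the test below:
M = torch.randn(64, 32)         # a layer's gradient, viewed as a matrix
Q = torch.randn(32, 1)          # low-rank factor carried across iterations
P, Q = powersgd_compress(M, Q)
M_hat = P @ Q.T                 # decompressed rank-1 approximation of M

In the actual hook, the thin factors P and Q are what get all-reduced across workers, and the decompressed approximation replaces the gradient.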
Showing 1 changed file with 4 additions and 2 deletions: torch/testing/_internal/distributed/distributed_test.py

@@ -2822,18 +2822,20 @@ def test_DistributedDataParallel_non_default_stream(self):

     @unittest.skipIf(
         BACKEND != "nccl",
-        "Only NCCL backend support DistributedDataParallel",
+        "Only NCCL backend supports DDP communication hook",
     )
     @skip_if_lt_x_gpu(int(os.environ["WORLD_SIZE"]))
     @skip_if_rocm
     def test_DistributedDataParallel_powerSGD_ddp_comm_hook(self):
         stream = torch.cuda.Stream(self.rank)
         rank = self.rank
+        rank_to_GPU = self._init_multigpu_helper()
+        gpus = list(rank_to_GPU[rank])
         with torch.cuda.stream(stream):
             net = torch.nn.parallel.DistributedDataParallel(
                 torch.nn.Linear(1, 5).to(rank), device_ids=[rank]
             )
-            process_group = torch.distributed.new_group([0, 1])
+            process_group = torch.distributed.new_group(gpus)
             state = powerSGD.PowerSGDState(
                 process_group=process_group, matrix_approximation_rank=1
             )
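For reference, the hook exercised by this test is attached to a DDP model through register_comm_hook, using the same PowerSGDState as above. A minimal sketch of standalone usage (the env:// rendezvous setup is an assumption, e.g. variables set by torchrun; module path matches the test's powerSGD import):

import torch
import torch.distributed as dist
import torch.distributed.algorithms.ddp_comm_hooks.powerSGD_hook as powerSGD
from torch.nn.parallel import DistributedDataParallel as DDP

# Assumes MASTER_ADDR/MASTER_PORT/RANK/WORLD_SIZE are set (e.g. by torchrun).
dist.init_process_group("nccl")
rank = dist.get_rank()

net = DDP(torch.nn.Linear(1, 5).to(rank), device_ids=[rank])

# matrix_approximation_rank=1 mirrors the test above; process_group=None
# falls back to the default (global) group inside the hook.
state = powerSGD.PowerSGDState(process_group=None, matrix_approximation_rank=1)
net.register_comm_hook(state, powerSGD.powerSGD_hook)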
