[Gradient Compression] Replace the assertions in PowerSGD comm hook by stream synchronization (#49435)

Summary:
Pull Request resolved: #49435

Previously, the assertions prevented illegal memory access only as a side effect: torch.any returns a boolean value, and evaluating it initiates a data transfer from the device to the host, which forces a synchronization.

An explicit synchronization is more to the point.
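As a rough illustration (not part of this commit; the tensor and its shape are made up for the example), the implicit synchronization that the old assertion relied on versus the explicit one introduced here looks roughly like this:

import torch

if torch.cuda.is_available():
    tensor = torch.randn(1024, 1024, device="cuda")

    # Old pattern: torch.isnan/torch.any launch extra kernels, and evaluating the
    # boolean result on the host blocks until the pending kernels have finished.
    assert not torch.any(torch.isnan(tensor))

    # New pattern: synchronize the device directly, without the extra kernels
    # or the device-to-host copy that the assertion relied on.
    torch.cuda.synchronize()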

Original PR issue: Investigate Applying PowerSGD to Communication Hook for Gradient Compression #47202
ghstack-source-id: 118664204

Test Plan:
buck test mode/dev-nosan caffe2/test/distributed:c10d -- test_powerSGD_ddp_comm_hook_nccl

buck test mode/dev-nosan caffe2/test/distributed:distributed_nccl_fork -- test_DistributedDataParallel_powerSGD_ddp_comm_hook

Reviewed By: rohan-varma

Differential Revision: D25573484

fbshipit-source-id: 516d0d502da2863b516c15332702335ee662f072
Yi Wang authored and facebook-github-bot committed Dec 21, 2020
1 parent 342bfd8 commit 96aed20
Showing 1 changed file with 4 additions and 2 deletions.
6 changes: 4 additions & 2 deletions torch/distributed/algorithms/ddp_comm_hooks/powerSGD_hook.py
@@ -268,7 +268,8 @@ def decompress(fut):

         for p, q, tensor in zip(ps, qs, high_rank_tensors):
             torch.matmul(p, q.t(), out=tensor)
-        assert not torch.any(torch.isnan(tensor))
+        if torch.cuda.is_available():
+            torch.cuda.synchronize()

         if state.use_error_feedback:
             # Memorize the local errors.
@@ -414,7 +415,8 @@ def decompress(fut):
         if state.use_error_feedback:
             # Memorize the local errors.
             state.error_dict[bucket_index] = input_tensor_cp - input_tensor
-        assert not torch.any(torch.isnan(state.error_dict[bucket_index]))
+        if torch.cuda.is_available():
+            torch.cuda.synchronize()
         ret = input_tensor.resize_(total_length)
         return [ret]

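For context, a minimal sketch (not taken from this commit or its tests) of how the hook touched here is typically registered on a DistributedDataParallel model; the `model` and `rank` names are placeholders, and the exact PowerSGDState arguments may differ at this revision:

import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.distributed.algorithms.ddp_comm_hooks import powerSGD_hook as powerSGD

# Assumes dist.init_process_group(...) has already been called.
ddp_model = DDP(model, device_ids=[rank])

state = powerSGD.PowerSGDState(
    process_group=None,  # use the default process group
    matrix_approximation_rank=1,
)
ddp_model.register_comm_hook(state, powerSGD.powerSGD_hook)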
