[pytorch][PR][Gradient Compression] Reduce the peak memory of fp16 compression provided by ddp comm hook

The peak memory usage of the DDP comm hooks has increased due to an extra copy of the gradient tensors. To reduce memory usage, decompress the fp16 tensor in place, writing the result back into the tensor already stored in the gradient bucket.

Closes #45968

Differential Revision: [D24178118](https://our.internmc.facebook.com/intern/diff/D24178118/)

[ghstack-poisoned]
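
Below is a minimal standalone sketch of the in-place decompression idea described in the commit message. It is an illustration under assumed names, not the hook itself: bucket_tensor stands in for bucket.get_tensors()[0] and world_size is a placeholder process-group size.

import torch

world_size = 4                                  # placeholder process-group size
bucket_tensor = torch.randn(1024, 1024)         # fp32 gradient storage owned by the bucket
compressed = bucket_tensor.to(torch.float16)    # fp16 copy that would be all-reduced

# Out of place (previous behavior): the decompressed result is a new fp32
# tensor that is returned and kept alongside the bucket's own fp32 storage.
# decompressed = compressed.to(torch.float32).div_(world_size)

# In place (this change): copy the decompressed values back into the tensor
# the bucket already owns, so no additional fp32 tensor outlives the hook.
bucket_tensor.copy_(compressed.to(torch.float32).div_(world_size))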
wayi committed Oct 9, 2020
1 parent 40828b6 commit 0d9557f
Showing 1 changed file with 5 additions and 1 deletion.
6 changes: 5 additions & 1 deletion torch/distributed/algorithms/ddp_comm_hooks/default_hooks.py
@@ -55,7 +55,11 @@ def fp16_compress_hook(process_group: object, bucket: dist._GradBucket):
     ).get_future()
 
     def decompress(fut):
-        return [fut.value()[0].to(torch.float32).div_(world_size)]
+        decompressed_tensor = bucket.get_tensors()[0]
+        # Decompress in place to reduce the peak memory.
+        # See: https://github.com/pytorch/pytorch/issues/45968
+        decompressed_tensor.copy_(fut.value()[0].to(torch.float32).div_(world_size))
+        return [decompressed_tensor]
 
     return fut.then(decompress)

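For reference, a usage sketch showing how this hook is typically registered on a DistributedDataParallel model. The process-group setup and the toy model are illustrative assumptions, not part of this change.

import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.distributed.algorithms.ddp_comm_hooks import default_hooks as default

# Assumes a NCCL process group has already been initialized via
# dist.init_process_group("nccl", ...); local_rank selects this process's GPU.
local_rank = dist.get_rank() % torch.cuda.device_count()
model = torch.nn.Linear(16, 4).cuda(local_rank)   # toy model for illustration
ddp_model = DDP(model, device_ids=[local_rank])

# DDP calls fp16_compress_hook(state, bucket) for every gradient bucket;
# passing state=None makes the hook fall back to the default process group.
ddp_model.register_comm_hook(state=None, hook=default.fp16_compress_hook)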
