
DDP fp16_compress_hook communication hook increases peak memory #45968

Closed
david-macleod opened this issue Oct 7, 2020 · 6 comments

Labels
module: ddp (Issues/PRs related distributed data parallel training)
oncall: distributed (Add this issue/PR to distributed oncall triage queue)

david-macleod commented Oct 7, 2020

🐛 Bug

When using the new fp16_compress_hook in torch 1.7 nightly, the peak memory usage increases by an amount equal to the (in-memory) size of the gradient tensors. In the example below the size of the 32-bit grads is 8 MB, so the expected peak allocated memory is ~24 MB (parameter tensors + gradient tensors + DDP buffers), and I have confirmed this is the case in the standard setting.

However, when using the hook this increases to 32 MB, presumably caused by the bucket.get_tensors()[0].to(torch.float16) copy. I would have expected this memory to be released after the synchronization was complete, but that does not seem to be the case. Is there a way to avoid this extra copy? It becomes an issue when dealing with very large models.
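
As a rough sanity check on these numbers (a back-of-the-envelope sketch, assuming the ~24 MB corresponds to three fp32-sized copies as described above):

# torch.nn.Linear(512, 4096): weight (4096 x 512) + bias (4096)
n_params = 4096 * 512 + 4096                 # 2,101,248 elements
fp32_copy_mb = n_params * 4 / 1024 ** 2      # ~8.02 MB per fp32-sized copy

print(3 * fp32_copy_mb)   # ~24.05 MB: params + grads + DDP bucket buffers
print(4 * fp32_copy_mb)   # ~32.06 MB: one extra fp32-sized copy retained when the hook is used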

To Reproduce

Steps to reproduce the behavior:

import os
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
import torch.multiprocessing as mp


def main(rank, world_size):

    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '12355'
    # initialize the process group
    dist.init_process_group("nccl", rank=rank, world_size=world_size)

    model = torch.nn.Linear(512, 4096).to(rank)

    ddp_model = DDP(model, device_ids=[rank])
    ddp_model._register_comm_hook(state=None, hook=fp16_compress_hook) 
    optimizer = torch.optim.SGD(ddp_model.parameters(), lr=1e-4)

    for _ in range(5):
        y = ddp_model(torch.randn(64, 512, device=rank)).mean()
        y.backward()
        optimizer.step()
        optimizer.zero_grad()
        
    print(rank, torch.cuda.memory_allocated(device=rank) / (1024 ** 2))
    dist.destroy_process_group()

# as per https://github.com/pytorch/pytorch/blob/master/torch/distributed/algorithms/ddp_comm_hooks/default_hooks.py
def fp16_compress_hook(process_group: object, bucket: dist._GradBucket):
    """
        This DDP communication hook implements a simple gradient compression
        approach that converts ``GradBucket`` tensors whose type is assumed to be
        ``torch.float32`` to half-precision floating point format (``torch.float16``).
        It allreduces those ``float16`` gradient tensors. Once the compressed gradient
        tensors are allreduced, the chained callback ``decompress`` converts the
        aggregated result back to ``float32`` and takes the mean.
        Example::
            >>> ddp_model._register_comm_hook(process_group, fp16_compress_hook)
    """
    group_to_use = process_group if process_group is not None else dist.group.WORLD
    world_size = (
        process_group.size() if process_group is not None else dist.get_world_size()
    )

    compressed_tensor = bucket.get_tensors()[0].to(torch.float16)

    fut = dist.all_reduce(
        compressed_tensor, group=group_to_use, async_op=True
    ).get_future()

    def decompress(fut):
        return [fut.value()[0].to(torch.float32).div_(world_size)]

    return fut.then(decompress)

if __name__ == "__main__":
    mp.spawn(main, args=(2,), nprocs=2, join=True)

Without fp16 hook

0 24.04736328125
1 24.04736328125

With fp16 hook

1 32.06298828125
0 32.06298828125

Expected behavior

The peak memory usage should be the same with and without the hook.

Environment

Collecting environment information...
PyTorch version: 1.7.0.dev20201001
Is debug build: True
CUDA used to build PyTorch: 10.2
ROCM used to build PyTorch: N/A

OS: Ubuntu 18.04.5 LTS (x86_64)
GCC version: (Ubuntu 7.5.0-3ubuntu1~18.04) 7.5.0
Clang version: 6.0.0-1ubuntu2 (tags/RELEASE_600/final)
CMake version: Could not collect

Python version: 3.7 (64-bit runtime)
Is CUDA available: True
CUDA runtime version: Could not collect
GPU models and configuration:
GPU 0: GeForce RTX 2080 Ti
GPU 1: GeForce RTX 2080 Ti
GPU 2: GeForce RTX 2080 Ti
GPU 3: GeForce RTX 2080 Ti
GPU 4: GeForce RTX 2080 Ti
GPU 5: GeForce RTX 2080 Ti
GPU 6: GeForce RTX 2080 Ti
GPU 7: GeForce RTX 2080 Ti
GPU 8: GeForce RTX 2080 Ti
GPU 9: GeForce RTX 2080 Ti

Nvidia driver version: 440.100
cuDNN version: Probably one of the following:
/usr/lib/x86_64-linux-gnu/libcudnn.so.8.0.2
/usr/lib/x86_64-linux-gnu/libcudnn_adv_infer.so.8.0.2
/usr/lib/x86_64-linux-gnu/libcudnn_adv_train.so.8.0.2
/usr/lib/x86_64-linux-gnu/libcudnn_cnn_infer.so.8.0.2
/usr/lib/x86_64-linux-gnu/libcudnn_cnn_train.so.8.0.2
/usr/lib/x86_64-linux-gnu/libcudnn_ops_infer.so.8.0.2
/usr/lib/x86_64-linux-gnu/libcudnn_ops_train.so.8.0.2
HIP runtime version: N/A
MIOpen runtime version: N/A

Versions of relevant libraries:
[pip3] numpy==1.19.1
[pip3] pytorch-memlab==0.2.1
[pip3] torch==1.7.0.dev20201001
[pip3] torchaudio==0.7.0.dev20200918
[pip3] torchsummary==1.5.1
[pip3] torchvision==0.8.0.dev20200910

cc @pietern @mrshenli @pritamdamania87 @zhaojuanmao @satgera @rohan-varma @gqchen @aazzolini @xush6528 @osalpekar @jiayisuse @agolynski

rohan-varma added the oncall: distributed and module: ddp labels on Oct 7, 2020
rohan-varma (Member) commented

I guess this might be because we create a new gradient tensor within the hook, but still keep the old tensor around and thus occupy an extra gradient_tensor_size of memory. cc @pritamdamania87 @SciPioneer, is there anything we could do about this, such as making the old bucket tensors point to this new tensor and thereby freeing that memory?

pritamdamania87 (Contributor) commented

@rohan-varma I think the additional memory overhead should be temporary since we clear out the bucket views here each time: https://github.com/pytorch/pytorch/blob/master/torch/csrc/distributed/c10d/reducer.cpp#L916.

I would have expected this memory to be released after the synchronization was complete

@david-macleod Which synchronization are you referring to here? Looking at the implementation, it seems like we would allocate additional memory during the backward pass, but it would be cleared out in the next backward pass. Are you referring to the fact that we don't clear out this memory at the end of the backward pass instead?

david-macleod (Author) commented Oct 7, 2020

Sorry, my original message wasn't very clear. I meant that, ideally, the memory allocated for each bucket is released during the backward pass, after that bucket has been reduced, or that the original buffers are reused as @rohan-varma suggested.

The big issue for me is that the peak memory usage has gone from 3*G to 4*G (where G is the size of the gradients), and that limits the size of the model I can fit in memory.
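
For anyone reproducing this, one way to capture the actual peak (rather than the allocation left after the loop, which is what the script above prints) is the CUDA memory-stats API. A minimal sketch, assuming the same main() as in the reproduction:

    # before the training loop
    torch.cuda.reset_peak_memory_stats(device=rank)

    # ... run the forward/backward/step loop ...

    torch.cuda.synchronize(device=rank)
    peak_mb = torch.cuda.max_memory_allocated(device=rank) / (1024 ** 2)
    print(rank, "peak allocated MB:", peak_mb)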

wayi1 (Contributor) commented Oct 7, 2020

To resolve this issue, I think I can add a clear() method to the GradBucket class for clearing the tensors, and call clear() right after the following compression line:
compressed_tensor = bucket.get_tensors()[0].to(torch.float16)
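
For illustration, a hook that reuses the bucket's existing fp32 buffer for the decompressed result might look like the sketch below. This is built on the hook from the reproduction above; the name fp16_compress_hook_inplace is hypothetical, and this is not the exact code that landed in #46078.

def fp16_compress_hook_inplace(process_group: object, bucket: dist._GradBucket):
    group_to_use = process_group if process_group is not None else dist.group.WORLD
    world_size = (
        process_group.size() if process_group is not None else dist.get_world_size()
    )

    # keep a reference to the bucket's fp32 tensor so the result can be written back into it
    bucket_tensor = bucket.get_tensors()[0]
    compressed_tensor = bucket_tensor.to(torch.float16)

    fut = dist.all_reduce(
        compressed_tensor, group=group_to_use, async_op=True
    ).get_future()

    def decompress(fut):
        # copy_ casts fp16 -> fp32 into the existing buffer instead of allocating a new fp32 tensor
        bucket_tensor.copy_(fut.value()[0]).div_(world_size)
        return [bucket_tensor]

    return fut.then(decompress)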

wayi1 pushed a commit that referenced this issue Oct 9, 2020
…mpression provided by ddp comm hook

Pull Request resolved: #46078

The peak memory usage of ddp comm hook has increased due to an extra copy of gradient tensors. To reduce the memory usage, decompress the fp16 tensor in place of the tensor stored in the gradient bucket.

#Closes: #45968

Differential Revision: [D24178118](https://our.internmc.facebook.com/intern/diff/D24178118/)

wayi1 (Contributor) commented Oct 9, 2020

@david-macleod Could you patch this PR to verify the fix?
#46078

I verified the fix on my end already.

david-macleod (Author) commented

Works great, thanks!
