
wait() does not block the default stream for NCCL's asynchronous P2P operations #68866

Open
jasperzhong opened this issue Nov 24, 2021 · 3 comments
Labels: module: nccl, oncall: distributed

Comments

jasperzhong commented Nov 24, 2021

According to the previous issue #68112, wait() should block the default stream until NCCL's asynchronous operations finish. However, our experience shows that this is not the case.

Here is a minimal reproducible example.

import os
from datetime import timedelta

import torch

def main():
    local_rank = int(os.environ['LOCAL_RANK'])
    torch.cuda.set_device(local_rank)
    torch.distributed.init_process_group(
        'nccl', init_method="env://",
        timeout=timedelta(seconds=5)
    )

    rank = torch.distributed.get_rank()
    size = (1000000, )
    for _ in range(100):
        dst_rank = (rank + 1) % 2
        x = torch.ones(size=size, requires_grad=True).cuda()
        y = torch.zeros(size=size, requires_grad=True).cuda()
        send_op = torch.distributed.P2POp(torch.distributed.isend, x,
                                          dst_rank)
        recv_op = torch.distributed.P2POp(torch.distributed.irecv, y,
                                          dst_rank)
        reqs = torch.distributed.batch_isend_irecv([send_op, recv_op])
        for req in reqs:
            req.wait()
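        # wait() returns here, but (per this issue) the default stream is not
        # blocked on the NCCL kernels, so the add below can read a stale y.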

        z = x + y
        assert z[0].item() == 2, "wrong"

    print("right")


if __name__ == '__main__':
    main()

The assert fails. [screenshot of the assertion error attached in the original issue]

However, when we add an explicit CUDA device synchronization (i.e., torch.cuda.synchronize()) before z = x + y, it works well.
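As a sketch, the loop body above with that workaround applied looks like this (only the torch.cuda.synchronize() call is new):

    reqs = torch.distributed.batch_isend_irecv([send_op, recv_op])
    for req in reqs:
        req.wait()

    # Workaround: block the host until all queued GPU work (including the
    # NCCL send/recv kernels) has finished before using x and y.
    torch.cuda.synchronize()

    z = x + y
    assert z[0].item() == 2, "wrong"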

I noticed that torch.cuda.synchronize() is called after wait() in the downstream repo Megatron-LM (https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/p2p_communication.py#L132-L137), so I am not sure whether this is a bug or a feature.

cc @pietern @mrshenli @pritamdamania87 @zhaojuanmao @satgera @rohan-varma @gqchen @aazzolini @osalpekar @jiayisuse @SciPioneer @H-Huang

ejguan added the oncall: distributed and module: nccl labels on Nov 24, 2021
wanchaol (Contributor)
Hi @vycezhong, thanks for posting the question. If you look at the details of @pritamdamania's explanation in #68112 (comment), it explains why work.wait() might not be enough:

Although, note that the execution on CUDA is all still async. So when work.wait() returns the allreduce hasn't finished but we have only instructed stream s1 to wait for s2. Even when res = x + x executes and returns the operation hasn't finished successfully, we are only launching (or queueing) these operations to the GPU. You actually need to use something like https://pytorch.org/docs/stable/generated/torch.cuda.synchronize.html to actually block for all GPU execution to finish.

You can also use torch.distributed.send/recv instead of isend/irecv to get synchronous behavior :)
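To illustrate the quoted explanation, here is a rough sketch (not the actual PyTorch/NCCL internals; comm_stream and the multiply are just stand-ins for the stream and kernel NCCL uses) of the difference between a stream-level wait, which is roughly what work.wait() issues, and a host-level synchronize:

import torch

comm_stream = torch.cuda.Stream()   # stand-in for the stream NCCL kernels run on
done = torch.cuda.Event()

x = torch.ones(1_000_000, device='cuda')

with torch.cuda.stream(comm_stream):
    y = x * 2                       # stand-in for the NCCL kernel
    done.record(comm_stream)

# Roughly what work.wait() amounts to: the default stream is told to wait on
# the event, but the host thread returns immediately; nothing has necessarily
# finished on the GPU yet.
torch.cuda.current_stream().wait_event(done)

# Only this blocks the CPU until all queued GPU work has actually completed.
torch.cuda.synchronize()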

Closing this issue; feel free to re-open if the question is unresolved.

rohan-varma (Member)

IIUC batch_isend_irecv followed by wait() should be equivalent to send/recv, although I think the former APIs have some limitations/bugs that are not ironed out. Should we keep this issue open until we can confirm whether it is a bug or not and track the appropriate fix if so? @wanchaol @pritamdamania87

jasperzhong (Author) commented Dec 6, 2021

FYI, when I use the send and recv APIs instead of batch_isend_irecv, it works well.

import os
from datetime import timedelta

import torch

def main():
    local_rank = int(os.environ['LOCAL_RANK'])
    torch.cuda.set_device(local_rank)
    torch.distributed.init_process_group(
        'nccl', init_method="env://",
        timeout=timedelta(seconds=5)
    )

    rank = torch.distributed.get_rank()
    size = (1000000, )
    for _ in range(100):
        dst_rank = (rank + 1) % 2
        x = torch.ones(size=size, requires_grad=True).cuda()
        y = torch.zeros(size=size, requires_grad=True).cuda()
        torch.distributed.send(x, dst_rank)
        torch.distributed.recv(y, dst_rank)
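        # Blocking send/recv: no extra synchronization is needed before using y.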

        z = x + y
        assert z[0].item() == 2, "wrong"

    print("right")


if __name__ == '__main__':
    main()
