
Distributed / MPI / CUDA: Incorrect messages received from isend/irecv in PyTorch 1.9 #63486

@tvogels

Description

🐛 Bug

When two workers send each other several asynchronous messages with isend/irecv, the received messages are often incorrect: a worker may receive the message it sent itself, receive the wrong message, or receive nothing at all.

To Reproduce

Setting:

  • PyTorch 1.9.0. The bug does not exist in 1.8.
  • CUDA. The bug does not exist on CPU.
  • MPI (OpenMPI 4.0.4)
  • Two processes on the same GPU (launched as sketched below). I didn't test with different GPUs or different nodes.
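
A launch command along these lines should reproduce this setup (a sketch; the script filename repro.py is an assumption, and mpirun is the one from the OpenMPI 4.0.4 install listed above):

# "repro.py" = the script under "Code:" below, saved to a file (the name is illustrative)
mpirun -np 2 python repro.py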

Code:

import torch

torch.distributed.init_process_group("mpi")

device = torch.device("cuda")  # bug does not appear on cpu

rank = torch.distributed.get_rank()
assert torch.distributed.get_world_size() == 2
other_worker = (rank + 1) % 2

for _ in range(10):
    handles = []

    # send two messages to the other worker
    for message in [1, 2]:  # there seems to be no bug if I send only one message
        payload = torch.tensor([10 * rank + message], device=device)  # the first digit indicates the sender
        handles.append(torch.distributed.isend(payload, dst=other_worker, tag=message))

    # receive the messages
    results = []
    for message in [1, 2]:
        recv_buffer = torch.tensor([-1], device=device)  # buffer is initialized to -1
        handles.append(torch.distributed.irecv(recv_buffer, src=other_worker, tag=message))
        results.append(recv_buffer)

    # wait for send and receive operations to finish
    for handle in handles:
        handle.wait()

    # print the received values
    if rank == 0:
        print([r.item() for r in results])  # expecting [11, 12]

Expected output

[11, 12]
[11, 12]
[11, 12]
[11, 12]
[11, 12]
[11, 12]
[11, 12]
[11, 12]
[11, 12]
[11, 12]

Output I observe

[11, 12]
[2, 12]
[-1, 12]
[2, 12]
[2, 12]
[2, 12]
[2, 12]
[2, 12]
[2, 12]
[2, 12]

Here 2 is the value that rank 0 itself sent to rank 1 with tag 2 (10 * 0 + 2), so rank 0 received its own message in the tag-1 slot; -1 means that receive buffer was never written.

Environment

PyTorch version: 1.9.0a0+gitd69c22d
Is debug build: False
CUDA used to build PyTorch: 11.3
ROCM used to build PyTorch: N/A

OS: Ubuntu 20.04.2 LTS (x86_64)
GCC version: (Ubuntu 9.3.0-17ubuntu1~20.04) 9.3.0
Clang version: Could not collect
CMake version: version 3.19.6
Libc version: glibc-2.31

Python version: 3.8.8 (default, Apr 13 2021, 19:58:26)  [GCC 7.3.0] (64-bit runtime)
Python platform: Linux-5.8.0-1038-gcp-x86_64-with-glibc2.10
Is CUDA available: False
CUDA runtime version: 11.3.109
GPU models and configuration: No devices found.
Nvidia driver version: Could not collect
cuDNN version: Probably one of the following:
/usr/lib/x86_64-linux-gnu/libcudnn.so.8.2.2
/usr/lib/x86_64-linux-gnu/libcudnn_adv_infer.so.8.2.2
/usr/lib/x86_64-linux-gnu/libcudnn_adv_train.so.8.2.2
/usr/lib/x86_64-linux-gnu/libcudnn_cnn_infer.so.8.2.2
/usr/lib/x86_64-linux-gnu/libcudnn_cnn_train.so.8.2.2
/usr/lib/x86_64-linux-gnu/libcudnn_ops_infer.so.8.2.2
/usr/lib/x86_64-linux-gnu/libcudnn_ops_train.so.8.2.2
HIP runtime version: N/A
MIOpen runtime version: N/A

Versions of relevant libraries:
[pip3] mypy-extensions==0.4.3
[pip3] numpy==1.20.1
[pip3] numpydoc==1.1.0
[pip3] torch==1.9.0a0+gitd69c22d
[pip3] torchvision==0.10.0
[conda] blas                      1.0                         mkl
[conda] magma-cuda113             2.5.2                         1    pytorch
[conda] mkl                       2021.2.0           h06a4308_296
[conda] mkl-include               2021.3.0           h06a4308_520
[conda] mkl-service               2.3.0            py38h27cfd23_1
[conda] mkl_fft                   1.3.0            py38h42c9631_2
[conda] mkl_random                1.2.1            py38ha9443f7_2
[conda] mypy_extensions           0.4.3                    py38_0
[conda] numpy                     1.20.1           py38h93e21f0_0
[conda] numpy-base                1.20.1           py38h7d8b39e_0
[conda] numpydoc                  1.1.0              pyhd3eb1b0_1
[conda] torch                     1.9.0a0+gitd69c22d          pypi_0    pypi
[conda] torchvision               0.10.0                   pypi_0    pypi

cc @ezyang @gchanan @zou3519 @bdhirsh @jbschlosser @anjali411 @pietern @mrshenli @pritamdamania87 @zhaojuanmao @satgera @rohan-varma @gqchen @aazzolini @osalpekar @jiayisuse @agolynski @SciPioneer @H-Huang @mrzzd @cbalioglu @gcramer23

Labels

high priority, module: correctness (silent), oncall: distributed, triaged
