Torch Nightly with MPI giving RuntimeError: No backend type associated with device type cuda #109543

Closed
ajindal1 opened this issue Sep 18, 2023 · 5 comments
Labels
oncall: distributed Add this issue/PR to distributed oncall triage queue

Comments

@ajindal1
Contributor

ajindal1 commented Sep 18, 2023

🐛 Describe the bug

I have built PyTorch with MPI support, and when running a CUDA-aware MPI test I get a RuntimeError saying "No backend type associated with device type cuda". This happens with the PyTorch nightly as well as with the release candidate for PyTorch 2.1.0. It worked fine up to and including PyTorch 2.0.1.

Steps to reproduce the error:

mpirun -np 2 python test_torch_cuda_aware_mpi.py

File: test_torch_cuda_aware_mpi.py

import os
import torch
import torch.distributed as dist

def run(devices):
    world_size = int(os.environ['OMPI_COMM_WORLD_SIZE'])
    assert len(devices) == world_size
    rank = int(os.environ['OMPI_COMM_WORLD_RANK'])
    device = devices[rank]
    if device != torch.device('cpu'):
        torch.cuda.set_device(device)
    for _ in range(10):
        if rank == 0:
            # Send a tensor to process 1
            tensor0 = torch.rand(1000000, device=device)
            send_req = dist.isend(tensor0, dst=1)
            # Receive a tensor from process 1
            tensor1 = torch.zeros(1000000, device=device)
            recv_req = dist.irecv(tensor1, src=1)
        else:
            # Send a tensor to process 0
            tensor1 = torch.rand(1000000, device=device)
            send_req = dist.isend(tensor1, dst=0)
            # Receive a tensor from process 0
            tensor0 = torch.zeros(1000000, device=device)
            recv_req = dist.irecv(tensor0, src=0)
        send_req.wait()
        recv_req.wait()
        # Both ranks print the same summary of the exchanged tensors
        print(rank, tensor0.mean(), tensor1.mean())


def init_processes(fn, devices):
    """ Initialize the distributed environment. """
    dist.init_process_group('mpi')
    fn(devices)


if __name__ == "__main__":
    devices = [torch.device('cuda:0'), torch.device('cuda:1')]
    init_processes(run, devices)

Error:

Traceback (most recent call last):
  File "integration_test/test_cuda_aware_mpi/test_torch_cuda_aware_mpi.py", line 49, in <module>
    init_processes(run, devices)
  File "integration_test/test_cuda_aware_mpi/test_torch_cuda_aware_mpi.py", line 44, in init_processes
    fn(devices)
  File "integration_test/test_cuda_aware_mpi/test_torch_cuda_aware_mpi.py", line 22, in run
    send_req = dist.isend(tensor0, dst=1)
  File "/opt/conda/envs/ptca/lib/python3.8/site-packages/torch/distributed/distributed_c10d.py", line 1529, in isend
    return default_pg.send([tensor], dst, tag)
RuntimeError: No backend type associated with device type cuda
Traceback (most recent call last):
  File "integration_test/test_cuda_aware_mpi/test_torch_cuda_aware_mpi.py", line 49, in <module>
    init_processes(run, devices)
  File "integration_test/test_cuda_aware_mpi/test_torch_cuda_aware_mpi.py", line 44, in init_processes
    fn(devices)
  File "integration_test/test_cuda_aware_mpi/test_torch_cuda_aware_mpi.py", line 29, in run
    send_req = dist.isend(tensor1, dst=0)
  File "/opt/conda/envs/ptca/lib/python3.8/site-packages/torch/distributed/distributed_c10d.py", line 1529, in isend
    return default_pg.send([tensor], dst, tag)
RuntimeError: No backend type associated with device type cuda
--------------------------------------------------------------------------
Primary job  terminated normally, but 1 process returned
a non-zero exit code. Per user-direction, the job has been aborted.
--------------------------------------------------------------------------
--------------------------------------------------------------------------
mpirun detected that one or more processes exited with non-zero status, thus causing
the job to be terminated. The first process to do so was:

  Process name: [[50595,1],0]
  Exit code:    1
--------------------------------------------------------------------------

Versions

PyTorch version: 2.2.0.dev20230918+cu118
Is debug build: False
CUDA used to build PyTorch: 11.8
ROCM used to build PyTorch: N/A

OS: Ubuntu 20.04.6 LTS (x86_64)
GCC version: (Ubuntu 9.4.0-1ubuntu1~20.04.2) 9.4.0
Clang version: Could not collect
CMake version: version 3.26.0
Libc version: glibc-2.31

Python version: 3.8.18 (default, Sep 11 2023, 13:40:15) [GCC 11.2.0] (64-bit runtime)
Python platform: Linux-5.15.0-1046-azure-x86_64-with-glibc2.17
Is CUDA available: True
CUDA runtime version: 11.8.89
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration:
GPU 0: Tesla V100-PCIE-16GB
GPU 1: Tesla V100-PCIE-16GB
GPU 2: Tesla V100-PCIE-16GB
GPU 3: Tesla V100-PCIE-16GB

Nvidia driver version: 525.125.06
cuDNN version: Probably one of the following:
/usr/lib/x86_64-linux-gnu/libcudnn.so.8.9.0
/usr/lib/x86_64-linux-gnu/libcudnn_adv_infer.so.8.9.0
/usr/lib/x86_64-linux-gnu/libcudnn_adv_train.so.8.9.0
/usr/lib/x86_64-linux-gnu/libcudnn_cnn_infer.so.8.9.0
/usr/lib/x86_64-linux-gnu/libcudnn_cnn_train.so.8.9.0
/usr/lib/x86_64-linux-gnu/libcudnn_ops_infer.so.8.9.0
/usr/lib/x86_64-linux-gnu/libcudnn_ops_train.so.8.9.0
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

CPU:
Architecture: x86_64
CPU op-mode(s): 32-bit, 64-bit
Byte Order: Little Endian
Address sizes: 46 bits physical, 48 bits virtual
CPU(s): 24
On-line CPU(s) list: 0-23
Thread(s) per core: 1
Core(s) per socket: 12
Socket(s): 2
NUMA node(s): 2
Vendor ID: GenuineIntel
CPU family: 6
Model: 79
Model name: Intel(R) Xeon(R) CPU E5-2690 v4 @ 2.60GHz
Stepping: 1
CPU MHz: 2593.990
BogoMIPS: 5187.98
Hypervisor vendor: Microsoft
Virtualization type: full
L1d cache: 768 KiB
L1i cache: 768 KiB
L2 cache: 6 MiB
L3 cache: 70 MiB
NUMA node0 CPU(s): 0-11
NUMA node1 CPU(s): 12-23
Vulnerability Gather data sampling: Not affected
Vulnerability Itlb multihit: KVM: Mitigation: VMX unsupported
Vulnerability L1tf: Mitigation; PTE Inversion
Vulnerability Mds: Mitigation; Clear CPU buffers; SMT Host state unknown
Vulnerability Meltdown: Mitigation; PTI
Vulnerability Mmio stale data: Vulnerable: Clear CPU buffers attempted, no microcode; SMT Host state unknown
Vulnerability Retbleed: Not affected
Vulnerability Spec store bypass: Vulnerable
Vulnerability Spectre v1: Mitigation; usercopy/swapgs barriers and __user pointer sanitization
Vulnerability Spectre v2: Mitigation; Retpolines, STIBP disabled, RSB filling, PBRSB-eIBRS Not affected
Vulnerability Srbds: Not affected
Vulnerability Tsx async abort: Mitigation; Clear CPU buffers; SMT Host state unknown
Flags: fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ss ht syscall nx pdpe1gb rdtscp lm constant_tsc rep_good nopl xtopology cpuid pni pclmulqdq ssse3 fma cx16 pcid sse4_1 sse4_2 movbe popcnt aes xsave avx f16c rdrand hypervisor lahf_lm abm 3dnowprefetch invpcid_single pti fsgsbase bmi1 hle avx2 smep bmi2 erms invpcid rtm rdseed adx smap xsaveopt md_clear

Versions of relevant libraries:
[pip3] numpy==1.22.2
[pip3] pytorch-lightning==1.9.3
[pip3] torch==2.2.0.dev20230918+cu118
[pip3] torch-nebula==0.16.5
[pip3] torch-ort==1.15.0
[pip3] torchaudio==2.2.0.dev20230918+cu118
[pip3] torchmetrics==0.11.3
[pip3] torchsnapshot==0.1.0
[pip3] torchvision==0.17.0.dev20230918+cu118
[conda] magma-cuda118 2.6.1 1 pytorch
[conda] mkl 2021.4.0 pypi_0 pypi
[conda] mkl-include 2021.4.0 pypi_0 pypi
[conda] numpy 1.22.2 pypi_0 pypi
[conda] pytorch-lightning 1.9.3 pypi_0 pypi
[conda] torch 2.2.0.dev20230918+cu118 pypi_0 pypi
[conda] torch-nebula 0.16.5 pypi_0 pypi
[conda] torch-ort 1.15.0 pypi_0 pypi
[conda] torchaudio 2.2.0.dev20230918+cu118 pypi_0 pypi
[conda] torchmetrics 0.11.3 pypi_0 pypi
[conda] torchsnapshot 0.1.0 pypi_0 pypi
[conda] torchvision 0.17.0.dev20230918+cu118 pypi_0 pypi

cc @mrshenli @pritamdamania87 @zhaojuanmao @satgera @rohan-varma @gqchen @aazzolini @osalpekar @jiayisuse @H-Huang @kwen2501 @awgu @penguinwu

H-Huang added the oncall: distributed label on Sep 18, 2023
@H-Huang
Member

H-Huang commented Sep 18, 2023

Can you try a workaround that explicitly specifies device support?

dist.init_process_group('mpi') -> dist.init_process_group('cuda:mpi')

I will look into why it says MPI does not support CUDA devices.
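Since PyTorch 2.0, the backend argument of init_process_group also accepts an explicit device-to-backend string such as "cpu:gloo,cuda:nccl", and that form is what the 'cuda:mpi' workaround relies on. The sketch below is a simplified, hypothetical model of how such a string maps device types to backends; parse_backend_string is an illustrative name, not PyTorch's internal code, and it deliberately handles only the explicit form.

```python
from typing import Dict


def parse_backend_string(spec: str) -> Dict[str, str]:
    """Parse an explicit '<device>:<backend>[,<device>:<backend>...]' string.

    Simplified, hypothetical model: bare names such as 'mpi' are rejected
    here, because PyTorch expands those from a capability table instead.
    """
    mapping: Dict[str, str] = {}
    for entry in spec.lower().split(","):
        entry = entry.strip()
        if not entry:
            continue  # tolerate a trailing comma
        if ":" not in entry:
            raise ValueError(f"expected '<device>:<backend>', got {entry!r}")
        device, backend = entry.split(":", 1)
        device, backend = device.strip(), backend.strip()
        if device in mapping:
            raise ValueError(f"device type {device!r} listed twice")
        mapping[device] = backend
    return mapping


print(parse_backend_string("cuda:mpi"))            # {'cuda': 'mpi'}
print(parse_backend_string("cpu:gloo,cuda:nccl"))  # {'cpu': 'gloo', 'cuda': 'nccl'}
```

Under this model, 'cuda:mpi' routes only CUDA tensors to MPI, so a job that also communicates CPU tensors would presumably need something like 'cpu:mpi,cuda:mpi'.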

@Aidyn-A
Collaborator

Aidyn-A commented Sep 18, 2023

For some reason, by default MPI is compatible with CPU only:

backend_capability: Dict[str, List[str]] = {
    GLOO : ["cpu", "cuda"],
    NCCL : ["cuda"],
    UCC : ["cpu", "cuda"],
    MPI : ["cpu"],
}

And the compatibility requirements were introduced in this PR:

@ajindal1
Contributor Author

@H-Huang I tried changing that, and then it started asking for the optional arguments rank and world_size. After providing those, it now asks for MASTER_ADDR. Is this behavior expected?

@Aidyn-A Thanks for sharing this. I don't understand why this change was made, or whether there is still a way to use CUDA tensors with the MPI backend.
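One plausible reading of the prompts above is that a 'device:backend' string is not treated as the bare 'mpi' backend, so init_process_group falls back to its default env:// rendezvous, which needs a rank, a world size, MASTER_ADDR and MASTER_PORT. Assuming that is the case, the values can be derived from the Open MPI variables the reproducer already reads. mpi_rendezvous_settings is a hypothetical helper, and the localhost/29500 defaults are placeholder assumptions that only suit a single-node job.

```python
from typing import Dict, Mapping, Tuple


def mpi_rendezvous_settings(
    environ: Mapping[str, str],
    master_addr: str = "localhost",  # assumption: single-node job
    master_port: str = "29500",      # assumption: any free port works
) -> Tuple[int, int, Dict[str, str]]:
    """Derive rank, world size and env:// rendezvous variables from Open MPI.

    Hypothetical helper; reads the same OMPI_COMM_WORLD_* variables as the
    reproducer and keeps any MASTER_ADDR/MASTER_PORT already exported.
    """
    rank = int(environ["OMPI_COMM_WORLD_RANK"])
    world_size = int(environ["OMPI_COMM_WORLD_SIZE"])
    rendezvous = {
        "MASTER_ADDR": environ.get("MASTER_ADDR", master_addr),
        "MASTER_PORT": environ.get("MASTER_PORT", master_port),
    }
    return rank, world_size, rendezvous
```

Usage, in place of the original dist.init_process_group('mpi') call and under the same assumption: rank, world_size, env = mpi_rendezvous_settings(os.environ); os.environ.update(env); dist.init_process_group('cuda:mpi', rank=rank, world_size=world_size).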

H-Huang added a commit to H-Huang/pytorch that referenced this issue Sep 19, 2023
Summary: Fixes pytorch#109543

Test Plan: We need to run CUDA-aware MPI in PyTorch to actually test this change; we currently have no MPI tests.

Differential Revision: D49420438
@H-Huang
Member

H-Huang commented Sep 19, 2023

Thanks, #100954 does look like it caused the issue. I am assuming that all that is needed to fix the BC-breaking change is to add "cuda" as shown in #109614, but we should also add tests for initializing MPI in PyTorch to keep this from happening again.
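The failure and the fix both follow from the capability table Aidyn-A quoted: when only a bare backend name such as 'mpi' is given, each device type listed for it is mapped to that backend, so CUDA tensors find no entry. The sketch below models that lookup in plain Python. BACKEND_CAPABILITY, default_device_map and backend_for are hypothetical names, and the error text only mirrors the message in the traceback; it is not PyTorch's implementation.

```python
from typing import Dict, List

# Hypothetical copy of the 2.1-era capability table quoted in this thread.
BACKEND_CAPABILITY: Dict[str, List[str]] = {
    "gloo": ["cpu", "cuda"],
    "nccl": ["cuda"],
    "ucc": ["cpu", "cuda"],
    "mpi": ["cpu"],
}


def default_device_map(backend: str, capability: Dict[str, List[str]]) -> Dict[str, str]:
    """Assign a bare backend name to every device type it claims to support."""
    return {device: backend for device in capability[backend]}


def backend_for(device_type: str, device_map: Dict[str, str]) -> str:
    """Pick the backend for a tensor's device type, failing like the traceback."""
    if device_type not in device_map:
        raise RuntimeError(f"No backend type associated with device type {device_type}")
    return device_map[device_type]


# With MPI listed as CPU-only, a CUDA tensor has no backend to dispatch to.
mpi_map = default_device_map("mpi", BACKEND_CAPABILITY)
try:
    backend_for("cuda", mpi_map)
except RuntimeError as err:
    print(err)  # No backend type associated with device type cuda

# Adding "cuda" to the MPI entry, as the proposed fix does, resolves the lookup.
patched = {**BACKEND_CAPABILITY, "mpi": ["cpu", "cuda"]}
print(backend_for("cuda", default_device_map("mpi", patched)))  # mpi
```

The same two-step structure (expand the bare name, then look up per device type) is why a regression test only needs to initialize the MPI backend and send a CUDA tensor.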

@ajindal1
Contributor Author

Thanks, the above solution worked for me.
