Torch RPC on multiple nodes with GPU returns an EOF error #95487

Open
vmoens opened this issue Feb 24, 2023 · 6 comments
Labels: module: rpc, module: tensorpipe, oncall: distributed, triaged

Comments

vmoens (Contributor) commented Feb 24, 2023

🐛 Describe the bug

When running torch RPC on multiple nodes with submitit (through SLURM), I get an EOF error even though I'm not using the GPUs and I'm not making them available to RPC.
Here's a script to reproduce:

import os
from torch.distributed import rpc
import torch
import submitit
import socket
import subprocess
import time

MAX_TIME_TO_CONNECT=1000

def rpc_init_node(
    rank,
    rank0_ip,
    tcp_port,
    world_size,
):
    DEVICES=[]  #list(range(torch.cuda.device_count()))
    os.environ["MASTER_ADDR"] = str(rank0_ip)
    os.environ["MASTER_PORT"] = "29500"
    # os.environ["TORCH_DISTRIBUTED_DEBUG"] = "DETAIL"
    # os.environ['TP_SOCKET_IFNAME']='lo'
    options = rpc.TensorPipeRpcBackendOptions(
        num_worker_threads=16,
        init_method=f"tcp://{rank0_ip}:{tcp_port}",
        rpc_timeout=MAX_TIME_TO_CONNECT,
        _transports=["uv"],
        # Currently fails when nodes have more than 0 gpus avail,
        # even when no device is made visible
        devices=DEVICES,
    )
    print(f"init rpc on {rank}")
    rpc.init_rpc(
        f"NODE_{rank}",
        rank=rank,
        backend=rpc.BackendType.TENSORPIPE,
        rpc_backend_options=options,
        world_size=world_size,
    )
    rpc.shutdown()

def rpc_init_master(
    tcp_port,
    world_size,
):
    hostname = socket.gethostname()
    rank0_ip = socket.gethostbyname(hostname)

    DEVICES=[]  # list(range(torch.cuda.device_count()))
    os.environ["MASTER_ADDR"] = str(rank0_ip)
    os.environ["MASTER_PORT"] = "29500"
    # os.environ["TORCH_DISTRIBUTED_DEBUG"] = "DETAIL"
    # os.environ['TP_SOCKET_IFNAME']='lo'
    options = rpc.TensorPipeRpcBackendOptions(
        num_worker_threads=16,
        init_method=f"tcp://{rank0_ip}:{tcp_port}",
        rpc_timeout=MAX_TIME_TO_CONNECT,
        _transports=["uv"],
        # Currently fails when nodes have more than 0 gpus avail,
        # even when no device is made visible
        devices=DEVICES,
    )
    print("init rpc on master")
    rpc.init_rpc(
        "TRAINER",
        rank=0,
        backend=rpc.BackendType.TENSORPIPE,
        rpc_backend_options=options,
        world_size=world_size,
    )
    # some dummy compute
    out = rpc.rpc_sync("NODE_1", torch.add, args=(torch.ones(()), torch.ones(())))
    rpc.shutdown()
    print("result", out)
    return out.item()

if __name__ == "__main__":
    slurm_conf = {
        "timeout_min": 100,
        "slurm_partition": "train",
        "slurm_cpus_per_task": 4, # works
        #"slurm_gpus_per_task": 1, "slurm_cpus_per_gpu": 8, # does not work
    }
    master_on_node = True
    num_nodes = 2
    executor = submitit.AutoExecutor(folder="log_test")
    executor.update_parameters(**slurm_conf)
    if not master_on_node:
        hostname = socket.gethostname()
        IPAddr = socket.gethostbyname(hostname)
    else:
        job = executor.submit(rpc_init_master, 1234, num_nodes+1)
        print("job id", job.job_id)
        time.sleep(2.0)
        cmd=f"squeue -j {job.job_id} -o %N | tail -1"
        node = subprocess.check_output(cmd, shell=True, text=True).strip()
        print("node", node)
        cmd=f'sinfo -n {node} -O nodeaddr | tail -1'
        print(cmd)
        IPAddr = subprocess.check_output(cmd, shell=True, text=True).strip()
    print("IP addr:", IPAddr)

    for i in range(num_nodes):
        _job = executor.submit(
                rpc_init_node, i + 1, IPAddr, 1234, num_nodes+1)
    if not master_on_node:
        out = rpc_init_master(1234, num_nodes+1)
    else:
        out = job.result()
    print("result", out)

The commented-out slurm_gpus_per_task line is the one that makes the code break when uncommented (if you uncomment it, comment out the slurm_cpus_per_task line above it tagged with # works).

### What does not matter

  • Passing devices=list_of_devices or devices=[] to the RPC options has the same effect.
  • Launching everything from the master node or spawning a dedicated master job (see the script) gives the same error.
  • The code runs fine with multiprocessing, presumably because everything stays on the same node (?); see the sketch below.
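
For reference, a minimal sketch of the single-node multiprocessing variant that works for me (my paraphrase, not a verified snippet; it assumes rpc_init_master / rpc_init_node from the script above are defined at module level so mp.spawn can pickle the worker):

# Single-node variant: all ranks share one machine, so TensorPipe only ever
# sees local GPUs. Assumes rpc_init_master / rpc_init_node from the script above.
import socket
import torch.multiprocessing as mp

def _worker(rank, world_size, tcp_port):
    if rank == 0:
        rpc_init_master(tcp_port, world_size)
    else:
        rank0_ip = socket.gethostbyname(socket.gethostname())
        rpc_init_node(rank, rank0_ip, tcp_port, world_size)

if __name__ == "__main__":
    world_size = 3  # 1 trainer + 2 nodes, as in the script above
    mp.spawn(_worker, args=(world_size, 1234), nprocs=world_size, join=True)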

I had to set _transports in the TensorPipe options because I'm running on AWS; without it, RPC doesn't connect at all.
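
For completeness, the transport setting I mean (together with the TP_SOCKET_IFNAME alternative that is commented out in the script) looks roughly like this; the interface name and address below are placeholders, not what I actually use:

# Either force the TCP/libuv transport via the (private) _transports option,
# as in the script above, or point TensorPipe at a specific NIC via
# TP_SOCKET_IFNAME. "ens3" and the address are illustrative placeholders.
import os
from torch.distributed import rpc

os.environ["TP_SOCKET_IFNAME"] = "ens3"  # placeholder interface name
options = rpc.TensorPipeRpcBackendOptions(
    num_worker_threads=16,
    init_method="tcp://10.0.0.1:1234",  # placeholder rank-0 address
    rpc_timeout=1000,
    _transports=["uv"],  # force the TCP (libuv) transport
)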

Here's the error:

Traceback (most recent call last):
  File "/data/home/vmoens/dump/dummy.py", line 109, in <module>
    out = job.result()
  File "/fsx/users/vmoens/conda/envs/compile/lib/python3.9/site-packages/submitit/core/core.py", line 266, in result
    r = self.results()
  File "/fsx/users/vmoens/conda/envs/compile/lib/python3.9/site-packages/submitit/core/core.py", line 294, in results
    raise job_exception  # pylint: disable=raising-bad-type
submitit.core.utils.FailedJobError: Job (task=0) failed during processing with trace:
----------------------
Traceback (most recent call last):
  File "/fsx/users/vmoens/conda/envs/compile/lib/python3.9/site-packages/submitit/core/submission.py", line 54, in process_job
    result = delayed.result()
  File "/fsx/users/vmoens/conda/envs/compile/lib/python3.9/site-packages/submitit/core/utils.py", line 133, in result
    self._result = self.function(*self.args, **self.kwargs)
  File "/data/home/vmoens/dump/dummy.py", line 63, in rpc_init_master
    rpc.init_rpc(
  File "/fsx/users/vmoens/conda/envs/compile/lib/python3.9/site-packages/torch/distributed/rpc/__init__.py", line 199, in init_rpc
    _init_rpc_backend(backend, store, name, rank, world_size, rpc_backend_options)
  File "/fsx/users/vmoens/conda/envs/compile/lib/python3.9/site-packages/torch/distributed/rpc/__init__.py", line 234, in _init_rpc_backend
    rpc_agent = backend_registry.init_backend(
  File "/fsx/users/vmoens/conda/envs/compile/lib/python3.9/site-packages/torch/distributed/rpc/backend_registry.py", line 104, in init_backend
    return backend.value.init_backend_handler(*args, **kwargs)
  File "/fsx/users/vmoens/conda/envs/compile/lib/python3.9/site-packages/torch/distributed/rpc/backend_registry.py", line 363, in _tensorpipe_init_backend_handler
    api._all_gather(None, timeout=rpc_backend_options.rpc_timeout)
  File "/fsx/users/vmoens/conda/envs/compile/lib/python3.9/site-packages/torch/distributed/rpc/api.py", line 82, in wrapper
    return func(*args, **kwargs)
  File "/fsx/users/vmoens/conda/envs/compile/lib/python3.9/site-packages/torch/distributed/rpc/api.py", line 224, in _all_gather
    rpc_sync(
  File "/fsx/users/vmoens/conda/envs/compile/lib/python3.9/site-packages/torch/distributed/rpc/api.py", line 82, in wrapper
    return func(*args, **kwargs)
  File "/fsx/users/vmoens/conda/envs/compile/lib/python3.9/site-packages/torch/distributed/rpc/api.py", line 809, in rpc_sync
    return fut.wait()
RuntimeError: EOF: end of file (this error originated at tensorpipe/transport/uv/connection_impl.cc:132)

Versions

Latest torch nightly, locally built

PyTorch version: 2.0.0.dev20230220+cu118
Is debug build: False
CUDA used to build PyTorch: 11.8
ROCM used to build PyTorch: N/A

OS: Ubuntu 18.04.6 LTS (x86_64)
GCC version: (Ubuntu 9.4.0-1ubuntu1~18.04) 9.4.0
Clang version: 6.0.0-1ubuntu2 (tags/RELEASE_600/final)
CMake version: version 3.25.0
Libc version: glibc-2.27

Python version: 3.9.15 | packaged by conda-forge | (main, Nov 22 2022, 15:55:03)  [GCC 10.4.0] (64-bit runtime)
Python platform: Linux-5.4.0-1069-aws-x86_64-with-glibc2.27
Is CUDA available: True
CUDA runtime version: 11.6.112
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration:
GPU 0: NVIDIA A100-SXM4-40GB
GPU 1: NVIDIA A100-SXM4-40GB
GPU 2: NVIDIA A100-SXM4-40GB
GPU 3: NVIDIA A100-SXM4-40GB
GPU 4: NVIDIA A100-SXM4-40GB
GPU 5: NVIDIA A100-SXM4-40GB
GPU 6: NVIDIA A100-SXM4-40GB
GPU 7: NVIDIA A100-SXM4-40GB

Nvidia driver version: 510.47.03
cuDNN version: Probably one of the following:
/usr/local/cuda-11.0/targets/x86_64-linux/lib/libcudnn.so.8.0.5
/usr/local/cuda-11.0/targets/x86_64-linux/lib/libcudnn_adv_infer.so.8.0.5
/usr/local/cuda-11.0/targets/x86_64-linux/lib/libcudnn_adv_train.so.8.0.5
/usr/local/cuda-11.0/targets/x86_64-linux/lib/libcudnn_cnn_infer.so.8.0.5
/usr/local/cuda-11.0/targets/x86_64-linux/lib/libcudnn_cnn_train.so.8.0.5
/usr/local/cuda-11.0/targets/x86_64-linux/lib/libcudnn_ops_infer.so.8.0.5
/usr/local/cuda-11.0/targets/x86_64-linux/lib/libcudnn_ops_train.so.8.0.5
/usr/local/cuda-11.1/targets/x86_64-linux/lib/libcudnn.so.8.0.5
/usr/local/cuda-11.1/targets/x86_64-linux/lib/libcudnn_adv_infer.so.8.0.5
/usr/local/cuda-11.1/targets/x86_64-linux/lib/libcudnn_adv_train.so.8.0.5
/usr/local/cuda-11.1/targets/x86_64-linux/lib/libcudnn_cnn_infer.so.8.0.5
/usr/local/cuda-11.1/targets/x86_64-linux/lib/libcudnn_cnn_train.so.8.0.5
/usr/local/cuda-11.1/targets/x86_64-linux/lib/libcudnn_ops_infer.so.8.0.5
/usr/local/cuda-11.1/targets/x86_64-linux/lib/libcudnn_ops_train.so.8.0.5
/usr/local/cuda-11.2/targets/x86_64-linux/lib/libcudnn.so.8.1.1
/usr/local/cuda-11.2/targets/x86_64-linux/lib/libcudnn_adv_infer.so.8.1.1
/usr/local/cuda-11.2/targets/x86_64-linux/lib/libcudnn_adv_train.so.8.1.1
/usr/local/cuda-11.2/targets/x86_64-linux/lib/libcudnn_cnn_infer.so.8.1.1
/usr/local/cuda-11.2/targets/x86_64-linux/lib/libcudnn_cnn_train.so.8.1.1
/usr/local/cuda-11.2/targets/x86_64-linux/lib/libcudnn_ops_infer.so.8.1.1
/usr/local/cuda-11.2/targets/x86_64-linux/lib/libcudnn_ops_train.so.8.1.1
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

CPU:
Architecture:        x86_64
CPU op-mode(s):      32-bit, 64-bit
Byte Order:          Little Endian
CPU(s):              96
On-line CPU(s) list: 0-95
Thread(s) per core:  2
Core(s) per socket:  24
Socket(s):           2
NUMA node(s):        2
Vendor ID:           GenuineIntel
CPU family:          6
Model:               85
Model name:          Intel(R) Xeon(R) Platinum 8275CL CPU @ 3.00GHz
Stepping:            7
CPU MHz:             1369.306
BogoMIPS:            5999.99
Hypervisor vendor:   KVM
Virtualization type: full
L1d cache:           32K
L1i cache:           32K
L2 cache:            1024K
L3 cache:            36608K
NUMA node0 CPU(s):   0-23,48-71
NUMA node1 CPU(s):   24-47,72-95
Flags:               fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ss ht syscall nx pdpe1gb rdtscp lm constant_tsc arch_perfmon rep_good nopl xtopology nonstop_tsc cpuid aperfmperf tsc_known_freq pni pclmulqdq monitor ssse3 fma cx16 pcid sse4_1 sse4_2 x2apic movbe popcnt tsc_deadline_timer aes xsave avx f16c rdrand hypervisor lahf_lm abm 3dnowprefetch invpcid_single pti fsgsbase tsc_adjust bmi1 avx2 smep bmi2 erms invpcid mpx avx512f avx512dq rdseed adx smap clflushopt clwb avx512cd avx512bw avx512vl xsaveopt xsavec xgetbv1 xsaves ida arat pku ospke

Versions of relevant libraries:
[pip3] numpy==1.24.1
[pip3] pytorch-triton==2.0.0+c8bfe3f548
[pip3] torch==2.0.0.dev20230220+cu118
[pip3] torchaudio==2.0.0.dev20230222+cu118
[pip3] torchrl==0.0.4+46ec988
[pip3] torchsnapshot==0.1.0
[pip3] torchvision==0.15.0.dev20230221+cu118
[conda] magma-cuda110             2.5.2                         1    pytorch
[conda] mkl                       2022.1.0           hc2b9512_224
[conda] mkl-include               2023.0.0         h84fe81f_26648    conda-forge
[conda] numpy                     1.24.1                   pypi_0    pypi
[conda] pytorch-triton            2.0.0+c8bfe3f548          pypi_0    pypi
[conda] torch                     2.0.0a0+gitd677432           dev_0    <develop>
[conda] torchaudio                2.0.0.dev20230222+cu118          pypi_0    pypi
[conda] torchrl                   0.0.4+46ec988            pypi_0    pypi
[conda] torchsnapshot             0.1.0                    pypi_0    pypi
[conda] torchvision               0.15.0.dev20230221+cu118          pypi_0    pypi

cc @mrshenli @pritamdamania87 @zhaojuanmao @satgera @rohan-varma @gqchen @aazzolini @osalpekar @jiayisuse @H-Huang @kwen2501 @awgu @pietern @jjlilley @mrzzd @lw @beauby

@H-Huang added the oncall: distributed, module: rpc, and module: tensorpipe labels on Feb 24, 2023
vmoens (Contributor, Author) commented Feb 27, 2023

Here's the error in the node log:

[W tensorpipe_agent.cpp:915] RPC agent for NODE_2 encountered error when sending outgoing request #0 to NODE_1: ECONNRESET: connection reset by peer (this error originated at tensorpipe/transport/uv/connection_impl.cc:132)
submitit ERROR (2023-02-27 07:55:58,722) - Submitted job triggered an exception
Traceback (most recent call last):
  File "/fsx/users/vmoens/conda/envs/compile/lib/python3.9/runpy.py", line 197, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "/fsx/users/vmoens/conda/envs/compile/lib/python3.9/runpy.py", line 87, in _run_code
    exec(code, run_globals)
  File "/fsx/users/vmoens/conda/envs/compile/lib/python3.9/site-packages/submitit/core/_submit.py", line 11, in <module>
    submitit_main()
  File "/fsx/users/vmoens/conda/envs/compile/lib/python3.9/site-packages/submitit/core/submission.py", line 72, in submitit_main
    process_job(args.folder)
  File "/fsx/users/vmoens/conda/envs/compile/lib/python3.9/site-packages/submitit/core/submission.py", line 65, in process_job
    raise error
  File "/fsx/users/vmoens/conda/envs/compile/lib/python3.9/site-packages/submitit/core/submission.py", line 54, in process_job
    result = delayed.result()
  File "/fsx/users/vmoens/conda/envs/compile/lib/python3.9/site-packages/submitit/core/utils.py", line 133, in result
    self._result = self.function(*self.args, **self.kwargs)
  File "/data/home/vmoens/dump/dummy.py", line 32, in rpc_init_node
    rpc.init_rpc(
  File "/fsx/users/vmoens/conda/envs/compile/lib/python3.9/site-packages/torch/distributed/rpc/__init__.py", line 199, in init_rpc
    _init_rpc_backend(backend, store, name, rank, world_size, rpc_backend_options)
  File "/fsx/users/vmoens/conda/envs/compile/lib/python3.9/site-packages/torch/distributed/rpc/__init__.py", line 234, in _init_rpc_backend
    rpc_agent = backend_registry.init_backend(
  File "/fsx/users/vmoens/conda/envs/compile/lib/python3.9/site-packages/torch/distributed/rpc/backend_registry.py", line 104, in init_backend
    return backend.value.init_backend_handler(*args, **kwargs)
  File "/fsx/users/vmoens/conda/envs/compile/lib/python3.9/site-packages/torch/distributed/rpc/backend_registry.py", line 363, in _tensorpipe_init_backend_handler
    api._all_gather(None, timeout=rpc_backend_options.rpc_timeout)
  File "/fsx/users/vmoens/conda/envs/compile/lib/python3.9/site-packages/torch/distributed/rpc/api.py", line 82, in wrapper
    return func(*args, **kwargs)
  File "/fsx/users/vmoens/conda/envs/compile/lib/python3.9/site-packages/torch/distributed/rpc/api.py", line 224, in _all_gather
    rpc_sync(
  File "/fsx/users/vmoens/conda/envs/compile/lib/python3.9/site-packages/torch/distributed/rpc/api.py", line 82, in wrapper
    return func(*args, **kwargs)
  File "/fsx/users/vmoens/conda/envs/compile/lib/python3.9/site-packages/torch/distributed/rpc/api.py", line 809, in rpc_sync
    return fut.wait()
RuntimeError: ECONNRESET: connection reset by peer (this error originated at tensorpipe/transport/uv/connection_impl.cc:132)
srun: error: a100-st-p4d24xlarge-14: task 0: Exited with exit code 1

kwen2501 (Contributor) commented:
"connection reset by peer" means that the peer is exiting due to a failure. Do you know what happened to that peer?

@kwen2501 added the triaged label on Feb 28, 2023
vmoens (Contributor, Author) commented Feb 28, 2023

I have 3 nodes (rank 0: EOF error, rank 2: connection reset).
The last error is probably the most informative:

terminate called after throwing an instance of 'std::runtime_error'
  what():  In globalIdxForDevice at tensorpipe/channel/cuda_ipc/context_impl.cc:102 "iter == globalUuids.end()Couldn't find GPU with UUID b951cea7-64f5-946d-a7fd-981149841eed"
srun: error: a100-st-p4d24xlarge-13: task 0: Aborted (core dumped)

This is the stderr of rank 1

vmoens (Contributor, Author) commented Mar 1, 2023

I tried to do a bit more debugging.
As @kwen2501 pointed out, only one node is failing. It is always the first node launched via submitit.
I wondered whether the unpickling done by submitit could be the cause of the bug, but that does not seem to be the case.
I checked, and the UUID of the CUDA device that fails to be found is the UUID of the GPU on the remote node (I thought it might be a reference to the device on the main node, but it is not).

I also tried to remove any submitit-specific behaviour (such as unpickling the target function) by rewriting the example above with rpc_init_node moved into a separate script that is invoked on the node via subprocess (that way I'm sure the workspace is fresh when RPC is initialized), but the same error persists.

So TL;DR: when launching multiple RPC jobs on different nodes with GPUs, a failure occurs in tensorpipe/channel/cuda_ipc/context_impl.cc when trying to find a GPU based on its UUID.
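
For anyone who wants to repeat that UUID check, this is roughly what I did: a small ad-hoc helper (not part of the repro) that relies on nvidia-smi -L printing one line per GPU with its UUID:

# Ad-hoc helper: list the GPU UUIDs visible on the current node, to compare
# against the UUID reported in the cuda_ipc error on the other node.
import subprocess

def local_gpu_uuids():
    out = subprocess.check_output(["nvidia-smi", "-L"], text=True)
    # Lines look like: "GPU 0: NVIDIA A100-SXM4-40GB (UUID: GPU-...)"
    return [
        line.split("UUID: ")[-1].rstrip(")")
        for line in out.splitlines()
        if "UUID:" in line
    ]

if __name__ == "__main__":
    print(local_gpu_uuids())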

H-Huang (Member) commented Mar 3, 2023

I'm able to reproduce this; here is the stack trace of the core dump I get on rank 1:

(gdb) bt
#0  __GI_raise (sig=sig@entry=6) at ../sysdeps/unix/sysv/linux/raise.c:50
#1  0x00007ffff7c4b859 in __GI_abort () at abort.c:79
#2  0x00007fffc237535a in __cxxabiv1::__terminate (handler=<optimized out>)
    at /opt/conda/conda-bld/gcc-compiler_1654084175708/work/gcc/libstdc++-v3/libsupc++/eh_terminate.cc:48
#3  0x00007fffc23753c5 in std::terminate ()
    at /opt/conda/conda-bld/gcc-compiler_1654084175708/work/gcc/libstdc++-v3/libsupc++/eh_terminate.cc:58
#4  0x00007fffc2375658 in __cxxabiv1::__cxa_throw (obj=<optimized out>, tinfo=0x7fffc24ca220 <typeinfo for std::runtime_error>, 
    dest=0x7fffc238a110 <std::runtime_error::~runtime_error()>)
    at /opt/conda/conda-bld/gcc-compiler_1654084175708/work/gcc/libstdc++-v3/libsupc++/eh_throw.cc:95
#5  0x00007fff815cd350 in tensorpipe::ExceptionThrower<std::runtime_error>::~ExceptionThrower (this=0x7fff25ff5f40, 
    __in_chrg=<optimized out>) at ../third_party/tensorpipe/tensorpipe/common/defs.h:48
#6  0x00007fff8161df8f in tensorpipe::channel::cuda_ipc::(anonymous namespace)::globalIdxForDevice (globalUuids=..., uuid=...)
    at ../third_party/tensorpipe/tensorpipe/channel/cuda_ipc/context_impl.cc:102
#7  0x00007fff81620d5c in tensorpipe::channel::cuda_ipc::ContextImpl::canCommunicateWithRemote (this=0x55555ad5c830, 
    localDeviceDescriptor=..., remoteDeviceDescriptor=...)
    at ../third_party/tensorpipe/tensorpipe/channel/cuda_ipc/context_impl.cc:329
#8  0x00007fff815edba6 in tensorpipe::channel::ContextBoilerplate<tensorpipe::channel::cuda_ipc::ContextImpl, tensorpipe::channel::cuda_ipc::ChannelImpl>::canCommunicateWithRemote (this=0x55555ad5c6b0, localDeviceDescriptor=..., remoteDeviceDescriptor=...)
    at ../third_party/tensorpipe/tensorpipe/channel/context_boilerplate.h:119
#9  0x00007fffb64f210c in tensorpipe::(anonymous namespace)::selectChannels (orderedChannels=..., remoteDescriptorsMap=...)
    at ../third_party/tensorpipe/tensorpipe/core/pipe_impl.cc:196
#10 0x00007fffb65017d7 in tensorpipe::PipeImpl::onReadWhileServerWaitingForBrochure (this=0x7fff1c002250, nopBrochure=...)
    at ../third_party/tensorpipe/tensorpipe/core/pipe_impl.cc:1014
#11 0x00007fffb64f3f99 in tensorpipe::PipeImpl::<lambda(tensorpipe::PipeImpl&)>::operator()(tensorpipe::PipeImpl &) const (
    __closure=0x7fff25ff6ac0, impl=...) at ../third_party/tensorpipe/tensorpipe/core/pipe_impl.cc:338
#12 0x00007fffb650c79b in tensorpipe::CallbackWrapper<tensorpipe::PipeImpl>::entryPointFromLoop<tensorpipe::PipeImpl::initFromLoop()::<lambda(tensorpipe::PipeImpl&)> >(tensorpipe::PipeImpl &, tensorpipe::PipeImpl::<lambda(tensorpipe::PipeImpl&)>, const tensorpipe::Error &) (this=0x7fff1c002578, subject=..., fn=..., error=...)

Unfortunately it's pretty deep in the tensorpipe library, which I don't understand well. I am not sure why non-CUDA RPC ends up with a CUDA context being set up in tensorpipe.

zw0610 commented Mar 7, 2024

I can reproduce the issue even with the official demo: https://github.com/pytorch/examples/tree/main/distributed/rpc/ddp_rpc

The core dump occurs on rank 0, where the other ranks try to register / update membership with rank 0. The ranks have to be launched on different machines, which means rank 0 cannot see the GPUs on the other ranks.

During this registration / membership update, tensorpipe tries to set up a CUDA IPC channel between the local devices and the devices on the newly joined rank. The CUDA IPC channel is meant to provide a faster connection between GPUs than transferring via the CPU.

However, it seems tensorpipe builds its GPU UUID list, as well as its p2p support matrix, at initialization and never updates them when a new rank joins. As a result, tensorpipe iterates over its own GPU UUID list looking for the GPU UUID coming from a remote rank, which by definition does not exist on rank 0.

A quick fix would be to change the behavior of the cuda_ipc channel in tensorpipe. However, the tensorpipe repository is archived. @H-Huang could you suggest the best approach for patching the archived tensorpipe repository? Or is there a plan to retire tensorpipe in PyTorch?
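
In the meantime, a possible workaround sketch (untested on my side): TensorPipeRpcBackendOptions also accepts a private _channels argument next to _transports, so restricting the agent to a CPU-only channel might keep the cuda_ipc negotiation out of the picture entirely. Both the "basic" channel name and whether this actually avoids the crash are assumptions on my part.

# Untested workaround sketch: whitelist channels so the cuda_ipc channel is
# never negotiated. "_channels" is a private option like "_transports"; the
# "basic" (CPU) channel name is an assumption, not a verified fix.
from torch.distributed import rpc

options = rpc.TensorPipeRpcBackendOptions(
    num_worker_threads=16,
    init_method="tcp://10.0.0.1:1234",  # placeholder
    rpc_timeout=1000,
    _transports=["uv"],
    _channels=["basic"],  # assumption: CPU-only channel, skips cuda_ipc
)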
