Possible deadlock in dist.init_process_group #9696

Closed · yaroslavvb opened this issue Jul 22, 2018 · 12 comments
Labels: oncall: distributed

@yaroslavvb (Contributor)

I'm observing hangs in dist.init_process_group. This happens reliably (100% of the time for me) when launching PyTorch distributed training runs on AWS using the official DLAMI.

It goes away when I pre-warm the volume. The workaround also makes PyTorch startup much faster, which is why I suspect the failure is caused by handshake logic that isn't robust to variability in worker startup timing.

I'm using NCCL 2.1.15+cuda9.1 and the Amazon Deep Learning AMI v11.

Looking at strace, I see some workers stuck in recvfrom while others are waiting in accept4:

[pid 10047] 21:59:49 accept4(32,  <unfinished ...>
[pid  9976] 21:59:49 recvfrom(34,
[pid  9985] 21:59:12 recvfrom(34,  <unfinished ...>
[pid 10058] 21:59:12 accept4(32,
@zou3519 added the oncall: distributed label on Jul 23, 2018
@zou3519 (Contributor) commented Jul 23, 2018

Could you provide a code snippet that reproduces this?

@yaroslavvb (Contributor, Author) commented Jul 23, 2018

Sorry for not isolating the issue; this seems to happen only in large-scale runs and I don't have a small repro. The complicated repro is to follow the instructions in
https://github.com/diux-dev/cluster/tree/master/pytorch and then run

python launch_nv.py --name yaro8 --num-tasks 8 --zone us-west-2c --params x8ar_args

I'll update this bug if I can isolate it further.
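
For context, the call each worker makes is just the standard rendezvous; below is a minimal per-worker sketch of the call that hangs (placeholder addresses and env-driven rank/world size, not the actual launch_nv.py code; gloo is used only to keep the sketch GPU-free):

import os
import torch.distributed as dist

# Placeholder rendezvous settings; in the real run these come from the launcher.
os.environ.setdefault('MASTER_ADDR', '10.0.0.1')   # master node's IP (placeholder)
os.environ.setdefault('MASTER_PORT', '6006')       # open port on the master (placeholder)

rank = int(os.environ['RANK'])                     # unique per worker
world_size = int(os.environ['WORLD_SIZE'])

# Every rank blocks here until all ranks have connected to the master.
dist.init_process_group(backend='gloo', init_method='env://',
                        rank=rank, world_size=world_size)
print(f'rank {rank}/{world_size} initialized')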

@ailzhang (Contributor)

Hi @yaroslavvb, what kind of pre-warm did you use?

Also, could you give this a try and see if it helps? https://github.com/pytorch/pytorch/blob/master/torch/nn/parallel/distributed.py#L72
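
(The exact line that link points at has likely moved since; assuming it refers to DistributedDataParallel, a minimal usage sketch, with a placeholder model and device index, would be:)

import torch
from torch.nn.parallel import DistributedDataParallel as DDP

# Assumes dist.init_process_group(...) has already completed on every rank
# and that each process drives one GPU; device_id is illustrative.
device_id = 0
torch.cuda.set_device(device_id)

model = torch.nn.Linear(10, 10).cuda(device_id)   # placeholder model
ddp_model = DDP(model, device_ids=[device_id])    # wraps the model so gradients are all-reduced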

@yaroslavvb (Contributor, Author)

My "warm-up" is to attach an existing EBS volume from a previously successful run, instead of initializing new volume from AMI.

Another workflow that works is this

  1. run it on all machines immediately after instance allocation (as soon as SSH succeeds)
  2. things get stuck in "dist.init_process_group" (>20 mins)
  3. kill all Python processes, repeat commands again
  4. things now work
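
A rough sketch of automating that kill-and-retry loop (hypothetical helper, not from the cluster repo; it assumes the worker touches a sentinel file once dist.init_process_group returns):

import os
import subprocess
import time

def run_with_retry(cmd, sentinel='/tmp/dist_ready', timeout_s=120, retries=3):
    """Launch the worker command; if the sentinel file (which the worker is
    assumed to create right after dist.init_process_group returns) doesn't
    appear in time, kill the process and relaunch it."""
    for attempt in range(1, retries + 1):
        if os.path.exists(sentinel):
            os.remove(sentinel)
        proc = subprocess.Popen(cmd)
        deadline = time.time() + timeout_s
        while time.time() < deadline:
            if os.path.exists(sentinel):
                return proc                  # rendezvous finished; hand the process back
            if proc.poll() is not None:
                break                        # process exited on its own; retry
            time.sleep(1)
        proc.kill()
        proc.wait()
        print(f'attempt {attempt}: stuck in init_process_group, retrying')
    raise RuntimeError('worker never completed init_process_group')

# e.g. run_with_retry(['python', 'train.py'])   # train.py is a placeholder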

@ailzhang (Contributor)

@yaroslavvb This is a bit hard for us to repro; a smaller repro would definitely help. If that's not possible, the next time it happens could you get a full gdb backtrace of all processes and threads? For each process, thread apply all bt should do the trick.
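
For anyone hitting this, the per-process dumps can be collected in one go; a small sketch (assumes gdb is installed and you have permission to ptrace the hung PIDs):

import subprocess

def dump_backtraces(pids, out_path='backtraces.txt'):
    """Attach gdb to each stuck process and dump 'thread apply all bt'."""
    with open(out_path, 'w') as out:
        for pid in pids:
            out.write(f'===== pid {pid} =====\n')
            result = subprocess.run(
                ['gdb', '-p', str(pid), '-batch', '-ex', 'thread apply all bt'],
                capture_output=True, text=True)
            out.write(result.stdout)

# e.g. dump_backtraces([10141, 10215])   # PIDs of the hung python workers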

@yaroslavvb (Contributor, Author)

Here's an example of things hanging; it looks stuck in discoverMaster:

Thread 2 (Thread 0x7f5a56fd7700 (LWP 10215)):
#0  0x00007f5aa97218c8 in accept4 (fd=32, addr=..., addr_len=0x7f5a56fd6e18, flags=524288)
    at ../sysdeps/unix/sysv/linux/accept4.c:40
#1  0x00007f5a96a153aa in ?? () from /usr/lib/x86_64-linux-gnu/libcuda.so.1
#2  0x00007f5a96a07c1d in ?? () from /usr/lib/x86_64-linux-gnu/libcuda.so.1
#3  0x00007f5a96a16018 in ?? () from /usr/lib/x86_64-linux-gnu/libcuda.so.1
#4  0x00007f5aa99ea6ba in start_thread (arg=0x7f5a56fd7700) at pthread_create.c:333
#5  0x00007f5aa972041d in clone () at ../sysdeps/unix/sysv/linux/x86_64/clone.S:109

Thread 1 (Thread 0x7f5aaa017740 (LWP 10141)):
#0  0x00007f5aa99f387f in __libc_recv (fd=34, buf=0x7ffd921a2198, n=8, flags=0)
    at ../sysdeps/unix/sysv/linux/x86_64/recv.c:28
#1  0x00007f5a897b3f90 in void thd::recv_bytes<unsigned long>(int, unsigned long*, unsigned long) ()
   from /home/ubuntu/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages/torch/_C.cpython-36m-x86_64-linux-gnu.so
#2  0x00007f5a897c2cb4 in thd::discoverMaster(std::vector<std::string, std::allocator<std::string> >, unsigned short) ()
   from /home/ubuntu/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages/torch/_C.cpython-36m-x86_64-linux-gnu.so
#3  0x00007f5a897b91bc in thd::init::initEnv(std::string, int, std::string, int) ()
   from /home/ubuntu/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages/torch/_C.cpython-36m-x86_64-linux-gnu.so
#4  0x00007f5a897b80d7 in std::_Function_handler<thd::InitMethod::Config (std::string, int, std::string, int), thd::InitMethod::Config (*)(std::string, int, std::string, int)>::_M_invoke(std::_Any_data const&, std::string, int, std::string, int) ()
   from /home/ubuntu/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages/torch/_C.cpython-36m-x86_64-linux-gnu.so
#5  0x00007f5a897b7781 in thd::getInitConfig(std::string, int, std::string, int) ()
   from /home/ubuntu/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages/torch/_C.cpython-36m-x86_64-linux-gnu.so
#6  0x00007f5a8977a3ef in thd::DataChannel::newChannel(THDChannelType, std::string, int, std::string, int) ()
   from /home/ubuntu/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages/torch/_C.cpython-36m-x86_64-linux-gnu.so
#7  0x00007f5a897792de in THDProcessGroupInit ()
   from /home/ubuntu/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages/torch/_C.cpython-36m-x86_64-linux-gnu.so
#8  0x00007f5a89606d1b in THDPModule_initProcessGroup (_unused=<optimized out>, 
    args=<optimized out>) at torch/csrc/distributed/Module.cpp:110
#9  0x00005621b6b4dad1 in _PyCFunction_FastCallDict ()
#10 0x00005621b6bdd67c in call_function ()
#11 0x00005621b6bffcba in _PyEval_EvalFrameDefault ()
#12 0x00005621b6bd6a94 in _PyEval_EvalCodeWithName ()
#13 0x00005621b6bd7941 in fast_function ()
#14 0x00005621b6bdd755 in call_function ()
#15 0x00005621b6c00a7a in _PyEval_EvalFrameDefault ()
#16 0x00005621b6bd770b in fast_function ()
#17 0x00005621b6bdd755 in call_function ()
#18 0x00005621b6bffcba in _PyEval_EvalFrameDefault ()
#19 0x00005621b6bd8459 in PyEval_EvalCodeEx ()
#20 0x00005621b6bd91ec in PyEval_EvalCode ()
#21 0x00005621b6c539a4 in run_mod ()
#22 0x00005621b6c53da1 in PyRun_FileExFlags ()
#23 0x00005621b6c53fa4 in PyRun_SimpleFileExFlags ()
#24 0x00005621b6c57a9e in Py_Main ()
#25 0x00005621b6b1f4be in main ()

@Stonesjtu (Contributor)

You can test whether your master port is reachable from the worker nodes.
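
For example, a quick check that can be run from each worker node (the address and port below are placeholders):

import socket

def master_reachable(addr, port, timeout_s=5):
    """Return True if a TCP connection to the master's rendezvous port succeeds."""
    try:
        with socket.create_connection((addr, port), timeout=timeout_s):
            return True
    except OSError as e:
        print(f'cannot reach {addr}:{port}: {e}')
        return False

# Run on every worker, e.g.:
# master_reachable('10.0.0.1', 6006)   # placeholder address/port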

@ailzhang (Contributor)

@yaroslavvb this is likely because AWS gives you machines in different subnets, so not all of them are reachable from each other.

@yaroslavvb (Contributor, Author)

I don't think that explains it, because the problem is fixed by killing the Python processes and restarting them on the same instances.

I'm explicitly specifying the zone to use and launching all instances into a single placement group. To be safe, I additionally make the master port reachable from the public internet.

@ailzhang (Contributor)

Hi @yaroslavvb, do you compile PyTorch from source? If so, could you add a few print statements around

std::tie(sockets[i], std::ignore) = accept(listen_socket);

and

socket = connect(address, port, true, 2000);

The gdb trace indicates the master is still waiting in accept() while the worker is waiting to recv a message from the master. Every time the master accepts a connection, it confirms by sending back the worker's address. Since it's hard to repro on our side, it would also be very helpful if you could tell us which lines the master and the workers get stuck on, respectively.
Thanks!
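
For readers unfamiliar with that handshake, here is a toy sketch of its shape (not the actual THD code): the master blocks in accept and, on each new connection, replies with the address it saw the worker connect from, while the worker blocks in recv until that reply arrives, matching the accept4/recvfrom pair in the straces above.

import socket

def master(port, world_size):
    srv = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    srv.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
    srv.bind(('', port))
    srv.listen(world_size)
    for _ in range(world_size - 1):
        conn, (worker_addr, _) = srv.accept()   # "master stuck in accept" == blocking here
        conn.sendall(worker_addr.encode())      # confirmation: echo back the worker's address
        conn.close()

def worker(master_addr, port):
    sock = socket.create_connection((master_addr, port))
    reply = sock.recv(64)                       # "worker stuck in recv" == blocking here
    print('master sees me as', reply.decode())
    sock.close()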

@yaroslavvb (Contributor, Author)

I'll close this issue for now since I haven't seen this error recently; it may have been fixed in master.

@brando90 commented Mar 5, 2021

> I'll close this issue for now since I have not seen this error recently, it may have been fixed in master

How did you solve this?

I am having a similar issue but I am unable to figure out what it is.

import os

import torch
import torch.distributed as dist


def setup_process(rank, world_size, backend='gloo'):
    """
    Initialize the distributed environment (for each process).

    gloo is a collective communications library (https://github.com/facebookincubator/gloo). My understanding is that
    it's a library/API that processes use to communicate/coordinate with each other and with the master. It's a backend library.

    export NCCL_SOCKET_IFNAME=eth0
    export NCCL_IB_DISABLE=1

    https://stackoverflow.com/questions/61075390/about-pytorch-nccl-error-unhandled-system-error-nccl-version-2-4-8

    https://pytorch.org/docs/stable/distributed.html#common-environment-variables
    """
    if rank != -1:  # -1 rank indicates serial code
        print(f'setting up rank={rank} (with world_size={world_size})')
        # MASTER_ADDR = 'localhost'
        MASTER_ADDR = '127.0.0.1'
        # helper defined elsewhere; note each spawned process calls this independently,
        # so different ranks may end up with different ports
        MASTER_PORT = find_free_port()
        # set up the master's ip address so this child process can coordinate
        os.environ['MASTER_ADDR'] = MASTER_ADDR
        print(f"{MASTER_ADDR=}")
        os.environ['MASTER_PORT'] = str(MASTER_PORT)  # environ values must be strings
        print(f"{MASTER_PORT=}")

        # - use NCCL if you are using gpus: https://pytorch.org/tutorials/intermediate/dist_tuto.html#communication-backends
        if torch.cuda.is_available():
            # unsure if this is really needed
            # os.environ['NCCL_SOCKET_IFNAME'] = 'eth0'
            # os.environ['NCCL_IB_DISABLE'] = '1'
            backend = 'nccl'
        print(f'{backend=}')
        # Initializes the default distributed process group, and this will also initialize the distributed package.
        dist.init_process_group(backend, rank=rank, world_size=world_size)
        # dist.init_process_group(backend, rank=rank, world_size=world_size)
        # dist.init_process_group(backend='nccl', init_method='env://', world_size=world_size, rank=rank)
        print(f'--> done setting up rank={rank}')
        dist.destroy_process_group()

It's easy to reproduce; just run the above in each process via

spawn_return = mp.spawn(fn=setup_process, args=(4,), nprocs=4)  # mp is torch.multiprocessing; spawn passes the rank as the first argument

For me this happens on my own single machine.
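
find_free_port() isn't shown above; a common sketch of such a helper (binding to port 0 and reading back the assigned port) looks like this:

import socket

def find_free_port() -> str:
    """Ask the OS for an unused TCP port by binding to port 0 and return it as a
    string (os.environ values must be strings). There is a small race: another
    process could grab the port before init_process_group binds it."""
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
        s.bind(('', 0))
        return str(s.getsockname()[1])

Note that in a multi-process run the port has to be chosen once and passed to every rank; if each spawned process picks its own port, the ranks end up trying to rendezvous on different ports.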

related: https://stackoverflow.com/questions/66498045/how-to-solve-dist-init-process-group-from-hanging-or-deadlocks
#53395
