torchrun c10d backend doesn't seem to work with python 3.12, giving segmentation fault because of calling obmalloc without holding GIL #125990
Comments
Thank you for your bug report. I can reproduce the crash in a clean Python 3.12 environment.
@kurman is this bug specific to one rendezvous method? I'm not sure why that would be, but if so, I wonder whether we plan to keep this rendezvous method after the cleanup/consolidation work.
I wonder if this issue can be reproduced when specifying
Issue still reproduces with
Tried isolating the Store type using a single test, and all of them are segfaulting: `pytest test/distributed/test_store.py -k "FileStoreTest and test_compare_set"`
Basic repro on TCP store (both libuv and non-libuv):
GDB:
@TanyaAdams1 Thanks a lot for the debugging info. Yeah, the Per-Interpreter GIL change in 3.12 is causing issues. Do you think adding a python-level lock would solve the issue?
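A python-level lock along the lines of the suggestion above might look like the sketch below. The `LockedStore` wrapper and `_DictStore` stand-in are hypothetical names for illustration; note that if the fault comes from C++ threads calling CPython APIs without the GIL, a pure-Python lock alone probably cannot fix it, since those threads never pass through Python code:

```python
import threading

# Hypothetical sketch of the suggested python-level lock: serialize all
# store operations through one lock. This only guards Python-side callers;
# it cannot help C++ threads that touch CPython APIs directly.
_store_lock = threading.Lock()

class LockedStore:
    """Wraps a store object so set/get are mutually exclusive."""

    def __init__(self, store):
        self._store = store  # e.g. a c10d Store; any object with set/get

    def set(self, key, value):
        with _store_lock:
            self._store.set(key, value)

    def get(self, key):
        with _store_lock:
            return self._store.get(key)

class _DictStore:
    """Stand-in store used here so the sketch is self-contained."""

    def __init__(self):
        self._d = {}

    def set(self, k, v):
        self._d[k] = v

    def get(self, k):
        return self._d[k]

store = LockedStore(_DictStore())
store.set("rank0", b"ready")
print(store.get("rank0"))  # b'ready'
```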
FYI, running the repro with a debug build of CPython points to the real issue: you're calling into CPython APIs without holding the GIL. See the logs below for full details:
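As an aside, even without a debug build of CPython, the stdlib `faulthandler` module can surface a Python-level traceback when a C extension segfaults. This is a generic debugging aid, not specific to this bug:

```python
import faulthandler
import sys

# Dump Python tracebacks on fatal signals (SIGSEGV, SIGABRT, ...).
# Equivalent to running with `python -X faulthandler` or setting
# PYTHONFAULTHANDLER=1 in the environment.
faulthandler.enable(file=sys.stderr)
print(faulthandler.is_enabled())
```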
🐛 Describe the bug
TLDR: Python 3.12 changed how the GIL works, and using torch distributed (especially the c10d rdzv backend) now triggers a segmentation fault. After debugging, I believe the error is caused by calling an object allocation function without holding the GIL.
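The GIL rule at play: every CPython C-API call, including allocation through obmalloc, must be made while the calling thread holds the GIL. The distinction can be illustrated from Python with ctypes: `ctypes.pythonapi` is a `PyDLL` handle, which keeps the GIL held during the call, whereas a plain `CDLL` handle releases the GIL around each call — exactly the unsafe situation for CPython API functions:

```python
import ctypes

# ctypes.pythonapi is a PyDLL: calls are made WITH the GIL held, which is
# required for CPython API functions. Routing the same call through a CDLL
# handle would release the GIL first -- the bug class described in this
# issue, where allocation runs without the GIL.
ctypes.pythonapi.PyLong_FromLong.argtypes = [ctypes.c_long]
ctypes.pythonapi.PyLong_FromLong.restype = ctypes.py_object
print(ctypes.pythonapi.PyLong_FromLong(42))  # 42
```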
To reproduce this bug, first create a new conda environment:
`conda create -n torch`
Then follow the installation instructions on the PyTorch website:
`conda install pytorch torchvision torchaudio pytorch-cuda=12.1 -c pytorch -c nvidia`
During this step, conda will by default download a very recent version of Python (3.12.3 for me). Then run torchrun with any random script name:
`torchrun --standalone --nproc-per-node 4 random_name.py`
(The program crashes even before launching the script!) Here's the error message I got:
I tried to debug this using gdb:
`gdb --args python -m torch.distributed.launch --standalone --nproc-per-node 4 random_name.py`
and here's the output:
Downgrading Python back to 3.10 solves the problem for me for now, but given that conda installs 3.12.3 by default, updating how PyTorch handles the GIL seems like the right fix.
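Until a fixed build is available, a fail-fast guard in the launching script can at least turn the segfault into an explicit error. The `check_python_version` helper and its `max_minor` cutoff are assumptions based only on this report (3.10 works, 3.12 crashes; 3.11 is untested here):

```python
import sys

def check_python_version(max_minor: int = 11) -> None:
    # Assumed cutoff: 3.10 is known-good per this report and 3.12 segfaults;
    # 3.11 is untested here, so adjust max_minor to match your findings.
    if sys.version_info[:2] > (3, max_minor):
        raise RuntimeError(
            "Python %d.%d may segfault in torchrun's c10d rendezvous; "
            "downgrade to an older interpreter." % sys.version_info[:2]
        )

check_python_version(max_minor=99)  # current interpreter passes this cutoff
```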
Versions
cc @ezyang @gchanan @zou3519 @kadeng @msaroufim @mrshenli @pritamdamania87 @zhaojuanmao @satgera @gqchen @aazzolini @osalpekar @jiayisuse @H-Huang @kwen2501 @awgu @penguinwu @fegin @XilunWu @wanchaol @fduwjj @wz337 @tianyu-l @wconstab @yf225 @chauhang @d4l3k