
Conversation

@H-Huang
Member

H-Huang commented Sep 19, 2023

Summary: Point-to-point ops don't enqueue their work to `workMetaList_`, which means the NCCL watchdog does not watch over them, and hence they do not respect collective timeouts.

Test Plan:
While trying to add a test, I found we don't have tests that validate the NCCL watchdog. It looks like this is because we don't have a good way to detect when the NCCL watchdog has thrown an error (the exception is thrown in a side thread) in our testing framework / `MultiprocessTestCase`.

I manually tested this change with the script in #109401 (a sketch of that kind of repro is below), but I need to look more closely at how to automate a test for the NCCL watchdog.

Differential Revision: D49418976
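
To make the manual check concrete, a repro of roughly this shape exercises the timeout path (this is a sketch, not the exact script from #109401; the file name, tensor size, and timeout are illustrative):

```python
# Sketch of a manual repro (not the exact script from #109401; values are
# illustrative). Run with: torchrun --nproc_per_node=2 p2p_timeout_repro.py
# Rank 0 posts a recv that rank 1 never matches; with this change the NCCL
# watchdog should detect the timed-out RECV and tear the process down instead
# of hanging forever.
import os
import time
from datetime import timedelta

import torch
import torch.distributed as dist


def main():
    rank = int(os.environ["RANK"])
    # Short timeout so the watchdog fires quickly. Depending on the build's
    # defaults you may also need NCCL_ASYNC_ERROR_HANDLING=1 (and
    # NCCL_DESYNC_DEBUG=1 for the desync report).
    dist.init_process_group("nccl", timeout=timedelta(seconds=10))
    torch.cuda.set_device(rank)
    t = torch.ones(10, device="cuda")
    if rank == 0:
        # No matching send is ever posted by rank 1, so this RECV can never
        # complete. Before this change it was not tracked in workMetaList_,
        # so the watchdog ignored it.
        dist.recv(t, src=1)
        torch.cuda.synchronize()  # block the host on the stuck op
    else:
        # Simulate a hung peer.
        time.sleep(120)
    dist.destroy_process_group()


if __name__ == "__main__":
    main()
```

Automating this is the hard part: since the watchdog raises from its own thread, the test would presumably have to assert on the subprocess dying with a non-zero exit code rather than catching an exception.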

@pytorch-bot

pytorch-bot bot commented Sep 19, 2023

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/109611

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 601093b with merge base d04b35e:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@facebook-github-bot
Contributor

This pull request was exported from Phabricator. Differential Revision: D49418976

@wconstab
Contributor

wconstab commented Sep 23, 2023

I tried using this on my PP hang. It seems like it partially works (the watchdog is picking up on a timeout for a RECV op). However, it is giving me this:
[E ProcessGroupNCCL.cpp:474] [Rank 1] Watchdog caught collective operation timeout: WorkNCCL(SeqNum=0, OpType=RECV, NumelIn=80000, NumelOut=80000, Timeout(ms)=20000) ran for 20175 milliseconds before timing out.
[E ProcessGroupNCCL.cpp:474] [Rank 0] Watchdog caught collective operation timeout: WorkNCCL(SeqNum=0, OpType=RECV, NumelIn=80000, NumelOut=80000, Timeout(ms)=20000) ran for 20189 milliseconds before timing out.
[E ProcessGroupNCCL.cpp:988] Failed to retrieve NCCL_DESYNC_DEBUG report. Please file an issue. Error: traceMap[thisSeq][myRank].second == kEventStart INTERNAL ASSERT FAILED at "/data/users/whc/pytorch/torch/csrc/distributed/c10d/TraceUtils.h":244, please report a bug to PyTorch. Timeout rank [0] last trace item must be kEventStart. thisSeq = 0, col = RECV

It's suspicious that it says the seq number is 0, as I've already executed a bunch of steps of forward/backward, and each step should be waiting on completion of a previous send/recv.

@facebook-github-bot
Contributor

This pull request was exported from Phabricator. Differential Revision: D49418976

@H-Huang
Member Author

H-Huang commented Sep 29, 2023

Thanks @wconstab, it does look like the sequence # also needs to be updated. That might have also caused the issue with desync debug, since it reads the seq #. I will double-check it.

@facebook-github-bot
Contributor

This pull request was exported from Phabricator. Differential Revision: D49418976

@wconstab
Contributor

wconstab commented Oct 3, 2023

OK, this works for me now. At least in a trivial case where I intentionally cause a hang, I get reasonable output from the desync report:

[E ProcessGroupNCCL.cpp:474] [Rank 1] Watchdog caught collective operation timeout: WorkNCCL(SeqNum=2, OpType=RECV, NumelIn=6, NumelOut=6, Timeout(ms)=20000) ran for 20757 milliseconds before timing out.
[E ProcessGroupNCCL.cpp:474] [Rank 0] Watchdog caught collective operation timeout: WorkNCCL(SeqNum=3, OpType=RECV, NumelIn=6, NumelOut=6, Timeout(ms)=20000) ran for 20702 milliseconds before timing out.
Done
[E ProcessGroupNCCL.cpp:986]
         - [1] Timeout at collective: RECV, #2
         - To our best knowledge, the lagging/dead/mismatched ranks that caused the desync are:
           - [1] joined but didn't finish collective #2 (count from 1)
         - Snapshot of ranks' latest states:
           #2 started ranks:
             [1] started RECV
           #3 started ranks:
             [0] started RECV
[E ProcessGroupNCCL.cpp:488] Some NCCL operations have failed or timed out. Due to the asynchronous nature of CUDA kernels, subsequent GPU operations might run on corrupted/incomplete data.
[E ProcessGroupNCCL.cpp:494] To avoid data inconsistency, we are taking the entire process down.
[E ProcessGroupNCCL.cpp:915] [Rank 1] NCCL watchdog thread terminated with exception: [Rank 1] Watchdog caught collective operation timeout: WorkNCCL(SeqNum=2, OpType=RECV, NumelIn=6, NumelOut=6, Timeout(ms)=20000) ran for 20757 milliseconds before timing out.
terminate called after throwing an instance of 'std::runtime_error'
  what():  [Rank 1] NCCL watchdog thread terminated with exception: [Rank 1] Watchdog caught collective operation timeout: WorkNCCL(SeqNum=2, OpType=RECV, NumelIn=6, NumelOut=6, Timeout(ms)=20000) ran for 20757 milliseconds before timing out.
devgpu012:611476:616545 [0] NCCL INFO [Service thread] Connection closed by localRank 0
Done
devgpu012:611476:613548 [0] NCCL INFO comm 0x7c78920 rank 0 nranks 2 cudaDev 0 busId 11000 - Abort COMPLETE
[E ProcessGroupNCCL.cpp:986]
         - [0] Timeout at collective: RECV, #3
         - To our best knowledge, the lagging/dead/mismatched ranks that caused the desync are:
           - [1] joined but didn't finish collective #2 (count from 1)
         - Snapshot of ranks' latest states:
           #2 started ranks:
             [1] started RECV
           #3 started ranks:
             [0] started RECV
[E ProcessGroupNCCL.cpp:488] Some NCCL operations have failed or timed out. Due to the asynchronous nature of CUDA kernels, subsequent GPU operations might run on corrupted/incomplete data.
[E ProcessGroupNCCL.cpp:494] To avoid data inconsistency, we are taking the entire process down.
[E ProcessGroupNCCL.cpp:915] [Rank 0] NCCL watchdog thread terminated with exception: [Rank 0] Watchdog caught collective operation timeout: WorkNCCL(SeqNum=3, OpType=RECV, NumelIn=6, NumelOut=6, Timeout(ms)=20000) ran for 20702 milliseconds before timing out.
terminate called after throwing an instance of 'std::runtime_error'
  what():  [Rank 0] NCCL watchdog thread terminated with exception: [Rank 0] Watchdog caught collective operation timeout: WorkNCCL(SeqNum=3, OpType=RECV, NumelIn=6, NumelOut=6, Timeout(ms)=20000) ran for 20702 milliseconds before timing out.
Traceback (most recent call last):
  File "/data/users/whc/pytorch/hang.py", line 44, in <module>
    mp.spawn(main, args=(world_size,), nprocs=world_size, join=True)
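
For reference, the intentional hang is roughly of this shape (this is a sketch, not the actual hang.py from the traceback; the warm-up exchange and tensor size are guesses):

```python
# Sketch of a trivial intentional hang (not the actual hang.py; the warm-up
# exchange and tensor size are guesses). After one matched send/recv pair,
# both ranks post a recv and neither ever sends, so both RECVs time out and
# the watchdog produces a report like the one above.
import os
from datetime import timedelta

import torch
import torch.distributed as dist
import torch.multiprocessing as mp


def main(rank, world_size):
    os.environ.setdefault("MASTER_ADDR", "localhost")
    os.environ.setdefault("MASTER_PORT", "29500")
    dist.init_process_group(
        "nccl", rank=rank, world_size=world_size, timeout=timedelta(seconds=20)
    )
    torch.cuda.set_device(rank)
    peer = 1 - rank
    t = torch.zeros(6, device="cuda")

    # One matched exchange first, so the deadlocking ops land at a later SeqNum.
    if rank == 0:
        dist.send(t, dst=peer)
        dist.recv(t, src=peer)
    else:
        dist.recv(t, src=peer)
        dist.send(t, dst=peer)

    # Both ranks now recv and neither sends: a guaranteed deadlock on the GPU.
    # With NCCL the call returns once the kernel is enqueued, so the hang is
    # only visible to the watchdog, which (with this fix) times the RECV out.
    dist.recv(t, src=peer)
    torch.cuda.synchronize()  # block the host on the stuck op


if __name__ == "__main__":
    world_size = 2
    mp.spawn(main, args=(world_size,), nprocs=world_size, join=True)
```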

Contributor

@wconstab wconstab left a comment

Thanks for fixing, @H-Huang.

@H-Huang
Member Author

H-Huang commented Oct 3, 2023

@pytorchbot merge

pytorch-bot bot added the ciflow/trunk (Trigger trunk jobs on your pull request) label Oct 3, 2023
@pytorchmergebot
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status here.

acphile pushed a commit to acphile/pytorch that referenced this pull request Jan 22, 2024
Summary: Point-to-point ops don't enqueue their work to `workMetaList_`, which means the NCCL watchdog does not watch over them, and hence they do not respect collective timeouts.

Test Plan:
While trying to add a test, I found we don't have tests that validate the NCCL watchdog. It looks like this is because we don't have a good way to detect when the NCCL watchdog has thrown an error (the exception is thrown in a side thread) in our testing framework / `MultiprocessTestCase`.

I manually tested this change with the script in pytorch#109401, but need to look more closely at how to automate a test for the NCCL watchdog.

Differential Revision: D49418976

Pull Request resolved: pytorch#109611
Approved by: https://github.com/wconstab

Labels

ciflow/trunk (Trigger trunk jobs on your pull request), fb-exported, Merged, release notes: distributed (c10d)
