
Conversation


@hongxiayang hongxiayang commented Mar 14, 2024

When testing all-reduce with an alternative RCCL replacement backend, my test script crashed. After debugging, I found that ncclGetLastError(NULL) returned null, and the code then constructed a std::string from that return value, which crashes with the exception basic_string::_M_construct null not valid.

This pull request fixes that edge case so that the program exits gracefully with useful information.
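
For context, a minimal standalone reproduction of the underlying failure, independent of NCCL/RCCL (the behavior shown is what libstdc++ produces; the C++ standard simply calls this undefined behavior):

```cpp
#include <iostream>
#include <stdexcept>
#include <string>

int main() {
  const char* msg = nullptr;  // stands in for ncclGetLastError(NULL) returning null
  try {
    std::string s(msg);  // constructing std::string from a null char* is invalid
  } catch (const std::logic_error& e) {
    // libstdc++ reports: "basic_string::_M_construct null not valid"
    std::cout << e.what() << std::endl;
  }
  return 0;
}
```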

Test:
Before the fix, my test script exited like this:

File "/opt/conda/envs/py_3.9/lib/python3.9/site-packages/torch/distributed/distributed_c10d.py", line 2051, in all_reduce
    work = group.allreduce([tensor], opts)
RuntimeError: basic_string::_M_construct null not valid

After the fix, my test script exited with a useful message like this:

[rank0]:   File "/opt/conda/envs/py_3.9/lib/python3.9/site-packages/torch/distributed/distributed_c10d.py", line 2219, in all_reduce
[rank0]:     work = group.allreduce([tensor], opts)
[rank0]: torch.distributed.DistBackendError: NCCL error in: /pytorch/torch/csrc/distributed/c10d/NCCLUtils.hpp:272, internal error - please report this issue to the NCCL developers, NCCL version 0.4.2
[rank0]: ncclInternalError: Internal check failed.
[rank0]:  Last error: Unknown NCCL Error
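
For illustration, here is a minimal sketch of the guard this PR describes (not the actual diff; the helper name and signature are hypothetical, and the "Unknown NCCL Error" fallback mirrors the message shown above):

```cpp
#include <string>
#include <nccl.h>  // ncclResult_t, ncclGetErrorString, ncclGetLastError (NCCL >= 2.13)

// Hypothetical helper modeled on the error-detail path in NCCLUtils.hpp.
std::string errorDetailSketch(ncclResult_t error) {
  const char* last = ncclGetLastError(/*comm=*/nullptr);
  // Guard: some backends (e.g. an RCCL replacement) may return a null pointer here;
  // constructing std::string directly from nullptr is what crashed before the fix.
  std::string lastError = (last != nullptr) ? last : "Unknown NCCL Error";
  return std::string(ncclGetErrorString(error)) + ". Last error: " + lastError;
}
```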

cc @mrshenli @pritamdamania87 @zhaojuanmao @satgera @rohan-varma @gqchen @aazzolini @osalpekar @jiayisuse @H-Huang @kwen2501 @awgu @penguinwu @fegin @XilunWu @wanchaol @fduwjj @wz337 @tianyu-l @wconstab @yf225 @chauhang


pytorch-bot bot commented Mar 14, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/121905


✅ You can merge normally! (1 Unrelated Failure)

As of commit da0c633 with merge base 5891c5b:

FLAKY - The following job failed but was likely due to flakiness present on trunk.

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@pytorch-bot pytorch-bot bot added the oncall: distributed and release notes: distributed (c10d) labels Mar 14, 2024
@hongxiayang hongxiayang marked this pull request as draft March 14, 2024 15:49
@hongxiayang hongxiayang marked this pull request as ready for review March 22, 2024 15:02
@jeffdaily

@pytorchbot rebase

@pytorchmergebot

@pytorchbot started a rebase job onto refs/remotes/origin/viable/strict. Check the current status here

@pytorchmergebot

Successfully rebased nccl_error_crash onto refs/remotes/origin/viable/strict, please pull locally before adding more changes (for example, via git checkout nccl_error_crash && git pull --rebase)

@jeffdaily

@pytorchbot merge

@pytorch-bot pytorch-bot bot added the ciflow/trunk label Mar 25, 2024
@pytorchmergebot

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: Check the merge workflow status here

@jeffdaily jeffdaily added the topic: not user facing label Mar 25, 2024
pytorch-bot bot pushed a commit that referenced this pull request Apr 22, 2024
…cclErrorDetailStr (#121905)

Pull Request resolved: #121905
Approved by: https://github.com/wconstab
