Skip to content

Conversation

shuqiangzhang
Copy link
Contributor

@shuqiangzhang shuqiangzhang commented Jul 19, 2024

Stack from ghstack (oldest at bottom):

Summary:
Need another dispacher macro to support more data types
Test Plan:
(sqzhang_1) [sqzhang@devgpu009.cln1 ~/pytorch (86fcae11)]$ python
test/distributed/test_c10d_nccl.py -k test_nan_assert_bfloat16
/home/sqzhang/pytorch/torch/csrc/distributed/c10d/Utils.cu:18:
checkForNaN: block: [0,0,0], thread: [85,0,0] Assertion
!isnan(data[i]) failed.
/home/sqzhang/pytorch/torch/csrc/distributed/c10d/Utils.cu:18:
checkForNaN: block: [0,0,0], thread: [18,0,0] Assertion
!isnan(data[i]) failed.
NCCL version 2.21.5+cuda12.0

devgpu009:1193787:1193787 [0] init.cc:1773 NCCL WARN Cuda failure
'device-side assert triggered'
.

Ran 1 test in 9.416s

OK
Reviewers:

Subscribers:

Tasks:

Tags:

cc @XilunWu @H-Huang @awgu @kwen2501 @wanchaol @fegin @fduwjj @wz337 @wconstab @d4l3k @c-p-i-o

Summary:
Need another dispacher macro to support more data types
Test Plan:
(sqzhang_1) [sqzhang@devgpu009.cln1 ~/pytorch (86fcae11)]$ python
test/distributed/test_c10d_nccl.py -k test_nan_assert_bfloat16
/home/sqzhang/pytorch/torch/csrc/distributed/c10d/Utils.cu:18:
checkForNaN: block: [0,0,0], thread: [85,0,0] Assertion
`!isnan(data[i])` failed.
/home/sqzhang/pytorch/torch/csrc/distributed/c10d/Utils.cu:18:
checkForNaN: block: [0,0,0], thread: [18,0,0] Assertion
`!isnan(data[i])` failed.
NCCL version 2.21.5+cuda12.0

devgpu009:1193787:1193787 [0] init.cc:1773 NCCL WARN Cuda failure
'device-side assert triggered'
.
----------------------------------------------------------------------
Ran 1 test in 9.416s

OK
Reviewers:

Subscribers:

Tasks:

Tags:

[ghstack-poisoned]
Copy link

pytorch-bot bot commented Jul 19, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/131131

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit fa482bf with merge base 8bf0be7 (image):
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@pytorch-bot pytorch-bot bot added oncall: distributed Add this issue/PR to distributed oncall triage queue release notes: distributed (c10d) release notes category labels Jul 19, 2024
Summary:
Need another dispacher macro to support more data types
Test Plan:
(sqzhang_1) [sqzhangdevgpu009.cln1 ~/pytorch (86fcae11)]$ python
test/distributed/test_c10d_nccl.py -k test_nan_assert_bfloat16
/home/sqzhang/pytorch/torch/csrc/distributed/c10d/Utils.cu:18:
checkForNaN: block: [0,0,0], thread: [85,0,0] Assertion
`!isnan(data[i])` failed.
/home/sqzhang/pytorch/torch/csrc/distributed/c10d/Utils.cu:18:
checkForNaN: block: [0,0,0], thread: [18,0,0] Assertion
`!isnan(data[i])` failed.
NCCL version 2.21.5+cuda12.0

devgpu009:1193787:1193787 [0] init.cc:1773 NCCL WARN Cuda failure
'device-side assert triggered'
.
----------------------------------------------------------------------
Ran 1 test in 9.416s

OK
Reviewers:

Subscribers:

Tasks:

Tags:

cc XilunWu H-Huang awgu kwen2501 wanchaol fegin fduwjj wz337 wconstab d4l3k c-p-i-o

[ghstack-poisoned]
shuqiangzhang added a commit that referenced this pull request Jul 19, 2024
Summary:
Need another dispacher macro to support more data types
Test Plan:
(sqzhang_1) [sqzhangdevgpu009.cln1 ~/pytorch (86fcae11)]$ python
test/distributed/test_c10d_nccl.py -k test_nan_assert_bfloat16
/home/sqzhang/pytorch/torch/csrc/distributed/c10d/Utils.cu:18:
checkForNaN: block: [0,0,0], thread: [85,0,0] Assertion
`!isnan(data[i])` failed.
/home/sqzhang/pytorch/torch/csrc/distributed/c10d/Utils.cu:18:
checkForNaN: block: [0,0,0], thread: [18,0,0] Assertion
`!isnan(data[i])` failed.
NCCL version 2.21.5+cuda12.0

devgpu009:1193787:1193787 [0] init.cc:1773 NCCL WARN Cuda failure
'device-side assert triggered'
.
----------------------------------------------------------------------
Ran 1 test in 9.416s

OK
Reviewers:

Subscribers:

Tasks:

Tags:

ghstack-source-id: fbd550c
Pull Request resolved: #131131
@shuqiangzhang
Copy link
Contributor Author

@pytorchbot merge

@pytorch-bot pytorch-bot bot added the ciflow/trunk Trigger trunk jobs on your pull request label Jul 19, 2024
@pytorchmergebot
Copy link
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status
here

@pytorchmergebot
Copy link
Collaborator

The merge job was canceled or timed out. This most often happen if two merge requests were issued for the same PR, or if merge job was waiting for more than 6 hours for tests to finish. In later case, please do not hesitate to reissue the merge command
For more information see pytorch-bot wiki.

Copy link
Contributor

@XilunWu XilunWu left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Cool. Inspiring me that this is something we should have in GLOO as well

@shuqiangzhang
Copy link
Contributor Author

@pytorchbot merge -f "all checks passed"

@shuqiangzhang
Copy link
Contributor Author

@pytorchbot merge

@pytorchmergebot
Copy link
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status
here

bigfootjon pushed a commit that referenced this pull request Jul 24, 2024
Summary:
Need another dispacher macro to support more data types
Test Plan:
(sqzhang_1) [sqzhang@devgpu009.cln1 ~/pytorch (86fcae11)]$ python
test/distributed/test_c10d_nccl.py -k test_nan_assert_bfloat16
/home/sqzhang/pytorch/torch/csrc/distributed/c10d/Utils.cu:18:
checkForNaN: block: [0,0,0], thread: [85,0,0] Assertion
`!isnan(data[i])` failed.
/home/sqzhang/pytorch/torch/csrc/distributed/c10d/Utils.cu:18:
checkForNaN: block: [0,0,0], thread: [18,0,0] Assertion
`!isnan(data[i])` failed.
NCCL version 2.21.5+cuda12.0

devgpu009:1193787:1193787 [0] init.cc:1773 NCCL WARN Cuda failure
'device-side assert triggered'
.
----------------------------------------------------------------------
Ran 1 test in 9.416s

OK
Reviewers:

Subscribers:

Tasks:

Tags:

Pull Request resolved: #131131
Approved by: https://github.com/wconstab, https://github.com/XilunWu

(cherry picked from commit 406f510)
xuhancn pushed a commit to xuhancn/pytorch that referenced this pull request Jul 25, 2024
Summary:
Need another dispacher macro to support more data types
Test Plan:
(sqzhang_1) [sqzhang@devgpu009.cln1 ~/pytorch (86fcae11)]$ python
test/distributed/test_c10d_nccl.py -k test_nan_assert_bfloat16
/home/sqzhang/pytorch/torch/csrc/distributed/c10d/Utils.cu:18:
checkForNaN: block: [0,0,0], thread: [85,0,0] Assertion
`!isnan(data[i])` failed.
/home/sqzhang/pytorch/torch/csrc/distributed/c10d/Utils.cu:18:
checkForNaN: block: [0,0,0], thread: [18,0,0] Assertion
`!isnan(data[i])` failed.
NCCL version 2.21.5+cuda12.0

devgpu009:1193787:1193787 [0] init.cc:1773 NCCL WARN Cuda failure
'device-side assert triggered'
.
----------------------------------------------------------------------
Ran 1 test in 9.416s

OK
Reviewers:

Subscribers:

Tasks:

Tags:

Pull Request resolved: pytorch#131131
Approved by: https://github.com/wconstab, https://github.com/XilunWu
@github-actions github-actions bot deleted the gh/shuqiangzhang/39/head branch August 23, 2024 01:59
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

ciflow/trunk Trigger trunk jobs on your pull request Merged oncall: distributed Add this issue/PR to distributed oncall triage queue release notes: distributed (c10d) release notes category

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants