-
Notifications
You must be signed in to change notification settings - Fork 21.6k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[c10d] Add an option for NAN check on every collective (#125726)
Summary: The NAN CHECK is done through device side assert without copying needed from GPU to CPU Test Plan: Unit test for collectives that should experience run time error (sqzhang_1) [sqzhang@devgpu009.cln1 ~/pytorch (38f5143e)]$ python test/distributed/test_c10d_nccl.py ProcessGroupNCCLTest.test_nan_assert /home/sqzhang/pytorch/torch/csrc/distributed/c10d/Utils.cu:15: checkForNaN: block: [0,0,0], thread: [0,0,0] Assertion `!isnan(val)` failed. /home/sqzhang/pytorch/torch/csrc/distributed/c10d/Utils.cu:15: checkForNaN: block: [0,0,0], thread: [1,0,0] Assertion `!isnan(val)` failed. /home/sqzhang/pytorch/torch/csrc/distributed/c10d/Utils.cu:15: checkForNaN: block: [0,0,0], thread: [2,0,0] Assertion `!isnan(val)` failed. /home/sqzhang/pytorch/torch/csrc/distributed/c10d/Utils.cu:15: checkForNaN: block: [0,0,0], thread: [3,0,0] Assertion `!isnan(val)` failed. /home/sqzhang/pytorch/torch/csrc/distributed/c10d/Utils.cu:15: checkForNaN: block: [0,0,0], thread: [4,0,0] Assertion `!isnan(val)` failed. /home/sqzhang/pytorch/torch/csrc/distributed/c10d/Utils.cu:15: checkForNaN: block: [0,0,0], thread: [5,0,0] Assertion `!isnan(val)` failed. [rank0]:[E507 17:31:56.885473996 Utils.cu:30] CUDA error during checkForNan: device-side assert triggered /home/sqzhang/pytorch/torch/csrc/distributed/c10d/Utils.cu:15: checkForNaN: block: [0,0,0], thread: [0,0,0] Assertion `!isnan(val)` failed. /home/sqzhang/pytorch/torch/csrc/distributed/c10d/Utils.cu:15: checkForNaN: block: [0,0,0], thread: [1,0,0] Assertion `!isnan(val)` failed. /home/sqzhang/pytorch/torch/csrc/distributed/c10d/Utils.cu:15: checkForNaN: block: [0,0,0], thread: [2,0,0] Assertion `!isnan(val)` failed. /home/sqzhang/pytorch/torch/csrc/distributed/c10d/Utils.cu:15: checkForNaN: block: [0,0,0], thread: [3,0,0] Assertion `!isnan(val)` failed. /home/sqzhang/pytorch/torch/csrc/distributed/c10d/Utils.cu:15: checkForNaN: block: [0,0,0], thread: [4,0,0] Assertion `!isnan(val)` failed. /home/sqzhang/pytorch/torch/csrc/distributed/c10d/Utils.cu:15: checkForNaN: block: [0,0,0], thread: [5,0,0] Assertion `!isnan(val)` failed. [rank1]:[E507 17:31:56.128961534 Utils.cu:30] CUDA error during checkForNan: device-side assert triggered . ---------------------------------------------------------------------- Ran 1 test in 7.723s OK Tags: Pull Request resolved: #125726 Approved by: https://github.com/kwen2501
- Loading branch information
1 parent
0214711
commit c860df5
Showing
7 changed files
with
92 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,45 @@ | ||
#include <ATen/Dispatch.h> | ||
#include <ATen/cuda/CUDAContext.h> | ||
#include <c10/cuda/CUDAGuard.h> | ||
#include <torch/csrc/distributed/c10d/Utils.hpp> | ||
#include <torch/torch.h> | ||
#include <algorithm> | ||
|
||
namespace c10d { | ||
|
||
// CUDA kernel to check if data has NAN, device side assert | ||
// is raised if NAN is found | ||
template <typename T> | ||
__global__ void checkForNaN(T* data, size_t size) { | ||
size_t tid = blockIdx.x * blockDim.x + threadIdx.x; | ||
size_t stride = blockDim.x * gridDim.x; | ||
|
||
for (size_t i = tid; i < size; i += stride) { | ||
CUDA_KERNEL_ASSERT(!isnan(data[i])); | ||
} | ||
} | ||
|
||
// CHECK if a Tensor contains NAN in any of its element | ||
void checkForNan(const at::Tensor& tensor) { | ||
// skip check for non float types | ||
if (!torch::is_floating_point(tensor)) { | ||
return; | ||
} | ||
const size_t maxNumThreadsPerBlock = 256; | ||
const size_t maxNumBlocks = 24; | ||
const size_t numThreadsPerBlock = | ||
std::min<size_t>(maxNumThreadsPerBlock, tensor.numel()); | ||
|
||
const size_t numBlocks = std::min<size_t>( | ||
maxNumBlocks, | ||
(tensor.numel() + numThreadsPerBlock - 1) / numThreadsPerBlock); | ||
|
||
AT_DISPATCH_FLOATING_TYPES_AND_HALF(tensor.scalar_type(), "checkForNaN", [&] { | ||
checkForNaN<scalar_t><<<numBlocks, numThreadsPerBlock>>>( | ||
tensor.data_ptr<scalar_t>(), tensor.numel()); | ||
C10_CUDA_KERNEL_LAUNCH_CHECK(); | ||
}); | ||
|
||
} | ||
|
||
} // namespace c10d |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters