[c10d] Add an option for NAN check on every collective (#125726)
Summary:
The NaN check is performed through a device-side assert, so no copy from
GPU to CPU is needed.
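
For context, the check is opt-in via the TORCH_NCCL_NAN_CHECK environment
variable added in this change. Below is a minimal, illustrative sketch of how
a user might enable it; the env var and the failure behavior come from this
PR, while the process-group setup is generic torch.distributed boilerplate
and is not part of the diff.

import os
import torch
import torch.distributed as dist

# Opt in before the NCCL process group is created; the flag is read in the
# ProcessGroupNCCL constructor.
os.environ["TORCH_NCCL_NAN_CHECK"] = "1"

dist.init_process_group("nccl")  # e.g. launched with torchrun on 2+ GPUs
rank = dist.get_rank()
torch.cuda.set_device(rank)

t = torch.ones(10, 10, device="cuda")
t[0, 0] = float("nan")

# The pre-collective NaN check trips a device-side assert and the call
# surfaces as a RuntimeError, as exercised by the new test_nan_assert test.
dist.all_reduce(t)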
Test Plan:
Unit test for collectives that should raise a runtime error.

(sqzhang_1) [sqzhang@devgpu009.cln1 ~/pytorch (38f5143e)]$  python
test/distributed/test_c10d_nccl.py ProcessGroupNCCLTest.test_nan_assert
/home/sqzhang/pytorch/torch/csrc/distributed/c10d/Utils.cu:15:
checkForNaN: block: [0,0,0], thread: [0,0,0] Assertion `!isnan(val)`
failed.
/home/sqzhang/pytorch/torch/csrc/distributed/c10d/Utils.cu:15:
checkForNaN: block: [0,0,0], thread: [1,0,0] Assertion `!isnan(val)`
failed.
/home/sqzhang/pytorch/torch/csrc/distributed/c10d/Utils.cu:15:
checkForNaN: block: [0,0,0], thread: [2,0,0] Assertion `!isnan(val)`
failed.
/home/sqzhang/pytorch/torch/csrc/distributed/c10d/Utils.cu:15:
checkForNaN: block: [0,0,0], thread: [3,0,0] Assertion `!isnan(val)`
failed.
/home/sqzhang/pytorch/torch/csrc/distributed/c10d/Utils.cu:15:
checkForNaN: block: [0,0,0], thread: [4,0,0] Assertion `!isnan(val)`
failed.
/home/sqzhang/pytorch/torch/csrc/distributed/c10d/Utils.cu:15:
checkForNaN: block: [0,0,0], thread: [5,0,0] Assertion `!isnan(val)`
failed.
[rank0]:[E507 17:31:56.885473996 Utils.cu:30] CUDA error during
checkForNan: device-side assert triggered

/home/sqzhang/pytorch/torch/csrc/distributed/c10d/Utils.cu:15:
checkForNaN: block: [0,0,0], thread: [0,0,0] Assertion `!isnan(val)`
failed.
/home/sqzhang/pytorch/torch/csrc/distributed/c10d/Utils.cu:15:
checkForNaN: block: [0,0,0], thread: [1,0,0] Assertion `!isnan(val)`
failed.
/home/sqzhang/pytorch/torch/csrc/distributed/c10d/Utils.cu:15:
checkForNaN: block: [0,0,0], thread: [2,0,0] Assertion `!isnan(val)`
failed.
/home/sqzhang/pytorch/torch/csrc/distributed/c10d/Utils.cu:15:
checkForNaN: block: [0,0,0], thread: [3,0,0] Assertion `!isnan(val)`
failed.
/home/sqzhang/pytorch/torch/csrc/distributed/c10d/Utils.cu:15:
checkForNaN: block: [0,0,0], thread: [4,0,0] Assertion `!isnan(val)`
failed.
/home/sqzhang/pytorch/torch/csrc/distributed/c10d/Utils.cu:15:
checkForNaN: block: [0,0,0], thread: [5,0,0] Assertion `!isnan(val)`
failed.
[rank1]:[E507 17:31:56.128961534 Utils.cu:30] CUDA error during
checkForNan: device-side assert triggered

.
----------------------------------------------------------------------
Ran 1 test in 7.723s

OK

Tags:

Pull Request resolved: #125726
Approved by: https://github.com/kwen2501
shuqiangzhang authored and pytorchmergebot committed May 16, 2024
1 parent 0214711 commit c860df5
Showing 7 changed files with 92 additions and 0 deletions.
2 changes: 2 additions & 0 deletions BUILD.bazel
@@ -663,6 +663,7 @@ cu_library(
name = "torch_cuda",
srcs = [
"torch/csrc/distributed/c10d/intra_node_comm.cu",
"torch/csrc/distributed/c10d/Utils.cu",
"torch/csrc/distributed/c10d/quantization/quantization_gpu.cu",
],
copts = torch_cuda_half_options,
@@ -830,6 +831,7 @@ cc_library(
"torch/csrc/cuda/python_nccl.cpp",
"torch/csrc/cuda/nccl.cpp",
"torch/csrc/distributed/c10d/intra_node_comm.cu",
"torch/csrc/distributed/c10d/Utils.cu",
"torch/csrc/distributed/c10d/quantization/quantization_gpu.cu",
],
)) + torch_sources,
1 change: 1 addition & 0 deletions build_variables.bzl
@@ -680,6 +680,7 @@ libtorch_cuda_distributed_extra_sources = [
"torch/csrc/distributed/c10d/UCCUtils.cpp",
"torch/csrc/distributed/c10d/intra_node_comm.cpp",
"torch/csrc/distributed/c10d/intra_node_comm.cu",
"torch/csrc/distributed/c10d/Utils.cu",
"torch/csrc/distributed/rpc/tensorpipe_cuda.cpp",
"torch/csrc/distributed/c10d/quantization/quantization_gpu.cu",
]
29 changes: 29 additions & 0 deletions test/distributed/test_c10d_nccl.py
@@ -226,6 +226,14 @@ def opts(self, high_priority_stream=False):

def setUp(self):
super().setUp()
# Need to skip return code checking for these tests since the child
# processes don't exit cleanly in some cuda versions
self.skip_return_code_checks = [
self.test_nan_assert_float16.__wrapped__,
self.test_nan_assert_float32.__wrapped__,
self.test_nan_assert_float64.__wrapped__,
]

# TORCH_NCCL_BLOCKING_WAIT overrides TORCH_NCCL_ASYNC_ERROR_HANDLING hence tests
# that use TORCH_NCCL_BLOCKING_WAIT will test it as expected.
os.environ["TORCH_NCCL_ASYNC_ERROR_HANDLING"] = "1"
@@ -325,6 +333,27 @@ def test_close_pg(self):

del pg

@requires_nccl()
@skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs")
@parametrize("type", [torch.float16, torch.float32, torch.float64])
@skip_if_rocm
def test_nan_assert(self, type):
os.environ["TORCH_NCCL_NAN_CHECK"] = "1"
store = c10d.FileStore(self.file_name, self.world_size)
pg = self._create_process_group_nccl(store, self.opts())
device = self.rank_to_GPU[self.rank][0]
size = (10, 10)
nan_tensor = torch.full(size, self.rank, dtype=type, device=device)
# randomly pick an nan element
i = random.randint(0, nan_tensor.size(0) - 1)
j = random.randint(0, nan_tensor.size(1) - 1)
nan_tensor[i, j] = float("nan")
with self.assertRaises(RuntimeError):
pg.allreduce(nan_tensor)
dist.destroy_process_group()
# reset env
os.environ["TORCH_NCCL_NAN_CHECK"] = "0"

@requires_nccl()
@skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs")
def test_destruct_before_terminate_pg(self):
8 changes: 8 additions & 0 deletions torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp
@@ -748,6 +748,7 @@ ProcessGroupNCCL::ProcessGroupNCCL(
// both timeout and other errors.
dumpOnException_ = getCvarBool(TORCH_NCCL_DUMP_ON_TIMEOUT, false) ||
(dist_debug_level_ >= DebugLevel::Detail);
enableNanCheck_ = getCvarBool(TORCH_NCCL_NAN_CHECK, false);
heartbeat_ = 1ULL;
monitorThreadEnabled_.store(getCvarBool(TORCH_NCCL_ENABLE_MONITORING, true));
heartbeatTimeoutInSec_ =
@@ -836,6 +837,7 @@ ProcessGroupNCCL::ProcessGroupNCCL(
<< ", TORCH_NCCL_HEARTBEAT_TIMEOUT_SEC: " << heartbeatTimeoutInSec_
<< ", TORCH_NCCL_TRACE_BUFFER_SIZE: " << ncclTraceBufferSize_
<< ", TORCH_NCCL_COORD_CHECK_MILSEC: " << coordCheckIntervalMilSec_
<< ", TORCH_NCCL_NAN_CHECK: " << enableNanCheck_
<< ", PG Name: " << options_->group_name;

if (options_->global_ranks_in_group.empty()) {
@@ -2424,6 +2426,9 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::collective(
OpType opType,
const char* profilingTitle,
bool avoidRecordStreams) {
if (enableNanCheck_) {
checkForNan(input);
}
// Environment setting by the user may add onto collective call's option
avoidRecordStreams |= avoidRecordStreams_;
c10::cuda::CaptureStatus capture_status =
@@ -2779,6 +2784,9 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::pointToPoint(
PreProcess pre,
PostProcess post,
const char* profilingTitle) {
if (enableNanCheck_) {
checkForNan(tensor);
}
// avoidRecordStreams_ note:
// send, recv, and irecv should be ok with avoidRecordStreams,
// However, for isend, I don't think the API requires the user
5 changes: 5 additions & 0 deletions torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp
@@ -100,6 +100,8 @@ static std::vector<std::string> TORCH_NCCL_WAIT_TIMEOUT_DUMP_MILSEC = {
static std::vector<std::string> TORCH_NCCL_COORD_CHECK_MILSEC = {
"TORCH_NCCL_COORD_CHECK_MILSEC"};

static std::vector<std::string> TORCH_NCCL_NAN_CHECK = {"TORCH_NCCL_NAN_CHECK"};

constexpr const char* NCCL_BACKEND_NAME = "nccl";

constexpr const char* EXCEPTION_DUMP = "exception_dump";
@@ -1024,6 +1026,9 @@ class TORCH_API ProcessGroupNCCL : public Backend {
// timeout and nccl errors.
bool dumpOnException_;

// Whether or not to enable nan check for input tensors to collectives.
bool enableNanCheck_;

// Whether or not to create start CUDAEvent and enable timing for start
// and end events. Note that enableTiming_ is always true if desyncDebug_
// is set to true.
45 changes: 45 additions & 0 deletions torch/csrc/distributed/c10d/Utils.cu
@@ -0,0 +1,45 @@
#include <ATen/Dispatch.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <torch/csrc/distributed/c10d/Utils.hpp>
#include <torch/torch.h>
#include <algorithm>

namespace c10d {

// CUDA kernel to check if data has NAN, device side assert
// is raised if NAN is found
template <typename T>
__global__ void checkForNaN(T* data, size_t size) {
size_t tid = blockIdx.x * blockDim.x + threadIdx.x;
size_t stride = blockDim.x * gridDim.x;

for (size_t i = tid; i < size; i += stride) {
CUDA_KERNEL_ASSERT(!isnan(data[i]));
}
}

// CHECK if a Tensor contains NAN in any of its element
void checkForNan(const at::Tensor& tensor) {
// skip check for non float types
if (!torch::is_floating_point(tensor)) {
return;
}
const size_t maxNumThreadsPerBlock = 256;
const size_t maxNumBlocks = 24;
const size_t numThreadsPerBlock =
std::min<size_t>(maxNumThreadsPerBlock, tensor.numel());

const size_t numBlocks = std::min<size_t>(
maxNumBlocks,
(tensor.numel() + numThreadsPerBlock - 1) / numThreadsPerBlock);

AT_DISPATCH_FLOATING_TYPES_AND_HALF(tensor.scalar_type(), "checkForNaN", [&] {
checkForNaN<scalar_t><<<numBlocks, numThreadsPerBlock>>>(
tensor.data_ptr<scalar_t>(), tensor.numel());
C10_CUDA_KERNEL_LAUNCH_CHECK();
});

}

} // namespace c10d
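
The kernel above uses a grid-stride loop, so the fixed launch bound (at most
24 blocks of 256 threads) covers tensors of any size while the assert fires
entirely on the device. For contrast only (not part of this PR), a host-side
check of the kind this kernel avoids would look roughly like the Python below:
torch.isnan builds a boolean mask on the GPU, and .item() forces a
device-to-host copy and synchronization.

import torch

def has_nan_host_side(tensor: torch.Tensor) -> bool:
    # Materializes a bool mask on the GPU, then syncs and copies the reduced
    # result back to the CPU, which is exactly the overhead checkForNaN skips.
    return bool(torch.isnan(tensor).any().item())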
2 changes: 2 additions & 0 deletions torch/csrc/distributed/c10d/Utils.hpp
@@ -612,6 +612,8 @@ using SizeType = uint64_t;
// Since SOCKET_ERROR = -1 in MSVC, so also leverage SYSCHECK_ERR_RETURN_NEG1
#define SYSCHECK_ERR_RETURN_NEG1(expr) SYSCHECK(expr, __output != -1)

void checkForNan(const at::Tensor& tensor);

namespace tcputil {

// Send and receive
