Skip to content

Commit

Permalink
check in (#111875)
Browse files Browse the repository at this point in the history
check in impl

address comments, skip test on rocm

unused
  • Loading branch information
eqy committed Oct 26, 2023
1 parent ab5ea22 commit 2f502cc
Show file tree
Hide file tree
Showing 4 changed files with 73 additions and 0 deletions.
32 changes: 32 additions & 0 deletions aten/src/ATen/cuda/CUDAGraph.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,13 @@
#include <c10/cuda/CUDACachingAllocator.h>
#include <c10/cuda/CUDAFunctions.h>

#include <chrono>
#include <thread>

namespace at::cuda {

static bool _cuda_graphs_debug = false;
constexpr int kSynchronizeBusyWaitMillis = 10;

MempoolId_t graph_pool_handle() {
#if !defined(USE_ROCM) || ROCM_VERSION >= 50300
Expand Down Expand Up @@ -55,6 +59,25 @@ CaptureId_t capture_sequence_id() {
* describes memory management for captures.
*/

// Count of event queries that are currently in flight (e.g. issued by a NCCL
// watchdog thread). Capture must not begin while any are outstanding, because
// event queries are illegal during a graph capture in the default capture mode.
std::atomic<int> CUDAGraph::pending_event_queries{0};

// Register one in-flight event query so that capture_begin() can wait for it
// to be resolved before starting the capture.
void CUDAGraph::inc_pending_event_queries() {
  pending_event_queries.fetch_add(1);
}

// Retire one in-flight event query. The count must be positive: a decrement
// with nothing outstanding indicates unbalanced inc/dec calls, so assert first
// and only then perform the decrement.
void CUDAGraph::dec_pending_event_queries() {
  TORCH_INTERNAL_ASSERT(pending_event_queries > 0,
    "Attempted to decrement the number of outstanding events to be queried, but it was <= 0.");
  pending_event_queries.fetch_sub(1);
}

// Snapshot of how many event queries are still outstanding. Used by
// capture_begin() to busy-wait until the count drains to zero.
int CUDAGraph::num_pending_event_queries() {
  return pending_event_queries.load();
}

CUDAGraph::CUDAGraph()
// CUDAStreams may not be default-constructed.
: capture_stream_(at::cuda::getCurrentCUDAStream()) {
Expand Down Expand Up @@ -115,6 +138,15 @@ void CUDAGraph::capture_begin(MempoolId_t pool/*=0*/, cudaStreamCaptureMode capt
// due to the capture status being updated _after_ a capture had already started.
c10::cuda::CUDACachingAllocator::beginAllocateStreamToPool(capture_dev_, capture_stream_, mempool_id_);

// At this point, any NCCL watchdogs should be aware that we are in capture mode
// and therefore should not enqueue any additional work that could be event-queried.
// We still must wait on any existing work that has not been cleaned up.
while (num_pending_event_queries()) {
TORCH_WARN_ONCE("Waiting for pending NCCL work to finish before starting graph capture.");
std::this_thread::sleep_for(
std::chrono::milliseconds(kSynchronizeBusyWaitMillis));
}

// cudaStreamCaptureModeGlobal is the most conservative option to
// prevent potentially unsafe CUDA API calls during capture. See
// https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__STREAM.html#group__CUDART__STREAM_1g9d0535d93a214cbf126835257b16ba85
Expand Down
7 changes: 7 additions & 0 deletions aten/src/ATen/cuda/CUDAGraph.h
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
#include <c10/cuda/CUDAGraphsC10Utils.h>
#include <c10/cuda/CUDAStream.h>

#include <mutex>

namespace at {

struct CUDAGeneratorImpl;
Expand All @@ -19,6 +21,9 @@ struct TORCH_CUDA_CPP_API CUDAGraph {
CUDAGraph();
~CUDAGraph();

static void inc_pending_event_queries();
static void dec_pending_event_queries();
static int num_pending_event_queries();
void capture_begin(MempoolId_t pool={0, 0}, cudaStreamCaptureMode capture_mode = cudaStreamCaptureModeGlobal);
void capture_end();
void replay();
Expand All @@ -33,6 +38,8 @@ struct TORCH_CUDA_CPP_API CUDAGraph {
cudaGraphExec_t graph_exec_ = NULL;
#endif

static std::atomic<int> pending_event_queries;

// internal states so reset() can do its best cleaning up
// Set to true in capture_end if cudaStreamEndCapture succeeded
// Set back to false soon after, when graph_ is consumed by cudaGraphInstantiate
Expand Down
27 changes: 27 additions & 0 deletions test/distributed/test_c10d_nccl.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@
TestCase,
run_tests,
retry_on_connect_failures,
skipIfRocm,
TEST_WITH_DEV_DBG_ASAN,
TEST_WITH_ROCM,
skip_but_pass_in_sandcastle,
Expand Down Expand Up @@ -457,6 +458,32 @@ def test_allreduce_in_cudagraph(self):
graph.replay()
self.assertEqual(xs[0].item(), 8)

@requires_nccl()
@skip_but_pass_in_sandcastle_if(torch.cuda.device_count() < 2, "NCCL test requires 2+ GPUs")
@skipIfRocm()
def test_nccl_watchdog_cudagraph(self):
    """Verify the NCCL watchdog does not crash a CUDA graph capture.

    Event queries are disallowed during capture in the default capture
    mode, so pending watchdog work must be drained before a capture
    begins and no new event-queryable work may be enqueued during it.
    """
    store = c10d.FileStore(self.file_name, self.world_size)
    pg = self._create_process_group_nccl(store, self.opts())
    rank = self.rank_to_GPU[self.rank][0]
    with torch.cuda.device(rank):
        # Repeat capture/replay many times to give the watchdog ample
        # opportunity to race with an in-progress capture.
        for _ in range(100):
            xs = [torch.FloatTensor([1]).cuda(rank)]
            # Warm up: enqueue eager collectives so some work is still
            # pending (and event-queryable) when the capture starts.
            for _ in range(30):
                pg.allreduce(xs[0]).wait()

            graph = torch.cuda.CUDAGraph()
            with torch.cuda.graph(graph):
                xs[0] += 0.0
                pg.allreduce(xs[0]).wait()
                pg.allreduce(xs[0]).wait()
                pg.allreduce(xs[0]).wait()
                xs[0] += 0.0

            for _ in range(1400):
                graph.replay()

@requires_nccl()
@skip_but_pass_in_sandcastle_if(torch.cuda.device_count() < 2, "NCCL test requires 2+ GPUs")
def test_reduce_ops(self):
Expand Down
7 changes: 7 additions & 0 deletions torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
#include <utility>

#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/CUDAGraph.h>
#include <c10/core/DeviceType.h>
#include <c10/cuda/CUDAGraphsC10Utils.h>
#include <c10/cuda/CUDAGuard.h>
Expand Down Expand Up @@ -1021,6 +1022,7 @@ void ProcessGroupNCCL::workCleanupLoop() {
} else {
it = workMetaList_.erase(it);
}
at::cuda::CUDAGraph::dec_pending_event_queries();
} else {
// Increment the iterator if the current WorkNCCL object is not
// completed.
Expand Down Expand Up @@ -1823,8 +1825,13 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::collective(
work->numelIn_ = inputs[0].numel();
work->numelOut_ = outputs[0].numel();

// Notify graphs before we check the capture status preemptively
at::cuda::CUDAGraph::inc_pending_event_queries();

if (!coalescing_state_ && capture_status == c10::cuda::CaptureStatus::None) {
workEnqueue(work);
} else {
at::cuda::CUDAGraph::dec_pending_event_queries();
}

return work;
Expand Down

0 comments on commit 2f502cc

Please sign in to comment.