-
Notifications
You must be signed in to change notification settings - Fork 21.4k
/
ProcessGroupNCCL.hpp
726 lines (610 loc) · 28.5 KB
/
ProcessGroupNCCL.hpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
#pragma once
#include <atomic>
#include <chrono>
#include <condition_variable>
#include <iostream>
#include <list>
#include <mutex>
#include <set>
#include <thread>
#include <unordered_map>
#include <unordered_set>

#include <c10d/NCCLUtils.hpp>
#include <c10d/ProcessGroup.hpp>
#include <c10d/Store.hpp>

#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/CUDAEvent.h>
#include <c10/core/Stream.h>
#include <c10/core/StreamGuard.h>
#include <c10/cuda/CUDACachingAllocator.h>
#include <c10/cuda/CUDAStream.h>

#include <torch/custom_class.h>
namespace c10d {
// Name of the environment variable that selects blocking vs. non-blocking
// behavior for wait().
constexpr const char* NCCL_BLOCKING_WAIT{"NCCL_BLOCKING_WAIT"};

// Name of the environment variable that toggles asynchronous NCCL error
// handling.
constexpr const char* NCCL_ASYNC_ERROR_HANDLING{"NCCL_ASYNC_ERROR_HANDLING"};
// ProcessGroupNCCL implements NCCL bindings for c10d.
//
// All functions of the class are expected to be called in the same order
// across all processes in the process group. This is the only way that we
// can guarantee to match up the same calls among all processes.
//
// All NCCL functions provided by this class are asynchronous functions. More
// specifically, each NCCL call is scheduled on a separate CUDA stream that is
// different from the current CUDA stream. This is for the purpose of
// achieving potential concurrency and better performance. As a result,
// it is the callers' responsibility to make sure that the CUDA stream their
// code works on waits for the NCCL operation from this class.
//
// This can be done by calling:
//
// either WorkNCCL::wait() or WorkNCCL::synchronize(), both achieve the same
// functionality and are synonyms.
//
// Also note that WorkNCCL::finishedGPUExecution() is a helper function only
// provided by ProcessGroupNCCL to check if the NCCL operation of WorkNCCL has
// finished execution on the GPU (not just scheduled).
//
// Example on using the NCCL process group
//
//   ProcessGroupNCCL pg(store, rank, size);
//   c10::intrusive_ptr<ProcessGroup::Work> work = pg.allreduce(tensors);
//
//   // At this point, the NCCL kernel has already been queued successfully.
//   // Now, let the current stream wait for the NCCL op to finish; this call
//   // is an async operation as well.
//
//   work->wait();
//
//   // Now continue on other work in the current stream.
class ProcessGroupNCCL : public ProcessGroup {
 public:
  class WorkNCCL : public ProcessGroup::Work,
                   public std::enable_shared_from_this<WorkNCCL> {
   public:
    // Constructor takes a list of CUDA devices
    WorkNCCL(
        const std::vector<at::Device>& devices,
        int rank,
        OpType opType,
        const char* profilingTitle = nullptr);
    // Copy constructor doing partial copy without outputs_. Cleanup thread
    // monitors and removes finished works. However it will deadlock when
    // destructs outputs_ tensors who are view tensors in autograd graph.
    WorkNCCL(const WorkNCCL& w);

    virtual ~WorkNCCL();

    // Checks if request has completed. In this specific case of NCCL, it checks
    // if the NCCL operation has completed on the GPU in its own NCCL stream.
    // Non-blocking operation.
    bool isCompleted() override;

    bool isSuccess() const override;

    // Same as calling synchronize() for NCCL work.
    bool wait(std::chrono::milliseconds timeout = kNoTimeout) override;

    void abort() override;

    // Let current stream wait on the completing of the NCCL work
    // Throws on exceptions. Blocking operation, which will wait for work
    // completion.
    void synchronize() override;

    // Synchronize streams by blocking each on the NCCL stream
    void synchronizeStreams();

    // Helper function used in CUDA Stream callbacks to complete WorkNCCL
    // objects and throw exceptions when needed.
    void handleNCCLGuard();

    // Helper function that checks if the NCCL kernels have finished
    // execution on the GPUs
    bool finishedGPUExecution();

    // Get a Future object that will be marked as completed internally.
    // It actually returns a FutureNCCL object which is a sub class Future.
    c10::intrusive_ptr<c10::ivalue::Future> getFuture() override;

    // Helper function that sets an exception_ptr on the WorkNCCL object.
    void setException(std::exception_ptr exception_ptr);

    // Helper function that returns True if the WorkNCCL object has timed out
    // and False otherwise.
    bool timedOut();

    std::vector<at::Tensor> result() override;

   protected:
    // The cached list of CUDA devices to operate on
    std::vector<at::Device> devices_;

    // The CUDA events tracking this work item on multiple CUDA devices
    std::shared_ptr<std::vector<at::cuda::CUDAEvent>> cudaEvents_;

    // The NCCL communicators used for this work item.
    std::vector<std::shared_ptr<NCCLComm>> ncclComms_;

    // Tensors used for barrier op
    std::vector<at::Tensor> barrierTensors_;

    // Clone of blockingWait_ from ProcessGroupNCCL.
    bool blockingWait_ = false;

    // Clone of opTimeout_ from ProcessGroupNCCL.
    std::chrono::milliseconds opTimeout_;

    // Time point representing when the work started.
    std::chrono::time_point<std::chrono::steady_clock> workStartTime_;

    // Wrapper method for the static checkForNCCLErrors which can be overridden
    // for tests.
    virtual std::exception_ptr checkForNCCLErrors(
        const std::vector<std::shared_ptr<NCCLComm>>& ncclComms) const;

    friend std::ostream& operator<<(
        std::ostream& output,
        const WorkNCCL& workNCCL);

   private:
    // Helper function for synchronize
    void synchronizeInternal(std::chrono::milliseconds timeout);
    // Checks for NCCL errors and sets an appropriate exception_ptr.
    void checkAndSetException();

    // Checks for NCCL errors and throws an appropriate exception.
    void checkAndThrowException();

    // Just checks whether GPU execution has completed, without modifying
    // exception_ptr.
    bool finishedGPUExecutionInternal() const;

    // Reference to the store so that we can write aborted communicators
    // to the store.
    c10::intrusive_ptr<Store> store_;

    // Store a reference to NCCL collective's outputs to be used by getFuture.
    std::shared_ptr<std::vector<at::Tensor>> outputs_;

    friend class ProcessGroupNCCL;
  };

  struct Options : torch::CustomClassHolder {
    explicit Options();

    // Returns an intrusive_ptr to a new Options object.
    //
    // timeout      - operation timeout to store in opTimeout. kNoTimeout (the
    //                default) keeps whatever Options() initialized it to.
    // isHighStream - when true, requests high-priority NCCL streams; false
    //                (the default) keeps whatever Options() initialized
    //                isHighPriorityStream to.
    //
    // Only non-default arguments override the fields, so the argument-less
    // create() (used as a default argument by the ProcessGroupNCCL
    // constructors) yields exactly the same object as before.
    static c10::intrusive_ptr<Options> create(
        std::chrono::milliseconds timeout = kNoTimeout,
        bool isHighStream = false) {
      auto options = c10::make_intrusive<Options>();
      if (timeout != kNoTimeout) {
        options->opTimeout = timeout;
      }
      if (isHighStream) {
        options->isHighPriorityStream = true;
      }
      return options;
    }

    std::chrono::milliseconds opTimeout;
    bool isHighPriorityStream;
  };

  // FutureNCCL is a subclass of ivalue's Future. The goal is to use
  // this class in getFuture API of WorkNCCL. This Future is mostly a
  // wrapper to synchronize streams appropriately and it mostly enables
  // the async programming model of CUDA while trying to adhere to the
  // Future interface. FutureNCCL does not support NCCL_BLOCKING_WAIT flag
  // or NCCL's barrier().
  //
  // If created by WorkNCCL's getFuture API, FutureNCCL has a reference to
  // WorkNCCL's cudaEvents, NCCL collective's outputs, and the device index of
  // outputs' device. Its value is NCCL collective's
  // outputs. FutureNCCL only supports single-process single-device mode where
  // the size of outputs is equal to 1.
  //
  // If created by FutureNCCL's then callback, its value becomes the value of
  // callback() and its cudaEvents will record the NCCL stream that runs that
  // callback. Before invoking the callback, FutureNCCL will synchronize its
  // own cudaEvents with the stream that runs the callback. This design
  // enables synchronizing the appropriate streams and avoids stalling PyTorch's
  // default stream while running the callback. In case of multiple then
  // callbacks, each will be executed on its own fresh stream.
  struct FutureNCCL : at::ivalue::Future {
   public:
    explicit FutureNCCL(
        at::IValue value,
        c10::DeviceIndex deviceIndex,
        std::shared_ptr<std::vector<at::cuda::CUDAEvent>> cudaEvents)
        : at::ivalue::Future(c10::ListType::create(c10::TensorType::get())),
          value_(std::move(value)),
          deviceIndex_(deviceIndex),
          cudaEvents_(std::move(cudaEvents)) {
      // Guard the dereference in the size check below.
      TORCH_INTERNAL_ASSERT(cudaEvents_ != nullptr);
      TORCH_INTERNAL_ASSERT(
          cudaEvents_->size() == 1,
          "FutureNCCL only supports single-process single-device mode.");
      for (const at::cuda::CUDAEvent& event : *cudaEvents_) {
        TORCH_INTERNAL_ASSERT(event.isCreated());
        TORCH_INTERNAL_ASSERT(event.device_index() == deviceIndex_);
      }
      for (const at::DataPtr& data_ptr : extractDataPtrs(value_)) {
        TORCH_INTERNAL_ASSERT(data_ptr.device().index() == deviceIndex_);
      }
    }

   private:
    explicit FutureNCCL(c10::DeviceIndex deviceIndex)
        : at::ivalue::Future(c10::ListType::create(c10::TensorType::get())),
          deviceIndex_(deviceIndex) {}
    // We need this because it will be the ::make() static method that actually
    // creates the instance. This is a brittle approach and the passkey idiom
    // would be a more robust solution. However, this will go away in #48505.
    friend c10::intrusive_ptr<FutureNCCL>;

   public:
    // Gets the current stream of the device and synchronizes recorded streams
    // with that. It will return after synchronizing the correct GPU streams to
    // ensure we can have async CUDA execution and it does not wait for the
    // entire operation to complete on GPU.
    //
    // NOTE(review): error_ is never assigned inside FutureNCCL, while then()
    // reports callback failures via setError(); confirm that errors set that
    // way are actually surfaced by this wait().
    void wait() override {
      if (error_) {
        throw *error_;
      }
      auto stream = at::cuda::getCurrentCUDAStream(deviceIndex_);
      (*cudaEvents_)[0].block(stream);
    }

    // If FutureNCCL was created by FutureNCCL::then, its value would be empty
    // initially. FutureNCCL::then will later use this method to set its value
    // to the return value of the callback.
    void markCompleted(at::IValue value) override {
      TORCH_INTERNAL_ASSERT(
          value_.isNone(),
          "Attempting to set value of a FutureNCCL which has a value. "
          "FutureNCCL's value was internally set to NCCL collective's "
          "outputs or the return value of the callback.");
      // A FutureNCCL created by then() has no events yet. Check every
      // precondition before mutating any member.
      TORCH_INTERNAL_ASSERT(cudaEvents_ == nullptr);
      for (const at::DataPtr& data_ptr : extractDataPtrs(value)) {
        TORCH_INTERNAL_ASSERT(data_ptr.device().index() == deviceIndex_);
      }
      value_ = std::move(value);

      // Create a new cudaEvents object of size 1 that will record the current
      // stream after callback and will be passed to the new FutureNCCL.
      cudaEvents_ = std::make_shared<std::vector<at::cuda::CUDAEvent>>(1);
      // In case of chained then callback calls, cudaEvents
      // records callback's stream.
      (*cudaEvents_)[0].record(at::cuda::getCurrentCUDAStream(deviceIndex_));
    }

    // Just returns FutureNCCL's value after wait returns.
    at::IValue value() override {
      TORCH_INTERNAL_ASSERT(hasValue(), "FutureNCCL's value is None.");
      wait();
      return value_;
    }

    const at::IValue& constValue() override {
      TORCH_INTERNAL_ASSERT(hasValue(), "FutureNCCL's value is None.");
      wait();
      return value_;
    }

    // Adds a callback to FutureNCCL. It invokes the callback inline after
    // synchronizing FutureNCCL's own cudaEvents with the stream that runs
    // this callback. This new FutureNCCL's cudaEvents will record the
    // callback's stream and will have the result value of the callback.
    void addCallback(std::function<void(void)> callback) override {
      // FIXME Should we find a way to allow to change the priority of streams?
      at::cuda::CUDAStream stream =
          at::cuda::getStreamFromPool(/*isHighPriority=*/false, deviceIndex_);

      // Do not free the underlying data storage of value_ before its
      // usage on the stream finishes.
      for (const at::DataPtr& data_ptr : extractDataPtrs(value_)) {
        c10::cuda::CUDACachingAllocator::recordStream(data_ptr, stream);
      }

      (*cudaEvents_)[0].block(stream);
      // Use the dedicated callback stream to run callback.
      c10::StreamGuard streamGuard{stream};
      callback();
    }

    // Adds a callback to FutureNCCL, and returns another FutureNCCL to hold
    // the return value of the callback and new cudaEvents that recorded the
    // stream that runs this callback.
    c10::intrusive_ptr<Future> then(
        std::function<at::IValue(void)> callback,
        at::TypePtr /* unused */) override {
      auto fut = c10::make_intrusive<FutureNCCL>(deviceIndex_);
      // The new future needs the DataPtr extractor when it gets marked complete
      // but this might happen immediately inline or in parallel by another
      // thread. In both these cases this would/might happen before the user has
      // time to set their own DataPtr extractor, which might lead to failures
      // if the default extractor can't handle some of the user's types.
      // Therefore we propagate our extractor.
      fut->setDataPtrExtractor(dataPtrExtractor_);

      // Cannot move capture std::function in lambda, because it cannot deduce
      // the template type for std::function. Hence use std::bind to explicitly
      // specify types.
      //
      // `fut` is captured by value (not by reference to this stack frame) so
      // the callback stays valid even if addCallback ever stops running it
      // inline.
      addCallback(std::bind(
          [fut](std::function<at::IValue(void)> cb) {
            try {
              fut->markCompleted(at::IValue(cb()));
            } catch (const std::exception&) {
              fut->setError(std::current_exception());
            }
          },
          std::move(callback)));
      return fut;
    }

    bool completed() const override {
      return true;
    }

    bool hasValue() const override {
      return !value_.isNone();
    }

    void setDataPtrExtractor(DataPtrExtractor data_ptr_extractor) override {
      // To avoid races with other threads that may be using the extractor, we
      // won't modify it after it's first set.
      if (dataPtrExtractor_ == nullptr) {
        dataPtrExtractor_ = std::move(data_ptr_extractor);
      }
    }

   private:
    at::IValue value_;
    c10::DeviceIndex deviceIndex_;
    std::shared_ptr<std::vector<at::cuda::CUDAEvent>> cudaEvents_;
    DataPtrExtractor dataPtrExtractor_;
    c10::optional<FutureError> error_;

    // Returns the DataPtrs held by `value`, using the user-provided extractor
    // when one was set and the default extractor otherwise. Asserts that
    // exactly one DataPtr is found.
    std::vector<std::reference_wrapper<const at::DataPtr>> extractDataPtrs(
        const at::IValue& value) {
      std::vector<std::reference_wrapper<const at::DataPtr>> data_ptrs;
      if (dataPtrExtractor_ != nullptr) {
        // If a Python communication hook is used, dataPtrExtractor_ will be
        // set in torch/csrc/jit/python/pybind_utils.h, which allows Python
        // dependency to be imported.
        data_ptrs = dataPtrExtractor_(value);
      } else {
        // If a C++ communication hook is used, use the default extractor.
        data_ptrs = at::ivalue::Future::defaultDataPtrExtractor(value);
      }
      TORCH_INTERNAL_ASSERT(data_ptrs.size() == 1, "expected exactly 1 tensor");
      return data_ptrs;
    }
  };

  // If you wish to create multiple process groups, each with a potentially
  // different rank and size, you can do so by passing a new store instance
  // to each one. If you have only a single store object, you can
  // use the `c10d::PrefixStore` to derive scoped instances.
  // This is also what the Python API in torch.distributed does.
  //
  // The process group instance keeps a reference to the store because
  // it may be used long after the constructor runs. In fact, the constructor
  // doesn't create any NCCL communicators. A single NCCL communicator can
  // only be used on a specific set of devices, and are therefore created
  // on-demand when a collective runs. If another collective is executed later,
  // against a different set of devices, the process group creates another NCCL
  // communicator. These NCCL communicators are cached and reused if possible.
  //
  ProcessGroupNCCL(
      const c10::intrusive_ptr<Store>& store,
      int rank,
      int size,
      c10::intrusive_ptr<Options> options = Options::create());

  // This constructor includes the deprecated `groupName` argument.
  // If you have existing code that uses the `groupName`, you can replace
  // it by specifying a `c10d::PrefixStore(groupName, store)` for store.
  C10_DEPRECATED ProcessGroupNCCL(
      const c10::intrusive_ptr<Store>& store,
      int rank,
      int size,
      const std::string& groupName,
      c10::intrusive_ptr<Options> options = Options::create())
      : ProcessGroupNCCL(store, rank, size, options) {}

  virtual ~ProcessGroupNCCL();

  c10::intrusive_ptr<ProcessGroup::Work> broadcast(
      std::vector<at::Tensor>& tensors,
      const BroadcastOptions& opts = BroadcastOptions()) override;

  c10::intrusive_ptr<ProcessGroup::Work> allreduce(
      std::vector<at::Tensor>& tensors,
      const AllreduceOptions& opts = AllreduceOptions()) override;

  c10::intrusive_ptr<ProcessGroup::Work> allreduce_coalesced(
      std::vector<at::Tensor>& tensors,
      const AllreduceCoalescedOptions& opts =
          AllreduceCoalescedOptions()) override;

  c10::intrusive_ptr<ProcessGroup::Work> reduce(
      std::vector<at::Tensor>& tensors,
      const ReduceOptions& opts = ReduceOptions()) override;

  c10::intrusive_ptr<ProcessGroup::Work> allgather(
      std::vector<std::vector<at::Tensor>>& outputTensors,
      std::vector<at::Tensor>& inputTensors,
      const AllgatherOptions& opts = AllgatherOptions()) override;

  c10::intrusive_ptr<ProcessGroup::Work> allgather_base(
      at::Tensor& outputbuffer,
      at::Tensor& inputbuffer,
      const AllgatherOptions& opts = AllgatherOptions()) override;

  c10::intrusive_ptr<ProcessGroup::Work> allgather_coalesced(
      std::vector<std::vector<at::Tensor>>& outputTensorLists,
      std::vector<at::Tensor>& inputTensors,
      const AllgatherOptions& opts = AllgatherOptions()) override;

  c10::intrusive_ptr<ProcessGroup::Work> reduce_scatter(
      std::vector<at::Tensor>& outputTensors,
      std::vector<std::vector<at::Tensor>>& inputTensors,
      const ReduceScatterOptions& opts = ReduceScatterOptions()) override;

  c10::intrusive_ptr<ProcessGroup::Work> barrier(
      const BarrierOptions& opts = BarrierOptions()) override;

  c10::intrusive_ptr<ProcessGroup::Work> alltoall_base(
      at::Tensor& outputTensor,
      at::Tensor& inputTensor,
      std::vector<int64_t>& outputSplitSizes,
      std::vector<int64_t>& inputSplitSizes,
      const AllToAllOptions& opts = AllToAllOptions()) override;

  c10::intrusive_ptr<ProcessGroup::Work> alltoall(
      std::vector<at::Tensor>& outputTensors,
      std::vector<at::Tensor>& inputTensors,
      const AllToAllOptions& opts = AllToAllOptions()) override;

  c10::intrusive_ptr<ProcessGroup::Work> send(
      std::vector<at::Tensor>& tensors,
      int dstRank,
      int tag) override;

  c10::intrusive_ptr<ProcessGroup::Work> recv(
      std::vector<at::Tensor>& tensors,
      int srcRank,
      int tag) override;

  static void groupStart();

  static void groupEnd();

  // Unsupported Ops
  c10::intrusive_ptr<ProcessGroup::Work> gather(
      std::vector<std::vector<at::Tensor>>& outputTensors,
      std::vector<at::Tensor>& inputTensors,
      const GatherOptions& opts = GatherOptions()) override;

  c10::intrusive_ptr<ProcessGroup::Work> scatter(
      std::vector<at::Tensor>& outputTensors,
      std::vector<std::vector<at::Tensor>>& inputTensors,
      const ScatterOptions& opts = ScatterOptions()) override;

  c10::intrusive_ptr<ProcessGroup::Work> recvAnysource(
      std::vector<at::Tensor>& tensors,
      int tag) override;

  static const int64_t kProcessGroupNCCLOpTimeoutMillis;

 protected:
  // Helper that broadcasts nccl unique ID to all ranks through the store
  void broadcastUniqueNCCLID(
      ncclUniqueId* ncclID,
      OpType opType,
      const std::string& devicesKey,
      int p2pRank);

  // Helper that either looks up the cached NCCL communicators or creates
  // a new set of NCCL communicators as a cache entry
  std::vector<std::shared_ptr<NCCLComm>>& getNCCLComm(
      const std::string& devicesKey,
      const std::vector<at::Device>& devices,
      OpType opType,
      int p2pRank = 0,
      bool isSendRecvSelf = false);

  // Wrapper method which can be overridden for tests.
  virtual std::exception_ptr checkForNCCLErrors(
      const std::vector<std::shared_ptr<NCCLComm>>& ncclComms);

  virtual c10::intrusive_ptr<ProcessGroupNCCL::WorkNCCL> initWork(
      std::vector<at::Device> devices,
      int rank,
      OpType opType,
      const char* profilingTitle = nullptr);

 private:
  // Helper that encapsulates work shared across all collective communication
  // primitives. The callbacks have the following signatures:
  //
  //    ncclResult_t fn(at::Tensor& input, at::Tensor& output,
  //                    ncclComm_t, at::cuda::CUDAStream&);
  //    void {pre,post}(std::vector<at::cuda::CUDAStream&>);
  template <typename Fn>
  c10::intrusive_ptr<ProcessGroup::Work> collective(
      std::vector<at::Tensor>& input,
      std::vector<at::Tensor>& output,
      Fn fn,
      OpType opType,
      const char* profilingTitle = nullptr);
  template <typename Fn, typename PreProcess, typename PostProcess>
  c10::intrusive_ptr<ProcessGroup::Work> collective(
      std::vector<at::Tensor>& input,
      std::vector<at::Tensor>& output,
      Fn fn,
      PreProcess pre,
      PostProcess post,
      OpType opType,
      const char* profilingTitle = nullptr);

  // Helper that encapsulates work shared across point-to-point communication
  // primitives. It is the same structure as the helper used for collective
  // communication primitives.
  template <typename Fn>
  c10::intrusive_ptr<ProcessGroup::Work> pointToPoint(
      std::vector<at::Tensor>& tensor,
      Fn fn,
      int peer,
      OpType opType);
  template <typename Fn, typename PreProcess, typename PostProcess>
  c10::intrusive_ptr<ProcessGroup::Work> pointToPoint(
      std::vector<at::Tensor>& tensor,
      Fn fn,
      int peer,
      OpType opType,
      PreProcess pre,
      PostProcess post);

  // Checks for NCCL errors on each of the communicators and returns an
  // appropriate exception_ptr (nullptr if no errors).
  static std::exception_ptr checkForNCCLErrorsInternal(
      const std::vector<std::shared_ptr<NCCLComm>>& ncclComms);

  // Function that runs as part of a separate thread and checks for errors on
  // NCCL communicators. We need a separate thread to check for NCCL errors
  // since we can't rely on the user calling certain methods like wait(),
  // isCompleted() etc. to detect and remediate errors. In addition to this, we
  // need a mechanism to safely abort and remove NCCL communicators from our
  // cache. This can be done cleanly by having a thread for the ProcessGroupNCCL
  // class. Attempting to modify the communicator cache from the WorkNCCL class
  // might run into issues with object lifetime since the ProcessGroupNCCL
  // object might get destroyed before the WorkNCCL object.
  void ncclCommWatchdog();

  void ncclCommWatchdogInternal();

  // This function iterates through the list of WorkNCCL objects in the
  // workList_ corresponding to incomplete collectives and then aborts NCCL
  // communicators associated with timed out collectives.
  void abortTimedOutCollectives(
      std::unordered_set<std::string>& abortedCommIds);

  void workCleanupLoop();

 protected:
  static const int64_t kWatchdogThreadSleepMillis;
  static const int64_t kWorkCleanupThreadSleepMillis;

  // The store is used to broadcast the NCCL unique ID of rank 0.
  c10::intrusive_ptr<Store> store_;

  // The number of NCCL communicators that have been created during
  // the lifetime of this process group. This sequence number is
  // used to scope keys used in the store.
  uint64_t ncclCommCounter_{0};

  // The NCCL communicators that the process group has cached.
  //
  // For collective operations:
  // The key is a list of GPU devices that an operation is operating on.
  // The GPU devices are stored in a device sequence and the cached NCCL
  // communicator is associated with this GPU device sequence.
  //
  // e.g. If the process group op only uses device 0, then the key (the used
  // device string) would be "0".
  //
  // If the process group op uses devices 0 - 7 and each tensor of the
  // input tensor list is on device 0, 1, 2, 3, 4, 5, 6, 7 respectively,
  // then the key would be "0,1,2,3,4,5,6,7".
  //
  // If the process group op uses devices 0 - 7 and each tensor of the
  // input tensor list is on device 0, 4, 5, 6, 7, 1, 2, 3 respectively,
  // then the key would be "0,4,5,6,7,1,2,3".
  //
  // Note that the order of the devices in the tensor list matters.
  //
  // For point-to-point operations:
  // The key is a string of my current rank and the peer process rank.
  // e.g. If process 1 and process 2 are involved in a point-to-point
  // communication, the key will be "1:2" on both processes. Note: this is for
  // the scenario where there is only 1 GPU per process. When it comes to
  // multiple GPUs per process, this part may need to be redesigned.
  std::unordered_map<std::string, std::vector<std::shared_ptr<NCCLComm>>>
      devNCCLCommMap_;

  // Map from ncclUniqueId to appropriate communicator.
  std::unordered_map<std::string, std::vector<std::shared_ptr<NCCLComm>>>
      ncclIdToCommMap_;

  // Mutex to guard maps like devNCCLCommMap_ and ncclIdToCommMap_.
  std::mutex mutex_;

  // Watchdog thread which looks for errors on the cached NCCL communicators.
  std::thread ncclCommWatchdogThread_;

  // Whether or not we should terminate the watchdog and workCleanup threads.
  std::atomic<bool> terminateProcessGroup_;

  // Condition variable to control how long the watchdog thread waits.
  std::condition_variable watchdogCV_;

  // Mutex for watchdog.
  std::mutex watchdogCVMutex_;

  // Thread that removes NCCL Work upon timeout
  std::thread workCleanupThread_;

  // Mutex to guard workMetaList_
  std::mutex workMetaListMutex_;

  // Condition variable for the work cleanup thread's sleep
  std::condition_variable workMetaListCV_;

  // List of WorkNCCL objects (stored by value, not as pointers) watched by
  // the cleanup thread. Stored copies use WorkNCCL's partial copy
  // constructor, which omits outputs_.
  std::list<ProcessGroupNCCL::WorkNCCL> workMetaList_;

  // Enqueues a work item for the cleanup thread (presumably by appending a
  // copy to workMetaList_ -- confirm in the implementation).
  void workEnqueue(c10::intrusive_ptr<ProcessGroupNCCL::WorkNCCL>);

  // The CUDA streams used by NCCL kernels
  std::unordered_map<std::string, std::vector<at::cuda::CUDAStream>>
      ncclStreams_;

  // The CUDA events used to sync NCCL streams
  std::unordered_map<std::string, std::vector<at::cuda::CUDAEvent>> ncclEvents_;

  // Device Indexes used for all collectives in this group
  std::set<int> usedDeviceIdxs_;

  // map from the key: "group name + pg counter (ID)" to the
  // unique NCCL ID count. This needs to be group and pg specific
  //
  // For each process group, we need a uniform unique NCCL ID counter to ensure
  // that NCCL operations in this process group can be completed successfully.
  // Since each process group ID belongs to a group name, the key to this map
  // is a combination of group name and ProcessGroupNCCL ID.
  static std::unordered_map<std::string, ssize_t> pgUniqueNCCLIDCnt_;

  // map from group name to the pg counter (ID) within that group
  //
  // For each group with the "group name" (which is the key), we need to
  // keep track of a unique process group ID when creating a new
  // ProcessGroupNCCL for this "group name". Therefore, the value of this
  // map keeps the unique ProcessGroupNCCL's ID for a specific group with
  // the "group name". The reason we need a per-group process group ID counter
  // is that different groups can have different ranks and we need to ensure
  // that each group has its own uniform process group ID for all its ranks.
  static std::unordered_map<std::string, ssize_t> processGroupCounterMap_;

  // Whether or not wait() and synchronize() are blocking operations that wait
  // for the operation to complete.
  bool blockingWait_ = false;

  // Whether or not the workCleanupThread is used to perform async error
  // handling.
  bool asyncErrorHandling_ = false;

  // Timeout for operations. This is only used when blockingWait_ is enabled.
  std::chrono::milliseconds opTimeout_;

  // Set of communicators that this process group has aborted and their
  // ncclUniqueId has been written to the store. We don't need a lock
  // for this map since only the watchdog thread accesses this set. The
  // set contains the string representation of ncclUniqueId.
  std::unordered_set<std::string> abortedComms_;

  // Schedule NCCL operations on high priority CUDA streams.
  bool isHighPriorityStream_ = false;

  // The number of active ncclGroupStart() calls. This counter will be increased
  // by 1 when ncclGroupStart() is called and decreased by 1 when ncclGroupEnd()
  // is called.
  static thread_local uint64_t ncclActiveGroupCounter_;
};
} // namespace c10d