#pragma once
#include <mutex>
#include <thread>
#include <unordered_map>
#include <c10d/NCCLUtils.hpp>
#include <c10d/ProcessGroup.hpp>
#include <c10d/Store.hpp>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/CUDAEvent.h>
namespace c10d {
// Environment variable which controls whether wait() is blocking or
// non-blocking.
constexpr const char* NCCL_BLOCKING_WAIT = "NCCL_BLOCKING_WAIT";
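// For example, a typical launch with blocking wait enabled might look like
// the following (a sketch; the training script name is illustrative):
//
//   export NCCL_BLOCKING_WAIT=1
//   python train.py
//
// With blocking wait enabled, wait() blocks the calling CPU thread until the
// operation completes or the configured operation timeout expires, instead
// of only enqueueing a stream dependency.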
// ProcessGroupNCCL implements NCCL bindings for c10d.
//
// All functions of the class are expected to be called in the same order
// across all processes in the process group. This is the only way that we
// can guarantee to match up the same calls among all processes.
//
// All NCCL functions provided by this class are asynchronous functions. More
// specifically, each NCCL call is scheduled on a separate CUDA stream that is
// different from the current CUDA stream. This is for the purpose of
// potentially achieving concurrency and better performance. As a result,
// it is the caller's responsibility to make sure that the CUDA stream their
// code works on waits for the NCCL operation from this class.
//
// This can be done by calling:
//
// either WorkNCCL::wait() or WorkNCCL::synchronize(); both achieve the same
// functionality and are synonyms.
//
// Also note that WorkNCCL::finishedGPUExecution() is a helper function only
// provided by ProcessGroupNCCL to check if the NCCL operation of WorkNCCL has
// finished execution on the GPU (not just scheduled).
//
// Example of using the NCCL process group
//
// ProcessGroupNCCL pg(store, rank, size);
// std::shared_ptr<WorkNCCL> work = pg.allreduce(tensors);
//
// // At this point, the NCCL kernel has already been queued successfully.
// // Now, let the current stream wait for NCCL to finish; this function is
// // an async operation as well.
//
// work->wait();
//
// // Now continue on other work in the current stream.
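//
// A polling variant that checks for GPU completion instead of synchronizing
// the current stream (a sketch; doOtherCpuWork() is a hypothetical stand-in
// for CPU work the caller wants to overlap with the NCCL kernel):
//
// auto work = pg.allreduce(tensors);
// while (!work->isCompleted()) {
//   doOtherCpuWork();
// }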
class ProcessGroupNCCL : public ProcessGroup {
public:
class WorkNCCL : public ProcessGroup::Work {
public:
// Constructor takes a list of CUDA devices
WorkNCCL(const std::vector<at::Device>& devices);
virtual ~WorkNCCL();
// Checks if request has completed. In this specific case of NCCL, it checks
// if the NCCL operation has completed on the GPU in its own NCCL stream.
// Non-blocking operation.
bool isCompleted() override;
bool isSuccess() const override;
// Same as calling synchronize() for NCCL work.
void wait() override;
// Let the current stream wait on the completion of the NCCL work.
// Throws on exceptions. Blocking operation, which will wait for work
// completion.
void synchronize() override;
// Helper function that checks if the NCCL kernels have finished
// execution on the GPUs
bool finishedGPUExecution();
protected:
// The cached list of CUDA devices to operate on
std::vector<at::Device> devices_;
// The CUDA events tracking this work item on multiple CUDA devices
std::vector<at::cuda::CUDAEvent> cudaEvents_;
// The NCCL communicators used for this work item.
std::vector<std::shared_ptr<NCCLComm>> ncclComms_;
// Tensors used for barrier op
std::vector<at::Tensor> barrierTensors_;
// Clone of blockingWait_ from ProcessGroupNCCL.
bool blockingWait_ = false;
// Clone of opTimeout_ from ProcessGroupNCCL.
std::chrono::milliseconds opTimeout_;
// Time point representing when the work started.
std::chrono::time_point<std::chrono::steady_clock> workStartTime_;
// Wrapper method for the static checkForNCCLErrors which can be overridden
// for tests.
virtual std::exception_ptr checkForNCCLErrors(
const std::vector<std::shared_ptr<NCCLComm>>& ncclComms) const;
private:
// Checks for NCCL errors and sets an appropriate exception_ptr.
void checkAndSetException();
// Checks for NCCL errors and throws an appropriate exception.
void checkAndThrowException();
// Just checks whether GPU execution has completed, without modifying
// exception_ptr.
bool finishedGPUExecutionInternal() const;
friend class ProcessGroupNCCL;
};
// If you wish to create multiple process groups, each with a potentially
// different rank and size, you can do so by passing a new store instance
// to each one. If you have only a single store object, you can
// use the `c10d::PrefixStore` to derive scoped instances.
// This is also what the Python API in torch.distributed does.
//
// The process group instance keeps a reference to the store because
// it may be used long after the constructor runs. In fact, the constructor
// doesn't create any NCCL communicators. A single NCCL communicator can
// only be used on a specific set of devices, so communicators are created
// on demand when a collective runs. If another collective is executed later,
// against a different set of devices, the process group creates another NCCL
// communicator. These NCCL communicators are cached and reused if possible.
//
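// For example, deriving scoped stores for two process groups from a single
// base store (a sketch; `store`, `rank`, and `size` are assumed to be set
// up already):
//
//   auto pg1Store = std::make_shared<PrefixStore>("pg1", store);
//   auto pg2Store = std::make_shared<PrefixStore>("pg2", store);
//   ProcessGroupNCCL pg1(pg1Store, rank, size);
//   ProcessGroupNCCL pg2(pg2Store, rank, size);
//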
ProcessGroupNCCL(
const std::shared_ptr<Store>& store,
int rank,
int size,
const std::chrono::milliseconds& opTimeout =
std::chrono::milliseconds(kProcessGroupNCCLOpTimeoutMillis));
// This constructor includes the deprecated `groupName` argument.
// If you have existing code that uses the `groupName`, you can replace
// it by specifying a `c10d::PrefixStore(groupName, store)` for store.
C10_DEPRECATED ProcessGroupNCCL(
const std::shared_ptr<Store>& store,
int rank,
int size,
const std::string& groupName,
const std::chrono::milliseconds& opTimeout =
std::chrono::milliseconds(kProcessGroupNCCLOpTimeoutMillis))
: ProcessGroupNCCL(store, rank, size, opTimeout) {}
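// For example, migrating off the deprecated overload (a sketch):
//
//   // Deprecated:
//   //   ProcessGroupNCCL pg(store, rank, size, "my_group");
//   // Replacement:
//   auto scoped = std::make_shared<PrefixStore>("my_group", store);
//   ProcessGroupNCCL pg(scoped, rank, size);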
virtual ~ProcessGroupNCCL();
std::shared_ptr<ProcessGroup::Work> broadcast(
std::vector<at::Tensor>& tensors,
const BroadcastOptions& opts = BroadcastOptions()) override;
std::shared_ptr<ProcessGroup::Work> allreduce(
std::vector<at::Tensor>& tensors,
const AllreduceOptions& opts = AllreduceOptions()) override;
std::shared_ptr<ProcessGroup::Work> allreduce_coalesced(
std::vector<at::Tensor>& tensors,
const AllreduceCoalescedOptions& opts =
AllreduceCoalescedOptions()) override;
std::shared_ptr<ProcessGroup::Work> reduce(
std::vector<at::Tensor>& tensors,
const ReduceOptions& opts = ReduceOptions()) override;
std::shared_ptr<ProcessGroup::Work> allgather(
std::vector<std::vector<at::Tensor>>& outputTensors,
std::vector<at::Tensor>& inputTensors,
const AllgatherOptions& opts = AllgatherOptions()) override;
std::shared_ptr<ProcessGroup::Work> reduce_scatter(
std::vector<at::Tensor>& outputTensors,
std::vector<std::vector<at::Tensor>>& inputTensors,
const ReduceScatterOptions& opts = ReduceScatterOptions()) override;
std::shared_ptr<ProcessGroup::Work> barrier(
const BarrierOptions& opts = BarrierOptions()) override;
// Unsupported ops; calling any of these throws an error.
std::shared_ptr<ProcessGroup::Work> gather(
std::vector<std::vector<at::Tensor>>& outputTensors,
std::vector<at::Tensor>& inputTensors,
const GatherOptions& opts = GatherOptions()) override;
std::shared_ptr<ProcessGroup::Work> scatter(
std::vector<at::Tensor>& outputTensors,
std::vector<std::vector<at::Tensor>>& inputTensors,
const ScatterOptions& opts = ScatterOptions()) override;
std::shared_ptr<ProcessGroup::Work> send(
std::vector<at::Tensor>& tensors,
int dstRank,
int tag) override;
std::shared_ptr<ProcessGroup::Work> recv(
std::vector<at::Tensor>& tensors,
int srcRank,
int tag) override;
std::shared_ptr<ProcessGroup::Work> recvAnysource(
std::vector<at::Tensor>& tensors,
int tag) override;
static const int64_t kProcessGroupNCCLOpTimeoutMillis;
protected:
// Helper that broadcasts the NCCL unique ID to all ranks through the store.
void broadcastUniqueNCCLID(ncclUniqueId* ncclID);
// Helper that either looks up the cached NCCL communicators or creates
// a new set of NCCL communicators as a cache entry
std::vector<std::shared_ptr<NCCLComm>>& getNCCLComm(
const std::string& devicesKey,
const std::vector<at::Device>& devices);
// Wrapper method which can be overridden for tests.
virtual std::exception_ptr checkForNCCLErrors(
const std::vector<std::shared_ptr<NCCLComm>>& ncclComms);
virtual std::shared_ptr<ProcessGroupNCCL::WorkNCCL> initWork(
std::vector<at::Device> devices);
private:
// Helper that encapsulates work shared across all collective communication
// primitives. The callbacks have the following signatures:
//
// ncclResult_t fn(at::Tensor& input, at::Tensor& output,
// ncclComm_t, at::cuda::CUDAStream&);
// void {pre,post}(std::vector<at::cuda::CUDAStream>&);
template <typename Fn>
std::shared_ptr<ProcessGroup::Work> collective(
std::vector<at::Tensor>& input,
std::vector<at::Tensor>& output,
Fn fn);
template <typename Fn, typename PreProcess, typename PostProcess>
std::shared_ptr<ProcessGroup::Work> collective(
std::vector<at::Tensor>& input,
std::vector<at::Tensor>& output,
Fn fn,
PreProcess pre,
PostProcess post);
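// As a sketch, allreduce might be expressed in terms of this helper roughly
// as follows (toNcclDataType and toNcclRedOp are hypothetical conversion
// helpers, not part of this header):
//
//   return collective(
//       tensors,
//       tensors,
//       [&](at::Tensor& input,
//           at::Tensor& output,
//           ncclComm_t comm,
//           at::cuda::CUDAStream& stream) {
//         return ncclAllReduce(
//             input.data_ptr(),
//             output.data_ptr(),
//             input.numel(),
//             toNcclDataType(input.scalar_type()),
//             toNcclRedOp(opts.reduceOp),
//             comm,
//             stream.stream());
//       });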
// Checks for NCCL errors on each of the communicators and returns an
// appropriate exception_ptr (nullptr if no errors).
static std::exception_ptr checkForNCCLErrorsInternal(
const std::vector<std::shared_ptr<NCCLComm>>& ncclComms);
// Function that runs as part of a separate thread and checks for errors on
// NCCL communicators. We need a separate thread to check for NCCL errors
// since we can't rely on the user calling certain methods like wait(),
// isCompleted() etc. to detect and remediate errors. In addition to this, we
// need a mechanism to safely abort and remove NCCL communicators from our
// cache. This can be done cleanly by having a thread for the ProcessGroupNCCL
// class. Attempting to modify the communicator cache from the WorkNCCL class
// might run into issues with object lifetime since the ProcessGroupNCCL
// object might get destroyed before the WorkNCCL object.
void ncclCommWatchdog();
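// The loop is roughly of this shape (a sketch, not the exact
// implementation):
//
//   while (!terminateWatchdog_.load()) {
//     std::unique_lock<std::mutex> lock(watchdogCVMutex_);
//     watchdogCV_.wait_for(
//         lock,
//         std::chrono::milliseconds(kWatchdogThreadSleepMillis),
//         [&] { return terminateWatchdog_.load(); });
//     // Check each cached communicator for asynchronous NCCL errors, then
//     // abort and remove the ones that failed.
//   }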
protected:
static const int64_t kWatchdogThreadSleepMillis;
// The store is used to broadcast the NCCL unique ID of rank 0.
std::shared_ptr<Store> store_;
// The number of NCCL communicators that have been created during
// the lifetime of this process group. This sequence number is
// used to scope keys used in the store.
uint64_t ncclCommCounter_{0};
// The NCCL communicators that the process group has cached.
// The key is a comma-separated list of the GPU devices that an operation
// is operating on. The devices are listed in the order in which they appear
// in the operation's tensor list, and the cached NCCL communicators are
// associated with this device sequence.
//
// e.g. If the process group op only uses device 0, then the device string
// key stored in this map would be "0".
//
// If the process group op uses devices 0 - 7 and each tensor of the
// input tensor list is on devices 0, 1, 2, 3, 4, 5, 6, 7 respectively,
// then the device string key stored would be "0,1,2,3,4,5,6,7".
//
// If the process group op uses devices 0 - 7 and each tensor of the
// input tensor list is on devices 0, 4, 5, 6, 7, 1, 2, 3 respectively,
// then the device string key stored would be "0,4,5,6,7,1,2,3".
//
// Note that the order of the devices in the tensor list matters.
std::unordered_map<std::string, std::vector<std::shared_ptr<NCCLComm>>>
devNCCLCommMap_;
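// A key of this form can be built by joining device indices in tensor-list
// order (a sketch):
//
//   std::string key;
//   for (const auto& device : devices) {
//     if (!key.empty()) {
//       key += ",";
//     }
//     key += std::to_string(device.index());
//   }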
// Mutex to guard devNCCLCommMap_.
std::mutex devNCCLCommMapLock_;
// Watchdog thread which looks for errors on the cached NCCL communicators.
std::thread ncclCommWatchdogThread_;
// Whether or not we should terminate the watchdog thread.
std::atomic<bool> terminateWatchdog_;
// Condition variable to control how long the watchdog thread waits.
std::condition_variable watchdogCV_;
// Mutex for watchdog.
std::mutex watchdogCVMutex_;
// The CUDA streams used by NCCL kernels.
std::unordered_map<std::string, std::vector<at::cuda::CUDAStream>>
ncclStreams_;
// The CUDA events used to sync NCCL streams
std::unordered_map<std::string, std::vector<at::cuda::CUDAEvent>> ncclEvents_;
// Device indices used for all collectives in this group.
std::set<int> usedDeviceIdxs_;
// Map from the key "group name + pg counter (ID)" to the
// unique NCCL ID count. This needs to be group- and pg-specific.
//
// For each process group, we need a uniform unique NCCL ID counter to ensure
// that NCCL operations in this process group can be completed successfully.
// Since each process group ID belongs to a group name, the key to this map
// is a combination of the group name and the ProcessGroupNCCL ID.
static std::unordered_map<std::string, ssize_t> pgUniqueNCCLIDCnt_;
// Map from group name to the pg counter (ID) within that group.
//
// For each group with the "group name" (which is the key), we need to
// keep track of a unique process group ID when creating a new
// ProcessGroupNCCL for this "group name". Therefore, the value of this
// map keeps the unique ProcessGroupNCCL ID for a specific group with
// the "group name". The reason we need a per-group process group ID counter
// is that different groups can have different ranks, and we need to ensure
// that each group has its own uniform process group ID for all its ranks.
static std::unordered_map<std::string, ssize_t> processGroupCounterMap_;
// Whether or not wait() and synchronize() are blocking operations that wait
// for the operation to complete.
bool blockingWait_ = false;
// Timeout for operations. This is only used when blockingWait_ is enabled.
std::chrono::milliseconds opTimeout_;
};
} // namespace c10d