-
Notifications
You must be signed in to change notification settings - Fork 21.3k
/
ProcessGroupNCCLTest.cpp
485 lines (412 loc) · 13.6 KB
/
ProcessGroupNCCLTest.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
#include <iostream>
#include <c10d/FileStore.hpp>
#include <c10d/ProcessGroupNCCL.hpp>
#include <c10d/test/CUDATest.hpp>
#include <c10d/test/TestUtils.hpp>
#include <ATen/cuda/CUDAMultiStreamGuard.h>
#include <c10/cuda/CUDAGuard.h>
#include <c10/cuda/CUDAStream.h>
#include <torch/csrc/autograd/profiler.h>
#include <gtest/gtest.h>
using namespace c10d::test;
using at::cuda::CUDAStream;
using c10d::ProcessGroup;
class NCCLTestBase {
public:
NCCLTestBase(const std::string& path) : path_(path) {}
NCCLTestBase(NCCLTestBase&& other) {
path_ = std::move(other.path_);
pg_ = std::move(other.pg_);
}
::c10d::ProcessGroupNCCL& getProcessGroup() {
return *pg_;
}
void initialize(int rank, int size) {
auto store = c10::make_intrusive<::c10d::FileStore>(path_, size);
pg_ = std::unique_ptr<::c10d::ProcessGroupNCCL>(
new ::c10d::ProcessGroupNCCL(store, rank, size));
}
protected:
std::string path_;
std::unique_ptr<::c10d::ProcessGroupNCCL> pg_;
};
// Fixture that, on top of the process group, allocates CUDA tensors and one
// pool stream for every visible CUDA device:
//   tensors_[d]    - one 3x3 tensor, allocated while device d is current
//   inputs_[d][k]  - worldSize * numDevices 3x3 tensors, allocated on device d
//   outputs_[d][k] - same layout as inputs_
//   streams_[d]    - a stream taken from the pool for device d
class NCCLTest : public NCCLTestBase {
 public:
  NCCLTest(const std::string& path, int worldSize)
      : NCCLTestBase(path),
        numDevices_(cudaNumDevices()),
        state_(::at::globalContext().lazyInitCUDA()),
        worldSize_(worldSize) {
    // Each device has a single tensor to perform the single-tensor NCCL ops
    // on, plus worldSize * numDevices input/output tensors for list-based ops.
    tensors_.resize(numDevices_);
    inputs_.resize(numDevices_);
    outputs_.resize(numDevices_);
    at::cuda::OptionalCUDAGuard deviceGuard;
    for (auto i = 0; i < numDevices_; ++i) {
      deviceGuard.set_index(i);
      tensors_[i] = at::empty({3, 3}, at::kCUDA);
      inputs_[i].resize(worldSize_ * numDevices_);
      outputs_[i].resize(worldSize_ * numDevices_);
      for (auto j = 0; j < worldSize_ * numDevices_; ++j) {
        inputs_[i][j] = at::empty({3, 3}, at::kCUDA);
        outputs_[i][j] = at::empty({3, 3}, at::kCUDA);
      }
    }

    // Allocate a stream per device.
    //
    // The "current stream" is set globally per device in THC, so we
    // can't make two tensors on the same device use different streams
    // and pass this along to the collective (since it uses the THC
    // getters to retrieve the current stream).
    //
    streams_.reserve(numDevices_);
    for (auto i = 0; i < numDevices_; i++) {
      deviceGuard.set_index(i);
      streams_.push_back(at::cuda::getStreamFromPool());
    }
  }

  // Calls work->wait(timeout) while streams_ are installed as the current
  // streams (via CUDAMultiStreamGuard).
  void wait(
      c10::intrusive_ptr<ProcessGroup::Work>& work,
      std::chrono::milliseconds timeout = kNoTimeout) {
    at::cuda::CUDAMultiStreamGuard guard(streams_);
    work->wait(timeout);
  }

  // Synchronizes every per-device stream and returns host (CPU) copies of
  // tensors_, one per device.
  std::vector<at::Tensor> getTensors() {
    std::vector<at::Tensor> outputs(numDevices_);

    // For the duration of this function, make THC use our streams
    at::cuda::CUDAMultiStreamGuard guard(streams_);

    // Copy the device tensors to the host
    for (auto i = 0; i < numDevices_; i++) {
      cudaStreamSynchronize(streams_[i].stream());
      outputs[i] = tensors_[i].cpu();
    }

    return outputs;
  }

  // Host copies of inputs_, indexed [device][slot].
  std::vector<std::vector<at::Tensor>> getInputTensors() {
    return getTensorLists(inputs_);
  }

  // Host copies of outputs_, indexed [device][slot].
  std::vector<std::vector<at::Tensor>> getOutputTensors() {
    return getTensorLists(outputs_);
  }

  int numDevices() const {
    return numDevices_;
  }

 private:
  // Synchronizes every per-device stream and returns host (CPU) copies of
  // `tensor_lists`, which must be laid out [numDevices_][worldSize_ * numDevices_].
  std::vector<std::vector<at::Tensor>> getTensorLists(
      const std::vector<std::vector<at::Tensor>>& tensor_lists) {
    const auto listSize = worldSize_ * numDevices_;
    std::vector<std::vector<at::Tensor>> outputs(
        numDevices_, std::vector<at::Tensor>(listSize));

    // For the duration of this function, make THC use our streams
    at::cuda::CUDAMultiStreamGuard guard(streams_);

    // Copy the device tensor lists to the host
    for (auto i = 0; i < numDevices_; ++i) {
      cudaStreamSynchronize(streams_[i].stream());
      for (auto j = 0; j < listSize; ++j) {
        outputs[i][j] = tensor_lists[i][j].cpu();
      }
    }

    return outputs;
  }

 protected:
  // Launches sleep on every CUDA device, on that device's stream in streams_.
  void launchDeviceSleep() {
    at::cuda::OptionalCUDAGuard deviceGuard;
    for (auto i = 0; i < numDevices_; i++) {
      deviceGuard.set_index(i);
      cudaSleep(streams_[i], 2000 * 1000 * 1000);
    }
  }

  // Fills tensors_[i] with rank * numDevices_ + i, i.e. the global index of
  // this (rank, device) pair.
  void valueInitialization() {
    at::cuda::OptionalCUDAGuard deviceGuard;
    for (auto i = 0; i < numDevices_; i++) {
      deviceGuard.set_index(i);
      tensors_[i].fill_(pg_->getRank() * numDevices_ + i);
    }
  }

  const int numDevices_;
  THCState* state_;
  int worldSize_;
  std::vector<at::Tensor> tensors_;
  std::vector<std::vector<at::Tensor>> inputs_;
  std::vector<std::vector<at::Tensor>> outputs_;
  std::vector<CUDAStream> streams_;
};
// Allreduce over tensors_, issued with the legacy CPU profiler enabled.
class AllreduceNCCLTest : public NCCLTest {
 public:
  AllreduceNCCLTest(const std::string& path, int worldSize)
      : NCCLTest(path, worldSize) {}

  // Launches a sleep on every device stream, fills tensors_[d] with
  // rank * numDevices + d, then launches an allreduce over tensors_.
  c10::intrusive_ptr<c10d::ProcessGroup::Work> run() {
    // For the duration of this function, make THC use our streams
    at::cuda::CUDAMultiStreamGuard guard(streams_);

    launchDeviceSleep();
    valueInitialization();

    using namespace torch::autograd::profiler;
    // Make sure enabling profile does not make any issue. Note, in single
    // process multi-device mode we do not expect any events be populated for
    // collective operations, since profiling for that mode is not supported.
    enableProfilerLegacy({ProfilerState::CPU});
    c10::intrusive_ptr<c10d::ProcessGroup::Work> results;
    try {
      results = pg_->allreduce(tensors_);
    } catch (...) {
      // Don't leave the profiler enabled for later tests in this process.
      disableProfilerLegacy();
      throw;
    }
    disableProfilerLegacy();
    return results;
  }
};
// Broadcast over tensors_ with a caller-chosen root rank and root tensor.
class BroadcastNCCLTest : public NCCLTest {
 public:
  BroadcastNCCLTest(const std::string& path, int worldSize)
      : NCCLTest(path, worldSize) {}

  // Queues a sleep on each device stream, fills tensors_[d] with
  // rank * numDevices + d, and launches a broadcast rooted at
  // (rootRank, rootTensor).
  c10::intrusive_ptr<c10d::ProcessGroup::Work> run(int rootRank, int rootTensor) {
    // Keep our per-device streams current while the work is issued.
    at::cuda::CUDAMultiStreamGuard streamGuard(streams_);

    launchDeviceSleep();
    valueInitialization();

    ::c10d::BroadcastOptions broadcastOpts;
    broadcastOpts.rootRank = rootRank;
    broadcastOpts.rootTensor = rootTensor;
    return pg_->broadcast(tensors_, broadcastOpts);
  }
};
// Reduce over tensors_ toward a caller-chosen root rank and root tensor.
class ReduceNCCLTest : public NCCLTest {
 public:
  ReduceNCCLTest(const std::string& path, int worldSize)
      : NCCLTest(path, worldSize) {}

  // Queues a sleep on each device stream, fills tensors_[d] with
  // rank * numDevices + d, and launches a reduce whose destination is
  // (rootRank, rootTensor).
  c10::intrusive_ptr<c10d::ProcessGroup::Work> run(int rootRank, int rootTensor) {
    // Keep our per-device streams current while the work is issued.
    at::cuda::CUDAMultiStreamGuard streamGuard(streams_);

    launchDeviceSleep();
    valueInitialization();

    ::c10d::ReduceOptions reduceOpts;
    reduceOpts.rootRank = rootRank;
    reduceOpts.rootTensor = rootTensor;
    return pg_->reduce(tensors_, reduceOpts);
  }
};
// Allgather of tensors_ into outputs_.
class AllgatherNCCLTest : public NCCLTest {
 public:
  AllgatherNCCLTest(const std::string& path, int worldSize)
      : NCCLTest(path, worldSize) {}

  // Queues a sleep on each device stream, fills tensors_[d] with
  // rank * numDevices + d, and gathers every participant's tensor into
  // outputs_.
  c10::intrusive_ptr<c10d::ProcessGroup::Work> run() {
    // Keep our per-device streams current while the work is issued.
    at::cuda::CUDAMultiStreamGuard streamGuard(streams_);

    launchDeviceSleep();
    valueInitialization();

    return pg_->allgather(outputs_, tensors_);
  }
};
// Reduce-scatter of inputs_ into tensors_.
struct ReduceScatterNCCLTest : NCCLTest {
  ReduceScatterNCCLTest(const std::string& path, int worldSize)
      : NCCLTest(path, worldSize) {}

  // Queues a sleep on each device stream, fills input slot `slot` of device
  // `dev` with rank * numDevices * worldSize + dev * worldSize + slot, and
  // launches a reduce_scatter from inputs_ into tensors_.
  c10::intrusive_ptr<c10d::ProcessGroup::Work> run() {
    // Keep our per-device streams current while the work is issued.
    at::cuda::CUDAMultiStreamGuard streamGuard(streams_);
    at::cuda::OptionalCUDAGuard deviceGuard;

    launchDeviceSleep();

    const int myRank = pg_->getRank();
    const int slotsPerDevice = worldSize_ * numDevices_;
    for (int dev = 0; dev < numDevices_; ++dev) {
      deviceGuard.set_index(dev);
      auto& deviceInputs = inputs_[dev];
      for (int slot = 0; slot < slotsPerDevice; ++slot) {
        deviceInputs[slot].fill_(
            myRank * numDevices_ * worldSize_ + dev * worldSize_ + slot);
      }
    }

    return pg_->reduce_scatter(tensors_, inputs_);
  }
};
// Runs one allreduce across every device of every rank and checks the result.
// Each participant's tensor starts at its global index (rank * numDevices +
// device), so every element must end up equal to N * (N - 1) / 2 where N is
// the total number of participating GPUs.
void testAllreduce(const std::string& path, int rank, int size) {
  auto test = AllreduceNCCLTest(path, size);
  test.initialize(rank, size);
  auto work = test.run();
  // Wait for work to finish
  test.wait(work);

  // Validation
  const int totalNumGPUs = test.numDevices() * size;
  const auto expected = (totalNumGPUs * (totalNumGPUs - 1)) / 2;
  const auto tensors = test.getTensors();
  for (const auto& tensor : tensors) {
    const auto* data = tensor.data_ptr<float>();
    for (int64_t k = 0; k < tensor.numel(); ++k) {
      EXPECT_EQ(data[k], expected)
          << "Allreduce outputs do not match expected outputs";
    }
  }
}
// Broadcasts from every (root rank, root device) combination in turn. After
// each broadcast, every local tensor must hold the root's initial value,
// srcRank * numDevices + srcDevice.
void testBroadcast(const std::string& path, int rank, int size) {
  auto test = BroadcastNCCLTest(path, size);
  test.initialize(rank, size);

  const int numDevices = test.numDevices();
  for (int srcRank = 0; srcRank < size; ++srcRank) {
    for (int srcDevice = 0; srcDevice < numDevices; ++srcDevice) {
      auto work = test.run(srcRank, srcDevice);
      // Block until the broadcast has finished.
      test.wait(work);

      const auto expected = (srcRank * numDevices + srcDevice);
      auto hostTensors = test.getTensors();
      for (auto& tensor : hostTensors) {
        auto values = tensor.data_ptr<float>();
        for (auto idx = 0; idx < tensor.numel(); idx++) {
          EXPECT_EQ(values[idx], expected)
              << "Broadcast outputs do not match expected outputs";
        }
      }
    }
  }
}
// Reduces toward every (root rank, root device) combination in turn. Only the
// root tensor on the root rank is validated; it must equal N * (N - 1) / 2,
// where N is the total number of participating GPUs.
void testReduce(const std::string& path, int rank, int size) {
  auto test = ReduceNCCLTest(path, size);
  test.initialize(rank, size);

  const int numDevices = test.numDevices();
  const int totalNumGPUs = numDevices * size;
  const auto expected = (totalNumGPUs * (totalNumGPUs - 1)) / 2;

  for (int root = 0; root < size; ++root) {
    for (int rootDevice = 0; rootDevice < numDevices; ++rootDevice) {
      auto work = test.run(root, rootDevice);
      // Block until the reduce has finished.
      test.wait(work);

      // Fetched on every rank, which also synchronizes each device stream.
      auto hostTensors = test.getTensors();
      if (rank != root) {
        continue;
      }

      auto& tensor = hostTensors[rootDevice];
      auto values = tensor.data_ptr<float>();
      for (auto idx = 0; idx < tensor.numel(); idx++) {
        EXPECT_EQ(values[idx], expected)
            << "Reduce outputs do not match expected outputs";
      }
    }
  }
}
// Runs one allgather. On every device, output slot g must contain the value
// contributed by participant g, which is g itself.
void testAllgather(const std::string& path, int rank, int size) {
  auto test = AllgatherNCCLTest(path, size);
  test.initialize(rank, size);
  auto work = test.run();
  // Block until the allgather has finished.
  test.wait(work);

  auto gathered = test.getOutputTensors();
  for (auto& perDevice : gathered) {
    for (size_t slot = 0; slot < perDevice.size(); ++slot) {
      auto& tensor = perDevice[slot];
      auto values = tensor.data_ptr<float>();
      for (auto idx = 0; idx < tensor.numel(); idx++) {
        EXPECT_EQ(values[idx], slot)
            << "Allgather outputs do not match expected outputs";
      }
    }
  }
}
// Runs one reduce_scatter and checks that each local output holds the sum of
// its slot across all participants.
//
// ReduceScatterNCCLTest::run() fills input slot j of device d on rank r with
//   r * numDevices * size + d * size + j  ==  size * (r * numDevices + d) + j.
// Summing slot g over all P = numDevices * size participants p gives
//   size * (0 + 1 + ... + (P - 1)) + P * g  ==  size * P(P-1)/2 + P * g.
// The previous formula, P(P-1)/2 + P * (rank * P + i), only agreed with this
// when size == 1, so it broke multi-node (WORLD_SIZE > 1) runs.
// NOTE(review): assumes device i of this rank is NCCL participant
// g = rank * numDevices + i, as ProcessGroupNCCL assigns; confirm if that
// mapping changes.
void testReduceScatter(const std::string& path, int rank, int size) {
  auto test = ReduceScatterNCCLTest(path, size);
  test.initialize(rank, size);
  auto work = test.run();
  // Wait for work to finish
  test.wait(work);

  const int numDevices = test.numDevices();
  const int participants = numDevices * size;
  const int base = size * ((participants * (participants - 1)) / 2);

  // Validation
  auto tensors = test.getTensors();
  // device index
  for (size_t i = 0; i < tensors.size(); ++i) {
    const int participant = rank * numDevices + static_cast<int>(i);
    const auto expected = base + participants * participant;
    auto& tensor = tensors[i];
    auto data = tensor.data_ptr<float>();
    for (auto j = 0; j < tensor.numel(); j++) {
      EXPECT_EQ(data[j], expected) << "ReduceScatter outputs do not match expected outputs!";
    }
  }
}
// gtest fixture shared by all ProcessGroupNCCL tests. World size and rank
// default to a single process (size 1, rank 0).
class ProcessGroupNCCLTest : public ::testing::Test {
 protected:
  // Use WORLD_SIZE and RANK environmental variables to do multi-node
  // distributed testing. Both must be set for either to take effect.
  void SetUp() override {
    const char* worldSizeVar = std::getenv("WORLD_SIZE");
    const char* rankVar = std::getenv("RANK");
    if (worldSizeVar != nullptr && rankVar != nullptr) {
      size_ = std::stoi(std::string(worldSizeVar));
      rank_ = std::stoi(std::string(rankVar));
    }
    LOG(INFO) << "Multi-node world size: " << size_ << " rank: " << rank_;
  }

  // Resets NCCL_BLOCKING_WAIT to "0" after each test.
  void TearDown() override {
    ASSERT_TRUE(setenv(c10d::NCCL_BLOCKING_WAIT, "0", 1) == 0);
  }

  // Returns true (and logs why) when CUDA is unavailable.
  bool skipTest() {
    if (at::cuda::is_available()) {
      return false;
    }
    LOG(INFO) << "CUDA not available, skipping test";
    return true;
  }

  int size_{1};
  int rank_{0};
};
TEST_F(ProcessGroupNCCLTest, testAllreduce) {
  // Only runs when CUDA is available.
  if (!skipTest()) {
    TemporaryFile file;
    testAllreduce(file.path, rank_, size_);
  }
}
TEST_F(ProcessGroupNCCLTest, testBroadcast) {
  // Only runs when CUDA is available.
  if (!skipTest()) {
    TemporaryFile file;
    testBroadcast(file.path, rank_, size_);
  }
}
TEST_F(ProcessGroupNCCLTest, testReduce) {
  // Only runs when CUDA is available.
  if (!skipTest()) {
    TemporaryFile file;
    testReduce(file.path, rank_, size_);
  }
}
TEST_F(ProcessGroupNCCLTest, testAllgather) {
  // Only runs when CUDA is available.
  if (!skipTest()) {
    TemporaryFile file;
    testAllgather(file.path, rank_, size_);
  }
}
TEST_F(ProcessGroupNCCLTest, testReduceScatter) {
  // Only runs when CUDA is available.
  if (!skipTest()) {
    TemporaryFile file;
    testReduceScatter(file.path, rank_, size_);
  }
}
TEST_F(ProcessGroupNCCLTest, testBackendName) {
  // Only runs when CUDA is available; checks the reported backend name.
  if (!skipTest()) {
    TemporaryFile file;
    auto test = NCCLTestBase(file.path);
    test.initialize(rank_, size_);
    const std::string expectedName(c10d::NCCL_BACKEND_NAME);
    EXPECT_EQ(test.getProcessGroup().getBackendName(), expectedName);
  }
}