-
Notifications
You must be signed in to change notification settings - Fork 21.4k
/
ProcessGroupGloo.cpp
2633 lines (2281 loc) · 81.6 KB
/
ProcessGroupGloo.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
#include <c10d/ProcessGroupGloo.hpp>
#include <c10d/GlooDeviceFactory.hpp>
#ifdef _WIN32
#include <winsock2.h>
#include <ws2tcpip.h>
#include <gloo/common/win.h>
#else
#include <netdb.h>
#include <sys/socket.h>
#include <unistd.h>
#endif
#include <sys/types.h>
#include <type_traits>
#include <gloo/allgather.h>
#include <gloo/allgatherv.h>
#include <gloo/allreduce.h>
#include <gloo/alltoall.h>
#include <gloo/alltoallv.h>
#include <gloo/barrier.h>
#include <gloo/broadcast.h>
#include <gloo/gather.h>
#include <gloo/reduce.h>
#include <gloo/scatter.h>
#include <ATen/SparseTensorUtils.h>
#ifdef USE_CUDA
#include <ATen/cuda/CUDAEvent.h>
#include <ATen/cuda/Exceptions.h>
#include <ATen/cuda/PinnedMemoryAllocator.h>
#include <c10/cuda/CUDACachingAllocator.h>
#include <c10/cuda/CUDAGuard.h>
#include <c10/cuda/CUDAStream.h>
#endif
#include <c10/util/StringUtil.h>
#include <gloo/config.h>
#include <gloo/rendezvous/context.h>
#include <gloo/rendezvous/prefix_store.h>
// GENERATE_ALL_TYPES(type, func, ...) dispatches on the runtime
// ::at::ScalarType `type` and calls func<T>(...) with T set to the matching
// C++ element type (gloo::float16 for Half). Any other scalar type throws
// std::runtime_error("Invalid scalar type").
//
// The standard C++ __VA_ARGS__ form is used on every platform instead of the
// GNU-only named variadic parameter (`args...`), so one definition serves
// MSVC, GCC and Clang alike.
#define GENERATE_ALL_TYPES(type, func, ...)             \
  switch (type) {                                       \
    case ::at::ScalarType::Float:                       \
      func<float>(__VA_ARGS__);                         \
      break;                                            \
    case ::at::ScalarType::Double:                      \
      func<double>(__VA_ARGS__);                        \
      break;                                            \
    case ::at::ScalarType::Half:                        \
      func<gloo::float16>(__VA_ARGS__);                 \
      break;                                            \
    case ::at::ScalarType::Char:                        \
      func<int8_t>(__VA_ARGS__);                        \
      break;                                            \
    case ::at::ScalarType::Byte:                        \
      func<uint8_t>(__VA_ARGS__);                       \
      break;                                            \
    case ::at::ScalarType::Int:                         \
      func<int32_t>(__VA_ARGS__);                       \
      break;                                            \
    case ::at::ScalarType::Long:                        \
      func<int64_t>(__VA_ARGS__);                       \
      break;                                            \
    default:                                            \
      throw std::runtime_error("Invalid scalar type");  \
  }

#ifdef _WIN32
// HOST_NAME_MAX is only defined here for Windows builds (presumably the
// Windows headers do not provide it); it sizes the gethostname() buffer.
#define HOST_NAME_MAX 256
#endif
namespace c10d {
namespace {
// Wrap c10d store as Gloo store
// Adapter that exposes a c10d::Store through Gloo's rendezvous::Store
// interface. Values are converted element-wise between Gloo's
// std::vector<char> and c10d's std::vector<uint8_t>.
class GlooStore : public ::gloo::rendezvous::Store {
 public:
  GlooStore(const std::shared_ptr<::c10d::Store>& store) : store_(store) {}

  // Stores `value` under `key` in the wrapped c10d store.
  void set(const std::string& key, const std::vector<char>& value) override {
    std::vector<uint8_t> bytes(value.cbegin(), value.cend());
    store_->set(key, bytes);
  }

  // Fetches the value for `key` from the wrapped c10d store.
  std::vector<char> get(const std::string& key) override {
    const auto bytes = store_->get(key);
    return std::vector<char>(bytes.cbegin(), bytes.cend());
  }

  // Waits for `keys` using the c10d store's default timeout.
  void wait(const std::vector<std::string>& keys) override {
    store_->wait(keys, Store::kDefaultTimeout);
  }

  // Waits for `keys`, giving up after `timeout`.
  void wait(
      const std::vector<std::string>& keys,
      const std::chrono::milliseconds& timeout) override {
    store_->wait(keys, timeout);
  }

 protected:
  std::shared_ptr<::c10d::Store> store_;
};
// Element-wise reduction kernel signature: (out, lhs, rhs, element_count).
using ReduceFunc = void (*)(void*, const void*, const void*, size_t);
// Maps a c10d ReduceOp to the matching Gloo reduction kernel for a
// non-integral element type T. Bitwise reductions are rejected with
// std::runtime_error because they are not defined for such types;
// ReduceOp::UNUSED (and any unlisted value) ends in "Unhandled ReduceOp".
template <
    typename T,
    typename std::enable_if<!std::is_integral<T>::value, int>::type = 0>
ReduceFunc toFunction(const ReduceOp& r) {
  switch (r) {
    case ReduceOp::SUM: return ReduceFunc(&::gloo::sum<T>);
    case ReduceOp::PRODUCT: return ReduceFunc(&::gloo::product<T>);
    case ReduceOp::MIN: return ReduceFunc(&::gloo::min<T>);
    case ReduceOp::MAX: return ReduceFunc(&::gloo::max<T>);

    // Bitwise reductions: invalid for floating-point / half types.
    case ReduceOp::BAND:
      throw std::runtime_error(
          "Cannot use ReduceOp.BAND with non-integral dtype");
    case ReduceOp::BOR:
      throw std::runtime_error(
          "Cannot use ReduceOp.BOR with non-integral dtype");
    case ReduceOp::BXOR:
      throw std::runtime_error(
          "Cannot use ReduceOp.BXOR with non-integral dtype");

    case ReduceOp::UNUSED:
      break;
  }
  throw std::runtime_error("Unhandled ReduceOp");
}
// Element-wise bitwise AND reduction kernel: c[i] = a[i] & b[i] for
// i in [0, n). SFINAE restricts T to integral types, since bitwise
// operators are not defined for floating-point or half element types.
template <
    typename T,
    typename std::enable_if<std::is_integral<T>::value, int>::type = 0>
void band(void* c, const void* a, const void* b, size_t n) {
  T* dst = static_cast<T*>(c);
  const T* lhs = static_cast<const T*>(a);
  const T* rhs = static_cast<const T*>(b);
  const T* const lhsEnd = lhs + n;
  while (lhs != lhsEnd) {
    *dst = static_cast<T>(*lhs & *rhs);
    ++dst;
    ++lhs;
    ++rhs;
  }
}
// Element-wise bitwise OR reduction kernel: c[i] = a[i] | b[i] for
// i in [0, n). Only instantiable for integral T (SFINAE guard).
template <
    typename T,
    typename std::enable_if<std::is_integral<T>::value, int>::type = 0>
void bor(void* c, const void* a, const void* b, size_t n) {
  auto* out = static_cast<T*>(c);
  const auto* x = static_cast<const T*>(a);
  const auto* y = static_cast<const T*>(b);
  for (size_t idx = 0; idx != n; ++idx) {
    out[idx] = static_cast<T>(x[idx] | y[idx]);
  }
}
// Element-wise bitwise XOR reduction kernel: c[i] = a[i] ^ b[i] for
// i in [0, n). Only instantiable for integral T (SFINAE guard).
template <
    typename T,
    typename std::enable_if<std::is_integral<T>::value, int>::type = 0>
void bxor(void* c, const void* a, const void* b, size_t n) {
  T* out = static_cast<T*>(c);
  const T* lhs = static_cast<const T*>(a);
  const T* rhs = static_cast<const T*>(b);
  for (size_t remaining = n; remaining > 0; --remaining) {
    *out = static_cast<T>(*lhs ^ *rhs);
    ++out;
    ++lhs;
    ++rhs;
  }
}
// Maps a c10d ReduceOp to the matching reduction kernel for an integral
// element type T. Arithmetic reductions use Gloo's kernels; bitwise ones use
// the band/bor/bxor kernels defined in this file. ReduceOp::UNUSED (and any
// unlisted value) ends in std::runtime_error("Unhandled ReduceOp").
template <
    typename T,
    typename std::enable_if<std::is_integral<T>::value, int>::type = 0>
ReduceFunc toFunction(const ReduceOp& r) {
  switch (r) {
    // Arithmetic reductions provided by Gloo.
    case ReduceOp::SUM: return ReduceFunc(&::gloo::sum<T>);
    case ReduceOp::PRODUCT: return ReduceFunc(&::gloo::product<T>);
    case ReduceOp::MIN: return ReduceFunc(&::gloo::min<T>);
    case ReduceOp::MAX: return ReduceFunc(&::gloo::max<T>);

    // Bitwise reductions implemented locally.
    case ReduceOp::BAND: return ReduceFunc(&band<T>);
    case ReduceOp::BOR: return ReduceFunc(&bor<T>);
    case ReduceOp::BXOR: return ReduceFunc(&bxor<T>);

    case ReduceOp::UNUSED:
      break;
  }
  throw std::runtime_error("Unhandled ReduceOp");
}
template <typename T, typename O>
void setInputs(O& opts, std::vector<at::Tensor>& tensors) {
opts.setInputs(getDataPointers<T>(tensors), tensors[0].numel());
}
template <typename T, typename O>
void setInput(O& opts, at::Tensor& tensor) {
opts.setInput(getDataPointer<T>(tensor), tensor.numel());
}
template <typename T, typename O>
void setInput(O& opts, at::Tensor& tensor, std::vector<size_t>& counts) {
opts.setInput(getDataPointer<T>(tensor), counts);
}
template <typename T, typename O>
void setInput(O& opts, at::Tensor& tensor, std::vector<int64_t>& counts) {
opts.setInput(getDataPointer<T>(tensor), counts);
}
template <typename T, typename O>
void setOutputs(O& opts, std::vector<at::Tensor>& tensors) {
opts.setOutputs(getDataPointers<T>(tensors), tensors[0].numel());
}
template <typename T, typename O>
void setOutput(O& opts, at::Tensor& tensor) {
opts.setOutput(getDataPointer<T>(tensor), tensor.numel());
}
template <typename T, typename O>
void setOutput(O& opts, at::Tensor& tensor, std::vector<size_t>& counts) {
opts.setOutput(getDataPointer<T>(tensor), counts);
}
template <typename T, typename O>
void setOutput(O& opts, at::Tensor& tensor, std::vector<int64_t>& counts) {
opts.setOutput(getDataPointer<T>(tensor), counts);
}
#ifdef USE_CUDA
// Returns a CPU tensor with the same sizes, strides and dtype as `tensor`,
// backed by storage from the CUDA pinned-memory allocator (presumably so
// host<->device copies into/out of it can be non_blocking). The contents are
// not initialized here.
at::Tensor pinnedLike(at::Tensor& tensor) {
  auto* pinnedAllocator = at::cuda::getPinnedMemoryAllocator();
  const auto nbytes = at::detail::computeStorageNbytes(
      tensor.sizes(), tensor.strides(), tensor.dtype().itemsize());
  auto pinnedStorage = c10::Storage(
      c10::Storage::use_byte_size_t(),
      nbytes,
      pinnedAllocator,
      /*resizable=*/false);
  auto cpuOptions = tensor.options().device(at::kCPU);
  auto result = at::empty({0}, cpuOptions);
  return result.set_(pinnedStorage, 0, tensor.sizes(), tensor.strides());
}
// This function initializes a vector of CUDA streams, one for every
// tensor in the input tensor vector, and ensures that these streams are
// synchronized with the current default streams. This is needed so
// that new work on the new streams is serialized w.r.t. all operations
// on the tensors.
//
// For every tensor i:
//   - events[i] is recorded on the caller's current stream of tensor i's
//     device;
//   - a high-priority pool stream for that device is appended to `streams`
//     and made to wait on events[i];
//   - the tensor's storage (indices/values storage for coalesced sparse
//     tensors) is registered with the CUDA caching allocator for streams[i],
//     so it is not freed before work queued on streams[i] finishes.
// NOTE(review): `streams` is reserve()d and appended to, yet indexed as
// streams[i]; this assumes callers pass an empty vector — confirm.
void initializeStreamsEvents(
const std::vector<at::Tensor>& tensors,
std::vector<at::cuda::CUDAStream>& streams,
std::vector<at::cuda::CUDAEvent>& events) {
at::cuda::OptionalCUDAGuard guard;
streams.reserve(tensors.size());
events.resize(tensors.size());
for (size_t i = 0; i < tensors.size(); i++) {
// Make tensor i's device current so getCurrentCUDAStream() below refers
// to that device's stream.
guard.set_index(tensors[i].device().index());
// Record event on current stream
events[i].record(at::cuda::getCurrentCUDAStream());
// Get a non-default stream to execute asynchronous CUDA operations
// on for this device. This ensures that the default stream used
// by the caller is not occupied by c10d related operations.
streams.push_back(at::cuda::getStreamFromPool(
/* isHighPriority */ true, tensors[i].device().index()));
// Ensure the new stream is synchronized with the current stream.
events[i].block(streams[i]);
// `tensors` are created on a different stream. Hence, they must record
// new streams in this Work to prevent being freed before the Work finishes.
if (tensors[i].is_sparse()) {
if (tensors[i].is_coalesced()) {
c10::cuda::CUDACachingAllocator::recordStream(
tensors[i].indices().storage().data_ptr(), streams[i]);
c10::cuda::CUDACachingAllocator::recordStream(
tensors[i].values().storage().data_ptr(), streams[i]);
} else {
// We will need to coalesce first, which means new tensors will
// be allocated on the streams we just allocated, and there
// is no need to record them separately.
}
} else {
c10::cuda::CUDACachingAllocator::recordStream(
tensors[i].storage().data_ptr(), streams[i]);
}
}
}
// This function initializes a vector of CUDA streams, one per inner tensor
// vector (i.e. one per entry of `tensors`, on that vector's device), and
// ensures that these streams are synchronized with the caller's current
// streams. All tensors within one inner vector must be on the same device;
// otherwise std::runtime_error is thrown before any stream is created.
//
// Every tensor's storage is registered with the CUDA caching allocator for
// its inner vector's stream, so it is not freed before that work finishes.
// NOTE(review): tensors[i][0] is accessed unconditionally, so every inner
// vector must be non-empty; unlike the flat overload there is no sparse
// handling, so dense tensors are assumed. `streams` is appended to but
// indexed as streams[i], so callers presumably pass an empty vector.
void initializeStreamsEvents(
std::vector<std::vector<at::Tensor>>& tensors,
std::vector<at::cuda::CUDAStream>& streams,
std::vector<at::cuda::CUDAEvent>& events) {
// Ensure that the tensors in the nested tensor vectors are on the same
// device.
for (size_t i = 0; i < tensors.size(); i++) {
auto device_id = tensors[i][0].device().index();
for (size_t j = 1; j < tensors[i].size(); j++) {
if (tensors[i][j].device().index() != device_id) {
throw std::runtime_error(
"tensors in the nested tensor vectors need to "
"be on the same device");
}
}
}
at::cuda::OptionalCUDAGuard guard;
streams.reserve(tensors.size());
events.resize(tensors.size());
for (size_t i = 0; i < tensors.size(); i++) {
// Make the inner vector's device current for getCurrentCUDAStream().
guard.set_index(tensors[i][0].device().index());
// Record event on current stream
events[i].record(at::cuda::getCurrentCUDAStream());
// Get a non-default stream to execute asynchronous CUDA operations
// on for this output. This ensures that the default stream used
// by the caller is not occupied by c10d related operations.
streams.push_back(at::cuda::getStreamFromPool(
/* isHighPriority */ true, tensors[i][0].device().index()));
// Ensure the new stream is synchronized with the current stream.
events[i].block(streams[i]);
for (at::Tensor& tensor : tensors[i]) {
// `tensors` are created on a different stream. Hence, they must record
// new streams in this Work to prevent being freed before the Work
// finishes.
c10::cuda::CUDACachingAllocator::recordStream(
tensor.storage().data_ptr(), streams[i]);
}
}
}
#endif
// IPv4 loopback address used as the fallback when the hostname is unusable.
constexpr char kLoopbackAddress[] = "127.0.0.1";
} // namespace
ProcessGroupGloo::SendWork::SendWork(
at::Tensor& tensor,
std::unique_ptr<::gloo::transport::UnboundBuffer> buffer)
: tensor_(tensor), buffer_(std::move(buffer)) {}
bool ProcessGroupGloo::SendWork::wait(std::chrono::milliseconds timeout) {
bool sendCompleted = false;
std::exception_ptr exception{nullptr};
try {
if (timeout == kNoTimeout) {
sendCompleted = buffer_->waitSend();
} else {
sendCompleted = buffer_->waitSend(timeout);
}
} catch (...) {
exception = std::current_exception();
}
// Completes the Work object and throws the exception.
finishAndThrow(exception);
return sendCompleted;
}
void ProcessGroupGloo::SendWork::abort() {
buffer_->abortWaitSend();
}
ProcessGroupGloo::RecvWork::RecvWork(
at::Tensor& tensor,
std::unique_ptr<::gloo::transport::UnboundBuffer> buffer)
: tensor_(tensor), buffer_(std::move(buffer)), srcRank_(-1) {}
int ProcessGroupGloo::RecvWork::sourceRank() const {
std::lock_guard<std::mutex> lock(mutex_);
return srcRank_;
}
bool ProcessGroupGloo::RecvWork::wait(std::chrono::milliseconds timeout) {
bool recvCompleted = false;
std::exception_ptr exception{nullptr};
try {
if (timeout == kNoTimeout) {
recvCompleted = buffer_->waitRecv(&srcRank_);
} else {
recvCompleted = buffer_->waitRecv(&srcRank_, timeout);
}
} catch (...) {
exception = std::current_exception();
}
// Completes the Work object and throws the exception.
finishAndThrow(exception);
return recvCompleted;
}
void ProcessGroupGloo::RecvWork::abort() {
buffer_->abortWaitRecv();
}
// Default options: a 10 second (10000 ms) timeout and 2 worker threads.
ProcessGroupGloo::Options::Options()
    : timeout(std::chrono::milliseconds(10000)), threads(2) {}
namespace {

// Performs platform socket-library initialization before socket APIs are
// used: calls ::gloo::init_winsock() on Windows, no-op elsewhere.
void socketInitialize() {
#ifdef _WIN32
  ::gloo::init_winsock();
#endif
}

// Gloo assumes that this machine's hostname can always be resolved to an
// address; if it cannot, Gloo throws a runtime error saying so. Instead of
// catching that, we proactively check whether an address can be resolved
// (and bound), so callers can gracefully fall back to an alternative.
//
// Returns true iff getaddrinfo() succeeds for `hostname` and at least one of
// the returned addresses can be bound by a fresh socket of that family.
bool doesHostnameResolveToUsableAddress(const std::string& hostname) {
  socketInitialize();
  struct addrinfo hints {};
  hints.ai_family = AF_UNSPEC;
  hints.ai_socktype = SOCK_STREAM;
  struct addrinfo* result = nullptr;
  // getaddrinfo() returns 0 on success and a non-zero error code otherwise.
  // The codes are negative on glibc but positive on other platforms (e.g.
  // macOS, Windows), so test != 0; on failure `result` is not populated and
  // must be neither walked nor passed to freeaddrinfo().
  const auto rv = getaddrinfo(hostname.c_str(), nullptr, &hints, &result);
  if (rv != 0) {
    return false;
  }
  struct addrinfo* rp = nullptr;
  for (rp = result; rp != nullptr; rp = rp->ai_next) {
    auto fd = socket(rp->ai_family, rp->ai_socktype, rp->ai_protocol);
    if (fd == -1) {
      continue;
    }
    const auto bindRv = bind(fd, rp->ai_addr, rp->ai_addrlen);
#ifdef _WIN32
    closesocket(fd);
#else
    close(fd);
#endif
    if (bindRv == 0) {
      break; // Found an address this machine can bind to.
    }
  }
  freeaddrinfo(result);
  return rp != nullptr;
}

} // namespace
// Creates a Gloo transport device bound to the network interface named
// `interface_name`, delegating to GlooDeviceFactory.
std::shared_ptr<::gloo::transport::Device> ProcessGroupGloo::
    createDeviceForInterface(const std::string& interface_name) {
  auto device =
      ::c10d::GlooDeviceFactory::makeDeviceForInterface(interface_name);
  return device;
}
// Creates a Gloo transport device for `hostname`. Fails via TORCH_CHECK
// unless the name resolves to an address this machine can bind to.
std::shared_ptr<::gloo::transport::Device> ProcessGroupGloo::
    createDeviceForHostname(const std::string& hostname) {
  const bool usable = doesHostnameResolveToUsableAddress(hostname);
  TORCH_CHECK(usable, "Cannot resolve ", hostname, " to a (local) address");
  return ::c10d::GlooDeviceFactory::makeDeviceForHostname(hostname);
}
#if defined(__linux__) || defined(_WIN32)
// Creates the default Gloo device: uses this machine's hostname when it
// resolves to a usable local address, otherwise warns once and falls back to
// the loopback address.
//
// Throws std::system_error if gethostname() fails.
std::shared_ptr<::gloo::transport::Device> ProcessGroupGloo::
    createDefaultDevice() {
  // Use the hostname to resolve the network address to use. Note: if the
  // hostname does not resolve to an address (e.g. because of a misconfigured
  // /etc/hosts file), the loopback fallback below is taken.
  socketInitialize();
  // HOST_NAME_MAX excludes the terminating NUL, so reserve one extra byte;
  // otherwise a maximum-length hostname is rejected or left unterminated.
  std::array<char, HOST_NAME_MAX + 1> hostname{};
  auto rv = gethostname(hostname.data(), HOST_NAME_MAX + 1);
  if (rv != 0) {
    throw std::system_error(errno, std::system_category());
  }
  // POSIX leaves NUL-termination unspecified on truncation; enforce it.
  hostname.back() = '\0';

  // Use this machine's hostname if it resolves to an address.
  if (doesHostnameResolveToUsableAddress(hostname.data())) {
    return ::c10d::GlooDeviceFactory::makeDeviceForHostname(hostname.data());
  }

  // Otherwise, use the loopback address.
  TORCH_WARN_ONCE(
      "Unable to resolve hostname to a (local) address. ",
      "Using the loopback address as fallback. ",
      "Manually set the network interface to bind to with GLOO_SOCKET_IFNAME.");
  return createDeviceForHostname(kLoopbackAddress);
}
#endif
#ifdef __APPLE__
// Creates the default Gloo device on macOS: uses this machine's hostname
// when it resolves to a usable local address, otherwise warns once and falls
// back to the loopback address.
//
// Throws std::system_error if gethostname() fails.
std::shared_ptr<::gloo::transport::Device> ProcessGroupGloo::
    createDefaultDevice() {
  // Use the hostname to resolve the network address to use. Note: if the
  // hostname does not resolve to an address (e.g. because of a misconfigured
  // /etc/hosts file), the loopback fallback below is taken.
  //
  // sysconf(_SC_HOST_NAME_MAX) reports the maximum hostname length without
  // the terminating NUL and returns -1 when the limit is indeterminate.
  // Fall back to the POSIX minimum of 255 in that case, and allocate one
  // extra byte so a maximum-length name still fits with its terminator.
  constexpr long kFallbackHostNameMax = 255;
  long hostNameMax = sysconf(_SC_HOST_NAME_MAX);
  if (hostNameMax <= 0) {
    hostNameMax = kFallbackHostNameMax;
  }
  std::vector<char> hostname(static_cast<size_t>(hostNameMax) + 1, '\0');
  auto rv = gethostname(hostname.data(), hostname.size());
  if (rv != 0) {
    throw std::system_error(errno, std::system_category());
  }
  // POSIX leaves NUL-termination unspecified on truncation; enforce it.
  hostname.back() = '\0';

  // Use this machine's hostname if it resolves to an address.
  if (doesHostnameResolveToUsableAddress(hostname.data())) {
    return ::c10d::GlooDeviceFactory::makeDeviceForHostname(hostname.data());
  }

  // Otherwise, use the loopback address.
  TORCH_WARN_ONCE(
      "Unable to resolve hostname to a (local) address. ",
      "Using the loopback address as fallback. ",
      "Manually set the network interface to bind to with GLOO_SOCKET_IFNAME.");
  return createDeviceForHostname(kLoopbackAddress);
}
#endif
// Creates a Gloo-backed process group for `rank` out of `size` ranks.
//
// One rendezvous context is created and fully connected per entry in
// options.devices, rendezvousing through `store` (wrapped as a Gloo store,
// with keys prefixed by the device index). Then options.threads worker
// threads are started, each running runLoop() with its own index.
//
// Throws std::runtime_error if options.devices is empty.
ProcessGroupGloo::ProcessGroupGloo(
    const std::shared_ptr<Store>& store,
    int rank,
    int size,
    Options options)
    : ProcessGroup(rank, size),
      store_(new GlooStore(store)),
      stop_(false),
      collectiveCounter_(0) {
  // Non-const: connectFullMesh() below receives each device by reference.
  auto& devices = options.devices;
  if (devices.empty()) {
    throw std::runtime_error("No device(s) specified");
  }

  // Create and connect a context for every device.
  //
  // Note that the same device can be specified multiple times, either
  // the same object, or the same logical device as different objects.
  // Either mode is fine and only has performance implications.
  //
  // Using the same object multiple times means all contexts share a
  // single I/O thread. If you use different objects for the same
  // logical device they will have independent I/O threads. The latter
  // option is needed if you have a fast NIC that cannot be saturated
  // by a single I/O thread.
  contexts_.reserve(devices.size());
  for (size_t deviceIndex = 0; deviceIndex < devices.size(); ++deviceIndex) {
    auto context =
        std::make_shared<::gloo::rendezvous::Context>(rank_, size_);
    auto prefixedStore = ::gloo::rendezvous::PrefixStore(
        std::to_string(deviceIndex), *store_);
    context->setTimeout(options.timeout);
    context->connectFullMesh(prefixedStore, devices[deviceIndex]);
    contexts_.push_back(std::move(context));
  }

  // Every worker thread stores the AsyncWork object it's currently
  // working on in the workInProgress_ vector. It must have size equal
  // to the number of workers such that they can simply index into it
  // using the worker index they are started with.
  workInProgress_.resize(options.threads);
  threads_.resize(options.threads);
  for (size_t workerIndex = 0; workerIndex < threads_.size(); ++workerIndex) {
    threads_[workerIndex] =
        std::thread(&ProcessGroupGloo::runLoop, this, workerIndex);
  }
}
// Drains the work queue, then stops and joins all worker threads.
// Work items already running are finished before their worker exits, since
// join() waits for runLoop() to return.
ProcessGroupGloo::~ProcessGroupGloo() {
  {
    std::unique_lock<std::mutex> lock(workMutex_);
    // Block until every queued item has been picked up by a worker.
    workConsumeCV_.wait(lock, [&] { return workQueue_.empty(); });
    stop_ = true;
  } // Lock released here so woken workers can reacquire workMutex_.
  workProduceCV_.notify_all();
  for (auto& worker : threads_) {
    worker.join();
  }
}
// Returns the current value of collectiveCounter_ as the tag for the next
// collective and post-increments the counter. Callers use the tag both to
// pick a context (getContext) and as the Gloo op tag (opts.setTag).
// NOTE(review): thread-safety depends on the declared type of
// collectiveCounter_ (not visible here) — confirm whether it is atomic.
uint32_t ProcessGroupGloo::nextTag() {
return collectiveCounter_++;
}
// Selects the Gloo context for a collective by distributing tags
// round-robin over contexts_ (tag modulo the number of contexts).
std::shared_ptr<::gloo::Context> ProcessGroupGloo::getContext(uint32_t tag) {
  const auto slot = tag % contexts_.size();
  return contexts_[slot];
}
// Worker thread body; the constructor starts one per options.threads and
// passes its slot index as `workerIndex`.
//
// Repeatedly takes the oldest AsyncWork from workQueue_ and executes it with
// workMutex_ released, until stop_ is observed as true under the lock. The
// destructor only sets stop_ after the queue has drained, so queued work is
// never dropped.
void ProcessGroupGloo::runLoop(int workerIndex) {
std::unique_lock<std::mutex> lock(workMutex_);
while (!stop_) {
if (workQueue_.empty()) {
// Sleep until enqueue() or the destructor signals workProduceCV_;
// the loop re-checks stop_ and the queue after every wakeup.
workProduceCV_.wait(lock);
continue;
}
auto work = std::move(workQueue_.front());
workQueue_.pop_front();
// Publish this worker's in-flight work while still holding the lock.
workInProgress_[workerIndex] = work;
lock.unlock();
// Notify after releasing the lock so that the waiter
// does not immediately block. (The destructor waits on workConsumeCV_
// for the queue to become empty.)
workConsumeCV_.notify_one();
AsyncWork::execute(std::move(work));
lock.lock();
workInProgress_[workerIndex] = nullptr;
}
}
// Appends `work` to the back of the work queue and wakes one worker thread.
void ProcessGroupGloo::enqueue(std::shared_ptr<AsyncWork> work) {
  {
    std::lock_guard<std::mutex> guard(workMutex_);
    workQueue_.push_back(std::move(work));
  }
  // Notify after releasing the lock so that the woken worker does not
  // immediately block on workMutex_.
  workProduceCV_.notify_one();
}
namespace {
class AsyncBroadcastWork : public ProcessGroupGloo::AsyncWork {
public:
AsyncBroadcastWork(
const std::shared_ptr<gloo::Context>& context,
std::vector<at::Tensor>& inputs,
int rootRank,
int rootTensor,
uint32_t tag)
: ProcessGroupGloo::AsyncWork("gloo:broadcast"),
context(context),
inputs(inputs),
rootRank(rootRank),
rootTensor(rootTensor),
tag(tag) {}
std::shared_ptr<gloo::Context> context;
std::vector<at::Tensor> inputs;
const int rootRank;
const int rootTensor;
const uint32_t tag;
void broadcast(at::Tensor& tensor) {
const auto& scalarType = tensor.scalar_type();
gloo::BroadcastOptions opts(context);
opts.setRoot(rootRank);
opts.setTag(tag);
GENERATE_ALL_TYPES(scalarType, setOutput, opts, tensor);
gloo::broadcast(opts);
}
void run() override {
broadcast(inputs[rootTensor]);
// Copy to non-root tensors
for (size_t i = 0; i < inputs.size(); i++) {
if (i == static_cast<size_t>(rootTensor)) {
continue;
}
inputs[i].copy_(inputs[rootTensor]);
}
}
};
#ifdef USE_CUDA
// CUDA broadcast work: stages the data through a pinned host tensor (`tmp`),
// runs the CPU Gloo broadcast on it, then copies the result back into every
// CUDA input tensor on per-tensor side streams (see synchronize()).
class AsyncBroadcastCUDAWork : public AsyncBroadcastWork {
public:
AsyncBroadcastCUDAWork(
const std::shared_ptr<gloo::Context>& context,
std::vector<at::Tensor>& inputs,
int rootRank,
int rootTensor,
uint32_t tag)
: AsyncBroadcastWork(context, inputs, rootRank, rootTensor, tag) {
// One side stream + event per input tensor, ordered after the caller's
// current streams.
initializeStreamsEvents(inputs, streams, events);
// Create pinned host side tensors.
tmp = pinnedLike(inputs[rootTensor]);
at::cuda::OptionalCUDAStreamGuard guard;
// Only the root rank contributes data: start its non-blocking
// device->host copy on the root tensor's side stream; run() waits for it.
if (context->rank == rootRank) {
guard.reset_stream(streams[rootTensor]);
tmp.copy_(inputs[rootTensor], /* non_blocking */ true);
}
}
void run() override {
at::cuda::OptionalCUDAStreamGuard guard;
// Synchronize with copy operation if applicable.
if (context->rank == rootRank) {
guard.reset_stream(streams[rootTensor]);
AT_CUDA_CHECK(cudaStreamSynchronize(streams[rootTensor]));
}
// Run broadcast on host side tensors.
broadcast(tmp);
// Kick off copy back to the CUDA tensors.
// events[i] marks completion of the copy queued on streams[i].
for (size_t i = 0; i < inputs.size(); i++) {
guard.reset_stream(streams[i]);
inputs[i].copy_(tmp, /* non_blocking */ true);
events[i].record(streams[i]);
}
}
// Makes each input tensor's current stream wait on events[i], so work the
// caller queues afterwards is ordered after the copy-back from run().
void synchronize() override {
at::cuda::OptionalCUDAGuard guard;
// Synchronize with the copy back to CUDA tensors.
for (size_t i = 0; i < inputs.size(); i++) {
guard.set_index(inputs[i].device().index());
events[i].block(at::cuda::getCurrentCUDAStream());
}
}
// Pinned host staging tensor shaped like inputs[rootTensor].
at::Tensor tmp;
// Per-input side streams and events (index-aligned with `inputs`).
std::vector<at::cuda::CUDAStream> streams;
std::vector<at::cuda::CUDAEvent> events;
};
#endif
} // namespace
// Broadcasts inputs[opts.rootTensor] from rank opts.rootRank to every rank
// and into all local `inputs`. The work is queued for a worker thread and
// the returned Work handle tracks it.
//
// Throws std::invalid_argument (message prefixed with
// "ProcessGroupGloo::broadcast: ") when any of the assertRootRank /
// assertRootTensor / assertDense / assertTypeAndSizesMatch checks fails, or
// when the device type is neither CPU nor (in CUDA builds) CUDA.
std::shared_ptr<ProcessGroup::Work> ProcessGroupGloo::broadcast(
    std::vector<at::Tensor>& inputs,
    const BroadcastOptions& opts) {
  static auto invalidArgument = [](const std::string& msg) {
    throw std::invalid_argument("ProcessGroupGloo::broadcast: " + msg);
  };

  assertRootRank(invalidArgument, opts.rootRank, size_);
  assertRootTensor(invalidArgument, opts.rootTensor, inputs.size());
  assertDense(invalidArgument, inputs);
  assertTypeAndSizesMatch(invalidArgument, inputs);

  const auto deviceType = inputs[0].device().type();
  switch (deviceType) {
    case at::kCPU:
#ifdef USE_CUDA
    case at::kCUDA:
#endif
      break;
    default:
      invalidArgument(c10::str("unsupported device type ", deviceType));
  }

  const auto tag = nextTag();
  auto context = getContext(tag);

  std::shared_ptr<AsyncBroadcastWork> work;
  if (deviceType == at::kCPU) {
    work = std::make_shared<AsyncBroadcastWork>(
        std::move(context), inputs, opts.rootRank, opts.rootTensor, tag);
#ifdef USE_CUDA
  } else if (deviceType == at::kCUDA) {
    work = std::make_shared<AsyncBroadcastCUDAWork>(
        std::move(context), inputs, opts.rootRank, opts.rootTensor, tag);
#endif
  } else {
    throw std::runtime_error("Invalid backend");
  }

  enqueue(work);
  return work;
}
namespace {
// Dense allreduce work: reduces `inputs` element-wise across all ranks with
// `reduceOp` using one Gloo allreduce whose output buffers are the input
// tensors themselves. After run(), result() returns those tensors.
class AsyncAllreduceWork : public ProcessGroupGloo::AsyncWork {
 public:
  AsyncAllreduceWork(
      const std::shared_ptr<gloo::Context>& context,
      std::vector<at::Tensor>& inputs,
      ReduceOp reduceOp,
      uint32_t tag)
      : ProcessGroupGloo::AsyncWork("gloo:all_reduce"),
        context(context),
        inputs(inputs),
        reduceOp(reduceOp),
        tag(tag) {}

  std::shared_ptr<gloo::Context> context;
  std::vector<at::Tensor> inputs;
  const ReduceOp reduceOp;
  const uint32_t tag;

  // Runs one Gloo allreduce over `tensors`. The dtype and per-buffer element
  // count are taken from tensors[0].
  void allreduce(std::vector<at::Tensor>& tensors) {
    const auto dtype = tensors[0].scalar_type();
    gloo::AllreduceOptions opts(context);
    opts.setReduceFunction(getFunction(dtype, reduceOp));
    opts.setTag(tag);
    GENERATE_ALL_TYPES(dtype, setOutputs, opts, tensors);
    gloo::allreduce(opts);
  }

  void run() override {
    allreduce(inputs);
    outputs_ = inputs;
  }

  // Typed helper invoked through GENERATE_ALL_TYPES.
  template <typename T>
  void getFunction(gloo::AllreduceOptions::Func& fn, const ReduceOp op) {
    fn = toFunction<T>(op);
  }

  // Resolves the reduction kernel for `op` on element type `dtype`.
  gloo::AllreduceOptions::Func getFunction(
      const at::ScalarType& dtype,
      const ReduceOp op) {
    gloo::AllreduceOptions::Func reducer;
    GENERATE_ALL_TYPES(dtype, getFunction, reducer, op);
    return reducer;
  }

  std::vector<at::Tensor> result() override {
    TORCH_CHECK(
        isCompleted(),
        "Work needs to be completed before calling result(). "
        "Should call wait() before result().");
    return outputs_;
  }

 protected:
  std::vector<at::Tensor> outputs_;
};
class AsyncAllreduceCoalescedWork : public AsyncAllreduceWork {
public:
AsyncAllreduceCoalescedWork(
const std::shared_ptr<gloo::Context>& context,
std::vector<at::Tensor>& inputs,
ReduceOp reduceOp,
uint32_t tag)
: AsyncAllreduceWork(context, inputs, reduceOp, tag) {}
void run() override {
allreduceCoalesced(inputs);
}
private:
void allreduceCoalesced(std::vector<at::Tensor>& tensors) {
// reduce coalesced, flattened tensors.
at::Tensor coalescedTensor = flattenDenseTensors(tensors);
std::vector<at::Tensor> allreduceInput = {coalescedTensor};
allreduce(allreduceInput);
// separate and reshape tensors.
size_t offset = 0;
for (at::Tensor& tensor : tensors) {
const int64_t tensorNumel = tensor.numel();
const c10::IntArrayRef tensorShape = tensor.sizes();
tensor.copy_(coalescedTensor.slice(0, offset, offset + tensorNumel)
.view(tensorShape));
offset += tensorNumel;
}
}
};
class AsyncSparseAllreduceWork : public ProcessGroupGloo::AsyncWork {
public:
AsyncSparseAllreduceWork(
const std::shared_ptr<gloo::Context>& context,
std::vector<at::Tensor>& inputs,
uint32_t tag)
: context(context), inputs(inputs), tag(tag) {}
std::shared_ptr<gloo::Context> context;
std::vector<at::Tensor> inputs;
std::vector<at::Tensor> outputs;
const uint32_t tag;
// We share dimensionality about the sparse tensors before collecting
// their contents. We assume here that the maximum number of sparse
// and dense dimensions is 4. This is stored in a contiguous piece of
// memory so that we can easily run allgather on it.
//
// The layout of this memory is as follows:
//
// - [0:4]: sparse dims
// - [4:8]: dense dims
// - [8]: nnz
//
class SparseTensorMetadata {
public:
static constexpr auto dim = 9;
// Construct from an existing metadata tensor to facilitate structured
// access to metadata from peers, after gathering it.
explicit SparseTensorMetadata(at::Tensor metadata)
: metadata_(metadata), data_(metadata_.data_ptr<int64_t>()) {
AT_ASSERT(metadata.scalar_type() == at::kLong);
AT_ASSERT(metadata.dim() == 1);
AT_ASSERT(metadata.size(0) == dim);
}
// Populate the metadata.
void populate_from_sparse_tensor(const at::Tensor& tensor) {
const auto sparse_dim = tensor.sparse_dim();
AT_ASSERT(sparse_dim <= 4);
for (auto i = 0; i < 4; i++) {
if (i < sparse_dim) {
data_[i] = tensor.size(i);
}
}
const auto dense_dim = tensor.dense_dim();
AT_ASSERT(dense_dim <= 4);
for (auto i = 0; i < 4; i++) {
if (i < dense_dim) {
data_[i + 4] = tensor.size(sparse_dim + i);
}
}
data_[8] = tensor._nnz();
}
std::vector<int64_t> sizes() const {
std::vector<int64_t> sizes;
// Sparse sizes
for (auto i = 0; i < 4; i++) {
if (data_[i] <= 0) {
break;
}
sizes.push_back(data_[i]);
}
// Dense sizes
for (auto i = 4; i < 8; i++) {
if (data_[i] <= 0) {
break;
}
sizes.push_back(data_[i]);
}
return sizes;
}
int64_t nnz() const {
return data_[8];
}
protected:
at::Tensor metadata_;
int64_t* data_;
};
// Sparse allreduce is implemented with allgather on indices and values.
// Every process then sums the resulting sparse tensors locally.
// The nnz for sparse tensors may be different across processes, so first
// we run allgather on the nnz, and then allgather with max(nnz).
// We could use an allgatherv for this, if it were available.
at::Tensor allreduce(std::vector<at::Tensor>& tensors) {
// TODO: This is a massive hack! There is some confusion about
// Variable/Tensor inside the body of this function. Turning off
// grad smooths over the confusion for now. This fixes
// test/test_c10d.py ProcessGroupGlooTest.test_sparse_allreduce_basics