From 160db3db4fa97f1dad0c029c53ff175c022edc1d Mon Sep 17 00:00:00 2001 From: Mehdi Mirzazadeh Date: Fri, 6 Nov 2020 09:48:28 -0800 Subject: [PATCH] Adding profiling capability to c++ ddp collective functions (#46471) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/46471 ghstack-source-id: 116018837 Test Plan: Added unit tests: buck test mode/dev-nosan caffe2/test/distributed:distributed_gloo_fork buck test mode/dev-nosan caffe2/test/distributed:distributed_nccl_fork Reviewed By: rohan-varma Differential Revision: D23948397 fbshipit-source-id: 6d93a370aff26bf96c39e5d78a2492c5142a9156 --- torch/csrc/distributed/c10d/init.cpp | 3 +- torch/lib/c10d/ProcessGroup.cpp | 26 ++- torch/lib/c10d/ProcessGroup.hpp | 8 +- torch/lib/c10d/ProcessGroupGloo.cpp | 27 ++- torch/lib/c10d/ProcessGroupGloo.hpp | 2 + torch/lib/c10d/ProcessGroupNCCL.cpp | 54 ++++-- torch/lib/c10d/ProcessGroupNCCL.hpp | 11 +- .../c10d/test/ProcessGroupNCCLErrorsTest.cpp | 6 +- torch/lib/c10d/test/ProcessGroupNCCLTest.cpp | 10 +- .../_internal/distributed/distributed_test.py | 171 +++++++++++++++--- 10 files changed, 252 insertions(+), 66 deletions(-) diff --git a/torch/csrc/distributed/c10d/init.cpp b/torch/csrc/distributed/c10d/init.cpp index 4c031dde960e..d9ddf35ee1df 100644 --- a/torch/csrc/distributed/c10d/init.cpp +++ b/torch/csrc/distributed/c10d/init.cpp @@ -1086,8 +1086,7 @@ that adds a prefix to each key inserted to the store. &::c10d::ProcessGroup::Work::wait, py::arg("timeout") = kNoTimeout, py::call_guard()) - .def( - "get_future", + .def("get_future", [](::c10d::ProcessGroup::Work& work) -> std::shared_ptr { return std::make_shared(work.getFuture()); diff --git a/torch/lib/c10d/ProcessGroup.cpp b/torch/lib/c10d/ProcessGroup.cpp index 83035666d7e9..3521ed42c840 100644 --- a/torch/lib/c10d/ProcessGroup.cpp +++ b/torch/lib/c10d/ProcessGroup.cpp @@ -1,4 +1,6 @@ #include +#include + #include @@ -51,10 +53,20 @@ bool isP2POp(OpType opType) { opType == OpType::RECVANYSOURCE; } -ProcessGroup::Work::Work() : rank_(-1), opType_(OpType::UNKNOWN) {} -ProcessGroup::Work::Work(int rank, OpType opType) - : rank_(rank), opType_(opType) {} +ProcessGroup::Work::Work(int rank, OpType opType, const char* profilingTitle) + : rank_(rank), opType_(opType) { + if (profilingTitle != nullptr) { + auto recordingFunction = std::make_shared(at::RecordScope::USER_SCOPE); + if (recordingFunction->active) { + recordingFunction->before(profilingTitle, {}); + std::function end_handler = [this, recordingFunction]() { + recordingFunction->end(); + }; + recordFunctionEndCallback_ = at::wrapPropagateTLSState(end_handler); + } + } +} OpType ProcessGroup::Work::retrieveOpType() { return opType_; @@ -123,6 +135,10 @@ void ProcessGroup::Work::finish(std::exception_ptr exception) { std::unique_lock lock(mutex_); completed_ = true; exception_ = exception; + if (recordFunctionEndCallback_) { + recordFunctionEndCallback_(); + recordFunctionEndCallback_ = nullptr; + } lock.unlock(); cv_.notify_all(); } @@ -131,6 +147,10 @@ void ProcessGroup::Work::finishAndThrow(std::exception_ptr exception) { std::unique_lock lock(mutex_); completed_ = true; exception_ = exception; + if (recordFunctionEndCallback_) { + recordFunctionEndCallback_(); + recordFunctionEndCallback_ = nullptr; + } if (exception_) { std::rethrow_exception(exception_); } diff --git a/torch/lib/c10d/ProcessGroup.hpp b/torch/lib/c10d/ProcessGroup.hpp index 774fcd97262a..5e90dccc25c0 100644 --- a/torch/lib/c10d/ProcessGroup.hpp +++ b/torch/lib/c10d/ProcessGroup.hpp @@ -77,9 +77,7 @@ 
class ProcessGroup { // this will be bound using pybind. class Work { public: - Work(); - - Work(int rank, OpType opType); + Work(int rank = -1, OpType opType = OpType::UNKNOWN, const char* profilingTitle = nullptr); virtual ~Work(); @@ -156,6 +154,10 @@ class ProcessGroup { // Operation type that this work object refers to. OpType opType_; + + // When profiling, the callback to record end of operation event. This + // callback needs to be called when collective operation is complete. + std::function recordFunctionEndCallback_; }; explicit ProcessGroup(int rank, int size); diff --git a/torch/lib/c10d/ProcessGroupGloo.cpp b/torch/lib/c10d/ProcessGroupGloo.cpp index c139ac7a34fd..cd3e83e6b714 100644 --- a/torch/lib/c10d/ProcessGroupGloo.cpp +++ b/torch/lib/c10d/ProcessGroupGloo.cpp @@ -677,7 +677,8 @@ class AsyncBroadcastWork : public ProcessGroupGloo::AsyncWork { int rootRank, int rootTensor, uint32_t tag) - : context(context), + : ProcessGroupGloo::AsyncWork("gloo:broadcast"), + context(context), inputs(inputs), rootRank(rootRank), rootTensor(rootTensor), @@ -823,7 +824,8 @@ class AsyncAllreduceWork : public ProcessGroupGloo::AsyncWork { std::vector& inputs, ReduceOp reduceOp, uint32_t tag) - : context(context), inputs(inputs), reduceOp(reduceOp), tag(tag) {} + : ProcessGroupGloo::AsyncWork("gloo:all_reduce"), + context(context), inputs(inputs), reduceOp(reduceOp), tag(tag) {} std::shared_ptr context; std::vector inputs; @@ -1431,7 +1433,8 @@ class AsyncReduceWork : public ProcessGroupGloo::AsyncWork { int rootTensor, ReduceOp reduceOp, uint32_t tag) - : context(context), + : ProcessGroupGloo::AsyncWork("gloo:reduce"), + context(context), inputs(inputs), rootRank(rootRank), rootTensor(rootTensor), @@ -1595,7 +1598,8 @@ class AsyncAllgatherWork : public ProcessGroupGloo::AsyncWork { std::vector>& outputs, std::vector& inputs, uint32_t tag) - : context(context), outputs(outputs), inputs(inputs), tag(tag) {} + : ProcessGroupGloo::AsyncWork("gloo:all_gather"), + context(context), outputs(outputs), inputs(inputs), tag(tag) {} std::shared_ptr context; std::vector> outputs; @@ -1792,7 +1796,8 @@ class AsyncAllgatherCoalescedWork : public ProcessGroupGloo::AsyncWork { std::vector>& output_lists, std::vector& input_list, uint32_t tag) - : context(context), + : ProcessGroupGloo::AsyncWork("gloo:all_gather"), + context(context), output_lists(output_lists), input_list(input_list), tag(tag) {} @@ -1921,7 +1926,8 @@ class AsyncGatherWork : public ProcessGroupGloo::AsyncWork { std::vector& inputs, int root, uint32_t tag) - : context(context), + : ProcessGroupGloo::AsyncWork("gloo:gather"), + context(context), outputs(outputs), inputs(inputs), root(root), @@ -2125,7 +2131,8 @@ class AsyncScatterWork : public ProcessGroupGloo::AsyncWork { std::vector>& inputs, int root, uint32_t tag) - : context(context), + : ProcessGroupGloo::AsyncWork("gloo:scatter"), + context(context), outputs(outputs), inputs(inputs), root(root), @@ -2319,7 +2326,8 @@ class AsyncAlltoallWork : public ProcessGroupGloo::AsyncWork { std::vector& outputCounts, std::vector& inputCounts, uint32_t tag) - : context(context), + : ProcessGroupGloo::AsyncWork("gloo:all_to_all"), + context(context), outputTensor(outputTensor), inputTensor(inputTensor), outputCounts(std::move(outputCounts)), @@ -2576,7 +2584,8 @@ class AsyncBarrierWork : public ProcessGroupGloo::AsyncWork { const std::shared_ptr& context, std::vector> priorWork, uint32_t tag) - : context(context), priorWork(std::move(priorWork)), tag(tag) {} + : 
ProcessGroupGloo::AsyncWork("gloo:barrier"), + context(context), priorWork(std::move(priorWork)), tag(tag) {} std::shared_ptr context; std::vector> priorWork; diff --git a/torch/lib/c10d/ProcessGroupGloo.hpp b/torch/lib/c10d/ProcessGroupGloo.hpp index dfae068de244..31664ad0b6cf 100644 --- a/torch/lib/c10d/ProcessGroupGloo.hpp +++ b/torch/lib/c10d/ProcessGroupGloo.hpp @@ -68,6 +68,8 @@ class ProcessGroupGloo : public ProcessGroup { // class AsyncWork : public ProcessGroup::Work { public: + AsyncWork(const char* profilingTitle = nullptr): ProcessGroup::Work(-1, OpType::UNKNOWN, profilingTitle) {} + static void execute(std::shared_ptr work) { std::exception_ptr eptr; try { diff --git a/torch/lib/c10d/ProcessGroupNCCL.cpp b/torch/lib/c10d/ProcessGroupNCCL.cpp index 25bd338138a6..ba0b4b36c77d 100644 --- a/torch/lib/c10d/ProcessGroupNCCL.cpp +++ b/torch/lib/c10d/ProcessGroupNCCL.cpp @@ -240,8 +240,9 @@ std::ostream& operator<<( ProcessGroupNCCL::WorkNCCL::WorkNCCL( const std::vector& devices, int rank, - OpType opType) - : Work(rank, opType), + OpType opType, + const char* profilingTitle) + : Work(rank, opType, profilingTitle), devices_(devices), workStartTime_(std::chrono::steady_clock::now()) { // Creates the CUDA event wrappers @@ -986,8 +987,9 @@ std::vector flatten_for_scatter_gather( std::shared_ptr ProcessGroupNCCL::initWork( std::vector devices, int rank, - OpType opType) { - return std::make_shared(devices, rank, opType); + OpType opType, + const char* profilingTitle) { + return std::make_shared(devices, rank, opType, profilingTitle); } std::vector ProcessGroupNCCL::WorkNCCL::result() { @@ -1031,7 +1033,8 @@ std::shared_ptr ProcessGroupNCCL::collective( Fn fn, PreProcess pre, PostProcess post, - OpType opType) { + OpType opType, + const char* profilingTitle) { const auto devices = getDeviceList(inputs); const auto key = getKeyFromDevices(devices); auto& ncclComms = getNCCLComm(key, devices, opType); @@ -1040,13 +1043,25 @@ std::shared_ptr ProcessGroupNCCL::collective( syncStreams(devices, ncclEvents_[key], ncclStreams_[key]); // Work itself will create the CUDA events on all GPUs of tensors - auto work = initWork(devices, rank_, opType); + bool can_profile = outputs.size() == 1; + auto work = initWork(devices, rank_, opType, can_profile ? profilingTitle : nullptr); // Store references to outputs and futureNCCLCallbackStream to be used by // WorkNCCL::getFuture. work->outputs_ = std::make_shared>(outputs); work->futureNCCLCallbackStreams_ = futureNCCLCallbackStreams_; + if (work->recordFunctionEndCallback_) { + // recordFunctionEndCallback_ is normally called in fininsh() function by + // base class, but since finish is not called by WorkNCCL, we schedule this + // function to be run when work is done. + // Note when can_profile is false, profilingTitle is not provided and so, + // recordFunctionEndCallback_ is not set. 
+ work->getFuture()->addCallback(std::move(work->recordFunctionEndCallback_)); + } + + + at::cuda::OptionalCUDAGuard gpuGuard; pre(ncclStreams_[key]); @@ -1175,14 +1190,16 @@ std::shared_ptr ProcessGroupNCCL::collective( std::vector& inputs, std::vector& outputs, Fn fn, - OpType opType) { + OpType opType, + const char* profilingTitle) { return collective( inputs, outputs, fn, [](std::vector&) {}, [](std::vector&) {}, - opType); + opType, + profilingTitle); } template @@ -1221,7 +1238,8 @@ std::shared_ptr ProcessGroupNCCL::allreduce( comm, stream.stream()); }, - OpType::ALLREDUCE); + OpType::ALLREDUCE, + "nccl:all_reduce"); } std::shared_ptr ProcessGroupNCCL::allreduce_coalesced( @@ -1252,7 +1270,8 @@ std::shared_ptr ProcessGroupNCCL::broadcast( comm, stream.stream()); }, - OpType::BROADCAST); + OpType::BROADCAST, + "nccl:broadcast"); } std::shared_ptr ProcessGroupNCCL::reduce( @@ -1278,7 +1297,8 @@ std::shared_ptr ProcessGroupNCCL::reduce( comm, stream.stream()); }, - OpType::REDUCE); + OpType::REDUCE, + "nccl:reduce"); } std::shared_ptr ProcessGroupNCCL::allgather( @@ -1322,7 +1342,8 @@ std::shared_ptr ProcessGroupNCCL::allgather( } } }, - OpType::ALLGATHER); + OpType::ALLGATHER, + "nccl:all_gather"); } std::shared_ptr ProcessGroupNCCL::allgather_coalesced( @@ -1375,7 +1396,8 @@ std::shared_ptr ProcessGroupNCCL::reduce_scatter( } }, [&](std::vector& ncclStreams) {}, - OpType::REDUCE_SCATTER); + OpType::REDUCE_SCATTER, + "nccl:reduce_scatter"); } std::shared_ptr ProcessGroupNCCL::barrier( @@ -1448,7 +1470,8 @@ std::shared_ptr ProcessGroupNCCL::alltoall_base( stream); return ncclSuccess; }, - OpType::ALLTOALL_BASE); + OpType::ALLTOALL_BASE, + "nccl:all_to_all"); } else { c10d::checkSplitSizes(inputSplitSizes, inputTensor, size_); c10d::checkSplitSizes(outputSplitSizes, outputTensor, size_); @@ -1484,7 +1507,8 @@ std::shared_ptr ProcessGroupNCCL::alltoall_base( comm, stream.stream()); }, - OpType::ALLTOALL_BASE); + OpType::ALLTOALL_BASE, + "nccl:all_to_all"); } } diff --git a/torch/lib/c10d/ProcessGroupNCCL.hpp b/torch/lib/c10d/ProcessGroupNCCL.hpp index e6f589275a11..1520604629f2 100644 --- a/torch/lib/c10d/ProcessGroupNCCL.hpp +++ b/torch/lib/c10d/ProcessGroupNCCL.hpp @@ -68,7 +68,7 @@ class ProcessGroupNCCL : public ProcessGroup { public std::enable_shared_from_this { public: // Constructor takes a list of CUDA devices - WorkNCCL(const std::vector& devices, int rank, OpType opType); + WorkNCCL(const std::vector& devices, int rank, OpType opType, const char* profilingTitle = nullptr); // Copy constructor doing partial copy without outputs_. Cleanup thread // monitors and removes finished works. However it will deadlock when // destructs outputs_ tensors who are view tensors in autograd graph. 
@@ -518,7 +518,8 @@ class ProcessGroupNCCL : public ProcessGroup { virtual std::shared_ptr initWork( std::vector devices, int rank, - OpType opType); + OpType opType, + const char* profilingTitle=nullptr); private: // Helper that encapsulates work shared across all collective communication @@ -532,7 +533,8 @@ class ProcessGroupNCCL : public ProcessGroup { std::vector& input, std::vector& output, Fn fn, - OpType opType); + OpType opType, + const char* profilingTitle = nullptr); template std::shared_ptr collective( std::vector& input, @@ -540,7 +542,8 @@ class ProcessGroupNCCL : public ProcessGroup { Fn fn, PreProcess pre, PostProcess post, - OpType opType); + OpType opType, + const char* profilingTitle = nullptr); // Helper that encapsulates work shared across point-to-point communication // primitives. It is the same structure as the helper used for collective diff --git a/torch/lib/c10d/test/ProcessGroupNCCLErrorsTest.cpp b/torch/lib/c10d/test/ProcessGroupNCCLErrorsTest.cpp index 0df197d17cbb..e906702a889d 100644 --- a/torch/lib/c10d/test/ProcessGroupNCCLErrorsTest.cpp +++ b/torch/lib/c10d/test/ProcessGroupNCCLErrorsTest.cpp @@ -59,7 +59,8 @@ class ProcessGroupNCCLSimulateErrors : public c10d::ProcessGroupNCCL { std::shared_ptr initWork( std::vector devices, int rank, - c10d::OpType opType) override { + c10d::OpType opType, + const char* profilingTitle) override { return std::make_shared( devices, simulate_error_, rank, opType); } @@ -115,7 +116,8 @@ class ProcessGroupNCCLTimedOutErrors : public ProcessGroupNCCLSimulateErrors { std::shared_ptr initWork( std::vector devices, int rank, - c10d::OpType opType) override { + c10d::OpType opType, + const char* profilingTitle) override { return std::make_shared( devices, set_timedout_error_, rank, opType); } diff --git a/torch/lib/c10d/test/ProcessGroupNCCLTest.cpp b/torch/lib/c10d/test/ProcessGroupNCCLTest.cpp index 16cc778325a1..92b477fae7de 100644 --- a/torch/lib/c10d/test/ProcessGroupNCCLTest.cpp +++ b/torch/lib/c10d/test/ProcessGroupNCCLTest.cpp @@ -9,6 +9,7 @@ #include #include +#include #include using namespace c10d::test; @@ -172,7 +173,14 @@ class AllreduceNCCLTest : public NCCLTest { launchDeviceSleep(); valueInitialization(); - return pg_->allreduce(tensors_); + using namespace torch::autograd::profiler; + // Make sure enabling profile does not make any issue. Note, in single + // process multi-device mode we do not expect any events be populated for + // collective operations, since profiling for that mode is not supported. 
+ enableProfiler({ProfilerState::CPU}); + auto results = pg_->allreduce(tensors_); + disableProfiler(); + return results; } }; diff --git a/torch/testing/_internal/distributed/distributed_test.py b/torch/testing/_internal/distributed/distributed_test.py index 1ef3400584b2..9ce1c58cb4da 100644 --- a/torch/testing/_internal/distributed/distributed_test.py +++ b/torch/testing/_internal/distributed/distributed_test.py @@ -953,9 +953,9 @@ def _test_broadcast_helper( opts = dist.BroadcastOptions() opts.rootTensor = 0 opts.rootRank = src - group_id.broadcast([expected_tensor], opts).wait() + self.call_dist_op(":broadcast", True, group_id.broadcast, [expected_tensor], opts) else: - dist.broadcast(expected_tensor, src, group_id) + self.call_dist_op(":broadcast", False, dist.broadcast, expected_tensor, src, group_id) else: tensor = _build_tensor(src + 1, -1, dtype) if cuda: @@ -964,9 +964,9 @@ def _test_broadcast_helper( opts = dist.BroadcastOptions() opts.rootTensor = 0 opts.rootRank = src - group_id.broadcast([tensor], opts).wait() + self.call_dist_op(":broadcast", True, group_id.broadcast, [tensor], opts) else: - dist.broadcast(tensor, src, group_id) + self.call_dist_op(":broadcast", False, dist.broadcast, tensor, src, group_id) self.assertEqual(tensor.size(), expected_tensor.size()) self.assertEqual(tensor.ne(expected_tensor).max(), torch.tensor(False)) @@ -1034,17 +1034,12 @@ def _test_reduce_helper( rank_to_GPU=None, ): for src in group: + tensor = _build_tensor(src + 1).fill_(master_value if rank == src else worker_value) + if cuda: + tensor = tensor.cuda(rank_to_GPU[rank][0]) + self.call_dist_op(":reduce", False, dist.reduce, tensor, src, op, group_id) if rank == src: - tensor = _build_tensor(src + 1).fill_(master_value) - if cuda: - tensor = tensor.cuda(rank_to_GPU[rank][0]) - dist.reduce(tensor, src, op, group_id) self.assertEqual(tensor, _build_tensor(src + 1, expected_value)) - else: - tensor = _build_tensor(src + 1).fill_(worker_value) - if cuda: - tensor = tensor.cuda(rank_to_GPU[rank][0]) - dist.reduce(tensor, src, op, group_id) self._barrier() @@ -1178,6 +1173,64 @@ def test_reduce_full_group_max(self): group, group_id, rank = self._init_full_group_test() self._test_reduce_helper(group, group_id, rank, dist.ReduceOp.MAX, -1, 10, 10) + # REDUCE TWICE + def _test_reduce_twice_helper( + self, + group, + group_id, + rank, + op, + master_value, + worker_value, + expected_value, + cuda=False, + rank_to_GPU=None, + ): + for src in group: + tensors = [_build_tensor(src + 1).fill_(master_value if rank == src else worker_value) for i in range(2)] + if cuda: + for i in range(2): + tensors[i] = tensors[i].cuda(rank_to_GPU[rank][0]) + self.call_dist_op(":reduce", False, dist.reduce, tensors[0], src, op, group_id, + secondary_op_call=lambda: dist.reduce(tensors[1], src, op, group_id)) + if rank == src: + for tensor in tensors: + self.assertEqual(tensor, _build_tensor(src + 1, expected_value)) + + self._barrier() + + @unittest.skipIf(BACKEND == "nccl", "Nccl does not support CPU tensors") + def test_reduce_sum_twice(self): + group, group_id, rank = self._init_global_test() + self._test_reduce_twice_helper( + group, + group_id, + rank, + dist.ReduceOp.SUM, + 2, + 10, + 2 + (10 * (len(group) - 1)), + ) + + @unittest.skipIf(BACKEND != "nccl", "Only Nccl supports CUDA reduce") + @skip_if_no_gpu + @skip_if_rocm + def test_reduce_sum_cuda_twice(self): + group, group_id, rank = self._init_global_test() + rank_to_GPU = self._init_multigpu_helper() + self._test_reduce_twice_helper( + group, + group_id, + 
rank, + dist.ReduceOp.SUM, + 2, + 10, + 2 + 10 * (len(group) - 1), + True, + rank_to_GPU, + ) + + @skip_if_no_gpu @require_backend({"gloo", "nccl"}) @skip_if_rocm @@ -1225,6 +1278,29 @@ def test_all_reduce_result_cuda(self): self.assertEqual(result, [_build_tensor(src + 1, expected_value)]) self._barrier() + def call_dist_op(self, profiling_title_postfix, is_async, op, *args, expect_event=True, secondary_op_call=None, **kwargs): + op_calls = [lambda: op(*args, **kwargs)] + if secondary_op_call is not None: + op_calls.append(secondary_op_call) + + with torch.autograd.profiler.profile() as prof: + works = [op_call() for op_call in op_calls] + if is_async: + for work in works: + work.wait() + + def get_event(postfix): + return [event for event in prof.function_events if event.name.endswith(postfix)] + + events = get_event(profiling_title_postfix) + if expect_event: + self.assertEqual(len(events), len(op_calls)) + for e in events: + self.assertEqual(e.count, 1) + self.assertGreater(e.cpu_time, 0) + else: + self.assertEqual([], events) + # ALL REDUCE def _test_all_reduce_helper( self, @@ -1238,6 +1314,7 @@ def _test_all_reduce_helper( cuda=False, rank_to_GPU=None, dtype=torch.float, + async_op=False, ): for src in group: curr_value = master_value if rank == src else worker_value @@ -1245,9 +1322,7 @@ def _test_all_reduce_helper( tensor = _build_tensor(src + 1, dtype=dtype).fill_(curr_value) if cuda: tensor = tensor.cuda(rank_to_GPU[rank][0]) - dist.all_reduce(tensor, op, group_id) - expected_tensor = _build_tensor(src + 1, expected_value, dtype=dtype) - self.assertEqual(tensor, expected_tensor) + self.call_dist_op(":all_reduce", async_op, dist.all_reduce, tensor, op, group_id, async_op=async_op) self._barrier() @@ -1264,6 +1339,20 @@ def test_all_reduce_sum(self): 2 + (10 * (len(group) - 1)), ) + @unittest.skipIf(BACKEND == "nccl", "Nccl does not support CPU tensors") + def test_all_reduce_sum_async(self): + group, group_id, rank = self._init_global_test() + self._test_all_reduce_helper( + group, + group_id, + rank, + dist.ReduceOp.SUM, + 2, + 10, + 2 + (10 * (len(group) - 1)), + async_op=True + ) + @unittest.skipIf( BACKEND != "gloo", "Only Gloo backend will have CUDA allReduce tested", @@ -1284,6 +1373,27 @@ def test_all_reduce_sum_cuda(self): rank_to_GPU, ) + @unittest.skipIf( + BACKEND != "gloo" and BACKEND != "nccl", + "Only Gloo and NCCL backends will have CUDA allReduce tested", + ) + @skip_if_no_gpu + def test_all_reduce_sum_cuda_async(self): + group, group_id, rank = self._init_global_test() + rank_to_GPU = self._init_multigpu_helper() + self._test_all_reduce_helper( + group, + group_id, + rank, + dist.ReduceOp.SUM, + 2, + 10, + 2 + (10 * (len(group) - 1)), + True, + rank_to_GPU, + async_op=True + ) + @unittest.skipIf(BACKEND == "nccl", "Nccl does not support CPU tensors") def test_all_reduce_sum_complex(self): group, group_id, rank = self._init_global_test() @@ -1531,7 +1641,7 @@ def _test_all_reduce_coalesced_helper( ] if cuda: tensors = [t.cuda(rank_to_GPU[rank][0]) for t in tensors] - dist.all_reduce_coalesced(tensors, op, group_id) + self.call_dist_op(":all_reduce", False, dist.all_reduce_coalesced, tensors, op, group_id) expected_tensors = [ _build_tensor(src + 1, expected_value, dtype=dtype) for dtype, expected_value in zip(dtypes, expected_values) @@ -1699,7 +1809,7 @@ def _test_scatter_helper(self, group, group_id, rank): tensors = ( [_build_tensor(dest + 1, i) for i in group] if rank == dest else [] ) - dist.scatter(tensor, src=dest, scatter_list=tensors, group=group_id) + 
self.call_dist_op(":scatter", False, dist.scatter, tensor, src=dest, scatter_list=tensors, group=group_id) self.assertEqual(tensor, expected_tensor) self._barrier() @@ -1750,7 +1860,7 @@ def _test_gather_helper(self, group, group_id, rank): tensors = ( [_build_tensor(dest + 1, -1) for i in group] if rank == dest else [] ) - dist.gather(tensor, dst=dest, gather_list=tensors, group=group_id) + self.call_dist_op(":gather", False, dist.gather, tensor, dst=dest, gather_list=tensors, group=group_id) if rank == dest: expected_tensors = [_build_tensor(dest + 1, i) for i in group] for t1, t2 in zip(tensors, expected_tensors): @@ -1808,7 +1918,7 @@ def _test_all_gather_helper( if cuda: tensor = tensor.cuda(rank_to_GPU[rank][0]) tensors = [t.cuda(rank_to_GPU[rank][0]) for t in tensors] - dist.all_gather(tensors, tensor, group_id) + self.call_dist_op(":all_gather", False, dist.all_gather, tensors, tensor, group_id) expected_tensors = [_build_tensor(dest + 1, i, dtype=dtype) for i in group] for t1, t2 in zip(tensors, expected_tensors): @@ -1860,8 +1970,8 @@ def _run_all_gather_coalesced_and_verify( Helper that runs all_gather_coalesced and returns true if output matches expectations. """ - dist.all_gather_coalesced( - output_tensor_lists, input_tensors, group_id) + self.call_dist_op(":all_gather", False, dist.all_gather_coalesced, + output_tensor_lists, input_tensors, group_id) for l1, l2 in zip(output_tensor_lists, expected_tensors): for t1, t2 in zip(l1, l2): @@ -1986,7 +2096,7 @@ def _test_all_to_all_single_equal_split_helper( in_tensor = in_tensor.cuda(rank_to_GPU[rank][0]) expected_tensor = expected_tensor.cuda(rank_to_GPU[rank][0]) out_tensor = out_tensor.cuda(rank_to_GPU[rank][0]) - dist.all_to_all_single(out_tensor, in_tensor, group=group_id) + self.call_dist_op(":all_to_all", False, dist.all_to_all_single, out_tensor, in_tensor, group=group_id) self.assertEqual(out_tensor, expected_tensor) self._barrier() @@ -2303,7 +2413,7 @@ def _test_all_reduce_multigpu_helper( _build_tensor(src + 1, curr_value, dtype=dtype).cuda(device=i) for i in rank_to_GPU[rank] ] - dist.all_reduce_multigpu(tensors, op, group_id) + self.call_dist_op(":all_reduce", False, dist.all_reduce_multigpu, tensors, op, group_id) expected_tensor = _build_tensor(src + 1, expected_value, dtype=dtype) for tensor in tensors: self.assertEqual(tensor, expected_tensor) @@ -2362,7 +2472,9 @@ def _test_reduce_multigpu_helper( _build_tensor(src + 1, master_value).cuda(device=i) for i in rank_to_GPU[rank] ] - dist.reduce_multigpu(tensors, src, op, group_id) + self.call_dist_op( + "reduce", False, dist.reduce_multigpu, tensors, src, op, group_id, + expect_event=len(tensors) == 1) expected_tensor = _build_tensor(src + 1, expected_value) self.assertEqual(tensors[0], expected_tensor) else: @@ -2370,7 +2482,9 @@ def _test_reduce_multigpu_helper( _build_tensor(src + 1, worker_value).cuda(device=i) for i in rank_to_GPU[rank] ] - dist.reduce_multigpu(tensors, src, op, group_id) + self.call_dist_op( + "reduce", False, dist.reduce_multigpu, tensors, src, op, group_id, + expect_event=len(tensors) == 1) self._barrier() @@ -2411,7 +2525,10 @@ def _test_all_gather_multigpu_helper(self, group, group_id, rank, rank_to_GPU, d output_tensors.append([t.cuda(device=gpu) for t in output_per_gpu]) expected_output.append([t.cuda(device=gpu) for t in expected_per_gpu]) - dist.all_gather_multigpu(output_tensors, tensors, group_id) + self.call_dist_op( + "all_gather", False, + dist.all_gather_multigpu, output_tensors, tensors, group_id, + 
expect_event=len(expected_output) == 1) self.assertEqual(output_tensors, expected_output) self._barrier()
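
Usage sketch (illustrative, not part of the diff): the profiling titles introduced here ("gloo:all_reduce", "nccl:all_reduce", etc.) surface as USER_SCOPE events in torch.autograd.profiler, which is what the call_dist_op helper above asserts. The snippet below is a minimal stand-alone version of that check; it assumes a process group has already been set up with torch.distributed.init_process_group on each rank, and the function name is hypothetical.

    import torch
    import torch.distributed as dist
    import torch.autograd.profiler as profiler

    def check_allreduce_is_profiled(rank):
        # Each rank contributes a different value; SUM all-reduce combines them.
        tensor = torch.ones(10) * (rank + 1)
        with profiler.profile() as prof:
            work = dist.all_reduce(tensor, op=dist.ReduceOp.SUM, async_op=True)
            work.wait()
        # The backend records the collective under a "<backend>:all_reduce"
        # title, so exactly one event with that suffix is expected here.
        events = [e for e in prof.function_events if e.name.endswith(":all_reduce")]
        assert len(events) == 1
        assert events[0].count == 1
        assert events[0].cpu_time > 0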