diff --git a/test/cpp_extensions/cpp_c10d_extension.cpp b/test/cpp_extensions/cpp_c10d_extension.cpp index b4901cdbcf4d..50e5f5861caa 100644 --- a/test/cpp_extensions/cpp_c10d_extension.cpp +++ b/test/cpp_extensions/cpp_c10d_extension.cpp @@ -23,85 +23,85 @@ ProcessGroupTest::ProcessGroupTest(int rank, int size) ProcessGroupTest::~ProcessGroupTest() {} -std::shared_ptr ProcessGroupTest::broadcast( +c10::intrusive_ptr ProcessGroupTest::broadcast( std::vector& tensors, const BroadcastOptions& opts) { - return std::make_shared(); + return c10::make_intrusive(); } -std::shared_ptr ProcessGroupTest::allreduce( +c10::intrusive_ptr ProcessGroupTest::allreduce( std::vector& tensors, const AllreduceOptions& opts) { - return std::make_shared(); + return c10::make_intrusive(); } -std::shared_ptr ProcessGroupTest::allreduce_coalesced( +c10::intrusive_ptr ProcessGroupTest::allreduce_coalesced( std::vector& tensors, const AllreduceCoalescedOptions& opts) { throw std::runtime_error("ProcessGroupTest does not support allreduce_coalesced"); } -std::shared_ptr ProcessGroupTest::reduce( +c10::intrusive_ptr ProcessGroupTest::reduce( std::vector& tensors, const ReduceOptions& opts) { throw std::runtime_error("ProcessGroupTest does not support reduce"); } -std::shared_ptr ProcessGroupTest::allgather( +c10::intrusive_ptr ProcessGroupTest::allgather( std::vector>& outputTensors, std::vector& inputTensors, const AllgatherOptions& opts) { throw std::runtime_error("ProcessGroupTest does not support allgather"); } -std::shared_ptr ProcessGroupTest::allgather_base( +c10::intrusive_ptr ProcessGroupTest::allgather_base( at::Tensor& outputBuffer, at::Tensor& inputBuffer, const AllgatherOptions& opts) { throw std::runtime_error("ProcessGroupTest does not support allgather_base"); } -std::shared_ptr ProcessGroupTest::barrier( +c10::intrusive_ptr ProcessGroupTest::barrier( const BarrierOptions& opts) { - return std::make_shared(); + return c10::make_intrusive(); } -std::shared_ptr ProcessGroupTest::gather( +c10::intrusive_ptr ProcessGroupTest::gather( std::vector>& outputTensors, std::vector& inputTensors, const GatherOptions& opts) { throw std::runtime_error("ProcessGroupTest does not support gather"); } -std::shared_ptr ProcessGroupTest::scatter( +c10::intrusive_ptr ProcessGroupTest::scatter( std::vector& outputTensors, std::vector>& inputTensors, const ScatterOptions& opts) { throw std::runtime_error("ProcessGroupTest does not support scatter"); } -std::shared_ptr ProcessGroupTest::reduce_scatter( +c10::intrusive_ptr ProcessGroupTest::reduce_scatter( std::vector& outputTensors, std::vector>& inputTensors, const ReduceScatterOptions& opts) { throw std::runtime_error("ProcessGroupTest does not support reduce_scatter"); } -std::shared_ptr ProcessGroupTest::send( +c10::intrusive_ptr ProcessGroupTest::send( std::vector& tensors, int dstRank, int tag) { throw std::runtime_error("ProcessGroupTest does not support send"); } -std::shared_ptr ProcessGroupTest::recv( +c10::intrusive_ptr ProcessGroupTest::recv( std::vector& tensors, int srcRank, int tag) { throw std::runtime_error("ProcessGroupTest does not support recv"); } -std::shared_ptr ProcessGroupTest::recvAnysource( +c10::intrusive_ptr ProcessGroupTest::recvAnysource( std::vector& tensor, int tag) { throw std::runtime_error("ProcessGroupTest does not support recvAnysource"); diff --git a/test/cpp_extensions/cpp_c10d_extension.hpp b/test/cpp_extensions/cpp_c10d_extension.hpp index d8dffcd20327..8aeec736d440 100644 --- a/test/cpp_extensions/cpp_c10d_extension.hpp +++ 
b/test/cpp_extensions/cpp_c10d_extension.hpp @@ -41,61 +41,61 @@ class ProcessGroupTest : public ProcessGroup { explicit ProcessGroupTest(int rank = -1, int size = -1); virtual ~ProcessGroupTest(); - std::shared_ptr broadcast( + c10::intrusive_ptr broadcast( std::vector& data, const BroadcastOptions& opts = BroadcastOptions()) override; - std::shared_ptr allreduce( + c10::intrusive_ptr allreduce( std::vector& tensors, const AllreduceOptions& opts = AllreduceOptions()) override; - std::shared_ptr allreduce_coalesced( + c10::intrusive_ptr allreduce_coalesced( std::vector& tensors, const AllreduceCoalescedOptions& opts = AllreduceCoalescedOptions()) override; - std::shared_ptr reduce( + c10::intrusive_ptr reduce( std::vector& tensors, const ReduceOptions& opts = ReduceOptions()) override; - std::shared_ptr allgather( + c10::intrusive_ptr allgather( std::vector>& outputTensors, std::vector& inputTensors, const AllgatherOptions& opts = AllgatherOptions()) override; - std::shared_ptr allgather_base( + c10::intrusive_ptr allgather_base( at::Tensor& outputBuffer, at::Tensor& inputBuffer, const AllgatherOptions& opts = AllgatherOptions()) override; - std::shared_ptr barrier( + c10::intrusive_ptr barrier( const BarrierOptions& opts = BarrierOptions()) override; - std::shared_ptr gather( + c10::intrusive_ptr gather( std::vector>& outputTensors, std::vector& inputTensors, const GatherOptions& opts = GatherOptions()) override; - std::shared_ptr scatter( + c10::intrusive_ptr scatter( std::vector& outputTensors, std::vector>& inputTensors, const ScatterOptions& opts = ScatterOptions()) override; - std::shared_ptr reduce_scatter( + c10::intrusive_ptr reduce_scatter( std::vector& outputTensors, std::vector>& inputTensors, const ReduceScatterOptions& opts = ReduceScatterOptions()) override; - std::shared_ptr send( + c10::intrusive_ptr send( std::vector& tensors, int dstRank, int tag); - std::shared_ptr recv( + c10::intrusive_ptr recv( std::vector& tensors, int srcRank, int tag); - std::shared_ptr recvAnysource( + c10::intrusive_ptr recvAnysource( std::vector& tensor, int tag); diff --git a/torch/csrc/distributed/c10d/init.cpp b/torch/csrc/distributed/c10d/init.cpp index d9ddf35ee1df..136efd32fc87 100644 --- a/torch/csrc/distributed/c10d/init.cpp +++ b/torch/csrc/distributed/c10d/init.cpp @@ -1,5 +1,6 @@ #include +#include #include #ifndef _WIN32 #include @@ -59,6 +60,8 @@ constexpr auto kDeprecationWarning = "{} API is being deprecated, please ping " "https://github.com/pytorch/pytorch/issues/46291 " "if you see this warning"; +template +using intrusive_ptr_class_ = py::class_>; // PythonStore is a pybind11 trampoline class to allow a Python // class to inherit from c10d.Store and implement its interface. @@ -1045,7 +1048,7 @@ that adds a prefix to each key inserted to the store. 
py::call_guard()); #endif - shared_ptr_class_<::c10d::ProcessGroup::Work>(module, "Work") + intrusive_ptr_class_<::c10d::ProcessGroup::Work>(module, "Work") .def("is_completed", &::c10d::ProcessGroup::Work::isCompleted) .def( "is_success", diff --git a/torch/csrc/distributed/rpc/process_group_agent.cpp b/torch/csrc/distributed/rpc/process_group_agent.cpp index 2f29adc8f0c4..13e685b8fe74 100644 --- a/torch/csrc/distributed/rpc/process_group_agent.cpp +++ b/torch/csrc/distributed/rpc/process_group_agent.cpp @@ -398,7 +398,7 @@ void ProcessGroupAgent::handleSend(const SendWork& work) { // ProcessGroup is not thread-safe when sending with the same tag, // hence the lock - std::vector> pendingSends; + std::vector> pendingSends; const auto dst = work.to_.id_; // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast) diff --git a/torch/csrc/distributed/rpc/process_group_agent.h b/torch/csrc/distributed/rpc/process_group_agent.h index 1bc8db9ebf20..70fb1b40244d 100644 --- a/torch/csrc/distributed/rpc/process_group_agent.h +++ b/torch/csrc/distributed/rpc/process_group_agent.h @@ -230,14 +230,14 @@ class TORCH_API ProcessGroupAgent : public RpcAgent { // Lock and shared ptr to currently pending work, set in listenloop() and // interruptible in shutdown(). std::mutex recvWorkMutex_; - std::shared_ptr recvWork_; + c10::intrusive_ptr recvWork_; // Map of dst rank to current oustanding sends that we are waiting on. In the // case of a call to ::shutdown() while we are still waiting on these sends, // the pending sends contained in this map will be aborted, allowing the // waiting thread to be unblocked. std::unordered_map< worker_id_t, - std::set>> + std::set>> currentPendingSends_; // Lock to serialize access to the above map. std::mutex pendingSendMutex_; diff --git a/torch/lib/c10d/ProcessGroup.cpp b/torch/lib/c10d/ProcessGroup.cpp index 3521ed42c840..1d0d451f21a9 100644 --- a/torch/lib/c10d/ProcessGroup.cpp +++ b/torch/lib/c10d/ProcessGroup.cpp @@ -164,7 +164,7 @@ ProcessGroup::~ProcessGroup() {} // This is introduced so that implementors of ProcessGroup would not need to // have this implmentation. -std::shared_ptr ProcessGroup::allgather_coalesced( +c10::intrusive_ptr ProcessGroup::allgather_coalesced( std::vector>& /* usused */, std::vector& /* usused */, const AllgatherOptions& /* usused */) { diff --git a/torch/lib/c10d/ProcessGroup.hpp b/torch/lib/c10d/ProcessGroup.hpp index 5e90dccc25c0..63996b516a06 100644 --- a/torch/lib/c10d/ProcessGroup.hpp +++ b/torch/lib/c10d/ProcessGroup.hpp @@ -70,12 +70,11 @@ bool isP2POp(OpType opType); // class ProcessGroup { public: - // Please do not use ProcessGroup::Work API, it is going away, to be // replaced by ivalue::Future. // Python binding for this class might change, please do not assume // this will be bound using pybind. - class Work { + class Work : public torch::CustomClassHolder { public: Work(int rank = -1, OpType opType = OpType::UNKNOWN, const char* profilingTitle = nullptr); @@ -171,25 +170,25 @@ class ProcessGroup { return size_; } - virtual std::shared_ptr broadcast( + virtual c10::intrusive_ptr broadcast( std::vector& data, const BroadcastOptions& opts = BroadcastOptions()) = 0; - virtual std::shared_ptr allreduce( + virtual c10::intrusive_ptr allreduce( std::vector& data, const AllreduceOptions& opts = AllreduceOptions()) = 0; // This will be moved out of ProcessGroup, do not add dependencies on this // function. 
- virtual std::shared_ptr allreduce_coalesced( + virtual c10::intrusive_ptr allreduce_coalesced( std::vector& tensors, const AllreduceCoalescedOptions& opts = AllreduceCoalescedOptions()) = 0; - virtual std::shared_ptr reduce( + virtual c10::intrusive_ptr reduce( std::vector& tensors, const ReduceOptions& opts = ReduceOptions()) = 0; - virtual std::shared_ptr allgather( + virtual c10::intrusive_ptr allgather( std::vector>& outputTensors, std::vector& inputTensors, const AllgatherOptions& opts = AllgatherOptions()) = 0; @@ -197,7 +196,7 @@ class ProcessGroup { // Gathers a single tensor inputBuffer into a single buffer outputBuffer that // is interpreted as a contigious collection of size inputBuffer * WORLD_SIZE. // For implementers of ProcessGroup API and advanced users only. - virtual std::shared_ptr allgather_base( + virtual c10::intrusive_ptr allgather_base( at::Tensor& outputBuffer, at::Tensor& inputBuffer, const AllgatherOptions& opts = AllgatherOptions()) = 0; @@ -206,27 +205,27 @@ class ProcessGroup { // * do not add dependencies on this function, // * do not implement it in your ProcessGroup, implement allgather_base // instead. - virtual std::shared_ptr allgather_coalesced( + virtual c10::intrusive_ptr allgather_coalesced( std::vector>& outputTensorLists, std::vector& inputTensors, const AllgatherOptions& opts = AllgatherOptions()); - virtual std::shared_ptr gather( + virtual c10::intrusive_ptr gather( std::vector>& outputTensors, std::vector& inputTensors, const GatherOptions& opts = GatherOptions()) = 0; - virtual std::shared_ptr scatter( + virtual c10::intrusive_ptr scatter( std::vector& outputTensors, std::vector>& inputTensors, const ScatterOptions& opts = ScatterOptions()) = 0; - virtual std::shared_ptr reduce_scatter( + virtual c10::intrusive_ptr reduce_scatter( std::vector& outputTensors, std::vector>& inputTensors, const ReduceScatterOptions& opts = ReduceScatterOptions()) = 0; - virtual std::shared_ptr alltoall_base( + virtual c10::intrusive_ptr alltoall_base( at::Tensor& outputTensor, at::Tensor& inputTensor, std::vector& outputSplitSizes, @@ -235,28 +234,28 @@ class ProcessGroup { throw std::runtime_error("ProcessGroup does not support alltoall"); } - virtual std::shared_ptr alltoall( + virtual c10::intrusive_ptr alltoall( std::vector& outputTensors, std::vector& inputTensors, const AllToAllOptions& opts = AllToAllOptions()) { throw std::runtime_error("ProcessGroup does not support alltoall"); } - virtual std::shared_ptr send( + virtual c10::intrusive_ptr send( std::vector& tensors, int dstRank, int tag) = 0; - virtual std::shared_ptr recv( + virtual c10::intrusive_ptr recv( std::vector& tensors, int srcRank, int tag) = 0; - virtual std::shared_ptr recvAnysource( + virtual c10::intrusive_ptr recvAnysource( std::vector& tensors, int tag) = 0; - virtual std::shared_ptr barrier( + virtual c10::intrusive_ptr barrier( const BarrierOptions& opts = BarrierOptions()) = 0; protected: diff --git a/torch/lib/c10d/ProcessGroupGloo.cpp b/torch/lib/c10d/ProcessGroupGloo.cpp index cd3e83e6b714..90c9b695de28 100644 --- a/torch/lib/c10d/ProcessGroupGloo.cpp +++ b/torch/lib/c10d/ProcessGroupGloo.cpp @@ -38,6 +38,7 @@ #endif #include +#include #include #include #include @@ -653,11 +654,11 @@ void ProcessGroupGloo::runLoop(int workerIndex) { AsyncWork::execute(std::move(work)); lock.lock(); - workInProgress_[workerIndex] = nullptr; + workInProgress_[workerIndex].reset(); } } -void ProcessGroupGloo::enqueue(std::shared_ptr work) { +void ProcessGroupGloo::enqueue(c10::intrusive_ptr work) { 
std::unique_lock lock(workMutex_); workQueue_.push_back(std::move(work)); lock.unlock(); @@ -773,7 +774,7 @@ class AsyncBroadcastCUDAWork : public AsyncBroadcastWork { } // namespace -std::shared_ptr ProcessGroupGloo::broadcast( +c10::intrusive_ptr ProcessGroupGloo::broadcast( std::vector& inputs, const BroadcastOptions& opts) { static auto invalidArgument = [](const std::string& msg) { @@ -796,15 +797,15 @@ std::shared_ptr ProcessGroupGloo::broadcast( invalidArgument(c10::str("unsupported device type ", device.type())); } - std::shared_ptr work; + c10::intrusive_ptr work; auto tag = nextTag(); auto context = getContext(tag); if (device.type() == at::kCPU) { - work = std::make_shared( + work = c10::make_intrusive( std::move(context), inputs, opts.rootRank, opts.rootTensor, tag); #ifdef USE_CUDA } else if (device.type() == at::kCUDA) { - work = std::make_shared( + work = c10::make_intrusive( std::move(context), inputs, opts.rootRank, opts.rootTensor, tag); #endif } else { @@ -1300,7 +1301,7 @@ class AsyncSparseAllreduceCUDAWork : public AsyncSparseAllreduceWork { } // namespace -std::shared_ptr ProcessGroupGloo::allreduce( +c10::intrusive_ptr ProcessGroupGloo::allreduce( std::vector& inputs, const AllreduceOptions& opts) { static auto invalidArgument = [](const std::string& msg) { @@ -1329,15 +1330,15 @@ std::shared_ptr ProcessGroupGloo::allreduce( "(allreduce of sparse tensors only works with ReduceOp.SUM)"); } - std::shared_ptr work; + c10::intrusive_ptr work; auto tag = nextTag(); auto context = getContext(tag); if (device.type() == at::kCPU) { if (layout == c10::kStrided) { - work = std::make_shared( + work = c10::make_intrusive( std::move(context), inputs, opts.reduceOp, tag); } else if (layout == c10::kSparse) { - work = std::make_shared( + work = c10::make_intrusive( std::move(context), inputs, tag); } else { invalidArgument("unsupported layout"); @@ -1345,10 +1346,10 @@ std::shared_ptr ProcessGroupGloo::allreduce( #ifdef USE_CUDA } else if (device.type() == at::kCUDA) { if (layout == c10::kStrided) { - work = std::make_shared( + work = c10::make_intrusive( std::move(context), inputs, opts.reduceOp, tag); } else if (layout == c10::kSparse) { - work = std::make_shared( + work = c10::make_intrusive( std::move(context), inputs, tag); } else { invalidArgument("unsupported layout"); @@ -1362,7 +1363,7 @@ std::shared_ptr ProcessGroupGloo::allreduce( return work; } -std::shared_ptr ProcessGroupGloo::allreduce_coalesced( +c10::intrusive_ptr ProcessGroupGloo::allreduce_coalesced( std::vector& tensors, const AllreduceCoalescedOptions& opts) { static auto invalidArgument = [](const std::string& msg) { @@ -1405,12 +1406,12 @@ std::shared_ptr ProcessGroupGloo::allreduce_coalesced( invalidArgument("unsupported layout"); } - std::shared_ptr work; + c10::intrusive_ptr work; const uint32_t tag = nextTag(); std::shared_ptr context = getContext(tag); if (device.type() == c10::kCPU) { if (layout == c10::kStrided) { - work = std::make_shared( + work = c10::make_intrusive( std::move(context), tensors, opts.reduceOp, tag); } else { invalidArgument("unsupported layout"); @@ -1538,7 +1539,7 @@ class AsyncReduceCUDAWork : public AsyncReduceWork { } // namespace -std::shared_ptr ProcessGroupGloo::reduce( +c10::intrusive_ptr ProcessGroupGloo::reduce( std::vector& inputs, const ReduceOptions& opts) { static auto invalidArgument = [](const std::string& msg) { @@ -1561,11 +1562,11 @@ std::shared_ptr ProcessGroupGloo::reduce( invalidArgument(c10::str("unsupported device type ", device.type())); } - std::shared_ptr 
work; + c10::intrusive_ptr work; auto tag = nextTag(); auto context = getContext(tag); if (device.type() == at::kCPU) { - work = std::make_shared( + work = c10::make_intrusive( std::move(context), inputs, opts.rootRank, @@ -1574,7 +1575,7 @@ std::shared_ptr ProcessGroupGloo::reduce( tag); #ifdef USE_CUDA } else if (device.type() == at::kCUDA) { - work = std::make_shared( + work = c10::make_intrusive( std::move(context), inputs, opts.rootRank, @@ -1720,7 +1721,7 @@ class AsyncAllgatherCUDAWork : public AsyncAllgatherWork { // Note: current CUDA implementation holds the assumption that the // tensors in the nested output tensor vectors are on the same device. -std::shared_ptr ProcessGroupGloo::allgather( +c10::intrusive_ptr ProcessGroupGloo::allgather( std::vector>& outputs, std::vector& inputs, const AllgatherOptions& opts) { @@ -1769,15 +1770,15 @@ std::shared_ptr ProcessGroupGloo::allgather( invalidArgument(c10::str("unsupported device type ", device.type())); } - std::shared_ptr work; + c10::intrusive_ptr work; auto tag = nextTag(); auto context = getContext(tag); if (device.type() == at::kCPU) { - work = std::make_shared( + work = c10::make_intrusive( std::move(context), outputs, inputs, tag); #ifdef USE_CUDA } else if (device.type() == at::kCUDA) { - work = std::make_shared( + work = c10::make_intrusive( std::move(context), outputs, inputs, tag); #endif } else { @@ -1852,7 +1853,7 @@ class AsyncAllgatherCoalescedWork : public ProcessGroupGloo::AsyncWork { } // namespace -std::shared_ptr ProcessGroupGloo::allgather_coalesced( +c10::intrusive_ptr ProcessGroupGloo::allgather_coalesced( std::vector>& output_lists, std::vector& input_list, const AllgatherOptions& /* unused */) { @@ -1902,13 +1903,13 @@ std::shared_ptr ProcessGroupGloo::allgather_coalesced( auto tag = nextTag(); auto context = getContext(tag); - auto work = std::make_shared( + auto work = c10::make_intrusive( std::move(context), output_lists, input_list, tag); enqueue(work); return work; } -std::shared_ptr ProcessGroupGloo::allgather_base( +c10::intrusive_ptr ProcessGroupGloo::allgather_base( at::Tensor& /*unused */, at::Tensor& /*unused */, const AllgatherOptions& /*unused */) { @@ -2057,7 +2058,7 @@ class AsyncGatherCUDAWork : public AsyncGatherWork { } // namespace -std::shared_ptr ProcessGroupGloo::gather( +c10::intrusive_ptr ProcessGroupGloo::gather( std::vector>& outputs, std::vector& inputs, const GatherOptions& opts) { @@ -2103,15 +2104,15 @@ std::shared_ptr ProcessGroupGloo::gather( invalidArgument(c10::str("unsupported device type ", device.type())); } - std::shared_ptr work; + c10::intrusive_ptr work; auto tag = nextTag(); auto context = getContext(tag); if (device.type() == at::kCPU) { - work = std::make_shared( + work = c10::make_intrusive( std::move(context), outputs, inputs, opts.rootRank, tag); #ifdef USE_CUDA } else if (device.type() == at::kCUDA) { - work = std::make_shared( + work = c10::make_intrusive( std::move(context), outputs, inputs, opts.rootRank, tag); #endif } else { @@ -2245,7 +2246,7 @@ class AsyncScatterCUDAWork : public AsyncScatterWork { } // namespace -std::shared_ptr ProcessGroupGloo::scatter( +c10::intrusive_ptr ProcessGroupGloo::scatter( std::vector& outputs, std::vector>& inputs, const ScatterOptions& opts) { @@ -2290,15 +2291,15 @@ std::shared_ptr ProcessGroupGloo::scatter( invalidArgument(c10::str("unsupported device type ", device.type())); } - std::shared_ptr work; + c10::intrusive_ptr work; auto tag = nextTag(); auto context = getContext(tag); if (device.type() == at::kCPU) { - work 
= std::make_shared( + work = c10::make_intrusive( std::move(context), outputs, inputs, opts.rootRank, tag); #ifdef USE_CUDA } else if (device.type() == at::kCUDA) { - work = std::make_shared( + work = c10::make_intrusive( std::move(context), outputs, inputs, opts.rootRank, tag); #endif } else { @@ -2308,7 +2309,7 @@ std::shared_ptr ProcessGroupGloo::scatter( return work; } -std::shared_ptr ProcessGroupGloo::reduce_scatter( +c10::intrusive_ptr ProcessGroupGloo::reduce_scatter( std::vector& outputs, std::vector>& inputs, const ReduceScatterOptions& opts) { @@ -2443,7 +2444,7 @@ class AsyncAlltoallCUDAWork : public AsyncAlltoallWork { } // namespace -std::shared_ptr ProcessGroupGloo::alltoall_base( +c10::intrusive_ptr ProcessGroupGloo::alltoall_base( at::Tensor& outputTensor, at::Tensor& inputTensor, std::vector& outputCounts, @@ -2460,12 +2461,12 @@ std::shared_ptr ProcessGroupGloo::alltoall_base( assertDense(invalidArgument, {inputTensor}); const auto& device = outputTensor.device(); - std::shared_ptr work; + c10::intrusive_ptr work; auto tag = nextTag(); auto context = getContext(tag); if (device.type() == at::kCPU) { - work = std::make_shared( + work = c10::make_intrusive( std::move(context), outputTensor, inputTensor, @@ -2474,7 +2475,7 @@ std::shared_ptr ProcessGroupGloo::alltoall_base( tag); #ifdef USE_CUDA } else if (device.type() == at::kCUDA) { - work = std::make_shared( + work = c10::make_intrusive( std::move(context), outputTensor, inputTensor, @@ -2510,7 +2511,7 @@ uint32_t checkTag(int32_t tag) { return (uint32_t)tag; } -std::shared_ptr ProcessGroupGloo::send( +c10::intrusive_ptr ProcessGroupGloo::send( std::vector& tensors, int dstRank, int tag) { @@ -2526,10 +2527,10 @@ std::shared_ptr ProcessGroupGloo::send( // The work captures the tensor to prevent it being deallocated and // the unbound buffer to synchronize on completion of the send. - return std::make_shared(tensor, std::move(buf)); + return c10::make_intrusive(tensor, std::move(buf)); } -std::shared_ptr ProcessGroupGloo::recv( +c10::intrusive_ptr ProcessGroupGloo::recv( std::vector& tensors, int srcRank, int tag) { @@ -2545,10 +2546,10 @@ std::shared_ptr ProcessGroupGloo::recv( // The work captures the tensor to prevent it being deallocated and // the unbound buffer to synchronize on completion of the recv. - return std::make_shared(tensor, std::move(buf)); + return c10::make_intrusive(tensor, std::move(buf)); } -std::shared_ptr ProcessGroupGloo::recvAnysource( +c10::intrusive_ptr ProcessGroupGloo::recvAnysource( std::vector& tensors, int tag) { auto& tensor = checkSingleTensor(tensors); @@ -2573,7 +2574,7 @@ std::shared_ptr ProcessGroupGloo::recvAnysource( // The work captures the tensor to prevent it being deallocated and // the unbound buffer to synchronize on completion of the recv. 
- return std::make_shared(tensor, std::move(buf)); + return c10::make_intrusive(tensor, std::move(buf)); } namespace { @@ -2582,13 +2583,13 @@ class AsyncBarrierWork : public ProcessGroupGloo::AsyncWork { public: AsyncBarrierWork( const std::shared_ptr& context, - std::vector> priorWork, + std::vector> priorWork, uint32_t tag) : ProcessGroupGloo::AsyncWork("gloo:barrier"), context(context), priorWork(std::move(priorWork)), tag(tag) {} std::shared_ptr context; - std::vector> priorWork; + std::vector> priorWork; const uint32_t tag; void run() override { @@ -2608,9 +2609,9 @@ class AsyncBarrierWork : public ProcessGroupGloo::AsyncWork { } // namespace -std::shared_ptr ProcessGroupGloo::barrier( +c10::intrusive_ptr ProcessGroupGloo::barrier( const BarrierOptions& opts) { - std::vector> priorWork; + std::vector> priorWork; // Snapshot all in progress and pending work as weak_ptr. // When executing a barrier, we need to ensure that all prior work @@ -2624,7 +2625,7 @@ std::shared_ptr ProcessGroupGloo::barrier( auto tag = nextTag(); auto context = getContext(tag); - auto work = std::make_shared( + auto work = c10::make_intrusive( std::move(context), std::move(priorWork), tag); enqueue(work); return work; diff --git a/torch/lib/c10d/ProcessGroupGloo.hpp b/torch/lib/c10d/ProcessGroupGloo.hpp index 31664ad0b6cf..74fd0f6e5165 100644 --- a/torch/lib/c10d/ProcessGroupGloo.hpp +++ b/torch/lib/c10d/ProcessGroupGloo.hpp @@ -70,7 +70,7 @@ class ProcessGroupGloo : public ProcessGroup { public: AsyncWork(const char* profilingTitle = nullptr): ProcessGroup::Work(-1, OpType::UNKNOWN, profilingTitle) {} - static void execute(std::shared_ptr work) { + static void execute(c10::intrusive_ptr work) { std::exception_ptr eptr; try { work->run(); @@ -159,75 +159,75 @@ class ProcessGroupGloo : public ProcessGroup { virtual ~ProcessGroupGloo(); - std::shared_ptr broadcast( + c10::intrusive_ptr broadcast( std::vector& tensors, const BroadcastOptions& opts = BroadcastOptions()) override; - std::shared_ptr allreduce( + c10::intrusive_ptr allreduce( std::vector& tensors, const AllreduceOptions& opts = AllreduceOptions()) override; - std::shared_ptr allreduce_coalesced( + c10::intrusive_ptr allreduce_coalesced( std::vector& tensors, const AllreduceCoalescedOptions& opts = AllreduceCoalescedOptions()) override; - std::shared_ptr reduce( + c10::intrusive_ptr reduce( std::vector& tensors, const ReduceOptions& opts = ReduceOptions()) override; - std::shared_ptr allgather( + c10::intrusive_ptr allgather( std::vector>& outputs, std::vector& inputs, const AllgatherOptions& opts = AllgatherOptions()) override; - std::shared_ptr allgather_base( + c10::intrusive_ptr allgather_base( at::Tensor& outputBuffer, at::Tensor& inputBuffer, const AllgatherOptions& opts = AllgatherOptions()) override; - std::shared_ptr allgather_coalesced( + c10::intrusive_ptr allgather_coalesced( std::vector>& output_lists, std::vector& input_list, const AllgatherOptions& opts = AllgatherOptions()) override; - std::shared_ptr gather( + c10::intrusive_ptr gather( std::vector>& outputs, std::vector& inputs, const GatherOptions& opts = GatherOptions()) override; - std::shared_ptr scatter( + c10::intrusive_ptr scatter( std::vector& outputs, std::vector>& inputs, const ScatterOptions& opts = ScatterOptions()) override; - std::shared_ptr reduce_scatter( + c10::intrusive_ptr reduce_scatter( std::vector& outputs, std::vector>& inputs, const ReduceScatterOptions& opts = ReduceScatterOptions()) override; - std::shared_ptr alltoall_base( + c10::intrusive_ptr 
alltoall_base( at::Tensor& outputTensor, at::Tensor& inputTensor, std::vector& outputCounts, std::vector& inputCounts, const AllToAllOptions& opts = AllToAllOptions()) override; - std::shared_ptr send( + c10::intrusive_ptr send( std::vector& tensors, int dstRank, int tag) override; - std::shared_ptr recv( + c10::intrusive_ptr recv( std::vector& tensors, int srcRank, int tag) override; - std::shared_ptr recvAnysource( + c10::intrusive_ptr recvAnysource( std::vector& tensors, int tag) override; - std::shared_ptr barrier( + c10::intrusive_ptr barrier( const BarrierOptions& opts = BarrierOptions()) override; protected: @@ -258,7 +258,7 @@ class ProcessGroupGloo : public ProcessGroup { void runLoop(int workerIndex); // Queue work to run on worker thread. - void enqueue(std::shared_ptr work); + void enqueue(c10::intrusive_ptr work); // Keep both a queue of pending work, and a vector with in progress work. // Both of these can only be mutated when holding the queue lock. @@ -266,8 +266,8 @@ class ProcessGroupGloo : public ProcessGroup { // to all in progress and pending work when executing a barrier. // When executing a barrier, we need to ensure that all prior work // has completed before completing itself. - std::deque> workQueue_; - std::vector> workInProgress_; + std::deque> workQueue_; + std::vector> workInProgress_; std::mutex workMutex_; std::condition_variable workProduceCV_; std::condition_variable workConsumeCV_; diff --git a/torch/lib/c10d/ProcessGroupMPI.cpp b/torch/lib/c10d/ProcessGroupMPI.cpp index d3e79a1dd424..5f9d0be41b8f 100644 --- a/torch/lib/c10d/ProcessGroupMPI.cpp +++ b/torch/lib/c10d/ProcessGroupMPI.cpp @@ -308,9 +308,9 @@ void ProcessGroupMPI::runLoop() { } } -std::shared_ptr ProcessGroupMPI::enqueue( +c10::intrusive_ptr ProcessGroupMPI::enqueue( std::unique_ptr entry) { - auto work = std::make_shared(); + auto work = c10::make_intrusive(); std::unique_lock lock(pgMutex_); queue_.push_back(std::make_tuple(std::move(entry), work)); lock.unlock(); @@ -318,7 +318,7 @@ std::shared_ptr ProcessGroupMPI::enqueue( return work; } -std::shared_ptr ProcessGroupMPI::broadcast( +c10::intrusive_ptr ProcessGroupMPI::broadcast( std::vector& tensors, const BroadcastOptions& opts) { checkSingleTensor(tensors); @@ -339,7 +339,7 @@ std::shared_ptr ProcessGroupMPI::broadcast( return enqueue(std::move(entry)); } -std::shared_ptr ProcessGroupMPI::allreduce( +c10::intrusive_ptr ProcessGroupMPI::allreduce( std::vector& tensors, const AllreduceOptions& opts) { checkSingleTensor(tensors); @@ -362,14 +362,14 @@ std::shared_ptr ProcessGroupMPI::allreduce( return enqueue(std::move(entry)); } -std::shared_ptr ProcessGroupMPI::allreduce_coalesced( +c10::intrusive_ptr ProcessGroupMPI::allreduce_coalesced( std::vector& tensors, const AllreduceCoalescedOptions& opts) { throw std::runtime_error( "allreduce_coalesced is currently not supported with MPI"); } -std::shared_ptr ProcessGroupMPI::reduce( +c10::intrusive_ptr ProcessGroupMPI::reduce( std::vector& tensors, const ReduceOptions& opts) { checkSingleTensor(tensors); @@ -397,7 +397,7 @@ std::shared_ptr ProcessGroupMPI::reduce( return enqueue(std::move(entry)); } -std::shared_ptr ProcessGroupMPI::allgather( +c10::intrusive_ptr ProcessGroupMPI::allgather( std::vector>& outputTensors, std::vector& inputTensors, const AllgatherOptions& opts) { @@ -441,7 +441,7 @@ std::shared_ptr ProcessGroupMPI::allgather( return enqueue(std::move(entry)); } -std::shared_ptr ProcessGroupMPI::allgather_coalesced( +c10::intrusive_ptr ProcessGroupMPI::allgather_coalesced( 
std::vector>& /* unused */, std::vector& /* unused */, const AllgatherOptions& /* unused */) { @@ -449,7 +449,7 @@ std::shared_ptr ProcessGroupMPI::allgather_coalesced( "ProcessGroupMPI does not support allgather_coalesced"); } -std::shared_ptr ProcessGroupMPI::gather( +c10::intrusive_ptr ProcessGroupMPI::gather( std::vector>& outputTensors, std::vector& inputTensors, const GatherOptions& opts) { @@ -516,7 +516,7 @@ std::shared_ptr ProcessGroupMPI::gather( } } -std::shared_ptr ProcessGroupMPI::scatter( +c10::intrusive_ptr ProcessGroupMPI::scatter( std::vector& outputTensors, std::vector>& inputTensors, const ScatterOptions& opts) { @@ -582,14 +582,14 @@ std::shared_ptr ProcessGroupMPI::scatter( } } -std::shared_ptr ProcessGroupMPI::reduce_scatter( +c10::intrusive_ptr ProcessGroupMPI::reduce_scatter( std::vector& outputTensors, std::vector>& inputTensors, const ReduceScatterOptions& opts) { throw std::runtime_error("ProcessGroupMPI does not support reduce_scatter"); } -std::shared_ptr ProcessGroupMPI::alltoall_base( +c10::intrusive_ptr ProcessGroupMPI::alltoall_base( at::Tensor& outputTensor, at::Tensor& inputTensor, std::vector& outputSplitSizes, @@ -665,7 +665,7 @@ std::shared_ptr ProcessGroupMPI::alltoall_base( return enqueue(std::move(entry)); } } -std::shared_ptr ProcessGroupMPI::alltoall( +c10::intrusive_ptr ProcessGroupMPI::alltoall( std::vector& outputTensors, std::vector& inputTensors, const AllToAllOptions& opts) { @@ -722,7 +722,7 @@ std::shared_ptr ProcessGroupMPI::alltoall( return enqueue(std::move(entry)); } -std::shared_ptr ProcessGroupMPI::send( +c10::intrusive_ptr ProcessGroupMPI::send( std::vector& tensors, int dstRank, int tag) { @@ -744,10 +744,10 @@ std::shared_ptr ProcessGroupMPI::send( &request)); } - return std::make_shared(tensor, request); + return c10::make_intrusive(tensor, request); } -std::shared_ptr ProcessGroupMPI::recv( +c10::intrusive_ptr ProcessGroupMPI::recv( std::vector& tensors, int srcRank, int tag) { @@ -769,10 +769,10 @@ std::shared_ptr ProcessGroupMPI::recv( &request)); } - return std::make_shared(tensor, request); + return c10::make_intrusive(tensor, request); } -std::shared_ptr ProcessGroupMPI::recvAnysource( +c10::intrusive_ptr ProcessGroupMPI::recvAnysource( std::vector& tensors, int tag) { checkSingleTensor(tensors); @@ -793,10 +793,10 @@ std::shared_ptr ProcessGroupMPI::recvAnysource( &request)); } - return std::make_shared(tensor, request); + return c10::make_intrusive(tensor, request); } -std::shared_ptr ProcessGroupMPI::barrier( +c10::intrusive_ptr ProcessGroupMPI::barrier( const BarrierOptions& opts) { std::function&)> runFunc = [this](std::unique_ptr& entry) { @@ -808,7 +808,7 @@ std::shared_ptr ProcessGroupMPI::barrier( return enqueue(std::move(entry)); } -std::shared_ptr ProcessGroupMPI::allgather_base( +c10::intrusive_ptr ProcessGroupMPI::allgather_base( at::Tensor& /*unused */, at::Tensor& /*unused */, const AllgatherOptions& /*unused */) { diff --git a/torch/lib/c10d/ProcessGroupMPI.hpp b/torch/lib/c10d/ProcessGroupMPI.hpp index 342fe87001a0..48d95eada887 100644 --- a/torch/lib/c10d/ProcessGroupMPI.hpp +++ b/torch/lib/c10d/ProcessGroupMPI.hpp @@ -108,80 +108,80 @@ class ProcessGroupMPI : public ProcessGroup { // Abort the MPI program, needs to be called when exception is detected void abort(); - std::shared_ptr broadcast( + c10::intrusive_ptr broadcast( std::vector& data, const BroadcastOptions& opts = BroadcastOptions()) override; - std::shared_ptr allreduce( + c10::intrusive_ptr allreduce( std::vector& tensors, const 
AllreduceOptions& opts = AllreduceOptions()) override; - std::shared_ptr allreduce_coalesced( + c10::intrusive_ptr allreduce_coalesced( std::vector& tensors, const AllreduceCoalescedOptions& opts = AllreduceCoalescedOptions()) override; - std::shared_ptr reduce( + c10::intrusive_ptr reduce( std::vector& tensors, const ReduceOptions& opts = ReduceOptions()) override; - std::shared_ptr allgather( + c10::intrusive_ptr allgather( std::vector>& outputTensors, std::vector& inputTensors, const AllgatherOptions& opts = AllgatherOptions()) override; - std::shared_ptr allgather_base( + c10::intrusive_ptr allgather_base( at::Tensor& outputbuffer, at::Tensor& inputbuffer, const AllgatherOptions& opts = AllgatherOptions()) override; - std::shared_ptr allgather_coalesced( + c10::intrusive_ptr allgather_coalesced( std::vector>& outputTensorLists, std::vector& inputTensors, const AllgatherOptions& opts = AllgatherOptions()) override; - std::shared_ptr gather( + c10::intrusive_ptr gather( std::vector>& outputTensors, std::vector& inputTensors, const GatherOptions& opts = GatherOptions()) override; - std::shared_ptr scatter( + c10::intrusive_ptr scatter( std::vector& outputTensors, std::vector>& inputTensors, const ScatterOptions& opts = ScatterOptions()) override; - std::shared_ptr reduce_scatter( + c10::intrusive_ptr reduce_scatter( std::vector& outputTensors, std::vector>& inputTensors, const ReduceScatterOptions& opts = ReduceScatterOptions()) override; - std::shared_ptr alltoall_base( + c10::intrusive_ptr alltoall_base( at::Tensor& outputTensor, at::Tensor& inputTensor, std::vector& outputSplitSizes, std::vector& inputSplitSizes, const AllToAllOptions& opts = AllToAllOptions()) override; - std::shared_ptr alltoall( + c10::intrusive_ptr alltoall( std::vector& outputTensors, std::vector& inputTensors, const AllToAllOptions& opts = AllToAllOptions()) override; - std::shared_ptr send( + c10::intrusive_ptr send( std::vector& tensors, int dstRank, int tag); - std::shared_ptr recv( + c10::intrusive_ptr recv( std::vector& tensors, int srcRank, int tag); - std::shared_ptr recvAnysource( + c10::intrusive_ptr recvAnysource( std::vector& tensor, int tag); - std::shared_ptr barrier( + c10::intrusive_ptr barrier( const BarrierOptions& opts = BarrierOptions()) override; // Creating a new ProcessGroupMPI, will initiialize MPI if not initialized @@ -190,13 +190,13 @@ class ProcessGroupMPI : public ProcessGroup { protected: using WorkType = - std::tuple, std::shared_ptr>; + std::tuple, c10::intrusive_ptr>; // Worker thread loop void runLoop(); // Helper function that is called by the destructor void destroy(); - std::shared_ptr enqueue(std::unique_ptr entry); + c10::intrusive_ptr enqueue(std::unique_ptr entry); bool stop_; diff --git a/torch/lib/c10d/ProcessGroupNCCL.cpp b/torch/lib/c10d/ProcessGroupNCCL.cpp index ba0b4b36c77d..bd1563226343 100644 --- a/torch/lib/c10d/ProcessGroupNCCL.cpp +++ b/torch/lib/c10d/ProcessGroupNCCL.cpp @@ -984,12 +984,12 @@ std::vector flatten_for_scatter_gather( } // namespace -std::shared_ptr ProcessGroupNCCL::initWork( +c10::intrusive_ptr ProcessGroupNCCL::initWork( std::vector devices, int rank, OpType opType, const char* profilingTitle) { - return std::make_shared(devices, rank, opType, profilingTitle); + return c10::make_intrusive(devices, rank, opType); } std::vector ProcessGroupNCCL::WorkNCCL::result() { @@ -1012,7 +1012,7 @@ c10::intrusive_ptr ProcessGroupNCCL::WorkNCCL:: } void ProcessGroupNCCL::workEnqueue( - std::shared_ptr work) { + c10::intrusive_ptr work) { if 
(!terminateProcessGroup_.load()) { std::lock_guard lock(workMetaListMutex_); // Avoid view tensors to be processed in cleanup thread. @@ -1027,7 +1027,7 @@ ProcessGroupNCCL::Options::Options() isHighPriorityStream(false) {} template -std::shared_ptr ProcessGroupNCCL::collective( +c10::intrusive_ptr ProcessGroupNCCL::collective( std::vector& inputs, std::vector& outputs, Fn fn, @@ -1114,7 +1114,7 @@ std::shared_ptr ProcessGroupNCCL::collective( } template -std::shared_ptr ProcessGroupNCCL::pointToPoint( +c10::intrusive_ptr ProcessGroupNCCL::pointToPoint( std::vector& tensors, Fn fn, int peer, @@ -1186,7 +1186,7 @@ std::shared_ptr ProcessGroupNCCL::pointToPoint( } template -std::shared_ptr ProcessGroupNCCL::collective( +c10::intrusive_ptr ProcessGroupNCCL::collective( std::vector& inputs, std::vector& outputs, Fn fn, @@ -1203,7 +1203,7 @@ std::shared_ptr ProcessGroupNCCL::collective( } template -std::shared_ptr ProcessGroupNCCL::pointToPoint( +c10::intrusive_ptr ProcessGroupNCCL::pointToPoint( std::vector& tensor, Fn fn, int peer, @@ -1217,7 +1217,7 @@ std::shared_ptr ProcessGroupNCCL::pointToPoint( [](std::vector&) {}); } -std::shared_ptr ProcessGroupNCCL::allreduce( +c10::intrusive_ptr ProcessGroupNCCL::allreduce( std::vector& tensors, const AllreduceOptions& opts) { check_gpu_tensors(tensors); @@ -1242,14 +1242,14 @@ std::shared_ptr ProcessGroupNCCL::allreduce( "nccl:all_reduce"); } -std::shared_ptr ProcessGroupNCCL::allreduce_coalesced( +c10::intrusive_ptr ProcessGroupNCCL::allreduce_coalesced( std::vector& tensors, const AllreduceCoalescedOptions& opts) { throw std::runtime_error( "allreduce_coalesced is currently not supported with NCCL"); } -std::shared_ptr ProcessGroupNCCL::broadcast( +c10::intrusive_ptr ProcessGroupNCCL::broadcast( std::vector& tensors, const BroadcastOptions& opts) { check_gpu_tensors(tensors); @@ -1274,7 +1274,7 @@ std::shared_ptr ProcessGroupNCCL::broadcast( "nccl:broadcast"); } -std::shared_ptr ProcessGroupNCCL::reduce( +c10::intrusive_ptr ProcessGroupNCCL::reduce( std::vector& tensors, const ReduceOptions& opts) { check_gpu_tensors(tensors); @@ -1301,7 +1301,7 @@ std::shared_ptr ProcessGroupNCCL::reduce( "nccl:reduce"); } -std::shared_ptr ProcessGroupNCCL::allgather( +c10::intrusive_ptr ProcessGroupNCCL::allgather( std::vector>& outputTensors, std::vector& inputTensors, const AllgatherOptions& opts) { @@ -1346,7 +1346,7 @@ std::shared_ptr ProcessGroupNCCL::allgather( "nccl:all_gather"); } -std::shared_ptr ProcessGroupNCCL::allgather_coalesced( +c10::intrusive_ptr ProcessGroupNCCL::allgather_coalesced( std::vector>& /* unused */, std::vector& /* unused */, const AllgatherOptions& /* unused */) { @@ -1354,7 +1354,7 @@ std::shared_ptr ProcessGroupNCCL::allgather_coalesced( "ProcessGroupNCCL does not support allgather_coalesced"); } -std::shared_ptr ProcessGroupNCCL::reduce_scatter( +c10::intrusive_ptr ProcessGroupNCCL::reduce_scatter( std::vector& outputTensors, std::vector>& inputTensors, const ReduceScatterOptions& opts) { @@ -1400,7 +1400,7 @@ std::shared_ptr ProcessGroupNCCL::reduce_scatter( "nccl:reduce_scatter"); } -std::shared_ptr ProcessGroupNCCL::barrier( +c10::intrusive_ptr ProcessGroupNCCL::barrier( const BarrierOptions& opts) { std::vector devices; if (usedDeviceIdxs_.empty()) { @@ -1441,7 +1441,7 @@ std::shared_ptr ProcessGroupNCCL::barrier( } #ifdef ENABLE_NCCL_P2P_SUPPORT -std::shared_ptr ProcessGroupNCCL::alltoall_base( +c10::intrusive_ptr ProcessGroupNCCL::alltoall_base( at::Tensor& outputTensor, at::Tensor& inputTensor, std::vector& 
outputSplitSizes, @@ -1512,7 +1512,7 @@ std::shared_ptr ProcessGroupNCCL::alltoall_base( } } -std::shared_ptr ProcessGroupNCCL::send( +c10::intrusive_ptr ProcessGroupNCCL::send( std::vector& tensors, int dstRank, int /* unused */) { @@ -1531,7 +1531,7 @@ std::shared_ptr ProcessGroupNCCL::send( return ret; } -std::shared_ptr ProcessGroupNCCL::recv( +c10::intrusive_ptr ProcessGroupNCCL::recv( std::vector& tensors, int srcRank, int /* unused */) { @@ -1550,7 +1550,7 @@ std::shared_ptr ProcessGroupNCCL::recv( return ret; } #else -std::shared_ptr ProcessGroupNCCL::alltoall_base( +c10::intrusive_ptr ProcessGroupNCCL::alltoall_base( at::Tensor& /* unused */, at::Tensor& /* unused */, std::vector& /* unused */, @@ -1560,7 +1560,7 @@ std::shared_ptr ProcessGroupNCCL::alltoall_base( "ProcessGroupNCCL only supports alltoall* for NCCL lib version >= 2.7.0"); } -std::shared_ptr ProcessGroupNCCL::send( +c10::intrusive_ptr ProcessGroupNCCL::send( std::vector& /* unused */, int /* unused */, int /* unused */) { @@ -1568,7 +1568,7 @@ std::shared_ptr ProcessGroupNCCL::send( "ProcessGroupNCCL only supports send for NCCL lib version >= 2.7.0"); } -std::shared_ptr ProcessGroupNCCL::recv( +c10::intrusive_ptr ProcessGroupNCCL::recv( std::vector& /* unused */, int /* unused */, int /* unused */) { @@ -1591,34 +1591,34 @@ void ProcessGroupNCCL::groupEnd() { --ncclActiveGroupCounter_; } -std::shared_ptr ProcessGroupNCCL::alltoall( +c10::intrusive_ptr ProcessGroupNCCL::alltoall( std::vector& /* unused */, std::vector& /* unused */, const AllToAllOptions& /* unused */) { throw std::runtime_error("ProcessGroupNCCL does not support alltoall"); } -std::shared_ptr ProcessGroupNCCL::gather( +c10::intrusive_ptr ProcessGroupNCCL::gather( std::vector>& /* unused */, std::vector& /* unused */, const GatherOptions& /* unused */) { throw std::runtime_error("ProcessGroupNCCL does not support gather"); } -std::shared_ptr ProcessGroupNCCL::scatter( +c10::intrusive_ptr ProcessGroupNCCL::scatter( std::vector& /* unused */, std::vector>& /* unused */, const ScatterOptions& /* unused */) { throw std::runtime_error("ProcessGroupNCCL does not support scatter"); } -std::shared_ptr ProcessGroupNCCL::recvAnysource( +c10::intrusive_ptr ProcessGroupNCCL::recvAnysource( std::vector& /* unused */, int /* unused */) { throw std::runtime_error("ProcessGroupNCCL does not support recvAnysource"); } -std::shared_ptr ProcessGroupNCCL::allgather_base( +c10::intrusive_ptr ProcessGroupNCCL::allgather_base( at::Tensor& /*unused */, at::Tensor& /*unused */, const AllgatherOptions& /*unused */) { diff --git a/torch/lib/c10d/ProcessGroupNCCL.hpp b/torch/lib/c10d/ProcessGroupNCCL.hpp index 1520604629f2..59f06fda1ec1 100644 --- a/torch/lib/c10d/ProcessGroupNCCL.hpp +++ b/torch/lib/c10d/ProcessGroupNCCL.hpp @@ -65,7 +65,7 @@ constexpr const char* NCCL_ASYNC_ERROR_HANDLING = "NCCL_ASYNC_ERROR_HANDLING"; class ProcessGroupNCCL : public ProcessGroup { public: class WorkNCCL : public ProcessGroup::Work, - public std::enable_shared_from_this { + public std::enable_shared_from_this { public: // Constructor takes a list of CUDA devices WorkNCCL(const std::vector& devices, int rank, OpType opType, const char* profilingTitle = nullptr); @@ -411,64 +411,64 @@ class ProcessGroupNCCL : public ProcessGroup { virtual ~ProcessGroupNCCL(); - std::shared_ptr broadcast( + c10::intrusive_ptr broadcast( std::vector& tensors, const BroadcastOptions& opts = BroadcastOptions()) override; - std::shared_ptr allreduce( + c10::intrusive_ptr allreduce( std::vector& tensors, const 
AllreduceOptions& opts = AllreduceOptions()) override; - std::shared_ptr allreduce_coalesced( + c10::intrusive_ptr allreduce_coalesced( std::vector& tensors, const AllreduceCoalescedOptions& opts = AllreduceCoalescedOptions()) override; - std::shared_ptr reduce( + c10::intrusive_ptr reduce( std::vector& tensors, const ReduceOptions& opts = ReduceOptions()) override; - std::shared_ptr allgather( + c10::intrusive_ptr allgather( std::vector>& outputTensors, std::vector& inputTensors, const AllgatherOptions& opts = AllgatherOptions()) override; - std::shared_ptr allgather_base( + c10::intrusive_ptr allgather_base( at::Tensor& outputbuffer, at::Tensor& inputbuffer, const AllgatherOptions& opts = AllgatherOptions()) override; - std::shared_ptr allgather_coalesced( + c10::intrusive_ptr allgather_coalesced( std::vector>& outputTensorLists, std::vector& inputTensors, const AllgatherOptions& opts = AllgatherOptions()) override; - std::shared_ptr reduce_scatter( + c10::intrusive_ptr reduce_scatter( std::vector& outputTensors, std::vector>& inputTensors, const ReduceScatterOptions& opts = ReduceScatterOptions()) override; - std::shared_ptr barrier( + c10::intrusive_ptr barrier( const BarrierOptions& opts = BarrierOptions()) override; - std::shared_ptr alltoall_base( + c10::intrusive_ptr alltoall_base( at::Tensor& outputTensor, at::Tensor& inputTensor, std::vector& outputSplitSizes, std::vector& inputSplitSizes, const AllToAllOptions& opts = AllToAllOptions()) override; - std::shared_ptr alltoall( + c10::intrusive_ptr alltoall( std::vector& outputTensors, std::vector& inputTensors, const AllToAllOptions& opts = AllToAllOptions()) override; - std::shared_ptr send( + c10::intrusive_ptr send( std::vector& tensors, int dstRank, int tag) override; - std::shared_ptr recv( + c10::intrusive_ptr recv( std::vector& tensors, int srcRank, int tag) override; @@ -478,17 +478,17 @@ class ProcessGroupNCCL : public ProcessGroup { static void groupEnd(); // Unsupported Ops - std::shared_ptr gather( + c10::intrusive_ptr gather( std::vector>& outputTensors, std::vector& inputTensors, const GatherOptions& opts = GatherOptions()) override; - std::shared_ptr scatter( + c10::intrusive_ptr scatter( std::vector& outputTensors, std::vector>& inputTensors, const ScatterOptions& opts = ScatterOptions()) override; - std::shared_ptr recvAnysource( + c10::intrusive_ptr recvAnysource( std::vector& tensors, int tag) override; @@ -515,7 +515,7 @@ class ProcessGroupNCCL : public ProcessGroup { virtual std::exception_ptr checkForNCCLErrors( const std::vector>& ncclComms); - virtual std::shared_ptr initWork( + virtual c10::intrusive_ptr initWork( std::vector devices, int rank, OpType opType, @@ -529,14 +529,14 @@ class ProcessGroupNCCL : public ProcessGroup { // ncclComm_t, at::cuda::CUDAStream&); // void {pre,post}(std::vector); template - std::shared_ptr collective( + c10::intrusive_ptr collective( std::vector& input, std::vector& output, Fn fn, OpType opType, const char* profilingTitle = nullptr); template - std::shared_ptr collective( + c10::intrusive_ptr collective( std::vector& input, std::vector& output, Fn fn, @@ -549,13 +549,13 @@ class ProcessGroupNCCL : public ProcessGroup { // primitives. It is the same structure as the helper used for collective // communicaiton primitives. 
template - std::shared_ptr pointToPoint( + c10::intrusive_ptr pointToPoint( std::vector& tensor, Fn fn, int peer, OpType opType); template - std::shared_ptr pointToPoint( + c10::intrusive_ptr pointToPoint( std::vector& tensor, Fn fn, int peer, @@ -664,7 +664,7 @@ class ProcessGroupNCCL : public ProcessGroup { std::list workMetaList_; // Add Work Pointer to workVector - void workEnqueue(std::shared_ptr); + void workEnqueue(c10::intrusive_ptr); // The CUDA steams used by NCCL kernels std::unordered_map> diff --git a/torch/lib/c10d/ProcessGroupRoundRobin.cpp b/torch/lib/c10d/ProcessGroupRoundRobin.cpp index 032f63c320f5..c77188577a62 100644 --- a/torch/lib/c10d/ProcessGroupRoundRobin.cpp +++ b/torch/lib/c10d/ProcessGroupRoundRobin.cpp @@ -17,66 +17,66 @@ ProcessGroupRoundRobin::ProcessGroupRoundRobin( ProcessGroupRoundRobin::~ProcessGroupRoundRobin() {} -std::shared_ptr ProcessGroupRoundRobin::broadcast( +c10::intrusive_ptr ProcessGroupRoundRobin::broadcast( std::vector& tensors, const BroadcastOptions& opts) { return next()->broadcast(tensors, opts); } -std::shared_ptr ProcessGroupRoundRobin::allreduce( +c10::intrusive_ptr ProcessGroupRoundRobin::allreduce( std::vector& tensors, const AllreduceOptions& opts) { return next()->allreduce(tensors, opts); } -std::shared_ptr ProcessGroupRoundRobin::allreduce_coalesced( +c10::intrusive_ptr ProcessGroupRoundRobin::allreduce_coalesced( std::vector& tensors, const AllreduceCoalescedOptions& opts) { return next()->allreduce_coalesced(tensors, opts); } -std::shared_ptr ProcessGroupRoundRobin::reduce( +c10::intrusive_ptr ProcessGroupRoundRobin::reduce( std::vector& tensors, const ReduceOptions& opts) { return next()->reduce(tensors, opts); } -std::shared_ptr ProcessGroupRoundRobin::allgather( +c10::intrusive_ptr ProcessGroupRoundRobin::allgather( std::vector>& outputs, std::vector& inputs, const AllgatherOptions& opts) { return next()->allgather(outputs, inputs, opts); }; -std::shared_ptr ProcessGroupRoundRobin::allgather_coalesced( +c10::intrusive_ptr ProcessGroupRoundRobin::allgather_coalesced( std::vector>& outputTensorLists, std::vector& inputTensors, const AllgatherOptions& opts) { return next()->allgather(outputTensorLists, inputTensors, opts); } -std::shared_ptr ProcessGroupRoundRobin::gather( +c10::intrusive_ptr ProcessGroupRoundRobin::gather( std::vector>& outputs, std::vector& inputs, const GatherOptions& opts) { return next()->gather(outputs, inputs, opts); }; -std::shared_ptr ProcessGroupRoundRobin::scatter( +c10::intrusive_ptr ProcessGroupRoundRobin::scatter( std::vector& outputs, std::vector>& inputs, const ScatterOptions& opts) { return next()->scatter(outputs, inputs, opts); }; -std::shared_ptr ProcessGroupRoundRobin::reduce_scatter( +c10::intrusive_ptr ProcessGroupRoundRobin::reduce_scatter( std::vector& outputs, std::vector>& inputs, const ReduceScatterOptions& opts) { return next()->reduce_scatter(outputs, inputs, opts); }; -std::shared_ptr ProcessGroupRoundRobin::alltoall_base( +c10::intrusive_ptr ProcessGroupRoundRobin::alltoall_base( at::Tensor& outputTensor, at::Tensor& inputTensor, std::vector& outputSplitSizes, @@ -86,27 +86,27 @@ std::shared_ptr ProcessGroupRoundRobin::alltoall_base( outputTensor, inputTensor, outputSplitSizes, inputSplitSizes, opts); }; -std::shared_ptr ProcessGroupRoundRobin::send( +c10::intrusive_ptr ProcessGroupRoundRobin::send( std::vector& /* unused */, int /* unused */, int /* unused */) { throw std::runtime_error("ProcessGroupRoundRobin does not support send"); }; -std::shared_ptr 
ProcessGroupRoundRobin::recv( +c10::intrusive_ptr ProcessGroupRoundRobin::recv( std::vector& /* unused */, int /* unused */, int /* unused */) { throw std::runtime_error("ProcessGroupRoundRobin does not support recv"); }; -std::shared_ptr ProcessGroupRoundRobin::recvAnysource( +c10::intrusive_ptr ProcessGroupRoundRobin::recvAnysource( std::vector& /* unused */, int /* unused */) { throw std::runtime_error("ProcessGroupRoundRobin does not support recv"); }; -std::shared_ptr ProcessGroupRoundRobin::barrier( +c10::intrusive_ptr ProcessGroupRoundRobin::barrier( const BarrierOptions& /* unused */) { throw std::runtime_error("ProcessGroupRoundRobin does not support barrier"); }; @@ -120,7 +120,7 @@ const std::shared_ptr& ProcessGroupRoundRobin::next() { return processGroup; } -std::shared_ptr ProcessGroupRoundRobin::allgather_base( +c10::intrusive_ptr ProcessGroupRoundRobin::allgather_base( at::Tensor& /*unused */, at::Tensor& /*unused */, const AllgatherOptions& /*unused */) { diff --git a/torch/lib/c10d/ProcessGroupRoundRobin.hpp b/torch/lib/c10d/ProcessGroupRoundRobin.hpp index bbbd0a1c756b..62d59ef18ce5 100644 --- a/torch/lib/c10d/ProcessGroupRoundRobin.hpp +++ b/torch/lib/c10d/ProcessGroupRoundRobin.hpp @@ -25,75 +25,75 @@ class ProcessGroupRoundRobin final : public ProcessGroup { ~ProcessGroupRoundRobin() override; - std::shared_ptr broadcast( + c10::intrusive_ptr broadcast( std::vector& tensors, const BroadcastOptions& opts = BroadcastOptions()) override; - std::shared_ptr allreduce( + c10::intrusive_ptr allreduce( std::vector& tensors, const AllreduceOptions& opts = AllreduceOptions()) override; - std::shared_ptr allreduce_coalesced( + c10::intrusive_ptr allreduce_coalesced( std::vector& tensors, const AllreduceCoalescedOptions& opts = AllreduceCoalescedOptions()) override; - std::shared_ptr reduce( + c10::intrusive_ptr reduce( std::vector& tensors, const ReduceOptions& opts = ReduceOptions()) override; - std::shared_ptr allgather( + c10::intrusive_ptr allgather( std::vector>& outputs, std::vector& inputs, const AllgatherOptions& opts = AllgatherOptions()) override; - std::shared_ptr allgather_base( + c10::intrusive_ptr allgather_base( at::Tensor& outputBuffer, at::Tensor& inputBuffer, const AllgatherOptions& opts = AllgatherOptions()) override; - std::shared_ptr allgather_coalesced( + c10::intrusive_ptr allgather_coalesced( std::vector>& outputTensorLists, std::vector& inputTensors, const AllgatherOptions& opts = AllgatherOptions()) override; - std::shared_ptr gather( + c10::intrusive_ptr gather( std::vector>& outputs, std::vector& inputs, const GatherOptions& opts = GatherOptions()) override; - std::shared_ptr scatter( + c10::intrusive_ptr scatter( std::vector& outputs, std::vector>& inputs, const ScatterOptions& opts = ScatterOptions()) override; - std::shared_ptr reduce_scatter( + c10::intrusive_ptr reduce_scatter( std::vector& outputs, std::vector>& inputs, const ReduceScatterOptions& opts = ReduceScatterOptions()) override; - std::shared_ptr alltoall_base( + c10::intrusive_ptr alltoall_base( at::Tensor& outputTensor, at::Tensor& inputTensor, std::vector& outputSplitSizes, std::vector& inputSplitSizes, const AllToAllOptions& opts = AllToAllOptions()) override; - std::shared_ptr send( + c10::intrusive_ptr send( std::vector& tensors, int dstRank, int tag) override; - std::shared_ptr recv( + c10::intrusive_ptr recv( std::vector& tensors, int srcRank, int tag) override; - std::shared_ptr recvAnysource( + c10::intrusive_ptr recvAnysource( std::vector& tensors, int tag) override; - 
std::shared_ptr barrier( + c10::intrusive_ptr barrier( const BarrierOptions& opts = BarrierOptions()) override; private: diff --git a/torch/lib/c10d/comm.cpp b/torch/lib/c10d/comm.cpp index a8628e0c942e..5ef88f058aca 100644 --- a/torch/lib/c10d/comm.cpp +++ b/torch/lib/c10d/comm.cpp @@ -45,8 +45,10 @@ class BroadcastWork { // because c10d::ProcessGroup::broadcast takes a vector argument. std::vector flat_tensor_; + private: + // The broadcast work that is kicked off upon construction. - std::shared_ptr work_; + c10::intrusive_ptr work_; }; } // namespace diff --git a/torch/lib/c10d/example/allreduce.cpp b/torch/lib/c10d/example/allreduce.cpp index 76d6a5588f7e..3de7447d092a 100644 --- a/torch/lib/c10d/example/allreduce.cpp +++ b/torch/lib/c10d/example/allreduce.cpp @@ -19,7 +19,7 @@ int main(int argc, char** argv) { } // Kick off work - std::vector> pending; + std::vector> pending; for (auto i = 0; i < ntensors; i++) { std::vector tmp = {tensors[i]}; pending.push_back(pg.allreduce(tmp)); diff --git a/torch/lib/c10d/reducer.cpp b/torch/lib/c10d/reducer.cpp index c05ce685bb7d..c5ee54a9ee8e 100644 --- a/torch/lib/c10d/reducer.cpp +++ b/torch/lib/c10d/reducer.cpp @@ -472,7 +472,7 @@ std::vector> Reducer::get_bucket_tensors() const { } void Reducer::set_forward_pass_work_handle( - std::shared_ptr forwardPassWorkHandle, + c10::intrusive_ptr forwardPassWorkHandle, bool useStaticWorldSize) { std::lock_guard lock(mutex_); forwardPassWorkHandle_.workHandle = std::move(forwardPassWorkHandle); diff --git a/torch/lib/c10d/reducer.hpp b/torch/lib/c10d/reducer.hpp index 4874f0dd8703..e0fe0004f88e 100644 --- a/torch/lib/c10d/reducer.hpp +++ b/torch/lib/c10d/reducer.hpp @@ -8,6 +8,7 @@ #include #include +#include #include #include #include @@ -96,7 +97,7 @@ class Reducer { // Creates and sets ForwardPassWorkHandle given a ProcessGroup::Work and the // corresponding tensor being reduced. void set_forward_pass_work_handle( - std::shared_ptr forwardPassWorkHandle, + c10::intrusive_ptr forwardPassWorkHandle, bool useStaticWorldSize); // Retrieve on-device tensors used to track locally unused parameters. For @@ -158,7 +159,7 @@ class Reducer { bool local_used_maps_reduced_; // Work handle for allreduce on local_used_maps_ - std::shared_ptr local_used_work_; + c10::intrusive_ptr local_used_work_; void verify_replicas_within_process(); @@ -282,7 +283,7 @@ class Reducer { size_t pending; // Keep work handle around when this set of buckets is being reduced. - std::shared_ptr work; + c10::intrusive_ptr work; // Keep future work handle around if DDP comm hook is registered. c10::intrusive_ptr future_work; @@ -340,7 +341,7 @@ class Reducer { // A struct containing work handle and tensor for allreduce scheduled in // forward pass, if applicable. struct ForwardPassAllreduceWork { - std::shared_ptr workHandle; + c10::intrusive_ptr workHandle; at::Tensor resultTensor; // whether we should divide by the initial world_size or the no. of // remaining DDP ranks. 
diff --git a/torch/lib/c10d/test/ProcessGroupGlooAsyncTest.cpp b/torch/lib/c10d/test/ProcessGroupGlooAsyncTest.cpp
index 92dede9a573e..1363a842eab3 100644
--- a/torch/lib/c10d/test/ProcessGroupGlooAsyncTest.cpp
+++ b/torch/lib/c10d/test/ProcessGroupGlooAsyncTest.cpp
@@ -93,7 +93,7 @@ class AsyncInputIsOutputTest : public AsyncTest {
     }
   }
 
-  void wait(std::shared_ptr<c10d::ProcessGroup::Work>& work) {
+  void wait(c10::intrusive_ptr<c10d::ProcessGroup::Work>& work) {
     at::cuda::CUDAMultiStreamGuard guard(streams_);
     work->wait();
   }
@@ -130,7 +130,7 @@ class AsyncAllreduceTest : public AsyncInputIsOutputTest {
   AsyncAllreduceTest(const std::string& path, int numTensors)
       : AsyncInputIsOutputTest(path, numTensors) {}
 
-  std::shared_ptr<c10d::ProcessGroup::Work> run() {
+  c10::intrusive_ptr<c10d::ProcessGroup::Work> run() {
     // For the duration of this function, make THC use our streams
     at::cuda::CUDAMultiStreamGuard guard(streams_);
 
@@ -156,7 +156,7 @@ class AsyncBroadcastTest : public AsyncInputIsOutputTest {
   AsyncBroadcastTest(const std::string& path, int numTensors)
       : AsyncInputIsOutputTest(path, numTensors) {}
 
-  std::shared_ptr<c10d::ProcessGroup::Work> run(int rootRank, int rootTensor) {
+  c10::intrusive_ptr<c10d::ProcessGroup::Work> run(int rootRank, int rootTensor) {
     // For the duration of this function, make THC use our streams
     at::cuda::CUDAMultiStreamGuard guard(streams_);
 
@@ -185,7 +185,7 @@ void runAsyncAllreduceTest(
     size_t numProcesses = 4,
     size_t numTensors = 2) {
   auto tests = initialize(path, numProcesses, numTensors);
-  std::vector<std::shared_ptr<c10d::ProcessGroup::Work>> work(numProcesses);
+  std::vector<c10::intrusive_ptr<c10d::ProcessGroup::Work>> work(numProcesses);
   for (size_t i = 0; i < numProcesses; i++) {
     work[i] = tests[i].run();
   }
@@ -229,7 +229,7 @@
   // Try every permutation of root rank and root tensor
   for (size_t rootRank = 0; rootRank < numProcesses; rootRank++) {
     for (size_t rootTensor = 0; rootTensor < numTensors; rootTensor++) {
-      std::vector<std::shared_ptr<c10d::ProcessGroup::Work>> work(numProcesses);
+      std::vector<c10::intrusive_ptr<c10d::ProcessGroup::Work>> work(numProcesses);
       for (size_t i = 0; i < numProcesses; i++) {
         work[i] = tests[i].run(rootRank, rootTensor);
       }
diff --git a/torch/lib/c10d/test/ProcessGroupGlooTest.cpp b/torch/lib/c10d/test/ProcessGroupGlooTest.cpp
index da4f9b5fc106..de993a1110b4 100644
--- a/torch/lib/c10d/test/ProcessGroupGlooTest.cpp
+++ b/torch/lib/c10d/test/ProcessGroupGlooTest.cpp
@@ -44,7 +44,7 @@ class SignalTest {
     });
   }
 
-  std::shared_ptr<::c10d::ProcessGroup::Work> run(int rank, int size) {
+  c10::intrusive_ptr<::c10d::ProcessGroup::Work> run(int rank, int size) {
     auto store = std::make_shared<::c10d::FileStore>(path_, size);
 
     ::c10d::ProcessGroupGloo::Options options;
@@ -62,7 +62,7 @@
     };
 
     // Loop until an exception happens
-    std::shared_ptr<::c10d::ProcessGroup::Work> work;
+    c10::intrusive_ptr<::c10d::ProcessGroup::Work> work;
     while (true) {
       work = pg.allreduce(tensors);
       try {
@@ -82,7 +82,7 @@ class SignalTest {
   Semaphore sem_;
 };
 
-std::shared_ptr<::c10d::ProcessGroup::Work> testSignal(
+c10::intrusive_ptr<::c10d::ProcessGroup::Work> testSignal(
     const std::string& path,
     int signal) {
   Fork fork;
@@ -107,7 +107,7 @@ class ProcessGroupGlooDelayed : public ::c10d::ProcessGroupGloo {
       Options options)
       : ProcessGroupGloo(store, rank, size, options) {}
 
-  std::shared_ptr<::c10d::ProcessGroup::Work> send(
+  c10::intrusive_ptr<::c10d::ProcessGroup::Work> send(
       std::vector<at::Tensor>& tensors,
       int dstRank,
       int tag) override {
@@ -200,7 +200,7 @@ void testAllreduce(const std::string& path, const at::DeviceType b) {
   }
 
   // Kick off work
-  std::vector<std::shared_ptr<::c10d::ProcessGroup::Work>> work(size);
+  std::vector<c10::intrusive_ptr<::c10d::ProcessGroup::Work>> work(size);
   for (auto i = 0; i < size; i++) {
     work[i] = tests[i].getProcessGroup().allreduce(inputs[i]);
   }
@@ -250,7 +250,7 @@ void testBroadcast(const std::string& path, const at::DeviceType b) {
      options.rootTensor = j;
 
      // Kick off work
-      std::vector<std::shared_ptr<::c10d::ProcessGroup::Work>> work(size);
+      std::vector<c10::intrusive_ptr<::c10d::ProcessGroup::Work>> work(size);
      for (auto i = 0; i < size; i++) {
        work[i] = tests[i].getProcessGroup().broadcast(inputs[i], options);
      }
@@ -316,7 +316,7 @@ void testAlltoall(const std::string& path, const at::DeviceType b) {
   };
 
   // Kick off work
-  std::vector<std::shared_ptr<::c10d::ProcessGroup::Work>> work(size);
+  std::vector<c10::intrusive_ptr<::c10d::ProcessGroup::Work>> work(size);
   for (auto rank = 0; rank < size; rank++) {
     work[rank] = tests[rank].getProcessGroup().alltoall_base(
         outputs[rank], inputs[rank], outputSplits[rank], inputSplits[rank]);
@@ -349,7 +349,7 @@ void testBarrier(const std::string& path) {
   auto tests = CollectiveTest::initialize(path, size);
 
   // Kick off work
-  std::vector<std::shared_ptr<::c10d::ProcessGroup::Work>> work(size);
+  std::vector<c10::intrusive_ptr<::c10d::ProcessGroup::Work>> work(size);
   for (auto i = 0; i < size; i++) {
     work[i] = tests[i].getProcessGroup().barrier();
   }
diff --git a/torch/lib/c10d/test/ProcessGroupMPITest.cpp b/torch/lib/c10d/test/ProcessGroupMPITest.cpp
index 3f5a9e4cf331..6c60b3d6742d 100644
--- a/torch/lib/c10d/test/ProcessGroupMPITest.cpp
+++ b/torch/lib/c10d/test/ProcessGroupMPITest.cpp
@@ -14,7 +14,7 @@
 // Wait for work to complete
 void waitWork(
     std::shared_ptr<c10d::ProcessGroupMPI> pg,
-    std::vector<std::shared_ptr<::c10d::ProcessGroup::Work>> works) {
+    std::vector<c10::intrusive_ptr<::c10d::ProcessGroup::Work>> works) {
   for (auto& work : works) {
     try {
       work->wait();
@@ -34,10 +34,11 @@ void testAllreduce(int iter = 1000) {
     allTensors[i] = std::vector<at::Tensor>({tensor});
   }
 
-  std::vector<std::shared_ptr<::c10d::ProcessGroup::Work>> works;
+  std::vector<c10::intrusive_ptr<::c10d::ProcessGroup::Work>> works;
   for (auto& tensors : allTensors) {
     // Kick off work
-    std::shared_ptr<::c10d::ProcessGroup::Work> work = pg->allreduce(tensors);
+    c10::intrusive_ptr<::c10d::ProcessGroup::Work> work =
+        pg->allreduce(tensors);
     works.push_back(std::move(work));
   }
@@ -73,10 +74,11 @@ void testBroadcast(int iter = 10000) {
     }
   }
 
-  std::vector<std::shared_ptr<::c10d::ProcessGroup::Work>> works;
+  std::vector<c10::intrusive_ptr<::c10d::ProcessGroup::Work>> works;
   for (auto& tensors : allTensors) {
     // Kick off work
-    std::shared_ptr<::c10d::ProcessGroup::Work> work = pg->broadcast(tensors);
+    c10::intrusive_ptr<::c10d::ProcessGroup::Work> work =
+        pg->broadcast(tensors);
     works.push_back(std::move(work));
   }
@@ -104,10 +106,10 @@ void testReduce(int iter = 10000) {
     allTensors[i] = std::vector<at::Tensor>({tensor});
   }
 
-  std::vector<std::shared_ptr<::c10d::ProcessGroup::Work>> works;
+  std::vector<c10::intrusive_ptr<::c10d::ProcessGroup::Work>> works;
   for (auto& tensors : allTensors) {
     // Kick off work
-    std::shared_ptr<::c10d::ProcessGroup::Work> work = pg->reduce(tensors);
+    c10::intrusive_ptr<::c10d::ProcessGroup::Work> work = pg->reduce(tensors);
     works.push_back(std::move(work));
   }
@@ -150,10 +152,10 @@ void testAllgather(int iter = 10000) {
     }
   }
 
-  std::vector<std::shared_ptr<::c10d::ProcessGroup::Work>> works;
+  std::vector<c10::intrusive_ptr<::c10d::ProcessGroup::Work>> works;
   for (size_t i = 0; i < allTensors.size(); ++i) {
     // Kick off work
-    std::shared_ptr<::c10d::ProcessGroup::Work> work =
+    c10::intrusive_ptr<::c10d::ProcessGroup::Work> work =
         pg->allgather(allOutputTensors[i], allTensors[i]);
     works.push_back(std::move(work));
   }
@@ -198,10 +200,10 @@ void testGather(int iter = 10000) {
     }
   }
 
-  std::vector<std::shared_ptr<::c10d::ProcessGroup::Work>> works;
+  std::vector<c10::intrusive_ptr<::c10d::ProcessGroup::Work>> works;
   for (size_t i = 0; i < allTensors.size(); ++i) {
     // Kick off work
-    std::shared_ptr<::c10d::ProcessGroup::Work> work =
+    c10::intrusive_ptr<::c10d::ProcessGroup::Work> work =
         pg->gather(allOutputTensors[i], allTensors[i]);
     works.push_back(std::move(work));
   }
@@ -249,10 +251,10 @@ void testScatter(int iter = 1) {
     }
   }
 
-  std::vector<std::shared_ptr<::c10d::ProcessGroup::Work>> works;
+  std::vector<c10::intrusive_ptr<::c10d::ProcessGroup::Work>> works;
   for (size_t i = 0; i < allTensors.size(); ++i) {
     // Kick off work
-    std::shared_ptr<::c10d::ProcessGroup::Work> work =
+    c10::intrusive_ptr<::c10d::ProcessGroup::Work> work =
         pg->scatter(allTensors[i], allInputTensors[i]);
     works.push_back(std::move(work));
   }
@@ -289,27 +291,27 @@ void testSendRecv(bool recvAnysource, int iter = 10000) {
   }
 
   if (rank == 0) {
-    std::vector<std::shared_ptr<::c10d::ProcessGroup::Work>> works;
+    std::vector<c10::intrusive_ptr<::c10d::ProcessGroup::Work>> works;
     for (auto& tensors : allTensors) {
       // Kick off work
-      std::shared_ptr<::c10d::ProcessGroup::Work> work =
+      c10::intrusive_ptr<::c10d::ProcessGroup::Work> work =
           pg->send(tensors, 1, 0);
       works.push_back(std::move(work));
     }
     waitWork(pg, works);
   }
 
   if (rank == 1) {
-    std::vector<std::shared_ptr<::c10d::ProcessGroup::Work>> works;
+    std::vector<c10::intrusive_ptr<::c10d::ProcessGroup::Work>> works;
     std::vector<int> srcRanks(allTensors.size(), -1);
     size_t i = 0;
     for (auto& tensors : allTensors) {
       // Kick off work
       if (!recvAnysource) {
-        std::shared_ptr<::c10d::ProcessGroup::Work> work =
+        c10::intrusive_ptr<::c10d::ProcessGroup::Work> work =
            pg->recv(tensors, 0, 0);
        works.push_back(std::move(work));
      } else {
-        std::shared_ptr<::c10d::ProcessGroup::Work> work =
+        c10::intrusive_ptr<::c10d::ProcessGroup::Work> work =
            pg->recvAnysource(tensors, 0);
        works.push_back(std::move(work));
      }
diff --git a/torch/lib/c10d/test/ProcessGroupNCCLErrorsTest.cpp b/torch/lib/c10d/test/ProcessGroupNCCLErrorsTest.cpp
index e906702a889d..f1348922e126 100644
--- a/torch/lib/c10d/test/ProcessGroupNCCLErrorsTest.cpp
+++ b/torch/lib/c10d/test/ProcessGroupNCCLErrorsTest.cpp
@@ -56,12 +56,12 @@ class ProcessGroupNCCLSimulateErrors : public c10d::ProcessGroupNCCL {
         ProcessGroupNCCLSimulateErrors::kWatchdogThreadSleepMillis);
   }
 
-  std::shared_ptr<c10d::ProcessGroupNCCL::WorkNCCL> initWork(
+  c10::intrusive_ptr<c10d::ProcessGroupNCCL::WorkNCCL> initWork(
      std::vector<at::Device> devices,
      int rank,
      c10d::OpType opType,
      const char* profilingTitle) override {
-    return std::make_shared(
+    return c10::make_intrusive(
        devices, simulate_error_, rank, opType);
   }
 
@@ -113,12 +113,12 @@ class ProcessGroupNCCLTimedOutErrors : public ProcessGroupNCCLSimulateErrors {
      : ProcessGroupNCCLSimulateErrors(store, rank, size, opts),
        set_timedout_error_(false) {}
 
-  std::shared_ptr<c10d::ProcessGroupNCCL::WorkNCCL> initWork(
+  c10::intrusive_ptr<c10d::ProcessGroupNCCL::WorkNCCL> initWork(
      std::vector<at::Device> devices,
      int rank,
      c10d::OpType opType,
      const char* profilingTitle) override {
-    return std::make_shared(
+    return c10::make_intrusive(
        devices, set_timedout_error_, rank, opType);
   }
 
diff --git a/torch/lib/c10d/test/ProcessGroupNCCLTest.cpp b/torch/lib/c10d/test/ProcessGroupNCCLTest.cpp
index 92b477fae7de..efa96312aba0 100644
--- a/torch/lib/c10d/test/ProcessGroupNCCLTest.cpp
+++ b/torch/lib/c10d/test/ProcessGroupNCCLTest.cpp
@@ -80,7 +80,7 @@ class NCCLTest : public NCCLTestBase {
   }
 
   void wait(
-      std::shared_ptr<c10d::ProcessGroup::Work>& work,
+      c10::intrusive_ptr<c10d::ProcessGroup::Work>& work,
       std::chrono::milliseconds timeout = kNoTimeout) {
     at::cuda::CUDAMultiStreamGuard guard(streams_);
     work->wait(timeout);
   }
@@ -166,7 +166,7 @@ class AllreduceNCCLTest : public NCCLTest {
   AllreduceNCCLTest(const std::string& path, int worldSize)
       : NCCLTest(path, worldSize) {}
 
-  std::shared_ptr<c10d::ProcessGroup::Work> run() {
+  c10::intrusive_ptr<c10d::ProcessGroup::Work> run() {
     // For the duration of this function, make THC use our streams
     at::cuda::CUDAMultiStreamGuard guard(streams_);
 
@@ -189,7 +189,7 @@ class BroadcastNCCLTest : public NCCLTest {
   BroadcastNCCLTest(const std::string& path, int worldSize)
       : NCCLTest(path, worldSize) {}
 
-  std::shared_ptr<c10d::ProcessGroup::Work> run(int rootRank, int rootTensor) {
+  c10::intrusive_ptr<c10d::ProcessGroup::Work> run(int rootRank, int rootTensor) {
     // For the duration of this function, make THC use our streams
     at::cuda::CUDAMultiStreamGuard guard(streams_);
 
@@ -208,7 +208,7 @@ class ReduceNCCLTest : public NCCLTest {
   ReduceNCCLTest(const std::string& path, int worldSize)
       : NCCLTest(path, worldSize) {}
 
-  std::shared_ptr<c10d::ProcessGroup::Work> run(int rootRank, int rootTensor) {
+  c10::intrusive_ptr<c10d::ProcessGroup::Work> run(int rootRank, int rootTensor) {
     // For the duration of this function, make THC use our streams
     at::cuda::CUDAMultiStreamGuard guard(streams_);
@@ -227,7 +227,7 @@ class AllgatherNCCLTest : public NCCLTest {
   AllgatherNCCLTest(const std::string& path, int worldSize)
       : NCCLTest(path, worldSize) {}
 
-  std::shared_ptr<c10d::ProcessGroup::Work> run() {
+  c10::intrusive_ptr<c10d::ProcessGroup::Work> run() {
     // For the duration of this function, make THC use our streams
     at::cuda::CUDAMultiStreamGuard guard(streams_);
 
@@ -242,7 +242,7 @@ struct ReduceScatterNCCLTest : NCCLTest {
   ReduceScatterNCCLTest(const std::string& path, int worldSize)
       : NCCLTest(path, worldSize) {}
 
-  std::shared_ptr<c10d::ProcessGroup::Work> run() {
+  c10::intrusive_ptr<c10d::ProcessGroup::Work> run() {
     // For the duration of this function, make THC use our streams
     at::cuda::CUDAMultiStreamGuard guard(streams_);