Skip to content

Commit

Permalink
[c10d] switch ProcessGroup::Work to be managed by intrusive_ptr (#44046)
Browse files Browse the repository at this point in the history
Summary: Pull Request resolved: #44046

Test Plan: Imported from OSS

Reviewed By: gmagogsfm

Differential Revision: D23632280

Pulled By: wanchaol

fbshipit-source-id: 0a4642a8ffabdd26c52c1baabfa30c0f446c3c85
  • Loading branch information
wanchaol authored and facebook-github-bot committed Nov 11, 2020
1 parent cbf439c commit 0650a61
Show file tree
Hide file tree
Showing 24 changed files with 295 additions and 287 deletions.
32 changes: 16 additions & 16 deletions test/cpp_extensions/cpp_c10d_extension.cpp
Expand Up @@ -23,85 +23,85 @@ ProcessGroupTest::ProcessGroupTest(int rank, int size)

ProcessGroupTest::~ProcessGroupTest() {}

std::shared_ptr<ProcessGroup::Work> ProcessGroupTest::broadcast(
c10::intrusive_ptr<ProcessGroup::Work> ProcessGroupTest::broadcast(
std::vector<at::Tensor>& tensors,
const BroadcastOptions& opts) {
return std::make_shared<ProcessGroupTest::WorkTest>();
return c10::make_intrusive<ProcessGroupTest::WorkTest>();
}

std::shared_ptr<ProcessGroup::Work> ProcessGroupTest::allreduce(
c10::intrusive_ptr<ProcessGroup::Work> ProcessGroupTest::allreduce(
std::vector<at::Tensor>& tensors,
const AllreduceOptions& opts) {
return std::make_shared<ProcessGroupTest::WorkTest>();
return c10::make_intrusive<ProcessGroupTest::WorkTest>();
}

std::shared_ptr<ProcessGroup::Work> ProcessGroupTest::allreduce_coalesced(
c10::intrusive_ptr<ProcessGroup::Work> ProcessGroupTest::allreduce_coalesced(
std::vector<at::Tensor>& tensors,
const AllreduceCoalescedOptions& opts) {
throw std::runtime_error("ProcessGroupTest does not support allreduce_coalesced");
}

std::shared_ptr<ProcessGroup::Work> ProcessGroupTest::reduce(
c10::intrusive_ptr<ProcessGroup::Work> ProcessGroupTest::reduce(
std::vector<at::Tensor>& tensors,
const ReduceOptions& opts) {
throw std::runtime_error("ProcessGroupTest does not support reduce");
}

std::shared_ptr<ProcessGroup::Work> ProcessGroupTest::allgather(
c10::intrusive_ptr<ProcessGroup::Work> ProcessGroupTest::allgather(
std::vector<std::vector<at::Tensor>>& outputTensors,
std::vector<at::Tensor>& inputTensors,
const AllgatherOptions& opts) {
throw std::runtime_error("ProcessGroupTest does not support allgather");
}

std::shared_ptr<ProcessGroup::Work> ProcessGroupTest::allgather_base(
c10::intrusive_ptr<ProcessGroup::Work> ProcessGroupTest::allgather_base(
at::Tensor& outputBuffer,
at::Tensor& inputBuffer,
const AllgatherOptions& opts) {
throw std::runtime_error("ProcessGroupTest does not support allgather_base");
}

std::shared_ptr<ProcessGroup::Work> ProcessGroupTest::barrier(
c10::intrusive_ptr<ProcessGroup::Work> ProcessGroupTest::barrier(
const BarrierOptions& opts) {
return std::make_shared<ProcessGroupTest::WorkTest>();
return c10::make_intrusive<ProcessGroupTest::WorkTest>();
}

std::shared_ptr<ProcessGroup::Work> ProcessGroupTest::gather(
c10::intrusive_ptr<ProcessGroup::Work> ProcessGroupTest::gather(
std::vector<std::vector<at::Tensor>>& outputTensors,
std::vector<at::Tensor>& inputTensors,
const GatherOptions& opts) {
throw std::runtime_error("ProcessGroupTest does not support gather");
}

std::shared_ptr<ProcessGroup::Work> ProcessGroupTest::scatter(
c10::intrusive_ptr<ProcessGroup::Work> ProcessGroupTest::scatter(
std::vector<at::Tensor>& outputTensors,
std::vector<std::vector<at::Tensor>>& inputTensors,
const ScatterOptions& opts) {
throw std::runtime_error("ProcessGroupTest does not support scatter");
}

std::shared_ptr<ProcessGroup::Work> ProcessGroupTest::reduce_scatter(
c10::intrusive_ptr<ProcessGroup::Work> ProcessGroupTest::reduce_scatter(
std::vector<at::Tensor>& outputTensors,
std::vector<std::vector<at::Tensor>>& inputTensors,
const ReduceScatterOptions& opts) {
throw std::runtime_error("ProcessGroupTest does not support reduce_scatter");
}

std::shared_ptr<ProcessGroup::Work> ProcessGroupTest::send(
c10::intrusive_ptr<ProcessGroup::Work> ProcessGroupTest::send(
std::vector<at::Tensor>& tensors,
int dstRank,
int tag) {
throw std::runtime_error("ProcessGroupTest does not support send");
}

std::shared_ptr<ProcessGroup::Work> ProcessGroupTest::recv(
c10::intrusive_ptr<ProcessGroup::Work> ProcessGroupTest::recv(
std::vector<at::Tensor>& tensors,
int srcRank,
int tag) {
throw std::runtime_error("ProcessGroupTest does not support recv");
}

std::shared_ptr<ProcessGroup::Work> ProcessGroupTest::recvAnysource(
c10::intrusive_ptr<ProcessGroup::Work> ProcessGroupTest::recvAnysource(
std::vector<at::Tensor>& tensor,
int tag) {
throw std::runtime_error("ProcessGroupTest does not support recvAnysource");
Expand Down
26 changes: 13 additions & 13 deletions test/cpp_extensions/cpp_c10d_extension.hpp
Expand Up @@ -41,61 +41,61 @@ class ProcessGroupTest : public ProcessGroup {
explicit ProcessGroupTest(int rank = -1, int size = -1);
virtual ~ProcessGroupTest();

std::shared_ptr<ProcessGroup::Work> broadcast(
c10::intrusive_ptr<ProcessGroup::Work> broadcast(
std::vector<at::Tensor>& data,
const BroadcastOptions& opts = BroadcastOptions()) override;

std::shared_ptr<ProcessGroup::Work> allreduce(
c10::intrusive_ptr<ProcessGroup::Work> allreduce(
std::vector<at::Tensor>& tensors,
const AllreduceOptions& opts = AllreduceOptions()) override;

std::shared_ptr<ProcessGroup::Work> allreduce_coalesced(
c10::intrusive_ptr<ProcessGroup::Work> allreduce_coalesced(
std::vector<at::Tensor>& tensors,
const AllreduceCoalescedOptions& opts = AllreduceCoalescedOptions()) override;

std::shared_ptr<ProcessGroup::Work> reduce(
c10::intrusive_ptr<ProcessGroup::Work> reduce(
std::vector<at::Tensor>& tensors,
const ReduceOptions& opts = ReduceOptions()) override;

std::shared_ptr<ProcessGroup::Work> allgather(
c10::intrusive_ptr<ProcessGroup::Work> allgather(
std::vector<std::vector<at::Tensor>>& outputTensors,
std::vector<at::Tensor>& inputTensors,
const AllgatherOptions& opts = AllgatherOptions()) override;

std::shared_ptr<ProcessGroup::Work> allgather_base(
c10::intrusive_ptr<ProcessGroup::Work> allgather_base(
at::Tensor& outputBuffer,
at::Tensor& inputBuffer,
const AllgatherOptions& opts = AllgatherOptions()) override;

std::shared_ptr<ProcessGroup::Work> barrier(
c10::intrusive_ptr<ProcessGroup::Work> barrier(
const BarrierOptions& opts = BarrierOptions()) override;

std::shared_ptr<ProcessGroup::Work> gather(
c10::intrusive_ptr<ProcessGroup::Work> gather(
std::vector<std::vector<at::Tensor>>& outputTensors,
std::vector<at::Tensor>& inputTensors,
const GatherOptions& opts = GatherOptions()) override;

std::shared_ptr<ProcessGroup::Work> scatter(
c10::intrusive_ptr<ProcessGroup::Work> scatter(
std::vector<at::Tensor>& outputTensors,
std::vector<std::vector<at::Tensor>>& inputTensors,
const ScatterOptions& opts = ScatterOptions()) override;

std::shared_ptr<ProcessGroup::Work> reduce_scatter(
c10::intrusive_ptr<ProcessGroup::Work> reduce_scatter(
std::vector<at::Tensor>& outputTensors,
std::vector<std::vector<at::Tensor>>& inputTensors,
const ReduceScatterOptions& opts = ReduceScatterOptions()) override;

std::shared_ptr<ProcessGroup::Work> send(
c10::intrusive_ptr<ProcessGroup::Work> send(
std::vector<at::Tensor>& tensors,
int dstRank,
int tag);

std::shared_ptr<ProcessGroup::Work> recv(
c10::intrusive_ptr<ProcessGroup::Work> recv(
std::vector<at::Tensor>& tensors,
int srcRank,
int tag);

std::shared_ptr<ProcessGroup::Work> recvAnysource(
c10::intrusive_ptr<ProcessGroup::Work> recvAnysource(
std::vector<at::Tensor>& tensor,
int tag);

Expand Down
5 changes: 4 additions & 1 deletion torch/csrc/distributed/c10d/init.cpp
@@ -1,5 +1,6 @@
#include <torch/csrc/python_headers.h>

#include <c10/util/intrusive_ptr.h>
#include <c10d/FileStore.hpp>
#ifndef _WIN32
#include <c10d/HashStore.hpp>
Expand Down Expand Up @@ -59,6 +60,8 @@ constexpr auto kDeprecationWarning =
"{} API is being deprecated, please ping "
"https://github.com/pytorch/pytorch/issues/46291 "
"if you see this warning";
template <typename T>
using intrusive_ptr_class_ = py::class_<T, c10::intrusive_ptr<T>>;

// PythonStore is a pybind11 trampoline class to allow a Python
// class to inherit from c10d.Store and implement its interface.
Expand Down Expand Up @@ -1045,7 +1048,7 @@ that adds a prefix to each key inserted to the store.
py::call_guard<py::gil_scoped_release>());
#endif

shared_ptr_class_<::c10d::ProcessGroup::Work>(module, "Work")
intrusive_ptr_class_<::c10d::ProcessGroup::Work>(module, "Work")
.def("is_completed", &::c10d::ProcessGroup::Work::isCompleted)
.def(
"is_success",
Expand Down
2 changes: 1 addition & 1 deletion torch/csrc/distributed/rpc/process_group_agent.cpp
Expand Up @@ -398,7 +398,7 @@ void ProcessGroupAgent::handleSend(const SendWork& work) {

// ProcessGroup is not thread-safe when sending with the same tag,
// hence the lock
std::vector<std::shared_ptr<c10d::ProcessGroup::Work>> pendingSends;
std::vector<c10::intrusive_ptr<c10d::ProcessGroup::Work>> pendingSends;
const auto dst = work.to_.id_;

// NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
Expand Down
4 changes: 2 additions & 2 deletions torch/csrc/distributed/rpc/process_group_agent.h
Expand Up @@ -230,14 +230,14 @@ class TORCH_API ProcessGroupAgent : public RpcAgent {
// Lock and intrusive_ptr to currently pending work, set in listenloop() and
// interruptible in shutdown().
std::mutex recvWorkMutex_;
std::shared_ptr<c10d::ProcessGroup::Work> recvWork_;
c10::intrusive_ptr<c10d::ProcessGroup::Work> recvWork_;
// Map of dst rank to current outstanding sends that we are waiting on. In the
// case of a call to ::shutdown() while we are still waiting on these sends,
// the pending sends contained in this map will be aborted, allowing the
// waiting thread to be unblocked.
std::unordered_map<
worker_id_t,
std::set<std::shared_ptr<c10d::ProcessGroup::Work>>>
std::set<c10::intrusive_ptr<c10d::ProcessGroup::Work>>>
currentPendingSends_;
// Lock to serialize access to the above map.
std::mutex pendingSendMutex_;
Expand Down
2 changes: 1 addition & 1 deletion torch/lib/c10d/ProcessGroup.cpp
Expand Up @@ -164,7 +164,7 @@ ProcessGroup::~ProcessGroup() {}

// This is introduced so that implementors of ProcessGroup would not need to
// have this implementation.
std::shared_ptr<ProcessGroup::Work> ProcessGroup::allgather_coalesced(
c10::intrusive_ptr<ProcessGroup::Work> ProcessGroup::allgather_coalesced(
std::vector<std::vector<at::Tensor>>& /* unused */,
std::vector<at::Tensor>& /* unused */,
const AllgatherOptions& /* unused */) {
Expand Down
35 changes: 17 additions & 18 deletions torch/lib/c10d/ProcessGroup.hpp
Expand Up @@ -70,12 +70,11 @@ bool isP2POp(OpType opType);
//
class ProcessGroup {
public:

// Please do not use ProcessGroup::Work API, it is going away, to be
// replaced by ivalue::Future.
// Python binding for this class might change, please do not assume
// this will be bound using pybind.
class Work {
class Work : public torch::CustomClassHolder {
public:
Work(int rank = -1, OpType opType = OpType::UNKNOWN, const char* profilingTitle = nullptr);

Expand Down Expand Up @@ -171,33 +170,33 @@ class ProcessGroup {
return size_;
}

virtual std::shared_ptr<ProcessGroup::Work> broadcast(
virtual c10::intrusive_ptr<ProcessGroup::Work> broadcast(
std::vector<at::Tensor>& data,
const BroadcastOptions& opts = BroadcastOptions()) = 0;

virtual std::shared_ptr<ProcessGroup::Work> allreduce(
virtual c10::intrusive_ptr<ProcessGroup::Work> allreduce(
std::vector<at::Tensor>& data,
const AllreduceOptions& opts = AllreduceOptions()) = 0;

// This will be moved out of ProcessGroup, do not add dependencies on this
// function.
virtual std::shared_ptr<ProcessGroup::Work> allreduce_coalesced(
virtual c10::intrusive_ptr<ProcessGroup::Work> allreduce_coalesced(
std::vector<at::Tensor>& tensors,
const AllreduceCoalescedOptions& opts = AllreduceCoalescedOptions()) = 0;

virtual std::shared_ptr<ProcessGroup::Work> reduce(
virtual c10::intrusive_ptr<ProcessGroup::Work> reduce(
std::vector<at::Tensor>& tensors,
const ReduceOptions& opts = ReduceOptions()) = 0;

virtual std::shared_ptr<ProcessGroup::Work> allgather(
virtual c10::intrusive_ptr<ProcessGroup::Work> allgather(
std::vector<std::vector<at::Tensor>>& outputTensors,
std::vector<at::Tensor>& inputTensors,
const AllgatherOptions& opts = AllgatherOptions()) = 0;

// Gathers a single tensor inputBuffer into a single buffer outputBuffer that
// is interpreted as a contiguous collection of size inputBuffer * WORLD_SIZE.
// For implementers of ProcessGroup API and advanced users only.
virtual std::shared_ptr<ProcessGroup::Work> allgather_base(
virtual c10::intrusive_ptr<ProcessGroup::Work> allgather_base(
at::Tensor& outputBuffer,
at::Tensor& inputBuffer,
const AllgatherOptions& opts = AllgatherOptions()) = 0;
Expand All @@ -206,27 +205,27 @@ class ProcessGroup {
// * do not add dependencies on this function,
// * do not implement it in your ProcessGroup, implement allgather_base
// instead.
virtual std::shared_ptr<ProcessGroup::Work> allgather_coalesced(
virtual c10::intrusive_ptr<ProcessGroup::Work> allgather_coalesced(
std::vector<std::vector<at::Tensor>>& outputTensorLists,
std::vector<at::Tensor>& inputTensors,
const AllgatherOptions& opts = AllgatherOptions());

virtual std::shared_ptr<ProcessGroup::Work> gather(
virtual c10::intrusive_ptr<ProcessGroup::Work> gather(
std::vector<std::vector<at::Tensor>>& outputTensors,
std::vector<at::Tensor>& inputTensors,
const GatherOptions& opts = GatherOptions()) = 0;

virtual std::shared_ptr<ProcessGroup::Work> scatter(
virtual c10::intrusive_ptr<ProcessGroup::Work> scatter(
std::vector<at::Tensor>& outputTensors,
std::vector<std::vector<at::Tensor>>& inputTensors,
const ScatterOptions& opts = ScatterOptions()) = 0;

virtual std::shared_ptr<ProcessGroup::Work> reduce_scatter(
virtual c10::intrusive_ptr<ProcessGroup::Work> reduce_scatter(
std::vector<at::Tensor>& outputTensors,
std::vector<std::vector<at::Tensor>>& inputTensors,
const ReduceScatterOptions& opts = ReduceScatterOptions()) = 0;

virtual std::shared_ptr<ProcessGroup::Work> alltoall_base(
virtual c10::intrusive_ptr<ProcessGroup::Work> alltoall_base(
at::Tensor& outputTensor,
at::Tensor& inputTensor,
std::vector<int64_t>& outputSplitSizes,
Expand All @@ -235,28 +234,28 @@ class ProcessGroup {
throw std::runtime_error("ProcessGroup does not support alltoall");
}

virtual std::shared_ptr<ProcessGroup::Work> alltoall(
virtual c10::intrusive_ptr<ProcessGroup::Work> alltoall(
std::vector<at::Tensor>& outputTensors,
std::vector<at::Tensor>& inputTensors,
const AllToAllOptions& opts = AllToAllOptions()) {
throw std::runtime_error("ProcessGroup does not support alltoall");
}

virtual std::shared_ptr<ProcessGroup::Work> send(
virtual c10::intrusive_ptr<ProcessGroup::Work> send(
std::vector<at::Tensor>& tensors,
int dstRank,
int tag) = 0;

virtual std::shared_ptr<ProcessGroup::Work> recv(
virtual c10::intrusive_ptr<ProcessGroup::Work> recv(
std::vector<at::Tensor>& tensors,
int srcRank,
int tag) = 0;

virtual std::shared_ptr<ProcessGroup::Work> recvAnysource(
virtual c10::intrusive_ptr<ProcessGroup::Work> recvAnysource(
std::vector<at::Tensor>& tensors,
int tag) = 0;

virtual std::shared_ptr<ProcessGroup::Work> barrier(
virtual c10::intrusive_ptr<ProcessGroup::Work> barrier(
const BarrierOptions& opts = BarrierOptions()) = 0;

protected:
Expand Down

0 comments on commit 0650a61

Please sign in to comment.