Skip to content

Commit

Permalink
[c10d] switch ProcessGroup to be managed by intrusive_ptr (#47343)
Browse files Browse the repository at this point in the history
Summary: Pull Request resolved: #47343

Test Plan: Imported from OSS

Reviewed By: gmagogsfm

Differential Revision: D24723418

Pulled By: wanchaol

fbshipit-source-id: 0463819b96c53b12bdbb3905431110d7b21beb77
  • Loading branch information
wanchaol authored and facebook-github-bot committed Nov 12, 2020
1 parent 859e054 commit 553cccc
Show file tree
Hide file tree
Showing 28 changed files with 166 additions and 167 deletions.
4 changes: 2 additions & 2 deletions test/cpp/rpc/test_e2e_process_group.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,8 @@ class TestE2EProcessGroup : public TestE2EBase {
options.timeout = rpcTimeout;

// Initialize server rpc agent.
auto pg =
std::make_shared<c10d::ProcessGroupGloo>(store, 0, numWorkers, options);
auto pg = c10::make_intrusive<c10d::ProcessGroupGloo>(
store, 0, numWorkers, options);

rpcAgent = std::make_shared<ProcessGroupAgent>(
"worker",
Expand Down
4 changes: 2 additions & 2 deletions test/cpp/rpc/test_e2e_tensorpipe.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,8 @@ class TestE2ETensorPipe : public TestE2EBase {
float rpcTimeout = 30;

// Initialize server rpc agent.
auto pg =
std::make_shared<c10d::ProcessGroupGloo>(store, 0, numWorkers, options);
auto pg = c10::make_intrusive<c10d::ProcessGroupGloo>(
store, 0, numWorkers, options);

TensorPipeRpcBackendOptions opts(
/*numWorkerThreads=*/std::max(16U, std::thread::hardware_concurrency()),
Expand Down
4 changes: 2 additions & 2 deletions test/cpp_extensions/cpp_c10d_extension.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -107,12 +107,12 @@ c10::intrusive_ptr<ProcessGroup::Work> ProcessGroupTest::recvAnysource(
throw std::runtime_error("ProcessGroupTest does not support recvAnysource");
}

std::shared_ptr<ProcessGroup> ProcessGroupTest::createProcessGroupTest(
c10::intrusive_ptr<ProcessGroup> ProcessGroupTest::createProcessGroupTest(
const c10::intrusive_ptr<::c10d::Store>& store,
int rank,
int size,
const std::chrono::duration<float>& timeout) {
return std::make_shared<ProcessGroupTest>(rank, size);
return c10::make_intrusive<ProcessGroupTest>(rank, size);
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
Expand Down
8 changes: 4 additions & 4 deletions test/cpp_extensions/cpp_c10d_extension.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -88,19 +88,19 @@ class ProcessGroupTest : public ProcessGroup {
c10::intrusive_ptr<ProcessGroup::Work> send(
std::vector<at::Tensor>& tensors,
int dstRank,
int tag);
int tag) override;

c10::intrusive_ptr<ProcessGroup::Work> recv(
std::vector<at::Tensor>& tensors,
int srcRank,
int tag);
int tag) override;

c10::intrusive_ptr<ProcessGroup::Work> recvAnysource(
std::vector<at::Tensor>& tensor,
int tag);
int tag) override;

// Create a new ProcessGroupTest instance
static std::shared_ptr<ProcessGroup> createProcessGroupTest(
static c10::intrusive_ptr<ProcessGroup> createProcessGroupTest(
const c10::intrusive_ptr<::c10d::Store>& store,
int rank,
int size,
Expand Down
25 changes: 12 additions & 13 deletions torch/csrc/distributed/c10d/init.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -223,7 +223,7 @@ An enum-like class for built-in communication hooks: ``ALLREDUCE`` and ``FP16_CO
py::init<
std::vector<std::vector<torch::autograd::Variable>>,
std::vector<std::vector<size_t>>,
std::shared_ptr<::c10d::ProcessGroup>,
c10::intrusive_ptr<::c10d::ProcessGroup>,
std::vector<std::vector<bool>>,
int64_t,
bool,
Expand Down Expand Up @@ -642,7 +642,7 @@ that adds a prefix to each key inserted to the store.
.def(py::init<const std::string&, c10::intrusive_ptr<::c10d::Store>>());

auto processGroup =
shared_ptr_class_<::c10d::ProcessGroup>(module, "ProcessGroup")
intrusive_ptr_class_<::c10d::ProcessGroup>(module, "ProcessGroup")
.def("rank", &::c10d::ProcessGroup::getRank)
.def("size", &::c10d::ProcessGroup::getSize)

Expand Down Expand Up @@ -907,21 +907,21 @@ that adds a prefix to each key inserted to the store.
#ifndef _WIN32
module.def(
"_round_robin_process_groups",
[](std::vector<std::shared_ptr<::c10d::ProcessGroup>> processGroups)
-> std::shared_ptr<::c10d::ProcessGroup> {
[](std::vector<c10::intrusive_ptr<::c10d::ProcessGroup>> processGroups)
-> c10::intrusive_ptr<::c10d::ProcessGroup> {
if (processGroups.size() == 0) {
throw std::invalid_argument("Specify at least 1 process group");
}
const auto& first = processGroups.front();
return std::make_shared<::c10d::ProcessGroupRoundRobin>(
return c10::make_intrusive<::c10d::ProcessGroupRoundRobin>(
first->getRank(), first->getSize(), std::move(processGroups));
},
py::arg("process_groups"),
py::call_guard<py::gil_scoped_release>());
#endif

#ifdef USE_C10D_GLOO
auto processGroupGloo = shared_ptr_class_<::c10d::ProcessGroupGloo>(
auto processGroupGloo = intrusive_ptr_class_<::c10d::ProcessGroupGloo>(
module, "ProcessGroupGloo", processGroup);

shared_ptr_class_<::gloo::transport::Device>(processGroupGloo, "Device");
Expand Down Expand Up @@ -981,7 +981,7 @@ that adds a prefix to each key inserted to the store.

options.timeout = timeout;
options.threads = options.devices.size() * 2;
return std::make_shared<::c10d::ProcessGroupGloo>(
return c10::make_intrusive<::c10d::ProcessGroupGloo>(
store, rank, size, options);
}),
py::arg("store"),
Expand All @@ -993,15 +993,14 @@ that adds a prefix to each key inserted to the store.

#ifdef USE_C10D_NCCL
auto processGroupNCCL =
shared_ptr_class_<::c10d::ProcessGroupNCCL>(
intrusive_ptr_class_<::c10d::ProcessGroupNCCL>(
module, "ProcessGroupNCCL", processGroup)
.def(
py::init<
const c10::intrusive_ptr<::c10d::Store>&,
int,
int,
const c10::intrusive_ptr<
::c10d::ProcessGroupNCCL::Options>&>(),
c10::intrusive_ptr<::c10d::ProcessGroupNCCL::Options>>(),
py::call_guard<py::gil_scoped_release>())
.def(
py::init([](const c10::intrusive_ptr<::c10d::Store>& store,
Expand All @@ -1011,7 +1010,7 @@ that adds a prefix to each key inserted to the store.
auto options = ::c10d::ProcessGroupNCCL::Options::create();
options->isHighPriorityStream = false;
options->opTimeout = timeout;
return std::make_shared<::c10d::ProcessGroupNCCL>(
return c10::make_intrusive<::c10d::ProcessGroupNCCL>(
store, rank, size, options);
}),
py::arg("store"),
Expand All @@ -1036,7 +1035,7 @@ that adds a prefix to each key inserted to the store.
#endif

#ifdef USE_C10D_MPI
auto processGroupMPI = shared_ptr_class_<::c10d::ProcessGroupMPI>(
auto processGroupMPI = intrusive_ptr_class_<::c10d::ProcessGroupMPI>(
module, "ProcessGroupMPI", processGroup);

// Define static create function instead of a constructor, because
Expand Down Expand Up @@ -1149,7 +1148,7 @@ that adds a prefix to each key inserted to the store.
// Define a lambda such that the pybind11 prototype can take a std::vector
// for the tensor list argument, but still pass it to the underlying
// function as a c10::ArrayRef.
[](std::shared_ptr<::c10d::ProcessGroup> process_group,
[](c10::intrusive_ptr<::c10d::ProcessGroup> process_group,
std::vector<at::Tensor> tensors, // NOLINT
size_t buffer_size,
int rank) {
Expand Down
4 changes: 2 additions & 2 deletions torch/csrc/distributed/rpc/init.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -494,7 +494,7 @@ PyObject* rpc_init(PyObject* _unused, PyObject* noargs) {

shared_ptr_class_<ProcessGroupAgent>(module, "ProcessGroupAgent", rpcAgent)
.def(py::init([](std::string workerName,
const std::shared_ptr<::c10d::ProcessGroup>& pg,
const c10::intrusive_ptr<::c10d::ProcessGroup>& pg,
int numSendRecvThreads,
std::chrono::milliseconds rpcTimeout) {
return std::make_unique<ProcessGroupAgent>(
Expand Down Expand Up @@ -580,7 +580,7 @@ PyObject* rpc_init(PyObject* _unused, PyObject* noargs) {
std::string selfName,
worker_id_t selfId,
int worldSize,
std::shared_ptr<::c10d::ProcessGroup> processGroup,
c10::intrusive_ptr<::c10d::ProcessGroup> processGroup,
TensorPipeRpcBackendOptions opts) {
return std::make_shared<TensorPipeAgent>(
store,
Expand Down
2 changes: 1 addition & 1 deletion torch/csrc/distributed/rpc/process_group_agent.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ void ProcessGroupAgent::collectNames() {

ProcessGroupAgent::ProcessGroupAgent(
std::string workerName,
std::shared_ptr<c10d::ProcessGroup> pg,
c10::intrusive_ptr<::c10d::ProcessGroup> pg,
int numSendRecvThreads,
std::chrono::milliseconds rpcTimeout,
std::unique_ptr<RequestCallback> cb)
Expand Down
4 changes: 2 additions & 2 deletions torch/csrc/distributed/rpc/process_group_agent.h
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ class TORCH_API ProcessGroupAgent : public RpcAgent {
public:
ProcessGroupAgent(
std::string workerName,
std::shared_ptr<c10d::ProcessGroup> pg,
c10::intrusive_ptr<::c10d::ProcessGroup> pg,
int numSendRecvThreads,
std::chrono::milliseconds rpcTimeout,
std::unique_ptr<RequestCallback> cb);
Expand Down Expand Up @@ -209,7 +209,7 @@ class TORCH_API ProcessGroupAgent : public RpcAgent {
return ++nextId_;
}

std::shared_ptr<c10d::ProcessGroup> pg_;
c10::intrusive_ptr<::c10d::ProcessGroup> pg_;
// worker name -> rank
std::unordered_map<std::string, worker_id_t> nameMap_;
std::vector<WorkerInfo> allWorkerInfo_;
Expand Down
2 changes: 1 addition & 1 deletion torch/csrc/distributed/rpc/tensorpipe_agent.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -224,7 +224,7 @@ TensorPipeAgent::TensorPipeAgent(
std::string selfName,
worker_id_t selfId,
int worldSize,
std::shared_ptr<c10d::ProcessGroup> processGroup,
c10::intrusive_ptr<::c10d::ProcessGroup> processGroup,
TensorPipeRpcBackendOptions opts,
std::unique_ptr<RequestCallback> cb)
: RpcAgent(
Expand Down
4 changes: 2 additions & 2 deletions torch/csrc/distributed/rpc/tensorpipe_agent.h
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,7 @@ class TensorPipeAgent : public RpcAgent {
std::string selfName,
worker_id_t selfId,
int worldSize,
std::shared_ptr<c10d::ProcessGroup> processGroup,
c10::intrusive_ptr<::c10d::ProcessGroup> processGroup,
TensorPipeRpcBackendOptions opts,
std::unique_ptr<RequestCallback> cb);

Expand Down Expand Up @@ -283,7 +283,7 @@ class TensorPipeAgent : public RpcAgent {
// The join method is required to behave like a barrier and perform collective
// operations. For simplicity and reliability, we offload this to a process
// group, but probably one day we might want to re-implement them using RPCs.
const std::shared_ptr<c10d::ProcessGroup> processGroup_;
const c10::intrusive_ptr<::c10d::ProcessGroup> processGroup_;

mutable std::mutex mutex_;
uint64_t nextMessageID_{0};
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ std::string fromVec(const std::vector<char>& vec) {

FaultyProcessGroupAgent::FaultyProcessGroupAgent(
std::string workerName,
std::shared_ptr<c10d::ProcessGroup> pg,
c10::intrusive_ptr<::c10d::ProcessGroup> pg,
int numSendRecvThreads,
std::chrono::milliseconds rpcTimeout,
const std::vector<std::string>& messagesToFail,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ class FaultyProcessGroupAgent : public ProcessGroupAgent {
public:
FaultyProcessGroupAgent(
std::string workerName,
std::shared_ptr<c10d::ProcessGroup> pg,
c10::intrusive_ptr<c10d::ProcessGroup> pg,
int numSendRecvThreads,
std::chrono::milliseconds rpcTimeout,
const std::vector<std::string>& messagesToFail,
Expand Down
2 changes: 1 addition & 1 deletion torch/csrc/distributed/rpc/testing/init.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ PyObject* faulty_agent_init(PyObject* _unused, PyObject* noargs) {
.def(
py::init<
std::string,
std::shared_ptr<::c10d::ProcessGroup>,
c10::intrusive_ptr<::c10d::ProcessGroup>,
int,
std::chrono::milliseconds,
const std::vector<std::string>&,
Expand Down
2 changes: 1 addition & 1 deletion torch/lib/c10d/ProcessGroup.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ bool isP2POp(OpType opType);
// process group to find each other (referred to as rendezvous from
// hereon)
//
class ProcessGroup {
class ProcessGroup : public torch::CustomClassHolder {
public:
// Please do not use ProcessGroup::Work API, it is going away, to be
// replaced by ivalue::Future.
Expand Down
6 changes: 3 additions & 3 deletions torch/lib/c10d/ProcessGroupMPI.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -199,7 +199,7 @@ void ProcessGroupMPI::initMPIOnce() {
});
}

std::shared_ptr<ProcessGroupMPI> ProcessGroupMPI::createProcessGroupMPI(
c10::intrusive_ptr<ProcessGroupMPI> ProcessGroupMPI::createProcessGroupMPI(
std::vector<int> ranks) {
// Once initialization
initMPIOnce();
Expand Down Expand Up @@ -238,10 +238,10 @@ std::shared_ptr<ProcessGroupMPI> ProcessGroupMPI::createProcessGroupMPI(
// process group instance. This is in line with the semantics of the
// other process group types.
if (groupComm == MPI_COMM_NULL) {
return std::shared_ptr<ProcessGroupMPI>();
return c10::intrusive_ptr<ProcessGroupMPI>();
}

return std::make_shared<ProcessGroupMPI>(rank, size, groupComm);
return c10::make_intrusive<ProcessGroupMPI>(rank, size, groupComm);
}

ProcessGroupMPI::ProcessGroupMPI(int rank, int size, MPI_Comm pgComm)
Expand Down
2 changes: 1 addition & 1 deletion torch/lib/c10d/ProcessGroupMPI.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -185,7 +185,7 @@ class ProcessGroupMPI : public ProcessGroup {
const BarrierOptions& opts = BarrierOptions()) override;

// Creating a new ProcessGroupMPI will initialize MPI if not initialized
static std::shared_ptr<ProcessGroupMPI> createProcessGroupMPI(
static c10::intrusive_ptr<ProcessGroupMPI> createProcessGroupMPI(
std::vector<int> ranks = {});

protected:
Expand Down
2 changes: 1 addition & 1 deletion torch/lib/c10d/ProcessGroupNCCL.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -440,7 +440,7 @@ ProcessGroupNCCL::ProcessGroupNCCL(
const c10::intrusive_ptr<Store>& store,
int rank,
int size,
const c10::intrusive_ptr<Options>& options)
c10::intrusive_ptr<Options> options)
: ProcessGroup(rank, size),
store_(store),
ncclCommCounter_(0),
Expand Down
4 changes: 2 additions & 2 deletions torch/lib/c10d/ProcessGroupNCCL.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -406,7 +406,7 @@ class ProcessGroupNCCL : public ProcessGroup {
const c10::intrusive_ptr<Store>& store,
int rank,
int size,
const c10::intrusive_ptr<Options>& options = Options::create());
c10::intrusive_ptr<Options> options = Options::create());

// This constructor includes the deprecated `groupName` argument.
// If you have existing code that uses the `groupName`, you can replace
Expand All @@ -416,7 +416,7 @@ class ProcessGroupNCCL : public ProcessGroup {
int rank,
int size,
const std::string& groupName,
const c10::intrusive_ptr<Options>& options = Options::create())
c10::intrusive_ptr<Options> options = Options::create())
: ProcessGroupNCCL(store, rank, size, options) {}

virtual ~ProcessGroupNCCL();
Expand Down
4 changes: 2 additions & 2 deletions torch/lib/c10d/ProcessGroupRoundRobin.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ namespace c10d {
ProcessGroupRoundRobin::ProcessGroupRoundRobin(
int rank,
int size,
std::vector<std::shared_ptr<ProcessGroup>> processGroups)
std::vector<c10::intrusive_ptr<ProcessGroup>> processGroups)
: ProcessGroup(rank, size), processGroups_(std::move(processGroups)) {
TORCH_CHECK(processGroups_.size() >= 1);
for (const auto& processGroup : processGroups_) {
Expand Down Expand Up @@ -111,7 +111,7 @@ c10::intrusive_ptr<ProcessGroup::Work> ProcessGroupRoundRobin::barrier(
throw std::runtime_error("ProcessGroupRoundRobin does not support barrier");
};

const std::shared_ptr<ProcessGroup>& ProcessGroupRoundRobin::next() {
const c10::intrusive_ptr<ProcessGroup>& ProcessGroupRoundRobin::next() {
auto& processGroup = *iterator_;
iterator_++;
if (iterator_ == processGroups_.end()) {
Expand Down
8 changes: 4 additions & 4 deletions torch/lib/c10d/ProcessGroupRoundRobin.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ class ProcessGroupRoundRobin final : public ProcessGroup {
explicit ProcessGroupRoundRobin(
int rank,
int size,
std::vector<std::shared_ptr<ProcessGroup>> processGroups);
std::vector<c10::intrusive_ptr<ProcessGroup>> processGroups);

~ProcessGroupRoundRobin() override;

Expand Down Expand Up @@ -97,11 +97,11 @@ class ProcessGroupRoundRobin final : public ProcessGroup {
const BarrierOptions& opts = BarrierOptions()) override;

private:
std::vector<std::shared_ptr<ProcessGroup>> processGroups_;
std::vector<std::shared_ptr<ProcessGroup>>::const_iterator iterator_;
std::vector<c10::intrusive_ptr<ProcessGroup>> processGroups_;
std::vector<c10::intrusive_ptr<ProcessGroup>>::const_iterator iterator_;

// Returns the next ProcessGroup to use.
const std::shared_ptr<ProcessGroup>& next();
const c10::intrusive_ptr<ProcessGroup>& next();
};

} // namespace c10d
5 changes: 2 additions & 3 deletions torch/lib/c10d/comm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ namespace {
class BroadcastWork {
public:
BroadcastWork(
const std::shared_ptr<c10d::ProcessGroup>& process_group,
const c10::intrusive_ptr<c10d::ProcessGroup>& process_group,
std::vector<at::Tensor> bucket_tensors,
int root_rank = 0)
: bucket_tensors_(std::move(bucket_tensors)),
Expand Down Expand Up @@ -55,7 +55,7 @@ class BroadcastWork {

// Broadcast many tensors to all processes in the process group.
void broadcast_coalesced(
std::shared_ptr<c10d::ProcessGroup> process_group,
c10::intrusive_ptr<c10d::ProcessGroup> process_group,
at::TensorList tensors,
size_t buffer_size,
int rank) {
Expand Down Expand Up @@ -87,5 +87,4 @@ void broadcast_coalesced(
}
}


} // namespace c10d
2 changes: 1 addition & 1 deletion torch/lib/c10d/comm.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ namespace c10d {

// Broadcast many tensors to all processes in the process group.
void broadcast_coalesced(
std::shared_ptr<c10d::ProcessGroup> process_group,
c10::intrusive_ptr<c10d::ProcessGroup> process_group,
at::TensorList tensors,
size_t buffer_size,
int rank = 0);
Expand Down

0 comments on commit 553cccc

Please sign in to comment.