Allow RPC to be initialized again after shutdown.
This PR addresses #39340
and allows users to initialize RPC again after shutdown. The major changes in
the PR are:

1. Changes to DistAutogradContainer to support reinitialization.
2. Ensuring PythonRpcHandler is reinitialized appropriately after a shutdown.
3. Using a PrefixStore in RPC initialization to ensure each new `init_rpc` call
uses a different key prefix.

Differential Revision: [D22993909](https://our.internmc.facebook.com/intern/diff/D22993909/)

ghstack-source-id: 109412898
Pull Request resolved: #42723
pritamdamania committed Aug 7, 2020
1 parent b852168 commit 109f590
Showing 5 changed files with 81 additions and 22 deletions.
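Before the per-file diffs, here is a minimal sketch of the workflow this commit enables, assuming a two-worker env:// rendezvous; the worker names, address, and port are illustrative, not part of the change:

import os
import torch
import torch.distributed.rpc as rpc

os.environ["MASTER_ADDR"] = "localhost"  # illustrative rendezvous settings
os.environ["MASTER_PORT"] = "29500"

def run(rank, world_size):
    # First RPC lifetime.
    rpc.init_rpc(f"worker{rank}", rank=rank, world_size=world_size)
    rpc.shutdown()

    # Before this commit, a second init_rpc in the same process failed;
    # now RPC can be brought up again after a clean shutdown.
    rpc.init_rpc(f"worker{rank}", rank=rank, world_size=world_size)
    if rank == 0:
        rpc.rpc_sync("worker1", torch.add, args=(torch.ones(2, 2), 1))
    rpc.shutdown()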
12 changes: 10 additions & 2 deletions torch/csrc/distributed/autograd/context/container.cpp
@@ -41,8 +41,16 @@ DistAutogradContainer& DistAutogradContainer::init(int64_t worker_id) {
 
   auto& container = getInstanceInternal();
   TORCH_CHECK(
-      !container.initialized_,
-      "Container is already initialized! Cannot initialize it twice!");
+      !container.initialized_ || (worker_id == container.worker_id_),
+      "Container is already initialized with worker_id: ",
+      container.worker_id_,
+      ", cannot initialize with different worker_id: ",
+      worker_id);
+
+  if (container.initialized_) {
+    LOG(INFO) << "DistAutogradContainer is already initialized";
+    return container;
+  }
 
   container.worker_id_ = worker_id;
   container.next_context_id_ = static_cast<int64_t>(worker_id)
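The new check makes init idempotent per worker: re-initialization with the same worker_id is a logged no-op, while a different worker_id is an error. A hypothetical Python analogue of that control flow, purely for illustration:

class DistAutogradContainerSketch:
    """Hypothetical analogue of the new DistAutogradContainer::init logic."""

    def __init__(self):
        self.initialized = False
        self.worker_id = None

    def init(self, worker_id):
        # Re-init with a different worker_id is an error...
        if self.initialized and worker_id != self.worker_id:
            raise RuntimeError(
                "Container is already initialized with worker_id: "
                f"{self.worker_id}, cannot initialize with different "
                f"worker_id: {worker_id}")
        # ...while re-init with the same worker_id is a logged no-op.
        if self.initialized:
            print("DistAutogradContainer is already initialized")
            return self
        self.initialized = True
        self.worker_id = worker_id
        return self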
46 changes: 27 additions & 19 deletions torch/csrc/distributed/rpc/python_rpc_handler.cpp
@@ -68,27 +68,33 @@ void cleanupPyObj(py::object& obj) {
 
 } // namespace
 
-PythonRpcHandler::PythonRpcHandler() {
-  PROFILE_GIL_SCOPED_ACQUIRE;
-  py::object rpcInternal = py::module::import(kInternalModule);
-  py::object rpcApi = py::module::import("torch.distributed.rpc.api");
-  py::object rrefProxy = py::module::import("torch.distributed.rpc.rref_proxy");
-
-  pyRunFunction_ = getFunction(rpcInternal, "_run_function");
-  pySerialize_ = getFunction(rpcInternal, "serialize");
-  pyDeserialize_ = getFunction(rpcInternal, "deserialize");
-  pyHandleException_ = getFunction(rpcInternal, "_handle_exception");
-
-  rrefProxyFunctions_.rpcSync_ = getFunction(rpcApi, "rpc_sync");
-  rrefProxyFunctions_.rpcAsync_ = getFunction(rpcApi, "rpc_async");
-  rrefProxyFunctions_.remote_ = getFunction(rpcApi, "remote");
-  rrefProxyFunctions_.rrefProxyCtor_ = getFunction(rrefProxy, "RRefProxy");
-
-  jitCompilationUnit_ = torch::jit::get_python_cu();
-  typeParser_ = std::make_shared<jit::ScriptTypeParser>(
-      std::make_shared<PythonTypeResolver>());
+void PythonRpcHandler::init() {
+  if (!initialized_) {
+    PROFILE_GIL_SCOPED_ACQUIRE;
+    py::object rpcInternal = py::module::import(kInternalModule);
+    py::object rpcApi = py::module::import("torch.distributed.rpc.api");
+    py::object rrefProxy =
+        py::module::import("torch.distributed.rpc.rref_proxy");
+
+    pyRunFunction_ = getFunction(rpcInternal, "_run_function");
+    pySerialize_ = getFunction(rpcInternal, "serialize");
+    pyDeserialize_ = getFunction(rpcInternal, "deserialize");
+    pyHandleException_ = getFunction(rpcInternal, "_handle_exception");
+
+    rrefProxyFunctions_.rpcSync_ = getFunction(rpcApi, "rpc_sync");
+    rrefProxyFunctions_.rpcAsync_ = getFunction(rpcApi, "rpc_async");
+    rrefProxyFunctions_.remote_ = getFunction(rpcApi, "remote");
+    rrefProxyFunctions_.rrefProxyCtor_ = getFunction(rrefProxy, "RRefProxy");
+
+    jitCompilationUnit_ = torch::jit::get_python_cu();
+    typeParser_ = std::make_shared<jit::ScriptTypeParser>(
+        std::make_shared<PythonTypeResolver>());
+    initialized_ = true;
+  }
 }
 
+PythonRpcHandler::PythonRpcHandler() : initialized_(false) {}
+
 void PythonRpcHandler::cleanup() {
   PROFILE_GIL_SCOPED_ACQUIRE;
   cleanupPyObj(pyRunFunction_);
@@ -103,6 +109,7 @@ void PythonRpcHandler::cleanup() {
 
   jitCompilationUnit_ = nullptr;
   typeParser_ = nullptr;
+  initialized_ = false;
 }
 
 PythonRpcHandler& PythonRpcHandler::getInstance() {
@@ -117,6 +124,7 @@ PythonRpcHandler& PythonRpcHandler::getInstance() {
   TORCH_INTERNAL_ASSERT(!PyGILState_Check());
   // Leaky singleton to avoid module destructor race.
   static PythonRpcHandler* handler = new PythonRpcHandler();
+  handler->init();
   return *handler;
 }

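In short, the heavy GIL-acquiring setup moves out of the constructor into an idempotent init(), cleanup() drops the initialized_ flag, and getInstance() primes the leaky singleton on every access. A hypothetical Python analogue of that lifecycle (all names invented for illustration):

class _HandlerSketch:
    def __init__(self):
        # The constructor is now trivial; real setup happens in init().
        self.initialized = False
        self.helpers = None

    def init(self):
        # Idempotent, so get_instance() may call it on every access.
        if not self.initialized:
            self.helpers = {"run": lambda f, *a: f(*a)}  # stand-in setup
            self.initialized = True

    def cleanup(self):
        # Clearing the flag is what allows a later init_rpc to re-prime
        # the same (leaky) singleton instance.
        self.helpers = None
        self.initialized = False

_instance = None

def get_instance():
    # Leaky singleton: constructed at most once per process, but
    # re-primed via init() after any cleanup().
    global _instance
    if _instance is None:
        _instance = _HandlerSketch()
    _instance.init()
    return _instance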
4 changes: 4 additions & 0 deletions torch/csrc/distributed/rpc/python_rpc_handler.h
@@ -73,6 +73,7 @@ class PYBIND11_EXPORT PythonRpcHandler {
   const RRefProxyFunctions& getRRefProxyFunctions() const;
 
  private:
+  void init();
   PythonRpcHandler();
   ~PythonRpcHandler() = default;
 
@@ -105,6 +106,9 @@ class PYBIND11_EXPORT PythonRpcHandler {
   // jit type parser to parse type_str back to TypePtr for RRef type
   // recovery when pickling and unpickling RRef
   std::shared_ptr<jit::ScriptTypeParser> typeParser_;
+
+  // Indicates whether or not we have properly initialized the handler.
+  std::atomic<bool> initialized_;
 };
 
 } // namespace rpc
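A note on the flag's type in the header above: cleanup() (run during shutdown) and getInstance()'s init() call may execute on different threads, so the flag is a std::atomic<bool> rather than a plain bool. A rough Python analogue of the same idea, using threading.Event as the thread-safe flag (illustrative only):

import threading

class _FlaggedHandler:
    def __init__(self):
        # threading.Event stands in for std::atomic<bool>: it can be set,
        # cleared, and tested safely from multiple threads.
        self._initialized = threading.Event()

    def init(self):
        if not self._initialized.is_set():
            # ... one-time setup would go here ...
            self._initialized.set()

    def cleanup(self):
        self._initialized.clear()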
10 changes: 9 additions & 1 deletion torch/distributed/rpc/__init__.py
@@ -1,11 +1,13 @@
 from __future__ import absolute_import, division, print_function, unicode_literals
 
 import numbers
 import sys
 
 import torch
 import torch.distributed as dist
+import threading
 
+_init_counter = 0
+_init_counter_lock = threading.Lock()
 
 def is_available():
     return hasattr(torch._C, "_rpc_init")
@@ -81,6 +83,12 @@ def init_rpc(
     )
     store, _, _ = next(rendezvous_iterator)
 
+    # Use a PrefixStore to distinguish multiple invocations.
+    with _init_counter_lock:
+        global _init_counter
+        store = dist.PrefixStore(str(_init_counter), store)
+        _init_counter += 1
+
     # Initialize autograd before RPC since _init_rpc_backend guarantees all
     # processes sync via the store. If we initialize autograd after RPC,
     # there could be a race where some nodes might have initialized autograd
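Why the PrefixStore helps: dist.PrefixStore wraps an existing store and namespaces every key under its prefix, so keys written during one init_rpc lifetime cannot collide with the rendezvous and agent keys of the next. A minimal sketch of those semantics, assuming a local TCPStore as the underlying store (host and port are illustrative):

import torch.distributed as dist

# A TCPStore standing in for whatever store rendezvous returns.
store = dist.TCPStore("localhost", 29500, 1, True)

first = dist.PrefixStore("0", store)   # first init_rpc  -> prefix "0"
second = dist.PrefixStore("1", store)  # second init_rpc -> prefix "1"

# The same logical key lands under different physical keys.
first.set("agent_id", "alice")
second.set("agent_id", "bob")

assert first.get("agent_id") == b"alice"
assert second.get("agent_id") == b"bob"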
31 changes: 31 additions & 0 deletions torch/testing/_internal/distributed/rpc/rpc_test.py
@@ -3434,6 +3434,37 @@ def test_wait_all_with_partial_exception(self):
         with self.assertRaisesRegex(ValueError, "Expected error"):
             ret = torch.futures.wait_all(futs)
 
+    @dist_init(setup_rpc=False)
+    def test_init_rpc_twice(self):
+        rpc.init_rpc(
+            name=worker_name(self.rank),
+            backend=self.rpc_backend,
+            rank=self.rank,
+            world_size=self.world_size,
+            rpc_backend_options=self.rpc_backend_options,
+        )
+        rpc.shutdown()
+
+        initialize_pg(self.init_method, self.rank, self.world_size)
+        # Wait for all init to complete.
+        dist.barrier()
+
+        # Ensure rpc initialization works again.
+        rpc.init_rpc(
+            name=worker_name(self.rank),
+            backend=self.rpc_backend,
+            rank=self.rank,
+            world_size=self.world_size,
+            rpc_backend_options=self.rpc_backend_options,
+        )
+
+        # Verify RPCs work after re-init.
+        dst = worker_name((self.rank + 1) % self.world_size)
+        rpc.rpc_sync(dst, torch.add, args=(torch.ones(2, 2), 1))
+        rpc.rpc_sync(dst, foo_add, args=())
+
+        rpc.shutdown()
+
 
 class FaultyAgentRpcTest(FaultyRpcAgentTestFixture):

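Note: in test_init_rpc_twice above, dist_init, worker_name, initialize_pg, and foo_add are helpers defined elsewhere in this test harness rather than part of the public RPC API; initialize_pg plus dist.barrier() give the workers a way to synchronize between the shutdown and the second init_rpc, since the RPC agent itself is down at that point.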

