Skip to content

Commit

Permalink
[jit] In RPC Server, handle TorchScript continuations asynchronously (p…
Browse files Browse the repository at this point in the history
…ytorch#34109)

Summary:
Pull Request resolved: pytorch#34109

This change adds glue to GraphExecutor to give the RPC server
access to the future-based Interpreter::runAsync() api.

Previously, if a server encountered a TorchScript continuation-based block
with fork/wait, it would simply block in the server thread until the handler
completed, since it uses the synchronous Interpreter::run() api.

With the ivalue::Future returned by the Interpreter, we can run the
TorchScript code asynchronously from c++ simply by connecting its
callback to the server callback.

We add test cases to cover the new logic, both rpc_async and remote.

ghstack-source-id: 101245438

Test Plan: buck test mode/dev-nosan caffe2/test/distributed/rpc/...

Differential Revision: D20194321

fbshipit-source-id: 16785ec5d9ed0b16cb1ffab0a9771a77de30fcb0
  • Loading branch information
jjlilley authored and facebook-github-bot committed Apr 1, 2020
1 parent e5746ee commit 8d64a38
Show file tree
Hide file tree
Showing 9 changed files with 192 additions and 51 deletions.
7 changes: 7 additions & 0 deletions aten/src/ATen/core/builtin_function.h
Expand Up @@ -29,6 +29,13 @@ struct BuiltinOpFunction : public Function {
callable_(stack);
}

// Async entry point for builtin ops. Builtins always execute eagerly, so
// this runs the op synchronously and returns a future that is already
// marked complete with the single result left on the stack.
c10::intrusive_ptr<c10::ivalue::Future> runAsync(Stack& stack) override {
  run(stack);
  auto& result = stack.front();
  auto future = c10::make_intrusive<c10::ivalue::Future>(result.type());
  future->markCompleted(std::move(result));
  return future;
}

at::IValue operator()(std::vector<at::IValue> stack, const Kwargs& kwargs)
override {
getSchema().checkAndNormalizeInputs(stack, kwargs);
Expand Down
2 changes: 2 additions & 0 deletions aten/src/ATen/core/function.h
Expand Up @@ -31,6 +31,8 @@ struct TORCH_API Function {

virtual void run(Stack&& stack) = 0;

virtual c10::intrusive_ptr<c10::ivalue::Future> runAsync(Stack& stack) = 0;

virtual at::IValue operator()(
std::vector<at::IValue> stack,
const Kwargs& kwargs = Kwargs()) = 0;
Expand Down
120 changes: 84 additions & 36 deletions torch/csrc/distributed/rpc/request_callback_impl.cpp
Expand Up @@ -127,20 +127,40 @@ void RequestCallbackImpl::processRpc(
auto& stack = scriptCall.stackRef();
if (scriptCall.hasOp()) {
scriptCall.op()->getOperation()(stack);
} else {
PythonRpcHandler::getInstance()
.jitCompilationUnit()
->get_function(scriptCall.qualifiedName())
.run(stack);
TORCH_INTERNAL_ASSERT(
stack.size() == 1,
"Return value of a builtin operator or a "
"TorchScript function should be a single IValue, got a vector of "
"size ",
stack.size());
markComplete(
std::move(ScriptResp(std::move(stack.front()))).toMessage());
return;
}

TORCH_INTERNAL_ASSERT(
stack.size() == 1,
"Return value of a builtin operator or a "
"TorchScript function should be a single IValue, got a vector of "
"size ",
stack.size());
markComplete(std::move(ScriptResp(std::move(stack.front()))).toMessage());
// runAsync() starts in the calling thread, but may return an uncompleted
// future (though for non-async code, it will typically be completed).
// If it was async, our callback will typically be invoked by the
// continuation on an at::launch() thread.
auto jitFuture = PythonRpcHandler::getInstance()
.jitCompilationUnit()
->get_function(scriptCall.qualifiedName())
.runAsync(stack);

if (jitFuture->completed()) {
markComplete(
std::move(ScriptResp(std::move(jitFuture->value()))).toMessage());
return;
}
jitFuture->addCallback([responseFuture, messageId, jitFuture]() {
try {
Message m = ScriptResp(std::move(jitFuture->value())).toMessage();
m.setId(messageId);
responseFuture->markCompleted(std::move(m));
} catch (const std::exception& e) {
responseFuture->setError(e.what());
}
});
return;
}
case MessageType::PYTHON_CALL: {
Expand Down Expand Up @@ -177,19 +197,37 @@ void RequestCallbackImpl::processRpc(
}

auto ownerRRef = ctx.getOrCreateOwnerRRef(rrefId, returnType);
auto postProcessing = [rrefId, forkId, messageId, responseFuture]() {
if (rrefId != forkId) {
// Caller is a user and callee is the owner, add fork
//
// NB: rrefId == forkId is true if and only if calling remote to
// self. In that case both the caller and the callee will access
// the OwnerRRef. Hence, on the callee side (here), it should not
// call addForkOfOwner as it is not a fork. To allow callee to
// distinguish when this request is sent to self, the caller will
// set forkId using rrefId (OwnerRRef does not have a forkId
// anyway).
RRefContext::getInstance().addForkOfOwner(rrefId, forkId);
}
Message m = RemoteRet(rrefId, forkId).toMessage();
m.setId(messageId);
responseFuture->markCompleted(std::move(m));
};

// TODO: make this asynchronous
// scriptRemoteCall is only alive within this block, use reference to
// avoid copy
// avoid copy. If the underlying code runs with a continuation, runAsync()
// below will std::move the appropriate portion of the stack.
auto& stack = scriptRemoteCall.stackRef();
try {
if (scriptRemoteCall.hasOp()) {
if (scriptRemoteCall.hasOp()) {
try {
scriptRemoteCall.op()->getOperation()(stack);
} else {
PythonRpcHandler::getInstance()
.jitCompilationUnit()
->get_function(scriptRemoteCall.qualifiedName())
.run(stack);
} catch (const std::exception& e) {
// Don't throw in this call, but rather transfer the exception
// to the rref.
ownerRRef->setError(e.what());
postProcessing();
return;
}
TORCH_INTERNAL_ASSERT(
stack.size() == 1,
Expand All @@ -198,24 +236,34 @@ void RequestCallbackImpl::processRpc(
"size ",
stack.size());
ownerRRef->setValue(std::move(stack.front()));
} catch (const std::exception& e) {
// Don't throw in this call, but rather transfer the exception
// to the rref.
ownerRRef->setError(e.what());
postProcessing();
return;
}

if (rrefId != forkId) {
// Caller is a user and callee is the owner, add fork
//
// NB: rrefId == forkId is true if and only if calling remote to self.
// In that case both the caller and the callee will access the
// OwnerRRef. Hence, on the callee side (here), it should not call
// addForkOfOwner as it is not a fork. To allow callee to distinguish
// when this request is sent to self, the caller will set forkId using
// rrefId (OwnerRRef does not have a forkId anyway).
ctx.addForkOfOwner(rrefId, forkId);
c10::intrusive_ptr<c10::ivalue::Future> jitFuture;
try {
jitFuture = PythonRpcHandler::getInstance()
.jitCompilationUnit()
->get_function(scriptRemoteCall.qualifiedName())
.runAsync(stack);
if (jitFuture->completed()) { // short-cut.
ownerRRef->setValue(jitFuture->value());
postProcessing();
return;
}
} catch (const std::exception& e) {
ownerRRef->setError(e.what());
postProcessing();
return;
}
markComplete(RemoteRet(rrefId, forkId).toMessage());
jitFuture->addCallback([ownerRRef, postProcessing, jitFuture]() {
try {
ownerRRef->setValue(jitFuture->value());
} catch (const std::exception& e) {
ownerRRef->setError(e.what());
}
postProcessing();
});
return;
}
case MessageType::PYTHON_REMOTE_CALL: {
Expand Down
4 changes: 4 additions & 0 deletions torch/csrc/jit/api/function_impl.cpp
Expand Up @@ -36,6 +36,10 @@ void GraphFunction::run(Stack&& stack) {
run(stack);
}

// Start asynchronous execution of this function's graph and return a
// future for its result. Delegates to the (lazily created) GraphExecutor;
// for purely synchronous code the future may already be completed.
c10::intrusive_ptr<c10::ivalue::Future> GraphFunction::runAsync(Stack& stack) {
  auto& executor = get_executor();
  return executor.runAsync(stack);
}

IValue GraphFunction::operator()(
std::vector<IValue> stack,
const Kwargs& kwargs) {
Expand Down
2 changes: 2 additions & 0 deletions torch/csrc/jit/api/function_impl.h
Expand Up @@ -25,6 +25,8 @@ struct TORCH_API GraphFunction : public Function {

void run(Stack&& stack) override;

c10::intrusive_ptr<c10::ivalue::Future> runAsync(Stack& stack) override;

IValue operator()(std::vector<IValue> stack, const Kwargs& kwargs = Kwargs())
override;

Expand Down
35 changes: 35 additions & 0 deletions torch/csrc/jit/runtime/graph_executor.cpp
Expand Up @@ -481,6 +481,37 @@ void GraphExecutorImplBase::run(Stack& stack) {
last_executed_optimized_graph = plan.graph;
}

c10::intrusive_ptr<Future> GraphExecutorImplBase::runAsync(Stack& stack) {
  // Validate that the caller pushed at least the declared number of inputs.
  TORCH_CHECK(
      stack.size() >= num_inputs,
      "expected ",
      num_inputs,
      " inputs, but got only ",
      stack.size());

  C10_LOG_API_USAGE_ONCE("torch.graph_executor.runAsync");
  logging::getLogger()->addStatValue(
      logging::runtime_counters::GRAPH_EXECUTOR_INVOCATIONS, 1.0);

  // Bundle the execution plan with the interpreter state so the two share
  // one lifetime: the interpreter state is constructed from the plan's
  // code, so both must stay alive together while execution is in flight.
  struct ExecutionFrame {
    explicit ExecutionFrame(ExecutionPlan eplan)
        : plan(std::move(eplan)), state(plan.code) {}
    ExecutionPlan plan;
    InterpreterState state;
  };

  auto execFrame = std::make_shared<ExecutionFrame>(
      getPlanFor(stack, GraphExecutor::getDefaultNumBailOuts()));
  auto future = execFrame->state.runAsync(stack);
  last_executed_optimized_graph = execFrame->plan.graph;
  if (!future->completed()) {
    // Still running asynchronously: pin the frame via the callback capture
    // so the interpreter state survives until the future resolves.
    future->addCallback([execFrame] {});
  }
  return future;
}

// a Graph can be created via tracing, or via a language-based frontend
// GraphExecutor runs it. It can run the same graph on many different sizes
// and different requires_grad states, and handles specializations for each
Expand Down Expand Up @@ -640,6 +671,10 @@ void GraphExecutor::run(Stack& inputs) {
return pImpl->run(inputs);
}

// Forward the asynchronous run request to the pimpl implementation.
c10::intrusive_ptr<Future> GraphExecutor::runAsync(Stack& stack) {
  auto& impl = *pImpl;
  return impl.runAsync(stack);
}

size_t GraphExecutor::getDefaultNumBailOuts() {
return getProfilingMode() ? getBailoutDepth().load() : 0;
}
Expand Down
3 changes: 3 additions & 0 deletions torch/csrc/jit/runtime/graph_executor.h
Expand Up @@ -45,7 +45,10 @@ struct GraphExecutorImplBase;
struct TORCH_API GraphExecutor {
GraphExecutor() = default;
GraphExecutor(std::shared_ptr<Graph> graph, std::string function_name);

void run(Stack& inputs);
c10::intrusive_ptr<Future> runAsync(Stack& stack);

// `remaining_bailout_depth` stands for the maximum number of profiled and
// specialized recompilations allowed for the current `GraphExecutor`. if
// remaining_bailout_depth is equal to 0, `GraphExecutor` won't perform any
Expand Down
1 change: 1 addition & 0 deletions torch/csrc/jit/runtime/graph_executor_impl.h
Expand Up @@ -66,6 +66,7 @@ struct GraphExecutorImplBase {

// entry point where execution begins
void run(Stack& stack);
c10::intrusive_ptr<Future> runAsync(Stack& stack);

virtual ExecutionPlan getPlanFor(
Stack& stack,
Expand Down
69 changes: 54 additions & 15 deletions torch/testing/_internal/distributed/rpc/jit/rpc_test.py
Expand Up @@ -65,6 +65,36 @@ def test_local_rref_local_value(self):
ret = rref_local_value(rref)
self.assertEqual(ret, 5)

# Define Script functions on both client and server sides.
@torch.jit.script
def no_arg():
    # Trivial scripted function used as an RPC target; always yields 0.
    result = 0
    return result

@torch.jit.script
def one_arg(value):
    # Scripted add-one; `value` is a Tensor (TorchScript's default arg type).
    incremented = value + 1
    return incremented

@torch.jit.script
def script_add_ones(x):
    # Add a size-1 ones tensor to the input (broadcasts over x's shape).
    ones = torch.ones(1)
    return torch.add(x, ones)

@torch.jit.script
def script_fork_wait_udf(tensor):
    # Fork script_add_ones onto a side task, then block on its future;
    # exercises the fork/wait continuation path in the RPC server.
    future = torch.jit._fork(script_add_ones, tensor)
    result = torch.jit._wait(future)
    return result

@torch.jit.script
def script_raise_func(value):
    # Raise for exactly-two-element inputs; otherwise return input plus one.
    if value.numel() == 2:
        raise ValueError("Expected error")
    incremented = value + 1
    return incremented

@torch.jit.script
def script_fork_wait_throw(invalue):
    # Fork a function that raises, then wait on it: the exception must
    # propagate through the fork/wait future to the caller.
    future = torch.jit._fork(script_raise_func, invalue)
    result = torch.jit._wait(future)
    return result

class MyScriptModuleWithRRefs(torch.jit.ScriptModule):
def __init__(self, dst_worker):
Expand Down Expand Up @@ -233,11 +263,6 @@ def python_function():
return 0


@torch.jit.script
def no_arg():
return 0


@torch.jit.script
def two_args_two_kwargs(
first_arg,
Expand Down Expand Up @@ -575,16 +600,6 @@ def rpc_async_call_remote_nonexisting_torchscript_in_torchscript(
self.assertEqual(ret, 0)


@torch.jit.script
def one_arg(value):
return value + 1

@torch.jit.script
def script_raise_func(value):
if value.numel() == 2:
raise ValueError("Expected error")
return value + 1

@torch.jit.script
def rref_to_here(rref_var):
# type: (RRef[Tensor]) -> Tensor
Expand Down Expand Up @@ -842,3 +857,27 @@ def test_remote_script_throw(self):
args=(torch.ones(2),))
with self.assertRaisesRegex(Exception, ".*Expected error.*"):
rref.to_here()

@dist_init
def test_remote_script_udf(self):
    # remote() a scripted fork/wait UDF on the next-rank worker and fetch
    # the result; covers the server's async continuation path for RRefs.
    dst_worker = "worker{}".format((self.rank + 1) % self.world_size)
    rref = rpc.remote(dst_worker, script_fork_wait_udf, args=(torch.ones(2),))
    self.assertEqual(rref.to_here(), torch.ones(2) * 2)

@dist_init
def test_async_script_udf(self):
    # rpc_async a scripted fork/wait UDF and block on the returned future;
    # covers the server's async continuation path for rpc_async.
    dst_worker = "worker{}".format((self.rank + 1) % self.world_size)
    future = rpc.rpc_async(dst_worker, script_fork_wait_udf, args=(torch.ones(2),))
    self.assertEqual(future.wait(), torch.ones(2) * 2)

@dist_init
def test_async_script_throw(self):
    # A remote fork/wait target that raises should surface the error when
    # the caller waits on the rpc_async future.
    dst_worker = "worker{}".format((self.rank + 1) % self.world_size)
    future = rpc.rpc_async(dst_worker, script_fork_wait_throw, args=(torch.ones(2),))
    with self.assertRaisesRegex(Exception, ".*Expected error.*"):
        future.wait()

0 comments on commit 8d64a38

Please sign in to comment.