Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion aten/src/ATen/core/builtin_function.h
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,9 @@ struct BuiltinOpFunction : public Function {
callable_(stack);
}

c10::intrusive_ptr<c10::ivalue::Future> runAsync(Stack& stack) override {
c10::intrusive_ptr<c10::ivalue::Future> runAsync(
Stack& stack,
TaskLauncher /* not used */) override {
run(stack);
auto res = c10::make_intrusive<c10::ivalue::Future>(stack.front().type());
res->markCompleted(std::move(stack.front()));
Expand Down
9 changes: 8 additions & 1 deletion aten/src/ATen/core/function.h
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,10 @@ namespace c10 {
struct FunctionSchema;
};

namespace at {
CAFFE2_API void launch(std::function<void()> func);
}

namespace torch {
namespace jit {

Expand All @@ -17,6 +21,7 @@ struct GraphExecutor;
using Stack = std::vector<at::IValue>;
using Kwargs = std::unordered_map<std::string, at::IValue>;
struct RecursiveMethodCallError : public std::exception {};
using TaskLauncher = std::function<void(std::function<void()>)>;

TORCH_API void preoptimizeGraph(std::shared_ptr<Graph>& graph);

Expand All @@ -36,7 +41,9 @@ struct TORCH_API Function {

virtual void run(Stack&& stack) = 0;

virtual c10::intrusive_ptr<c10::ivalue::Future> runAsync(Stack& stack) = 0;
// Asynchronous counterpart of run(): executes on `stack` and returns a Future
// for the result. `taskLauncher` is the callable handed to the implementation
// for scheduling work; it defaults to at::launch. Implementations may ignore
// it (BuiltinOpFunction runs synchronously and does not use it).
// NOTE(review): default arguments of virtual functions are selected from the
// static type of the call expression, so overrides that redeclare a default
// should keep it identical to this one.
virtual c10::intrusive_ptr<c10::ivalue::Future> runAsync(
Stack& stack,
TaskLauncher taskLauncher = at::launch) = 0;

virtual at::IValue operator()(
std::vector<at::IValue> stack,
Expand Down
38 changes: 38 additions & 0 deletions test/cpp/jit/test_graph_executor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,9 @@

#include "test/cpp/jit/test_utils.h"
#include "torch/csrc/jit/runtime/graph_executor.h"
#include "torch/jit.h"
#include "torch/script.h"
#include "torch/torch.h"

namespace torch {
namespace jit {
Expand Down Expand Up @@ -29,5 +32,40 @@ TEST(GraphExecutorTest, Basic_CUDA) {
ASSERT_TRUE(almostEqual(stack[1].toTensor(), r1));
}

// Verifies that GraphExecutor::runAsync routes asynchronous work through the
// caller-supplied TaskLauncher instead of calling at::launch directly.
TEST(GraphExecutorTest, runAsync_executor) {
  /*
  TODO: there are some problems with the C++ frontend parsing script programs
  involving fork. Use the pre-serialized test module below for now.
  Tracking issue: github.com/pytorch/pytorch/issues/46368
  The test module file was generated as follows:
    class DemoModule(torch.nn.Module):
      def forward(self):
        r1 = torch.jit.fork(torch.mm, torch.rand(100,100),torch.rand(100,100))
        r2 = torch.jit.fork(torch.mm, torch.rand(100,100),torch.rand(100,100))
        return r1.wait() + r2.wait()
    demo = DemoModule()
    torch.jit.save(torch.jit.script(demo), 'test_interpreter_async.pt')
  */
  // The serialized module is looked up in the same directory as this source
  // file.
  std::string filePath(__FILE__);
  auto testModelFile = filePath.substr(0, filePath.find_last_of("/\\") + 1);
  testModelFile.append("test_interpreter_async.pt");
  auto module = load(testModelFile);
  auto graph = module.get_method("forward").graph();
  GraphExecutor graphExecutor(graph, "");

  int asyncCounter = 0;
  std::mutex mtx;
  // A counting launcher: it still delegates to at::launch, but records every
  // task that is handed to it. The launcher may be invoked from several
  // threads, so the counter is guarded by a mutex.
  auto launcher = [&](std::function<void()> f) {
    {
      // Scoped lock: released before the task is dispatched, and released
      // even if the increment path were to throw.
      std::lock_guard<std::mutex> guard(mtx);
      ++asyncCounter;
    }
    at::launch(std::move(f));
  };

  std::vector<IValue> stack;
  stack.push_back(module._ivalue());
  graphExecutor.runAsync(stack, launcher)->wait();
  // At least one task must have gone through the custom launcher.
  ASSERT_TRUE(asyncCounter > 0);
}

} // namespace jit
} // namespace torch
40 changes: 40 additions & 0 deletions test/cpp/jit/test_interpreter.cpp
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
#include <gtest/gtest.h>

#include <mutex>
#include <utility>

#include <ATen/Parallel.h>
#include "test/cpp/jit/test_utils.h"
#include "torch/jit.h"
#include "torch/script.h"
#include "torch/torch.h"

namespace torch {
namespace jit {
Expand Down Expand Up @@ -138,5 +142,41 @@ TEST(InterpreterTest, Basic_CUDA) {
ASSERT_TRUE(exactlyEqual(outputs[0], hx));
ASSERT_TRUE(exactlyEqual(outputs[1], cx));
}

// Verifies that an InterpreterState constructed with a custom TaskLauncher
// uses that launcher for the asynchronous work it schedules.
TEST(InterpreterTest, runAsyncBasicTest) {
  /*
  TODO: there are some problems with the C++ frontend parsing script programs
  involving fork. Use the pre-serialized test module below for now.
  Tracking issue: github.com/pytorch/pytorch/issues/46368
  The test module file was generated as follows:
    class DemoModule(torch.nn.Module):
      def forward(self):
        r1 = torch.jit.fork(torch.mm, torch.rand(100,100),torch.rand(100,100))
        r2 = torch.jit.fork(torch.mm, torch.rand(100,100),torch.rand(100,100))
        return r1.wait() + r2.wait()
    demo = DemoModule()
    torch.jit.save(torch.jit.script(demo), 'test_interpreter_async.pt')
  */
  // The serialized module is looked up in the same directory as this source
  // file.
  std::string filePath(__FILE__);
  auto testModelFile = filePath.substr(0, filePath.find_last_of("/\\") + 1);
  testModelFile.append("test_interpreter_async.pt");
  auto model = load(testModelFile);
  auto graph = model.get_method("forward").graph();
  Code function(graph, "");

  int asyncCounter = 0;
  std::mutex mtx;
  // A counting launcher: it still delegates to at::launch, but records every
  // task that is handed to it. The launcher may be invoked from several
  // threads, so the counter is guarded by a mutex.
  auto launcher = [&](std::function<void()> f) {
    {
      // Scoped lock: released before the task is dispatched, and released
      // even if the increment path were to throw.
      std::lock_guard<std::mutex> guard(mtx);
      ++asyncCounter;
    }
    // Move the task instead of copying the std::function.
    at::launch(std::move(f));
  };

  std::vector<IValue> stack;
  stack.push_back(model._ivalue());
  InterpreterState interp(function, launcher);
  interp.runAsync(stack)->wait();
  // At least one task must have gone through the custom launcher.
  ASSERT_TRUE(asyncCounter > 0);
}
} // namespace jit
} // namespace torch
Binary file added test/cpp/jit/test_interpreter_async.pt
Binary file not shown.
6 changes: 4 additions & 2 deletions torch/csrc/jit/api/function_impl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,10 @@ void GraphFunction::run(Stack&& stack) {
run(stack);
}

c10::intrusive_ptr<c10::ivalue::Future> GraphFunction::runAsync(Stack& stack) {
return get_executor().runAsync(stack);
// Asynchronous entry point: hands the stack and the caller-supplied task
// launcher to this function's graph executor and returns the executor's
// future.
c10::intrusive_ptr<c10::ivalue::Future> GraphFunction::runAsync(
    Stack& stack,
    TaskLauncher taskLauncher) {
  auto&& executor = get_executor();
  return executor.runAsync(stack, std::move(taskLauncher));
}

IValue GraphFunction::operator()(
Expand Down
4 changes: 3 additions & 1 deletion torch/csrc/jit/api/function_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,9 @@ struct TORCH_API GraphFunction : public Function {

void run(Stack&& stack) override;

c10::intrusive_ptr<c10::ivalue::Future> runAsync(Stack& stack) override;
c10::intrusive_ptr<c10::ivalue::Future> runAsync(
Stack& stack,
TaskLauncher taskLauncher = at::launch) override;

IValue operator()(std::vector<IValue> stack, const Kwargs& kwargs = Kwargs())
override;
Expand Down
17 changes: 11 additions & 6 deletions torch/csrc/jit/runtime/graph_executor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -516,7 +516,9 @@ void GraphExecutorImplBase::run(Stack& stack) {
last_executed_optimized_graph = plan.graph;
}

c10::intrusive_ptr<Future> GraphExecutorImplBase::runAsync(Stack& stack) {
c10::intrusive_ptr<Future> GraphExecutorImplBase::runAsync(
Stack& stack,
TaskLauncher taskLauncher) {
TORCH_CHECK(
stack.size() >= num_inputs,
"expected ",
Expand All @@ -529,13 +531,14 @@ c10::intrusive_ptr<Future> GraphExecutorImplBase::runAsync(Stack& stack) {
logging::runtime_counters::GRAPH_EXECUTOR_INVOCATIONS, 1.0);

struct Frame {
explicit Frame(ExecutionPlan eplan)
: plan(std::move(eplan)), state(plan.code) {}
explicit Frame(ExecutionPlan eplan, TaskLauncher taskLauncher)
: plan(std::move(eplan)), state(plan.code, std::move(taskLauncher)) {}
ExecutionPlan plan;
InterpreterState state;
};
auto frame = std::make_shared<Frame>(
getPlanFor(stack, GraphExecutor::getDefaultNumBailOuts()));
getPlanFor(stack, GraphExecutor::getDefaultNumBailOuts()),
std::move(taskLauncher));
auto res = frame->state.runAsync(stack);
last_executed_optimized_graph = frame->plan.graph;
if (!res->completed()) {
Expand Down Expand Up @@ -731,8 +734,10 @@ void GraphExecutor::run(Stack& inputs) {
return pImpl->run(inputs);
}

c10::intrusive_ptr<Future> GraphExecutor::runAsync(Stack& stack) {
return pImpl->runAsync(stack);
// Public asynchronous entry point: delegates to the implementation object,
// passing ownership of the task launcher along.
c10::intrusive_ptr<Future> GraphExecutor::runAsync(
    Stack& stack,
    TaskLauncher taskLauncher) {
  auto& impl = *pImpl;
  return impl.runAsync(stack, std::move(taskLauncher));
}

size_t GraphExecutor::getDefaultNumBailOuts() {
Expand Down
4 changes: 3 additions & 1 deletion torch/csrc/jit/runtime/graph_executor.h
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,9 @@ struct TORCH_API GraphExecutor {
GraphExecutor(std::shared_ptr<Graph> graph, std::string function_name);

void run(Stack& inputs);
c10::intrusive_ptr<Future> runAsync(Stack& stack);
c10::intrusive_ptr<Future> runAsync(
Stack& stack,
TaskLauncher taskLauncher = at::launch);

// `remaining_bailout_depth` stands for the maximum number of profiled and
// specialized recompilations allowed for the current `GraphExecutor`. if
Expand Down
4 changes: 3 additions & 1 deletion torch/csrc/jit/runtime/graph_executor_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,9 @@ struct GraphExecutorImplBase {

// entry point where execution begins
void run(Stack& stack);
c10::intrusive_ptr<Future> runAsync(Stack& stack);
c10::intrusive_ptr<Future> runAsync(
Stack& stack,
TaskLauncher taskLauncher = at::launch);

virtual ExecutionPlan getPlanFor(
Stack& stack,
Expand Down
23 changes: 16 additions & 7 deletions torch/csrc/jit/runtime/interpreter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1031,7 +1031,8 @@ struct CodeImpl {

// InterpreterState state that is used to compute a Code
struct InterpreterStateImpl : c10::intrusive_ptr_target {
InterpreterStateImpl(const Code& code) {
InterpreterStateImpl(const Code& code, TaskLauncher taskLauncher)
: taskLauncher_(std::move(taskLauncher)) {
enterFrame(code, 0);
}

Expand All @@ -1057,6 +1058,7 @@ struct InterpreterStateImpl : c10::intrusive_ptr_target {
// including any inputs to this function
int64_t stack_start_ = -1;
c10::intrusive_ptr<Future> future_;
TaskLauncher taskLauncher_;

// this holds all the tensors for this interpreter run
// we don't bother minimizing the size of this vector, since the extra
Expand Down Expand Up @@ -1335,18 +1337,22 @@ struct InterpreterStateImpl : c10::intrusive_ptr_target {
Callback(
c10::intrusive_ptr<InterpreterStateImpl> state,
Stack stack)
: state_(std::move(state)), stack_(std::move(stack)) {
: stateImpl_(std::move(state)),
state_(stateImpl_),
stack_(std::move(stack)) {
dist_autograd_context_id_ = getDistAutogradContextId();
state_ = InterpreterState(stateImpl_);
}
void operator()() {
at::launch(InterpreterContinuation(
stateImpl_->taskLauncher_(InterpreterContinuation(
state_,
std::move(stack_),
dist_autograd_context_id_,
std::move(tls_state_)));
}

private:
c10::intrusive_ptr<InterpreterStateImpl> stateImpl_;
InterpreterState state_;
Stack stack_;
int64_t dist_autograd_context_id_;
Expand Down Expand Up @@ -1511,14 +1517,15 @@ struct InterpreterStateImpl : c10::intrusive_ptr_target {
InterpreterState forked_interpreter(
forked_fn->get_executor()
.getPlanFor(stack, GraphExecutor::getDefaultNumBailOuts())
.code);
.code,
taskLauncher_);
InterpreterContinuation continuation(
forked_interpreter,
Stack(stack.end() - inst.N, stack.end()),
getDistAutogradContextId());
drop(stack, inst.N);
push(stack, forked_interpreter.getFuture());
at::launch(std::move(continuation));
taskLauncher_(std::move(continuation));
++frame.pc;
} break;
case WARN: {
Expand Down Expand Up @@ -1740,8 +1747,10 @@ size_t Code::register_size() const {
return pImpl->register_size_;
}

InterpreterState::InterpreterState(const Code& code)
: pImpl(c10::make_intrusive<InterpreterStateImpl>(code)) {}
// Creates the backing InterpreterStateImpl for `code`. `taskLauncher` is moved
// into the implementation, which stores it and uses it to launch asynchronous
// work (forked tasks and continuations) in place of a hard-coded at::launch.
InterpreterState::InterpreterState(const Code& code, TaskLauncher taskLauncher)
: pImpl(c10::make_intrusive<InterpreterStateImpl>(
code,
std::move(taskLauncher))) {}
InterpreterState::~InterpreterState() = default;

void InterpreterState::run(Stack& stack) {
Expand Down
10 changes: 7 additions & 3 deletions torch/csrc/jit/runtime/interpreter.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,8 @@

namespace at {
class Tensor;
}
CAFFE2_API void launch(std::function<void()> func);
} // namespace at
namespace c10 {
struct IValue;
struct OperatorName;
Expand All @@ -32,6 +33,7 @@ struct Node;
struct Instruction;
using Stack = std::vector<c10::IValue>;
using c10::ivalue::Future;
using TaskLauncher = std::function<void(std::function<void()>)>;

struct TORCH_API Code {
Code() : pImpl(nullptr) {}
Expand Down Expand Up @@ -66,9 +68,11 @@ struct TORCH_API Code {
};

struct InterpreterState {
TORCH_API InterpreterState(const Code& code);
TORCH_API InterpreterState(
const Code& code,
TaskLauncher taskLauncher = at::launch);
TORCH_API void run(Stack& stack);
c10::intrusive_ptr<Future> runAsync(Stack& stack);
TORCH_API c10::intrusive_ptr<Future> runAsync(Stack& stack);
c10::intrusive_ptr<Future> getFuture();
TORCH_API ~InterpreterState();

Expand Down