Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP] cap the maximum depth of bailout chains at 1 #32073

Closed
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
18 changes: 13 additions & 5 deletions torch/csrc/jit/graph_executor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -477,7 +477,8 @@ void GraphExecutorImplBase::run(Stack& stack) {
logging::getLogger()->addStatValue(
logging::runtime_counters::GRAPH_EXECUTOR_INVOCATIONS, 1.0);

ExecutionPlan plan = getPlanFor(stack);
ExecutionPlan plan =
getPlanFor(stack, GraphExecutor::getDefaultNumBailOuts());
InterpreterState(plan.code).run(stack);
last_executed_optimized_graph = plan.graph;
}
Expand All @@ -494,8 +495,9 @@ struct GraphExecutorImpl : public GraphExecutorImplBase {
logging::runtime_counters::GRAPH_EXECUTORS_CONSTRUCTED, 1.0);
}

ExecutionPlan getPlanFor(Stack& stack) override {
return getGraphExecutorOptimize() ? getOrCompile(stack)
ExecutionPlan getPlanFor(Stack& stack, size_t remaining_bailout_depth)
override {
return getGraphExecutorOptimize() ? getOrCompile(stack)
: getOrCompileFallback();
}

Expand Down Expand Up @@ -632,8 +634,14 @@ void GraphExecutor::run(Stack& inputs) {
return pImpl->run(inputs);
}

ExecutionPlan GraphExecutor::getPlanFor(Stack& inputs) {
return pImpl->getPlanFor(inputs);
// Returns the default bailout depth that callers pass to `getPlanFor()`:
// 1 when `getProfilingMode()` evaluates to true, 0 otherwise.
size_t GraphExecutor::getDefaultNumBailOuts() {
return getProfilingMode() ? 1 : 0;
}

// Forwards the plan request, including `remaining_bailout_depth`, unchanged
// to the implementation object held in `pImpl` and returns its result.
ExecutionPlan GraphExecutor::getPlanFor(
Stack& inputs,
size_t remaining_bailout_depth) {
return pImpl->getPlanFor(inputs, remaining_bailout_depth);
}

std::shared_ptr<Graph> GraphExecutor::graph() const {
Expand Down
19 changes: 16 additions & 3 deletions torch/csrc/jit/graph_executor.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,10 @@ struct Code;

struct ExecutionPlan {
ExecutionPlan() = default;
ExecutionPlan(std::shared_ptr<Graph> graph)
: code(graph), graph(std::move(graph)) {}
ExecutionPlan(
std::shared_ptr<Graph> graph,
size_t remaining_bailout_depth = 0)
: code(graph, remaining_bailout_depth), graph(std::move(graph)) {}

operator bool() const {
return static_cast<bool>(graph);
Expand All @@ -42,13 +44,24 @@ struct TORCH_API GraphExecutor {
GraphExecutor() = default;
GraphExecutor(std::shared_ptr<Graph> graph);
void run(Stack& inputs);
ExecutionPlan getPlanFor(Stack& inputs);
// `remaining_bailout_depth` stands for the maximum number of profiled and
// specialized recompilations allowed for the current `GraphExecutor`. If
// remaining_bailout_depth is equal to 0, `GraphExecutor` won't perform any
// profiling and specialization. This is also equivalent to the
// SIMPLE_EXECUTOR mode. If remaining_bailout_depth is greater than 0,
// `GraphExecutor` will profile and specialize its input graph based on the
// profiled information. Whenever a bailout check fails (is triggered), a new
// `GraphExecutor` will be created. This new `GraphExecutor`'s
// remaining_bailout_depth will be reduced by 1.
ExecutionPlan getPlanFor(Stack& inputs, size_t remaining_bailout_depth);
explicit operator bool() const {
return pImpl != nullptr;
}
std::shared_ptr<Graph> graph() const;
GraphExecutorState getDebugState();

static size_t getDefaultNumBailOuts();

private:
std::shared_ptr<GraphExecutorImplBase> pImpl;
};
Expand Down
4 changes: 3 additions & 1 deletion torch/csrc/jit/graph_executor_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,9 @@ struct GraphExecutorImplBase {
// entry point where execution begins
void run(Stack& stack);

virtual ExecutionPlan getPlanFor(Stack& stack) = 0;
virtual ExecutionPlan getPlanFor(
Stack& stack,
size_t remaining_bailout_depth) = 0;
virtual GraphExecutorState getDebugState() = 0;
virtual ~GraphExecutorImplBase() = default;

Expand Down
45 changes: 38 additions & 7 deletions torch/csrc/jit/interpreter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -381,9 +381,12 @@ struct CodeImpl {
// out-of-line jumps for bailouts that are patched in at the end
std::vector<BailoutBlock> bailout_blocks_;
std::vector<std::unique_ptr<Function>> bailout_functions_;
size_t remaining_bailout_depth_;

CodeImpl(const std::shared_ptr<Graph>& graph)
: preprocess_(*graph), current_node_(preprocess_.graph->return_node()) {
CodeImpl(const std::shared_ptr<Graph>& graph, size_t remaining_bailout_depth)
: preprocess_(*graph),
current_node_(preprocess_.graph->return_node()),
remaining_bailout_depth_(remaining_bailout_depth) {
graph_ = preprocess_.graph;
n_outputs = graph_->outputs().size();
if (n_outputs == 1) {
Expand Down Expand Up @@ -936,7 +939,17 @@ struct InterpreterStateImpl : c10::intrusive_ptr_target {
} break;
case CALL: {
const Code& code =
af.functions[inst.X]->get_executor().getPlanFor(stack).code;
// Consider passing
// `frames.back().function->remaining_bailout_depth_` into
// `get_executor().getPlanFor()` to propagate the caller's depth
// restrictions onto children. While this strategy has the
// potential to reduce the number of compilations for overly
// dynamic callers, we might miss opportunities where a caller is
// dynamic but a callee gets stable arguments.
af.functions[inst.X]
->get_executor()
.getPlanFor(stack, GraphExecutor::getDefaultNumBailOuts())
.code;
frames.back().pc = af.pc + 1;
enterFrame(code, stack.size() - code.num_inputs());
af = ActiveFrame(frames.back());
Expand All @@ -946,11 +959,22 @@ struct InterpreterStateImpl : c10::intrusive_ptr_target {
// this can be more optimized if necessary, caching parts
// of the hashing computation or storing the offset when
// the object is turned into an interface

// Consider passing
// `frames.back().function->remaining_bailout_depth_` into
// `get_executor().getPlanFor()` to propagate the caller's depth
// restrictions onto children. While this strategy has the potential to
// reduce the number of compilations for overly dynamic callers, we
// might miss opportunities where a caller is dynamic but a callee
// gets stable arguments.
auto function = peek(stack, 0, inst.N)
.toObject()
->type()
->getMethod(af.constants[inst.X].toStringRef());
const Code& code = function->get_executor().getPlanFor(stack).code;
const Code& code =
function->get_executor()
.getPlanFor(stack, GraphExecutor::getDefaultNumBailOuts())
.code;
frames.back().pc = af.pc + 1;
enterFrame(code, stack.size() - inst.N);
af = ActiveFrame(frames.back());
Expand Down Expand Up @@ -1028,8 +1052,14 @@ struct InterpreterStateImpl : c10::intrusive_ptr_target {
} break;
case TAIL_CALL: {
af.functions[inst.X]->ensure_defined();
const Code &code =
af.functions[inst.X]->get_executor().getPlanFor(stack).code;
size_t remaining_bailout_depth =
frames.back().function->remaining_bailout_depth_ > 0
? frames.back().function->remaining_bailout_depth_ - 1
: 0;
const Code& code = af.functions[inst.X]
->get_executor()
.getPlanFor(stack, remaining_bailout_depth)
.code;
size_t num_inputs = code.num_inputs();
size_t base_pointer = frames.back().base_pointer;
TORCH_INTERNAL_ASSERT(stack.size() >= num_inputs);
Expand Down Expand Up @@ -1134,7 +1164,8 @@ std::ostream& operator<<(std::ostream& out, const Code& code) {
return out;
}

Code::Code(const std::shared_ptr<Graph>& graph) : pImpl(new CodeImpl(graph)) {}
Code::Code(const std::shared_ptr<Graph>& graph, size_t remaining_bailout_depth)
: pImpl(new CodeImpl(graph, remaining_bailout_depth)) {}
Code::~Code() = default;

const std::vector<GraphExecutor*>& Code::grad_executors() {
Expand Down
7 changes: 6 additions & 1 deletion torch/csrc/jit/interpreter.h
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,12 @@ using c10::ivalue::Future;

struct TORCH_API Code {
Code() : pImpl(nullptr) {}
explicit Code(const std::shared_ptr<Graph>& graph);
// remaining_bailout_depth is irrelevant in a `Code` object unless the `Code`
// is directly created by `GraphExecutor`, in which case it's likely to contain
// `prim::BailOut`s; the depth is then used to control the maximum depth of
// bailout chains.
explicit Code(
const std::shared_ptr<Graph>& graph,
size_t remaining_bailout_depth = 0);
~Code();

const std::vector<GraphExecutor*>& grad_executors();
Expand Down
8 changes: 5 additions & 3 deletions torch/csrc/jit/profiling_graph_executor_impl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,9 @@ ProfilingGraphExecutorImpl::ProfilingGraphExecutorImpl(
const std::shared_ptr<Graph>& graph)
: GraphExecutorImplBase(graph) {}

ExecutionPlan ProfilingGraphExecutorImpl::getPlanFor(Stack& stack) {
ExecutionPlan ProfilingGraphExecutorImpl::getPlanFor(
Stack& stack,
size_t remaining_bailout_depth) {
std::lock_guard<std::mutex> lock(compile_mutex);
GRAPH_DEBUG("Running ProfilingGraphExecutorImpl ", this);

Expand All @@ -135,7 +137,7 @@ ExecutionPlan ProfilingGraphExecutorImpl::getPlanFor(Stack& stack) {
}

// simple executor
if (!getProfilingMode()) {
if (remaining_bailout_depth == 0) {
auto copy = graph->copy();
runProfilingInsensitiveOptimizations(copy);
GRAPH_DUMP("Optimized SimpleExecutor Graph : ", copy);
Expand All @@ -162,7 +164,7 @@ ExecutionPlan ProfilingGraphExecutorImpl::getPlanFor(Stack& stack) {
auto copy = pr_->graph()->copy();
runProfilingOptimizations(copy);
// cache
optimized_plan_ = ExecutionPlan(copy);
optimized_plan_ = ExecutionPlan(copy, remaining_bailout_depth);
return *optimized_plan_;
}

Expand Down
3 changes: 2 additions & 1 deletion torch/csrc/jit/profiling_graph_executor_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,8 @@ namespace jit {
struct ProfilingGraphExecutorImpl : public GraphExecutorImplBase {
ProfilingGraphExecutorImpl(const std::shared_ptr<Graph>& graph);

ExecutionPlan getPlanFor(Stack& stack) override;
ExecutionPlan getPlanFor(Stack& stack, size_t remaining_bailout_depth)
override;
GraphExecutorState getDebugState() override;
~ProfilingGraphExecutorImpl() override = default;

Expand Down