diff --git a/torch/csrc/jit/graph_executor.cpp b/torch/csrc/jit/graph_executor.cpp index 85ac26e504620..643c3923f699b 100644 --- a/torch/csrc/jit/graph_executor.cpp +++ b/torch/csrc/jit/graph_executor.cpp @@ -477,7 +477,8 @@ void GraphExecutorImplBase::run(Stack& stack) { logging::getLogger()->addStatValue( logging::runtime_counters::GRAPH_EXECUTOR_INVOCATIONS, 1.0); - ExecutionPlan plan = getPlanFor(stack); + ExecutionPlan plan = + getPlanFor(stack, GraphExecutor::getDefaultNumBailOuts()); InterpreterState(plan.code).run(stack); last_executed_optimized_graph = plan.graph; } @@ -494,8 +495,9 @@ struct GraphExecutorImpl : public GraphExecutorImplBase { logging::runtime_counters::GRAPH_EXECUTORS_CONSTRUCTED, 1.0); } - ExecutionPlan getPlanFor(Stack& stack) override { - return getGraphExecutorOptimize() ? getOrCompile(stack) + ExecutionPlan getPlanFor(Stack& stack, size_t remaining_bailout_depth) + override { + return getGraphExecutorOptimize() ? getOrCompile(stack) : getOrCompileFallback(); } @@ -632,8 +634,14 @@ void GraphExecutor::run(Stack& inputs) { return pImpl->run(inputs); } -ExecutionPlan GraphExecutor::getPlanFor(Stack& inputs) { - return pImpl->getPlanFor(inputs); +size_t GraphExecutor::getDefaultNumBailOuts() { + return getProfilingMode() ? 
1 : 0; +} + +ExecutionPlan GraphExecutor::getPlanFor( + Stack& inputs, + size_t remaining_bailout_depth) { + return pImpl->getPlanFor(inputs, remaining_bailout_depth); } std::shared_ptr<Graph> GraphExecutor::graph() const { diff --git a/torch/csrc/jit/graph_executor.h b/torch/csrc/jit/graph_executor.h index 8c8dd3697253e..da4d46c50ebf6 100644 --- a/torch/csrc/jit/graph_executor.h +++ b/torch/csrc/jit/graph_executor.h @@ -16,8 +16,10 @@ struct Code; struct ExecutionPlan { ExecutionPlan() = default; - ExecutionPlan(std::shared_ptr<Graph> graph) - : code(graph), graph(std::move(graph)) {} + ExecutionPlan( + std::shared_ptr<Graph> graph, + size_t remaining_bailout_depth = 0) + : code(graph, remaining_bailout_depth), graph(std::move(graph)) {} operator bool() const { return static_cast<bool>(graph); @@ -42,13 +44,24 @@ struct TORCH_API GraphExecutor { GraphExecutor() = default; GraphExecutor(std::shared_ptr<Graph> graph); void run(Stack& inputs); - ExecutionPlan getPlanFor(Stack& inputs); + // `remaining_bailout_depth` stands for the maximum number of profiled and + // specialized recompilations allowed for the current `GraphExecutor`. If + // remaining_bailout_depth is equal to 0, `GraphExecutor` won't perform any + // profiling and specialization. This is also equivalent to the + // SIMPLE_EXECUTOR mode. If remaining_bailout_depth is greater than 0, + // `GraphExecutor` will profile and specialize its input graph based on the + // profiled information. Whenever a bailout check is failed/triggered, a new + // `GraphExecutor` will be created. This new `GraphExecutor`'s + // remaining_bailout_depth will be reduced by 1. 
+ ExecutionPlan getPlanFor(Stack& inputs, size_t remaining_bailout_depth); explicit operator bool() const { return pImpl != nullptr; } std::shared_ptr<Graph> graph() const; GraphExecutorState getDebugState(); + static size_t getDefaultNumBailOuts(); + private: std::shared_ptr<GraphExecutorImplBase> pImpl; }; diff --git a/torch/csrc/jit/graph_executor_impl.h b/torch/csrc/jit/graph_executor_impl.h index 1ef8ca959a6c1..8033b4486bb64 100644 --- a/torch/csrc/jit/graph_executor_impl.h +++ b/torch/csrc/jit/graph_executor_impl.h @@ -62,7 +62,9 @@ struct GraphExecutorImplBase { // entry point where execution begins void run(Stack& stack); - virtual ExecutionPlan getPlanFor(Stack& stack) = 0; + virtual ExecutionPlan getPlanFor( + Stack& stack, + size_t remaining_bailout_depth) = 0; virtual GraphExecutorState getDebugState() = 0; virtual ~GraphExecutorImplBase() = default; diff --git a/torch/csrc/jit/interpreter.cpp b/torch/csrc/jit/interpreter.cpp index f38c8e0c340b9..bc5820bc93004 100644 --- a/torch/csrc/jit/interpreter.cpp +++ b/torch/csrc/jit/interpreter.cpp @@ -381,9 +381,12 @@ struct CodeImpl { // out-of-line jumps for bailouts that are patched in at the end std::vector<BailoutBlock> bailout_blocks_; std::vector<std::unique_ptr<Function>> bailout_functions_; + size_t remaining_bailout_depth_; - CodeImpl(const std::shared_ptr<Graph>& graph) - : preprocess_(*graph), current_node_(preprocess_.graph->return_node()) { + CodeImpl(const std::shared_ptr<Graph>& graph, size_t remaining_bailout_depth) + : preprocess_(*graph), + current_node_(preprocess_.graph->return_node()), + remaining_bailout_depth_(remaining_bailout_depth) { graph_ = preprocess_.graph; n_outputs = graph_->outputs().size(); if (n_outputs == 1) { @@ -936,7 +939,17 @@ struct InterpreterStateImpl : c10::intrusive_ptr_target { } break; case CALL: { const Code& code = - af.functions[inst.X]->get_executor().getPlanFor(stack).code; + // Consider passing + // `frames.back().function->remaining_bailout_depth_` into + // `get_executor().getPlanFor()` to propagate caller's depth + // restrictions onto 
children. While this strategy has a + // potential to reduce the number of compilations for too + // dynamic callers, we might miss opportunities where a caller is + // dynamic but a callee gets stable arguments. + af.functions[inst.X] + ->get_executor() + .getPlanFor(stack, GraphExecutor::getDefaultNumBailOuts()) + .code; frames.back().pc = af.pc + 1; enterFrame(code, stack.size() - code.num_inputs()); af = ActiveFrame(frames.back()); @@ -946,11 +959,22 @@ struct InterpreterStateImpl : c10::intrusive_ptr_target { // this can be more optimized if necessary, caching parts // of the hashing computation or storing the offset when // the object is turned into an interface + + // Consider passing + // `frames.back().function->remaining_bailout_depth_` into + // `get_executor().getPlanFor()` to propagate caller's depth + // restrictions onto children. While this strategy has a potential to + // reduce the number of compilations for too dynamic callers, we + // might miss opportunities where a caller is dynamic but a callee + // gets stable arguments. auto function = peek(stack, 0, inst.N) .toObject() ->type() ->getMethod(af.constants[inst.X].toStringRef()); - const Code& code = function->get_executor().getPlanFor(stack).code; + const Code& code = + function->get_executor() + .getPlanFor(stack, GraphExecutor::getDefaultNumBailOuts()) + .code; frames.back().pc = af.pc + 1; enterFrame(code, stack.size() - inst.N); af = ActiveFrame(frames.back()); @@ -1028,8 +1052,14 @@ struct InterpreterStateImpl : c10::intrusive_ptr_target { } break; case TAIL_CALL: { af.functions[inst.X]->ensure_defined(); - const Code &code = - af.functions[inst.X]->get_executor().getPlanFor(stack).code; + size_t remaining_bailout_depth = + frames.back().function->remaining_bailout_depth_ > 0 + ? 
frames.back().function->remaining_bailout_depth_ - 1 + : 0; + const Code& code = af.functions[inst.X] + ->get_executor() + .getPlanFor(stack, remaining_bailout_depth) + .code; size_t num_inputs = code.num_inputs(); size_t base_pointer = frames.back().base_pointer; TORCH_INTERNAL_ASSERT(stack.size() >= num_inputs); @@ -1134,7 +1164,8 @@ std::ostream& operator<<(std::ostream& out, const Code& code) { return out; } -Code::Code(const std::shared_ptr<Graph>& graph) : pImpl(new CodeImpl(graph)) {} +Code::Code(const std::shared_ptr<Graph>& graph, size_t remaining_bailout_depth) + : pImpl(new CodeImpl(graph, remaining_bailout_depth)) {} Code::~Code() = default; const std::vector<GraphExecutor*>& Code::grad_executors() { diff --git a/torch/csrc/jit/interpreter.h b/torch/csrc/jit/interpreter.h index e39595cc3404a..1113d4d7254ed 100644 --- a/torch/csrc/jit/interpreter.h +++ b/torch/csrc/jit/interpreter.h @@ -33,7 +33,12 @@ using c10::ivalue::Future; struct TORCH_API Code { Code() : pImpl(nullptr) {} - explicit Code(const std::shared_ptr<Graph>& graph); + // remaining_bailout_depth is irrelevant in a `Code` object unless the `Code` + // is directly created by `GraphExecutor`, in which case it's likely to contain + // `prim::BailOut`s to control the maximum depth of bailout chains. + explicit Code( + const std::shared_ptr<Graph>& graph, + size_t remaining_bailout_depth = 0); ~Code(); const std::vector<GraphExecutor*>& grad_executors(); diff --git a/torch/csrc/jit/profiling_graph_executor_impl.cpp b/torch/csrc/jit/profiling_graph_executor_impl.cpp index 9c448f32f4b9b..5862a9a2a20f4 100644 --- a/torch/csrc/jit/profiling_graph_executor_impl.cpp +++ b/torch/csrc/jit/profiling_graph_executor_impl.cpp @@ -126,7 +126,9 @@ ProfilingGraphExecutorImpl::ProfilingGraphExecutorImpl( const std::shared_ptr<Graph>& graph) : GraphExecutorImplBase(graph) {} -ExecutionPlan ProfilingGraphExecutorImpl::getPlanFor(Stack& stack) { +ExecutionPlan ProfilingGraphExecutorImpl::getPlanFor( + Stack& stack, + size_t remaining_bailout_depth) { + std::lock_guard<std::mutex> 
lock(compile_mutex); GRAPH_DEBUG("Running ProfilingGraphExecutorImpl ", this); @@ -135,7 +137,7 @@ ExecutionPlan ProfilingGraphExecutorImpl::getPlanFor(Stack& stack) { } // simple executor - if (!getProfilingMode()) { + if (remaining_bailout_depth == 0) { auto copy = graph->copy(); runProfilingInsensitiveOptimizations(copy); GRAPH_DUMP("Optimized SimpleExecutor Graph : ", copy); @@ -162,7 +164,7 @@ ExecutionPlan ProfilingGraphExecutorImpl::getPlanFor(Stack& stack) { auto copy = pr_->graph()->copy(); runProfilingOptimizations(copy); // cache - optimized_plan_ = ExecutionPlan(copy); + optimized_plan_ = ExecutionPlan(copy, remaining_bailout_depth); return *optimized_plan_; } diff --git a/torch/csrc/jit/profiling_graph_executor_impl.h b/torch/csrc/jit/profiling_graph_executor_impl.h index 357cffee4d0dd..640dfc14b0d7f 100644 --- a/torch/csrc/jit/profiling_graph_executor_impl.h +++ b/torch/csrc/jit/profiling_graph_executor_impl.h @@ -7,7 +7,8 @@ namespace jit { struct ProfilingGraphExecutorImpl : public GraphExecutorImplBase { ProfilingGraphExecutorImpl(const std::shared_ptr<Graph>& graph); - ExecutionPlan getPlanFor(Stack& stack) override; + ExecutionPlan getPlanFor(Stack& stack, size_t remaining_bailout_depth) + override; GraphExecutorState getDebugState() override; ~ProfilingGraphExecutorImpl() override = default;