From cebd728698f4ef239c8e5a6bd4b303820148142c Mon Sep 17 00:00:00 2001 From: Nikolay Korovaiko Date: Fri, 10 Jan 2020 13:27:11 -0800 Subject: [PATCH 1/4] cap the maximum depth of bailout chains at 1 --- torch/csrc/jit/graph_executor.cpp | 15 ++++++--- torch/csrc/jit/graph_executor.h | 4 ++- torch/csrc/jit/graph_executor_impl.h | 2 +- torch/csrc/jit/interpreter.cpp | 32 ++++++++++++++++--- torch/csrc/jit/interpreter.h | 1 + .../jit/profiling_graph_executor_impl.cpp | 7 ++-- .../csrc/jit/profiling_graph_executor_impl.h | 2 +- 7 files changed, 48 insertions(+), 15 deletions(-) diff --git a/torch/csrc/jit/graph_executor.cpp b/torch/csrc/jit/graph_executor.cpp index 85ac26e504620..7f16a203d9593 100644 --- a/torch/csrc/jit/graph_executor.cpp +++ b/torch/csrc/jit/graph_executor.cpp @@ -477,7 +477,8 @@ void GraphExecutorImplBase::run(Stack& stack) { logging::getLogger()->addStatValue( logging::runtime_counters::GRAPH_EXECUTOR_INVOCATIONS, 1.0); - ExecutionPlan plan = getPlanFor(stack); + ExecutionPlan plan = + getPlanFor(stack, GraphExecutor::getDefaultNumBailOuts()); InterpreterState(plan.code).run(stack); last_executed_optimized_graph = plan.graph; } @@ -494,8 +495,8 @@ struct GraphExecutorImpl : public GraphExecutorImplBase { logging::runtime_counters::GRAPH_EXECUTORS_CONSTRUCTED, 1.0); } - ExecutionPlan getPlanFor(Stack& stack) override { - return getGraphExecutorOptimize() ? getOrCompile(stack) + ExecutionPlan getPlanFor(Stack& stack, size_t num_bailouts) override { + return getGraphExecutorOptimize() ? getOrCompile(stack) : getOrCompileFallback(); } @@ -632,8 +633,12 @@ void GraphExecutor::run(Stack& inputs) { return pImpl->run(inputs); } -ExecutionPlan GraphExecutor::getPlanFor(Stack& inputs) { - return pImpl->getPlanFor(inputs); +size_t GraphExecutor::getDefaultNumBailOuts() { + return getProfilingMode() ? 1 : 0; +} + +ExecutionPlan GraphExecutor::getPlanFor(Stack& inputs, size_t num_bailouts) { + return pImpl->getPlanFor(inputs, num_bailouts); } std::shared_ptr GraphExecutor::graph() const { diff --git a/torch/csrc/jit/graph_executor.h b/torch/csrc/jit/graph_executor.h index 8c8dd3697253e..2f123e8694bc8 100644 --- a/torch/csrc/jit/graph_executor.h +++ b/torch/csrc/jit/graph_executor.h @@ -42,13 +42,15 @@ struct TORCH_API GraphExecutor { GraphExecutor() = default; GraphExecutor(std::shared_ptr graph); void run(Stack& inputs); - ExecutionPlan getPlanFor(Stack& inputs); + ExecutionPlan getPlanFor(Stack& inputs, size_t num_bailouts); explicit operator bool() const { return pImpl != nullptr; } std::shared_ptr graph() const; GraphExecutorState getDebugState(); + static size_t getDefaultNumBailOuts(); + private: std::shared_ptr pImpl; }; diff --git a/torch/csrc/jit/graph_executor_impl.h b/torch/csrc/jit/graph_executor_impl.h index 1ef8ca959a6c1..a535d75a6d592 100644 --- a/torch/csrc/jit/graph_executor_impl.h +++ b/torch/csrc/jit/graph_executor_impl.h @@ -62,7 +62,7 @@ struct GraphExecutorImplBase { // entry point where execution begins void run(Stack& stack); - virtual ExecutionPlan getPlanFor(Stack& stack) = 0; + virtual ExecutionPlan getPlanFor(Stack& stack, size_t num_bailouts) = 0; virtual GraphExecutorState getDebugState() = 0; virtual ~GraphExecutorImplBase() = default; diff --git a/torch/csrc/jit/interpreter.cpp b/torch/csrc/jit/interpreter.cpp index f38c8e0c340b9..fb4bbf513665b 100644 --- a/torch/csrc/jit/interpreter.cpp +++ b/torch/csrc/jit/interpreter.cpp @@ -381,9 +381,12 @@ struct CodeImpl { // out-of-line jumps for bailouts that are patched in at the end std::vector bailout_blocks_; std::vector> bailout_functions_; + size_t num_bailouts_; CodeImpl(const std::shared_ptr& graph) - : preprocess_(*graph), current_node_(preprocess_.graph->return_node()) { + : preprocess_(*graph), + current_node_(preprocess_.graph->return_node()), + num_bailouts_(0) { graph_ = preprocess_.graph; n_outputs = graph_->outputs().size(); if (n_outputs == 1) { @@ -431,6 +434,10 @@ struct CodeImpl { } } + void setNumBailOuts(size_t num_bailouts) { + num_bailouts_ = num_bailouts; + } + void truncateInstructions(size_t size) { while(instructions_.size() > size) { instructions_.pop_back(); @@ -936,7 +943,10 @@ struct InterpreterStateImpl : c10::intrusive_ptr_target { } break; case CALL: { const Code& code = - af.functions[inst.X]->get_executor().getPlanFor(stack).code; + af.functions[inst.X] + ->get_executor() + .getPlanFor(stack, frames.back().function->num_bailouts_) + .code; frames.back().pc = af.pc + 1; enterFrame(code, stack.size() - code.num_inputs()); af = ActiveFrame(frames.back()); @@ -950,7 +960,10 @@ struct InterpreterStateImpl : c10::intrusive_ptr_target { .toObject() ->type() ->getMethod(af.constants[inst.X].toStringRef()); - const Code& code = function->get_executor().getPlanFor(stack).code; + const Code& code = + function->get_executor() + .getPlanFor(stack, frames.back().function->num_bailouts_) + .code; frames.back().pc = af.pc + 1; enterFrame(code, stack.size() - inst.N); af = ActiveFrame(frames.back()); @@ -1028,8 +1041,13 @@ struct InterpreterStateImpl : c10::intrusive_ptr_target { } break; case TAIL_CALL: { af.functions[inst.X]->ensure_defined(); - const Code &code = - af.functions[inst.X]->get_executor().getPlanFor(stack).code; + size_t num_bailouts = frames.back().function->num_bailouts_ > 0 + ? frames.back().function->num_bailouts_ - 1 + : 0; + const Code& code = af.functions[inst.X] + ->get_executor() + .getPlanFor(stack, num_bailouts) + .code; size_t num_inputs = code.num_inputs(); size_t base_pointer = frames.back().base_pointer; TORCH_INTERNAL_ASSERT(stack.size() >= num_inputs); @@ -1149,6 +1167,10 @@ size_t Code::num_outputs() const { return pImpl->n_outputs; } +void Code::setNumBailOuts(size_t num_bailouts) { + return pImpl->setNumBailOuts(num_bailouts); +} + const std::vector& Code::constant_table() const { return pImpl->constant_table(); } diff --git a/torch/csrc/jit/interpreter.h b/torch/csrc/jit/interpreter.h index e39595cc3404a..899bd9e3ef869 100644 --- a/torch/csrc/jit/interpreter.h +++ b/torch/csrc/jit/interpreter.h @@ -47,6 +47,7 @@ struct TORCH_API Code { const std::vector& instructions() const; const std::vector& instructions_source() const; size_t register_size() const; + void setNumBailOuts(size_t num_bailouts); private: std::shared_ptr pImpl; diff --git a/torch/csrc/jit/profiling_graph_executor_impl.cpp b/torch/csrc/jit/profiling_graph_executor_impl.cpp index 9c448f32f4b9b..e545cd756d3d3 100644 --- a/torch/csrc/jit/profiling_graph_executor_impl.cpp +++ b/torch/csrc/jit/profiling_graph_executor_impl.cpp @@ -126,7 +126,9 @@ ProfilingGraphExecutorImpl::ProfilingGraphExecutorImpl( const std::shared_ptr& graph) : GraphExecutorImplBase(graph) {} -ExecutionPlan ProfilingGraphExecutorImpl::getPlanFor(Stack& stack) { +ExecutionPlan ProfilingGraphExecutorImpl::getPlanFor( + Stack& stack, + size_t num_bailouts) { std::lock_guard lock(compile_mutex); GRAPH_DEBUG("Running ProfilingGraphExecutorImpl ", this); @@ -135,7 +137,7 @@ ExecutionPlan ProfilingGraphExecutorImpl::getPlanFor(Stack& stack) { } // simple executor - if (!getProfilingMode()) { + if (num_bailouts == 0) { auto copy = graph->copy(); runProfilingInsensitiveOptimizations(copy); GRAPH_DUMP("Optimized SimpleExecutor Graph : ", copy); @@ -163,6 +165,7 @@ ExecutionPlan ProfilingGraphExecutorImpl::getPlanFor(Stack& stack) { runProfilingOptimizations(copy); // cache optimized_plan_ = ExecutionPlan(copy); + optimized_plan_->code.setNumBailOuts(num_bailouts); return *optimized_plan_; } diff --git a/torch/csrc/jit/profiling_graph_executor_impl.h b/torch/csrc/jit/profiling_graph_executor_impl.h index 357cffee4d0dd..8db3d17fa3c4f 100644 --- a/torch/csrc/jit/profiling_graph_executor_impl.h +++ b/torch/csrc/jit/profiling_graph_executor_impl.h @@ -7,7 +7,7 @@ namespace jit { struct ProfilingGraphExecutorImpl : public GraphExecutorImplBase { ProfilingGraphExecutorImpl(const std::shared_ptr& graph); - ExecutionPlan getPlanFor(Stack& stack) override; + ExecutionPlan getPlanFor(Stack& stack, size_t num_bailouts) override; GraphExecutorState getDebugState() override; ~ProfilingGraphExecutorImpl() override = default; From 0a4fb6594d6cac1383422e16cb0d57c27112648e Mon Sep 17 00:00:00 2001 From: Nikolay Korovaiko Date: Fri, 10 Jan 2020 13:45:39 -0800 Subject: [PATCH 2/4] fix callee bailout depth propagation strategy --- torch/csrc/jit/interpreter.cpp | 19 +++++++++++++++++-- 1 file changed, 17 insertions(+), 2 deletions(-) diff --git a/torch/csrc/jit/interpreter.cpp b/torch/csrc/jit/interpreter.cpp index fb4bbf513665b..6727f80391cfd 100644 --- a/torch/csrc/jit/interpreter.cpp +++ b/torch/csrc/jit/interpreter.cpp @@ -943,9 +943,16 @@ struct InterpreterStateImpl : c10::intrusive_ptr_target { } break; case CALL: { const Code& code = + // consider passing `frames.back().function->num_bailouts_` + // into `get_executor().getPlanFor()` + // to propagate caller's depth restrictions onto children + // while this strategy has a potential to reduce the number + // of compilations for too dynamic callers + // we might miss opportunities where a caller is dynamic + // but a callee gets stable arguments af.functions[inst.X] ->get_executor() - .getPlanFor(stack, frames.back().function->num_bailouts_) + .getPlanFor(stack, GraphExecutor::getDefaultNumBailOuts()) .code; frames.back().pc = af.pc + 1; enterFrame(code, stack.size() - code.num_inputs()); @@ -956,13 +963,21 @@ struct InterpreterStateImpl : c10::intrusive_ptr_target { // this can be more optimized if necessary, caching parts // of the hashing computation or storing the offset when // the object is turned into an interface + + // consider passing `frames.back().function->num_bailouts_` + // into `get_executor().getPlanFor()` + // to propagate caller's depth restrictions onto children + // while this strategy has a potential to reduce the number + // of compilations for too dynamic callers + // we might miss opportunities where a caller is dynamic + // but a callee gets stable arguments auto function = peek(stack, 0, inst.N) .toObject() ->type() ->getMethod(af.constants[inst.X].toStringRef()); const Code& code = function->get_executor() - .getPlanFor(stack, frames.back().function->num_bailouts_) + .getPlanFor(stack, GraphExecutor::getDefaultNumBailOuts()) .code; frames.back().pc = af.pc + 1; enterFrame(code, stack.size() - inst.N); From fb7a432806432586c04ff1288384d471a9533ce5 Mon Sep 17 00:00:00 2001 From: Nikolay Korovaiko Date: Thu, 16 Jan 2020 18:10:40 -0800 Subject: [PATCH 3/4] thread num_bailouts via Code's c-tor --- torch/csrc/jit/graph_executor.h | 12 ++++++++++-- torch/csrc/jit/interpreter.cpp | 15 ++++----------- torch/csrc/jit/interpreter.h | 6 ++++-- torch/csrc/jit/profiling_graph_executor_impl.cpp | 3 +-- 4 files changed, 19 insertions(+), 17 deletions(-) diff --git a/torch/csrc/jit/graph_executor.h b/torch/csrc/jit/graph_executor.h index 2f123e8694bc8..6b5c3f75997c2 100644 --- a/torch/csrc/jit/graph_executor.h +++ b/torch/csrc/jit/graph_executor.h @@ -16,8 +16,8 @@ struct Code; struct ExecutionPlan { ExecutionPlan() = default; - ExecutionPlan(std::shared_ptr graph) - : code(graph), graph(std::move(graph)) {} + ExecutionPlan(std::shared_ptr graph, size_t num_bailouts = 0) + : code(graph, num_bailouts), graph(std::move(graph)) {} operator bool() const { return static_cast(graph); @@ -42,6 +42,14 @@ struct TORCH_API GraphExecutor { GraphExecutor() = default; GraphExecutor(std::shared_ptr graph); void run(Stack& inputs); + // `num_bailouts` stands for the maximum number of profiled and specialized + // recompilations allowed for the current `GraphExecutor`. if num_bailouts is + // equal to 0, `GraphExecutor` won't perform any profiling and specialization. + // This is also equivalent to the SIMPLE_EXECUTOR mode. if num_bailouts is + // greater than 0, `GraphExecutor` will profile and specialize its input graph + // based on the profiled information whenever a bailout check is + // failed/triggered, a new `GraphExecutor` will be created. This new + // `GraphExecutor`'s num_bailouts will be reduced by 1. ExecutionPlan getPlanFor(Stack& inputs, size_t num_bailouts); explicit operator bool() const { return pImpl != nullptr; diff --git a/torch/csrc/jit/interpreter.cpp b/torch/csrc/jit/interpreter.cpp index 6727f80391cfd..47a070e9987c0 100644 --- a/torch/csrc/jit/interpreter.cpp +++ b/torch/csrc/jit/interpreter.cpp @@ -383,10 +383,10 @@ struct CodeImpl { std::vector> bailout_functions_; size_t num_bailouts_; - CodeImpl(const std::shared_ptr& graph) + CodeImpl(const std::shared_ptr& graph, size_t num_bailouts) : preprocess_(*graph), current_node_(preprocess_.graph->return_node()), - num_bailouts_(0) { + num_bailouts_(num_bailouts) { graph_ = preprocess_.graph; n_outputs = graph_->outputs().size(); if (n_outputs == 1) { @@ -434,10 +434,6 @@ struct CodeImpl { } } - void setNumBailOuts(size_t num_bailouts) { - num_bailouts_ = num_bailouts; - } - void truncateInstructions(size_t size) { while(instructions_.size() > size) { instructions_.pop_back(); @@ -1167,7 +1163,8 @@ std::ostream& operator<<(std::ostream& out, const Code& code) { return out; } -Code::Code(const std::shared_ptr& graph) : pImpl(new CodeImpl(graph)) {} +Code::Code(const std::shared_ptr& graph, size_t num_bailouts) + : pImpl(new CodeImpl(graph, num_bailouts)) {} Code::~Code() = default; const std::vector& Code::grad_executors() { @@ -1182,10 +1179,6 @@ size_t Code::num_outputs() const { return pImpl->n_outputs; } -void Code::setNumBailOuts(size_t num_bailouts) { - return pImpl->setNumBailOuts(num_bailouts); -} - const std::vector& Code::constant_table() const { return pImpl->constant_table(); } diff --git a/torch/csrc/jit/interpreter.h b/torch/csrc/jit/interpreter.h index 899bd9e3ef869..b9e9d1a3164ce 100644 --- a/torch/csrc/jit/interpreter.h +++ b/torch/csrc/jit/interpreter.h @@ -33,7 +33,10 @@ using c10::ivalue::Future; struct TORCH_API Code { Code() : pImpl(nullptr) {} - explicit Code(const std::shared_ptr& graph); + // num_bailouts is irrelevant in a `Code` object unless the `Code` is directly + // created by `GraphExecutor` in which case it's likely to contain + // `prim::BailOut`s to control the maximum depth of bailout chains + explicit Code(const std::shared_ptr& graph, size_t num_bailouts = 0); ~Code(); const std::vector& grad_executors(); @@ -47,7 +50,6 @@ struct TORCH_API Code { const std::vector& instructions() const; const std::vector& instructions_source() const; size_t register_size() const; - void setNumBailOuts(size_t num_bailouts); private: std::shared_ptr pImpl; diff --git a/torch/csrc/jit/profiling_graph_executor_impl.cpp b/torch/csrc/jit/profiling_graph_executor_impl.cpp index e545cd756d3d3..971c7dda03f07 100644 --- a/torch/csrc/jit/profiling_graph_executor_impl.cpp +++ b/torch/csrc/jit/profiling_graph_executor_impl.cpp @@ -164,8 +164,7 @@ ExecutionPlan ProfilingGraphExecutorImpl::getPlanFor( auto copy = pr_->graph()->copy(); runProfilingOptimizations(copy); // cache - optimized_plan_ = ExecutionPlan(copy); - optimized_plan_->code.setNumBailOuts(num_bailouts); + optimized_plan_ = ExecutionPlan(copy, num_bailouts); return *optimized_plan_; } From 0e2c236adbd92b9819691cd322eb4d51f8fff0c0 Mon Sep 17 00:00:00 2001 From: Nikolay Korovaiko Date: Thu, 16 Jan 2020 19:20:38 -0800 Subject: [PATCH 4/4] rename num_bailouts to remaining_bailout_depth --- torch/csrc/jit/graph_executor.cpp | 9 ++-- torch/csrc/jit/graph_executor.h | 25 ++++++----- torch/csrc/jit/graph_executor_impl.h | 4 +- torch/csrc/jit/interpreter.cpp | 45 ++++++++++--------- torch/csrc/jit/interpreter.h | 8 ++-- .../jit/profiling_graph_executor_impl.cpp | 6 +-- .../csrc/jit/profiling_graph_executor_impl.h | 3 +- 7 files changed, 56 insertions(+), 44 deletions(-) diff --git a/torch/csrc/jit/graph_executor.cpp b/torch/csrc/jit/graph_executor.cpp index 7f16a203d9593..643c3923f699b 100644 --- a/torch/csrc/jit/graph_executor.cpp +++ b/torch/csrc/jit/graph_executor.cpp @@ -495,7 +495,8 @@ struct GraphExecutorImpl : public GraphExecutorImplBase { logging::runtime_counters::GRAPH_EXECUTORS_CONSTRUCTED, 1.0); } - ExecutionPlan getPlanFor(Stack& stack, size_t num_bailouts) override { + ExecutionPlan getPlanFor(Stack& stack, size_t remaining_bailout_depth) + override { return getGraphExecutorOptimize() ? getOrCompile(stack) : getOrCompileFallback(); } @@ -637,8 +638,10 @@ size_t GraphExecutor::getDefaultNumBailOuts() { return getProfilingMode() ? 1 : 0; } -ExecutionPlan GraphExecutor::getPlanFor(Stack& inputs, size_t num_bailouts) { - return pImpl->getPlanFor(inputs, num_bailouts); +ExecutionPlan GraphExecutor::getPlanFor( + Stack& inputs, + size_t remaining_bailout_depth) { + return pImpl->getPlanFor(inputs, remaining_bailout_depth); } std::shared_ptr GraphExecutor::graph() const { diff --git a/torch/csrc/jit/graph_executor.h b/torch/csrc/jit/graph_executor.h index 6b5c3f75997c2..da4d46c50ebf6 100644 --- a/torch/csrc/jit/graph_executor.h +++ b/torch/csrc/jit/graph_executor.h @@ -16,8 +16,10 @@ struct Code; struct ExecutionPlan { ExecutionPlan() = default; - ExecutionPlan(std::shared_ptr graph, size_t num_bailouts = 0) - : code(graph, num_bailouts), graph(std::move(graph)) {} + ExecutionPlan( + std::shared_ptr graph, + size_t remaining_bailout_depth = 0) + : code(graph, remaining_bailout_depth), graph(std::move(graph)) {} operator bool() const { return static_cast(graph); @@ -42,15 +44,16 @@ struct TORCH_API GraphExecutor { GraphExecutor() = default; GraphExecutor(std::shared_ptr graph); void run(Stack& inputs); - // `num_bailouts` stands for the maximum number of profiled and specialized - // recompilations allowed for the current `GraphExecutor`. if num_bailouts is - // equal to 0, `GraphExecutor` won't perform any profiling and specialization. - // This is also equivalent to the SIMPLE_EXECUTOR mode. if num_bailouts is - // greater than 0, `GraphExecutor` will profile and specialize its input graph - // based on the profiled information whenever a bailout check is - // failed/triggered, a new `GraphExecutor` will be created. This new - // `GraphExecutor`'s num_bailouts will be reduced by 1. - ExecutionPlan getPlanFor(Stack& inputs, size_t num_bailouts); + // `remaining_bailout_depth` stands for the maximum number of profiled and + // specialized recompilations allowed for the current `GraphExecutor`. if + // remaining_bailout_depth is equal to 0, `GraphExecutor` won't perform any + // profiling and specialization. This is also equivalent to the + // SIMPLE_EXECUTOR mode. if remaining_bailout_depth is greater than 0, + // `GraphExecutor` will profile and specialize its input graph based on the + // profiled information whenever a bailout check is failed/triggered, a new + // `GraphExecutor` will be created. This new `GraphExecutor`'s + // remaining_bailout_depth will be reduced by 1. + ExecutionPlan getPlanFor(Stack& inputs, size_t remaining_bailout_depth); explicit operator bool() const { return pImpl != nullptr; } diff --git a/torch/csrc/jit/graph_executor_impl.h b/torch/csrc/jit/graph_executor_impl.h index a535d75a6d592..8033b4486bb64 100644 --- a/torch/csrc/jit/graph_executor_impl.h +++ b/torch/csrc/jit/graph_executor_impl.h @@ -62,7 +62,9 @@ struct GraphExecutorImplBase { // entry point where execution begins void run(Stack& stack); - virtual ExecutionPlan getPlanFor(Stack& stack, size_t num_bailouts) = 0; + virtual ExecutionPlan getPlanFor( + Stack& stack, + size_t remaining_bailout_depth) = 0; virtual GraphExecutorState getDebugState() = 0; virtual ~GraphExecutorImplBase() = default; diff --git a/torch/csrc/jit/interpreter.cpp b/torch/csrc/jit/interpreter.cpp index 47a070e9987c0..bc5820bc93004 100644 --- a/torch/csrc/jit/interpreter.cpp +++ b/torch/csrc/jit/interpreter.cpp @@ -381,12 +381,12 @@ struct CodeImpl { // out-of-line jumps for bailouts that are patched in at the end std::vector bailout_blocks_; std::vector> bailout_functions_; - size_t num_bailouts_; + size_t remaining_bailout_depth_; - CodeImpl(const std::shared_ptr& graph, size_t num_bailouts) + CodeImpl(const std::shared_ptr& graph, size_t remaining_bailout_depth) : preprocess_(*graph), current_node_(preprocess_.graph->return_node()), - num_bailouts_(num_bailouts) { + remaining_bailout_depth_(remaining_bailout_depth) { graph_ = preprocess_.graph; n_outputs = graph_->outputs().size(); if (n_outputs == 1) { @@ -939,13 +939,13 @@ struct InterpreterStateImpl : c10::intrusive_ptr_target { } break; case CALL: { const Code& code = - // consider passing `frames.back().function->num_bailouts_` - // into `get_executor().getPlanFor()` - // to propagate caller's depth restrictions onto children - // while this strategy has a potential to reduce the number - // of compilations for too dynamic callers - // we might miss opportunities where a caller is dynamic - // but a callee gets stable arguments + // consider passing + // `frames.back().function->remaining_bailout_depth_` into + // `get_executor().getPlanFor()` to propagate caller's depth + // restrictions onto children while this strategy has a + // potential to reduce the number of compilations for too + // dynamic callers we might miss opportunities where a caller is + // dynamic but a callee gets stable arguments af.functions[inst.X] ->get_executor() .getPlanFor(stack, GraphExecutor::getDefaultNumBailOuts()) @@ -960,13 +960,13 @@ struct InterpreterStateImpl : c10::intrusive_ptr_target { // of the hashing computation or storing the offset when // the object is turned into an interface - // consider passing `frames.back().function->num_bailouts_` - // into `get_executor().getPlanFor()` - // to propagate caller's depth restrictions onto children - // while this strategy has a potential to reduce the number - // of compilations for too dynamic callers - // we might miss opportunities where a caller is dynamic - // but a callee gets stable arguments + // consider passing + // `frames.back().function->remaining_bailout_depth_` into + // `get_executor().getPlanFor()` to propagate caller's depth + // restrictions onto children while this strategy has a potential to + // reduce the number of compilations for too dynamic callers we + // might miss opportunities where a caller is dynamic but a callee + // gets stable arguments auto function = peek(stack, 0, inst.N) .toObject() ->type() @@ -1052,12 +1052,13 @@ struct InterpreterStateImpl : c10::intrusive_ptr_target { } break; case TAIL_CALL: { af.functions[inst.X]->ensure_defined(); - size_t num_bailouts = frames.back().function->num_bailouts_ > 0 - ? frames.back().function->num_bailouts_ - 1 + size_t remaining_bailout_depth = + frames.back().function->remaining_bailout_depth_ > 0 + ? frames.back().function->remaining_bailout_depth_ - 1 : 0; const Code& code = af.functions[inst.X] ->get_executor() - .getPlanFor(stack, num_bailouts) + .getPlanFor(stack, remaining_bailout_depth) .code; size_t num_inputs = code.num_inputs(); size_t base_pointer = frames.back().base_pointer; @@ -1163,8 +1164,8 @@ std::ostream& operator<<(std::ostream& out, const Code& code) { return out; } -Code::Code(const std::shared_ptr& graph, size_t num_bailouts) - : pImpl(new CodeImpl(graph, num_bailouts)) {} +Code::Code(const std::shared_ptr& graph, size_t remaining_bailout_depth) + : pImpl(new CodeImpl(graph, remaining_bailout_depth)) {} Code::~Code() = default; const std::vector& Code::grad_executors() { diff --git a/torch/csrc/jit/interpreter.h b/torch/csrc/jit/interpreter.h index b9e9d1a3164ce..1113d4d7254ed 100644 --- a/torch/csrc/jit/interpreter.h +++ b/torch/csrc/jit/interpreter.h @@ -33,10 +33,12 @@ using c10::ivalue::Future; struct TORCH_API Code { Code() : pImpl(nullptr) {} - // num_bailouts is irrelevant in a `Code` object unless the `Code` is directly - // created by `GraphExecutor` in which case it's likely to contain + // remaining_bailout_depth is irrelevant in a `Code` object unless the `Code` + // is directly created by `GraphExecutor` in which case it's likely to contain // `prim::BailOut`s to control the maximum depth of bailout chains - explicit Code(const std::shared_ptr& graph, size_t num_bailouts = 0); + explicit Code( + const std::shared_ptr& graph, + size_t remaining_bailout_depth = 0); ~Code(); const std::vector& grad_executors(); diff --git a/torch/csrc/jit/profiling_graph_executor_impl.cpp b/torch/csrc/jit/profiling_graph_executor_impl.cpp index 971c7dda03f07..5862a9a2a20f4 100644 --- a/torch/csrc/jit/profiling_graph_executor_impl.cpp +++ b/torch/csrc/jit/profiling_graph_executor_impl.cpp @@ -128,7 +128,7 @@ ProfilingGraphExecutorImpl::ProfilingGraphExecutorImpl( ExecutionPlan ProfilingGraphExecutorImpl::getPlanFor( Stack& stack, - size_t num_bailouts) { + size_t remaining_bailout_depth) { std::lock_guard lock(compile_mutex); GRAPH_DEBUG("Running ProfilingGraphExecutorImpl ", this); @@ -137,7 +137,7 @@ ExecutionPlan ProfilingGraphExecutorImpl::getPlanFor( } // simple executor - if (num_bailouts == 0) { + if (remaining_bailout_depth == 0) { auto copy = graph->copy(); runProfilingInsensitiveOptimizations(copy); GRAPH_DUMP("Optimized SimpleExecutor Graph : ", copy); @@ -164,7 +164,7 @@ ExecutionPlan ProfilingGraphExecutorImpl::getPlanFor( auto copy = pr_->graph()->copy(); runProfilingOptimizations(copy); // cache - optimized_plan_ = ExecutionPlan(copy, num_bailouts); + optimized_plan_ = ExecutionPlan(copy, remaining_bailout_depth); return *optimized_plan_; } diff --git a/torch/csrc/jit/profiling_graph_executor_impl.h b/torch/csrc/jit/profiling_graph_executor_impl.h index 8db3d17fa3c4f..640dfc14b0d7f 100644 --- a/torch/csrc/jit/profiling_graph_executor_impl.h +++ b/torch/csrc/jit/profiling_graph_executor_impl.h @@ -7,7 +7,8 @@ namespace jit { struct ProfilingGraphExecutorImpl : public GraphExecutorImplBase { ProfilingGraphExecutorImpl(const std::shared_ptr& graph); - ExecutionPlan getPlanFor(Stack& stack, size_t num_bailouts) override; + ExecutionPlan getPlanFor(Stack& stack, size_t remaining_bailout_depth) + override; GraphExecutorState getDebugState() override; ~ProfilingGraphExecutorImpl() override = default;