Fix perf regression caused by #122074 #126996

Closed · wants to merge 2 commits
7 changes: 2 additions & 5 deletions torch/_C/_functorch.pyi
@@ -33,11 +33,6 @@ def _grad_decrement_nesting() -> int: ...
def _jvp_increment_nesting() -> int: ...
def _jvp_decrement_nesting() -> int: ...

class _PreserveDynamicLayerStack:
def __init__(self): ...
def __enter__(self): ...
def __exit__(self, exc_type, exc_value, traceback): ...

# Defined in aten/src/ATen/functorch/Interpreter.h
class TransformType(Enum):
Torch: TransformType = ...
@@ -80,7 +75,9 @@ class CVmapInterpreterPtr:

class DynamicLayer: ...

def get_dynamic_layer_stack_depth() -> int: ...
def get_interpreter_stack() -> list[CInterpreter]: ...
def peek_interpreter_stack() -> CInterpreter: ...
def pop_dynamic_layer_stack() -> DynamicLayer: ...
def pop_dynamic_layer_stack_and_undo_to_depth(depth: int) -> None: ...
def push_dynamic_layer_stack(dl: DynamicLayer) -> int: ...
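
For orientation, here is a minimal sketch of how the two new bindings are meant to be paired; it mirrors the `eval_frame.py` change below, and `call_with_stack_guard` / `run_compiled_code` are made-up names used only for illustration:

```python
import torch

def call_with_stack_guard(run_compiled_code):
    # Snapshot the current DynamicLayerStack depth; it is just an integer,
    # so this is cheap compared with entering a Python context manager.
    saved_depth = torch._C._functorch.get_dynamic_layer_stack_depth()
    try:
        return run_compiled_code()
    finally:
        # Unwind anything the compiled code left behind; this is a no-op
        # when the depth is unchanged.
        torch._C._functorch.pop_dynamic_layer_stack_and_undo_to_depth(saved_depth)
```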
22 changes: 17 additions & 5 deletions torch/_dynamo/eval_frame.py
@@ -406,13 +406,25 @@ def _fn(*args, **kwargs):

cleanups = [enter() for enter in self.enter_exit_hooks]
prior = set_eval_frame(callback)

# Ensure that if an exception (such as a failed assertion) occurs after the
# graph pushes something onto the DynamicLayerStack, we pop it back off (the
# constructed graph code isn't guarded with try/finally).
#
# This used to be a context manager, but putting a `with` here is a
# noticeable perf regression (#126293).
saved_dynamic_layer_stack_depth = (
torch._C._functorch.get_dynamic_layer_stack_depth()
)

try:
# Ensure that if an assertion occurs after graph pushes
# something onto the DynamicLayerStack then we pop it off (the
# constructed graph code isn't guarded with try/finally).
with torch._C._functorch._PreserveDynamicLayerStack():
return fn(*args, **kwargs)
return fn(*args, **kwargs)
finally:
# Restore the dynamic layer stack depth if necessary.
torch._C._functorch.pop_dynamic_layer_stack_and_undo_to_depth(
saved_dynamic_layer_stack_depth
)

set_eval_frame(prior)
for cleanup in cleanups:
cleanup()
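
The comment above is the heart of the fix: on this hot path even an empty `with` block pays for a Python-level `__enter__`/`__exit__` round trip on every call, whereas saving and restoring an integer depth does not. As a rough illustration only (this measures a generic no-op context manager, not the actual regression tracked in #126293):

```python
import timeit
from contextlib import contextmanager

@contextmanager
def _noop():
    # Stand-in for the removed _PreserveDynamicLayerStack context manager.
    yield

def via_context_manager():
    with _noop():
        pass

def via_saved_depth():
    saved = 0     # stands in for get_dynamic_layer_stack_depth()
    return saved  # the restore is a single C-level call in the real code

print("with-block:   ", timeit.timeit(via_context_manager, number=1_000_000))
print("save/restore: ", timeit.timeit(via_saved_depth, number=1_000_000))
```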
69 changes: 31 additions & 38 deletions torch/csrc/functorch/init.cpp
@@ -23,9 +23,7 @@

// This file contains functorch's Python bindings.

namespace torch {
namespace functorch {
namespace impl {
namespace torch::functorch::impl {

using namespace at::functorch;

@@ -378,7 +376,7 @@ static int64_t currentLevel() {
static std::optional<int64_t> maybe_current_level() {
auto maybe_layer = maybeCurrentDynamicLayer();
if (maybe_layer.has_value()) {
int current_level = maybe_layer->layerId();
int64_t current_level = maybe_layer->layerId();
return current_level;
}
return nullopt;
@@ -405,36 +403,29 @@ static void dump_local_tls() {

namespace {

// An RAII to save and restore the DynamicLayer stack.
struct PreserveDynamicLayerStack {
size_t m_oldDepth;

~PreserveDynamicLayerStack() {
while (at::functorch::getDynamicLayerStack().size() > m_oldDepth) {
const auto& top = at::functorch::getDynamicLayerStack().back();
switch (top.key()) {
case at::functorch::TransformType::Vmap:
_vmap_decrement_nesting();
break;
case at::functorch::TransformType::Grad:
_grad_decrement_nesting();
break;
case at::functorch::TransformType::Jvp:
_jvp_decrement_nesting();
break;
case at::functorch::TransformType::Functionalize:
_func_decrement_nesting();
break;
case at::functorch::TransformType::Torch:
popDynamicLayerAndDeleteMetadata();
break;
}
// Pop the DynamicLayer stack until it's at the given depth.
void popDynamicLayerStackToDepth(size_t depth) {
while (at::functorch::getDynamicLayerStack().size() > depth) {
const auto top = popDynamicLayer();
switch (top.key()) {
case at::functorch::TransformType::Vmap:
_vmap_decrement_nesting();
break;
case at::functorch::TransformType::Grad:
_grad_decrement_nesting();
break;
case at::functorch::TransformType::Jvp:
_jvp_decrement_nesting();
break;
case at::functorch::TransformType::Functionalize:
_func_decrement_nesting();
break;
case at::functorch::TransformType::Torch:
popDynamicLayerAndDeleteMetadata();
break;
}
}

PreserveDynamicLayerStack()
: m_oldDepth(at::functorch::getDynamicLayerStack().size()) {}
};
}

} // anonymous namespace
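
The RAII destructor is gone; the same unwinding now lives in `popDynamicLayerStackToDepth`, which is exposed to Python below and called from the `finally` block in `eval_frame.py`. A toy model of why restoring to a saved depth is exception-safe (a plain Python list stands in for the real DynamicLayerStack, and the per-TransformType dispatch is elided):

```python
stack: list[str] = ["Torch"]  # a layer that existed before the call

def run_guarded(fn) -> None:
    saved_depth = len(stack)  # cf. get_dynamic_layer_stack_depth()
    try:
        fn()
    finally:
        # cf. pop_dynamic_layer_stack_and_undo_to_depth(saved_depth)
        while len(stack) > saved_depth:
            stack.pop()  # the real code dispatches on the layer's TransformType

def failing_graph() -> None:
    stack.append("Vmap")  # the generated graph code pushes a layer...
    raise RuntimeError("assertion failed in generated code")  # ...then blows up

try:
    run_guarded(failing_graph)
except RuntimeError:
    pass

assert len(stack) == 1  # the leaked layer is gone, the pre-existing one remains
```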

@@ -540,6 +531,7 @@ void initFuncTorchBindings(PyObject* module) {
return c10::nullopt;
}
std::vector<Interpreter> result;
result.reserve(stack.size());
for (auto i : stack) {
result.push_back(i.interpreter());
}
@@ -553,6 +545,12 @@ void initFuncTorchBindings(PyObject* module) {
auto result = stack.back().interpreter();
return result;
});
m.def("get_dynamic_layer_stack_depth", []() -> size_t {
return getDynamicLayerStack().size();
});
m.def(
"pop_dynamic_layer_stack_and_undo_to_depth",
&popDynamicLayerStackToDepth);
m.def("pop_dynamic_layer_stack", &popDynamicLayer);
m.def("push_dynamic_layer_stack", [](DynamicLayer layer) -> int64_t {
return pushDynamicLayer(std::move(layer));
@@ -598,11 +596,6 @@ void initFuncTorchBindings(PyObject* module) {
.def(
"functionalizeAddBackViews",
&FunctionalizeInterpreterPtr::functionalizeAddBackViews);

torch::impl::py_context_manager<PreserveDynamicLayerStack>(
m, "_PreserveDynamicLayerStack");
}

} // namespace impl
} // namespace functorch
} // namespace torch
} // namespace torch::functorch::impl