diff --git a/torch/_C/_functorch.pyi b/torch/_C/_functorch.pyi
index ef5ab302ea34d..111113221a0c3 100644
--- a/torch/_C/_functorch.pyi
+++ b/torch/_C/_functorch.pyi
@@ -33,11 +33,6 @@ def _grad_decrement_nesting() -> int: ...
 def _jvp_increment_nesting() -> int: ...
 def _jvp_decrement_nesting() -> int: ...
 
-class _PreserveDynamicLayerStack:
-    def __init__(self): ...
-    def __enter__(self): ...
-    def __exit__(self, exc_type, exc_value, traceback): ...
-
 # Defined in aten/src/ATen/functorch/Interpreter.h
 class TransformType(Enum):
     Torch: TransformType = ...
@@ -80,7 +75,9 @@ class CVmapInterpreterPtr:
 
 class DynamicLayer: ...
 
+def get_dynamic_layer_stack_depth() -> int: ...
 def get_interpreter_stack() -> list[CInterpreter]: ...
 def peek_interpreter_stack() -> CInterpreter: ...
 def pop_dynamic_layer_stack() -> DynamicLayer: ...
+def pop_dynamic_layer_stack_and_undo_to_depth(int) -> None: ...
 def push_dynamic_layer_stack(dl: DynamicLayer) -> int: ...
diff --git a/torch/_dynamo/eval_frame.py b/torch/_dynamo/eval_frame.py
index db35c0f631e8c..d9cb59a58f9db 100644
--- a/torch/_dynamo/eval_frame.py
+++ b/torch/_dynamo/eval_frame.py
@@ -406,13 +406,25 @@ def _fn(*args, **kwargs):
 
             cleanups = [enter() for enter in self.enter_exit_hooks]
             prior = set_eval_frame(callback)
+
+            # Ensure that if an assertion occurs after graph pushes
+            # something onto the DynamicLayerStack then we pop it off (the
+            # constructed graph code isn't guarded with try/finally).
+            #
+            # This used to be a context manager, but putting a `with` here
+            # is a noticeable perf regression (#126293).
+            saved_dynamic_layer_stack_depth = (
+                torch._C._functorch.get_dynamic_layer_stack_depth()
+            )
+
             try:
-                # Ensure that if an assertion occurs after graph pushes
-                # something onto the DynamicLayerStack then we pop it off (the
-                # constructed graph code isn't guarded with try/finally).
-                with torch._C._functorch._PreserveDynamicLayerStack():
-                    return fn(*args, **kwargs)
+                return fn(*args, **kwargs)
             finally:
+                # Restore the dynamic layer stack depth if necessary.
+                torch._C._functorch.pop_dynamic_layer_stack_and_undo_to_depth(
+                    saved_dynamic_layer_stack_depth
+                )
+
                 set_eval_frame(prior)
                 for cleanup in cleanups:
                     cleanup()
diff --git a/torch/csrc/functorch/init.cpp b/torch/csrc/functorch/init.cpp
index 6bce80ad27766..53da5a634746c 100644
--- a/torch/csrc/functorch/init.cpp
+++ b/torch/csrc/functorch/init.cpp
@@ -23,9 +23,7 @@
 
 // This file contains functorch's Python bindings.
 
-namespace torch {
-namespace functorch {
-namespace impl {
+namespace torch::functorch::impl {
 
 using namespace at::functorch;
 
@@ -378,7 +376,7 @@ static int64_t currentLevel() {
 static std::optional<int64_t> maybe_current_level() {
   auto maybe_layer = maybeCurrentDynamicLayer();
   if (maybe_layer.has_value()) {
-    int current_level = maybe_layer->layerId();
+    int64_t current_level = maybe_layer->layerId();
     return current_level;
   }
   return nullopt;
@@ -405,36 +403,29 @@ static void dump_local_tls() {
 
 namespace {
 
-// An RAII to save and restore the DynamicLayer stack.
-struct PreserveDynamicLayerStack {
-  size_t m_oldDepth;
-
-  ~PreserveDynamicLayerStack() {
-    while (at::functorch::getDynamicLayerStack().size() > m_oldDepth) {
-      const auto& top = at::functorch::getDynamicLayerStack().back();
-      switch (top.key()) {
-        case at::functorch::TransformType::Vmap:
-          _vmap_decrement_nesting();
-          break;
-        case at::functorch::TransformType::Grad:
-          _grad_decrement_nesting();
-          break;
-        case at::functorch::TransformType::Jvp:
-          _jvp_decrement_nesting();
-          break;
-        case at::functorch::TransformType::Functionalize:
-          _func_decrement_nesting();
-          break;
-        case at::functorch::TransformType::Torch:
-          popDynamicLayerAndDeleteMetadata();
-          break;
-      }
-    }
+// Pop the DynamicLayer stack until it's at the given depth.
+void popDynamicLayerStackToDepth(size_t depth) {
+  while (at::functorch::getDynamicLayerStack().size() > depth) {
+    const auto top = popDynamicLayer();
+    switch (top.key()) {
+      case at::functorch::TransformType::Vmap:
+        _vmap_decrement_nesting();
+        break;
+      case at::functorch::TransformType::Grad:
+        _grad_decrement_nesting();
+        break;
+      case at::functorch::TransformType::Jvp:
+        _jvp_decrement_nesting();
+        break;
+      case at::functorch::TransformType::Functionalize:
+        _func_decrement_nesting();
+        break;
+      case at::functorch::TransformType::Torch:
+        popDynamicLayerAndDeleteMetadata();
+        break;
+    }
   }
-
-  PreserveDynamicLayerStack()
-      : m_oldDepth(at::functorch::getDynamicLayerStack().size()) {}
-};
+}
 
 } // anonymous namespace
 
@@ -540,6 +531,7 @@ void initFuncTorchBindings(PyObject* module) {
           return c10::nullopt;
         }
         std::vector<Interpreter> result;
+        result.reserve(stack.size());
         for (auto i : stack) {
           result.push_back(i.interpreter());
         }
         return result;
       });
@@ -553,6 +545,12 @@ void initFuncTorchBindings(PyObject* module) {
     auto result = stack.back().interpreter();
     return result;
   });
+  m.def("get_dynamic_layer_stack_depth", []() -> size_t {
+    return getDynamicLayerStack().size();
+  });
+  m.def(
+      "pop_dynamic_layer_stack_and_undo_to_depth",
+      &popDynamicLayerStackToDepth);
   m.def("pop_dynamic_layer_stack", &popDynamicLayer);
   m.def("push_dynamic_layer_stack", [](DynamicLayer layer) -> int64_t {
     return pushDynamicLayer(std::move(layer));
@@ -598,11 +596,6 @@ void initFuncTorchBindings(PyObject* module) {
       .def(
           "functionalizeAddBackViews",
           &FunctionalizeInterpreterPtr::functionalizeAddBackViews);
-
-  torch::impl::py_context_manager<PreserveDynamicLayerStack>(
-      m, "_PreserveDynamicLayerStack");
 }
 
-} // namespace impl
-} // namespace functorch
-} // namespace torch
+} // namespace torch::functorch::impl
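
A minimal usage sketch of the two new bindings, mirroring the save/try/finally/restore pattern this patch adds to eval_frame.py (run_user_code is a hypothetical stand-in for the wrapped function):

    import torch

    # Save the current functorch DynamicLayerStack depth.
    saved_depth = torch._C._functorch.get_dynamic_layer_stack_depth()
    try:
        run_user_code()  # hypothetical; may leave vmap/grad/jvp layers pushed if it raises
    finally:
        # Pop any leftover layers (undoing their nesting) back down to the saved depth.
        torch._C._functorch.pop_dynamic_layer_stack_and_undo_to_depth(saved_depth)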