From 70dc59c55f6ed63d9aef13bab5df0f8431302b18 Mon Sep 17 00:00:00 2001
From: Aaron Orenstein
Date: Thu, 23 May 2024 11:55:50 -0700
Subject: [PATCH] Fix perf regression caused by #122074 (#126996)

The original change was about 9.5% slower than before #122074. This
improves it to be only about 1.4% slower.

Also touched up some unrelated nits that the linter complained about.

Fixes #126293

Ran torchbench 3 times on each change. Perf values before (stable),
after (fix), and with #122074 backed out (backout):
```
../inductor-tools/scripts/modelbench/inductor_single_run.sh single inference performance torchbench pyhpc_isoneutral_mixing amp first dynamic cpp
stable:  43.948x 45.754x 44.906x
fix:     47.505x 49.987x 47.493x
backout: 48.243x 48.199x 48.192x

../inductor-tools/scripts/modelbench/inductor_single_run.sh single inference performance torchbench pyhpc_equation_of_state amp first static default
stable:  15.224x 13.286x 15.354x
fix:     16.402x 16.370x 16.183x
backout: 16.554x 16.675x 16.787x

../inductor-tools/scripts/modelbench/inductor_single_run.sh single inference performance torchbench lennard_jones float32 first static default
stable:  1.712x 1.651x 1.640x
fix:     1.804x 1.798x 1.792x
backout: 1.864x 1.824x 1.836x
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/126996
Approved by: https://github.com/jansel
---
 torch/_C/_functorch.pyi       |  7 +---
 torch/_dynamo/eval_frame.py   | 22 ++++++++---
 torch/csrc/functorch/init.cpp | 69 ++++++++++++++++-------------------
 3 files changed, 50 insertions(+), 48 deletions(-)

diff --git a/torch/_C/_functorch.pyi b/torch/_C/_functorch.pyi
index ef5ab302ea34d..111113221a0c3 100644
--- a/torch/_C/_functorch.pyi
+++ b/torch/_C/_functorch.pyi
@@ -33,11 +33,6 @@ def _grad_decrement_nesting() -> int: ...
 def _jvp_increment_nesting() -> int: ...
 def _jvp_decrement_nesting() -> int: ...
 
-class _PreserveDynamicLayerStack:
-    def __init__(self): ...
-    def __enter__(self): ...
-    def __exit__(self, exc_type, exc_value, traceback): ...
-
 # Defined in aten/src/ATen/functorch/Interpreter.h
 class TransformType(Enum):
     Torch: TransformType = ...
@@ -80,7 +75,9 @@ class CVmapInterpreterPtr:
 
 class DynamicLayer: ...
 
+def get_dynamic_layer_stack_depth() -> int: ...
 def get_interpreter_stack() -> list[CInterpreter]: ...
 def peek_interpreter_stack() -> CInterpreter: ...
 def pop_dynamic_layer_stack() -> DynamicLayer: ...
+def pop_dynamic_layer_stack_and_undo_to_depth(int) -> None: ...
 def push_dynamic_layer_stack(dl: DynamicLayer) -> int: ...
diff --git a/torch/_dynamo/eval_frame.py b/torch/_dynamo/eval_frame.py
index db35c0f631e8c..d9cb59a58f9db 100644
--- a/torch/_dynamo/eval_frame.py
+++ b/torch/_dynamo/eval_frame.py
@@ -406,13 +406,25 @@ def _fn(*args, **kwargs):
 
             cleanups = [enter() for enter in self.enter_exit_hooks]
             prior = set_eval_frame(callback)
+
+            # Ensure that if an assertion occurs after graph pushes
+            # something onto the DynamicLayerStack then we pop it off (the
+            # constructed graph code isn't guarded with try/finally).
+            #
+            # This used to be a context manager but putting a `with` here is a
+            # noticeable perf regression (#126293)
+            saved_dynamic_layer_stack_depth = (
+                torch._C._functorch.get_dynamic_layer_stack_depth()
+            )
+
             try:
-                # Ensure that if an assertion occurs after graph pushes
-                # something onto the DynamicLayerStack then we pop it off (the
-                # constructed graph code isn't guarded with try/finally).
-                with torch._C._functorch._PreserveDynamicLayerStack():
-                    return fn(*args, **kwargs)
+                return fn(*args, **kwargs)
             finally:
+                # Restore the dynamic layer stack depth if necessary.
+                torch._C._functorch.pop_dynamic_layer_stack_and_undo_to_depth(
+                    saved_dynamic_layer_stack_depth
+                )
+
                 set_eval_frame(prior)
                 for cleanup in cleanups:
                     cleanup()
diff --git a/torch/csrc/functorch/init.cpp b/torch/csrc/functorch/init.cpp
index 6bce80ad27766..53da5a634746c 100644
--- a/torch/csrc/functorch/init.cpp
+++ b/torch/csrc/functorch/init.cpp
@@ -23,9 +23,7 @@
 
 // This file contains functorch's Python bindings.
 
-namespace torch {
-namespace functorch {
-namespace impl {
+namespace torch::functorch::impl {
 
 using namespace at::functorch;
 
@@ -378,7 +376,7 @@ static int64_t currentLevel() {
 static std::optional<int64_t> maybe_current_level() {
   auto maybe_layer = maybeCurrentDynamicLayer();
   if (maybe_layer.has_value()) {
-    int current_level = maybe_layer->layerId();
+    int64_t current_level = maybe_layer->layerId();
     return current_level;
   }
   return nullopt;
@@ -405,36 +403,29 @@ static void dump_local_tls() {
 
 namespace {
 
-// An RAII to save and restore the DynamicLayer stack.
-struct PreserveDynamicLayerStack {
-  size_t m_oldDepth;
-
-  ~PreserveDynamicLayerStack() {
-    while (at::functorch::getDynamicLayerStack().size() > m_oldDepth) {
-      const auto& top = at::functorch::getDynamicLayerStack().back();
-      switch (top.key()) {
-        case at::functorch::TransformType::Vmap:
-          _vmap_decrement_nesting();
-          break;
-        case at::functorch::TransformType::Grad:
-          _grad_decrement_nesting();
-          break;
-        case at::functorch::TransformType::Jvp:
-          _jvp_decrement_nesting();
-          break;
-        case at::functorch::TransformType::Functionalize:
-          _func_decrement_nesting();
-          break;
-        case at::functorch::TransformType::Torch:
-          popDynamicLayerAndDeleteMetadata();
-          break;
-      }
+// Pop the DynamicLayer stack until it's at the given depth.
+void popDynamicLayerStackToDepth(size_t depth) {
+  while (at::functorch::getDynamicLayerStack().size() > depth) {
+    const auto top = popDynamicLayer();
+    switch (top.key()) {
+      case at::functorch::TransformType::Vmap:
+        _vmap_decrement_nesting();
+        break;
+      case at::functorch::TransformType::Grad:
+        _grad_decrement_nesting();
+        break;
+      case at::functorch::TransformType::Jvp:
+        _jvp_decrement_nesting();
+        break;
+      case at::functorch::TransformType::Functionalize:
+        _func_decrement_nesting();
+        break;
+      case at::functorch::TransformType::Torch:
+        popDynamicLayerAndDeleteMetadata();
+        break;
+    }
   }
-
-  PreserveDynamicLayerStack()
-      : m_oldDepth(at::functorch::getDynamicLayerStack().size()) {}
-};
+}
 
 } // anonymous namespace
@@ -540,6 +531,7 @@ void initFuncTorchBindings(PyObject* module) {
           return c10::nullopt;
         }
         std::vector<Interpreter> result;
+        result.reserve(stack.size());
         for (auto i : stack) {
           result.push_back(i.interpreter());
         }
@@ -553,6 +545,12 @@ void initFuncTorchBindings(PyObject* module) {
     auto result = stack.back().interpreter();
    return result;
  });
+  m.def("get_dynamic_layer_stack_depth", []() -> size_t {
+    return getDynamicLayerStack().size();
+  });
+  m.def(
+      "pop_dynamic_layer_stack_and_undo_to_depth",
+      &popDynamicLayerStackToDepth);
   m.def("pop_dynamic_layer_stack", &popDynamicLayer);
   m.def("push_dynamic_layer_stack", [](DynamicLayer layer) -> int64_t {
     return pushDynamicLayer(std::move(layer));
@@ -598,11 +596,6 @@ void initFuncTorchBindings(PyObject* module) {
       .def(
           "functionalizeAddBackViews",
           &FunctionalizeInterpreterPtr::functionalizeAddBackViews);
-
-  torch::impl::py_context_manager<PreserveDynamicLayerStack>(
-      m, "_PreserveDynamicLayerStack");
 }
 
-} // namespace impl
-} // namespace functorch
-} // namespace torch
+} // namespace torch::functorch::impl
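
Note: for readers skimming the patch, the pattern adopted in `eval_frame.py` reduces to the following sketch. It assumes a PyTorch build that includes the bindings added above; the wrapper name `run_preserving_dynamic_layer_stack` and the `user_fn` argument are illustrative only and are not part of the patch.

```
# Minimal sketch of the save/restore-depth pattern used above (illustrative names).
import torch

def run_preserving_dynamic_layer_stack(user_fn, *args, **kwargs):
    # Record the functorch DynamicLayer stack depth up front; a plain integer
    # read is cheaper than entering a Python context manager on every call.
    saved_depth = torch._C._functorch.get_dynamic_layer_stack_depth()
    try:
        return user_fn(*args, **kwargs)
    finally:
        # Pop (and undo) any layers pushed past the saved depth, so an
        # exception raised inside generated graph code cannot leave the
        # stack in an inconsistent state.
        torch._C._functorch.pop_dynamic_layer_stack_and_undo_to_depth(saved_depth)
```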