Fix perf regression caused by #122074 (#126996)
The original change was about 9.5% slower than before #122074.
This improves it to be only about 1.4% slower.

Also touched up some unrelated nits that the linter complained about.

Fixes #126293

Ran torchbench 3 times on each change. Perf values before (stable), after (fix),
and with #122074 backed out (backout):
```
../inductor-tools/scripts/modelbench/inductor_single_run.sh single inference performance torchbench pyhpc_isoneutral_mixing amp first dynamic cpp
stable:
43.948x
45.754x
44.906x

fix:
47.505x
49.987x
47.493x

backout:
48.243x
48.199x
48.192x

../inductor-tools/scripts/modelbench/inductor_single_run.sh single inference performance torchbench pyhpc_equation_of_state amp first static default
stable:
15.224x
13.286x
15.354x

fix:
16.402x
16.370x
16.183x

backout:
16.554x
16.675x
16.787x

../inductor-tools/scripts/modelbench/inductor_single_run.sh single inference performance torchbench lennard_jones float32 first static default
stable:
1.712x
1.651x
1.640x

fix:
1.804x
1.798x
1.792x

backout:
1.864x
1.824x
1.836x
```
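
For reference, here is a small standalone Python sketch (not part of the commit) showing one way to turn the raw speedups above into the headline numbers: average the three runs per configuration, then compare each variant against the backout. The model names and values are copied from the output above; everything else is illustrative.
```
speedups = {
    "pyhpc_isoneutral_mixing": {
        "stable": [43.948, 45.754, 44.906],
        "fix": [47.505, 49.987, 47.493],
        "backout": [48.243, 48.199, 48.192],
    },
    "pyhpc_equation_of_state": {
        "stable": [15.224, 13.286, 15.354],
        "fix": [16.402, 16.370, 16.183],
        "backout": [16.554, 16.675, 16.787],
    },
    "lennard_jones": {
        "stable": [1.712, 1.651, 1.640],
        "fix": [1.804, 1.798, 1.792],
        "backout": [1.864, 1.824, 1.836],
    },
}

def mean(xs):
    return sum(xs) / len(xs)

stable_deltas, fix_deltas = [], []
for model, runs in speedups.items():
    stable, fix, backout = (mean(runs[k]) for k in ("stable", "fix", "backout"))
    # Relative change vs. the backout; negative means slower than the backout.
    stable_deltas.append(stable / backout - 1.0)
    fix_deltas.append(fix / backout - 1.0)
    print(f"{model}: stable {stable:.3f}x, fix {fix:.3f}x, backout {backout:.3f}x")

print(f"stable vs backout: {mean(stable_deltas):+.2%}")  # roughly -9.5%
print(f"fix vs backout:    {mean(fix_deltas):+.2%}")     # roughly -1.4%
```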

Pull Request resolved: #126996
Approved by: https://github.com/jansel
aorenste authored and pytorchmergebot committed May 24, 2024
1 parent cb6ef68 commit 70dc59c
Showing 3 changed files with 50 additions and 48 deletions.
7 changes: 2 additions & 5 deletions torch/_C/_functorch.pyi
@@ -33,11 +33,6 @@ def _grad_decrement_nesting() -> int: ...
def _jvp_increment_nesting() -> int: ...
def _jvp_decrement_nesting() -> int: ...

class _PreserveDynamicLayerStack:
def __init__(self): ...
def __enter__(self): ...
def __exit__(self, exc_type, exc_value, traceback): ...

# Defined in aten/src/ATen/functorch/Interpreter.h
class TransformType(Enum):
Torch: TransformType = ...
@@ -80,7 +75,9 @@ class CVmapInterpreterPtr:

class DynamicLayer: ...

def get_dynamic_layer_stack_depth() -> int: ...
def get_interpreter_stack() -> list[CInterpreter]: ...
def peek_interpreter_stack() -> CInterpreter: ...
def pop_dynamic_layer_stack() -> DynamicLayer: ...
def pop_dynamic_layer_stack_and_undo_to_depth(int) -> None: ...
def push_dynamic_layer_stack(dl: DynamicLayer) -> int: ...
22 changes: 17 additions & 5 deletions torch/_dynamo/eval_frame.py
@@ -406,13 +406,25 @@ def _fn(*args, **kwargs):

cleanups = [enter() for enter in self.enter_exit_hooks]
prior = set_eval_frame(callback)

# Ensure that if an assertion occurs after graph pushes
# something onto the DynamicLayerStack then we pop it off (the
# constructed graph code isn't guarded with try/finally).
#
# This used to be a context manager, but putting a `with` here is a
# noticeable perf regression (#126293)
saved_dynamic_layer_stack_depth = (
torch._C._functorch.get_dynamic_layer_stack_depth()
)

try:
# Ensure that if an assertion occurs after graph pushes
# something onto the DynamicLayerStack then we pop it off (the
# constructed graph code isn't guarded with try/finally).
with torch._C._functorch._PreserveDynamicLayerStack():
return fn(*args, **kwargs)
return fn(*args, **kwargs)
finally:
# Restore the dynamic layer stack depth if necessary.
torch._C._functorch.pop_dynamic_layer_stack_and_undo_to_depth(
saved_dynamic_layer_stack_depth
)

set_eval_frame(prior)
for cleanup in cleanups:
cleanup()
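The comment in the hunk above is the core of the fix: the per-call `with torch._C._functorch._PreserveDynamicLayerStack()` block is replaced by a saved depth plus a try/finally, and `_fn` runs on every call into the compiled region, so per-call overhead adds up. As a rough illustration of why that kind of change can matter on a hot path, here is a minimal, self-contained micro-benchmark sketch; `DepthGuard` and the list-based stack are hypothetical stand-ins, not functorch code.
```
import timeit

class DepthGuard:
    """Hypothetical stand-in for the removed _PreserveDynamicLayerStack."""

    def __init__(self, stack):
        self.stack = stack

    def __enter__(self):
        self.depth = len(self.stack)

    def __exit__(self, exc_type, exc_value, traceback):
        # Drop anything pushed after the snapshot was taken.
        del self.stack[self.depth:]

stack = []

def call_with_context_manager():
    # Allocates a guard object and dispatches __enter__/__exit__ on every call.
    with DepthGuard(stack):
        pass

def call_with_try_finally():
    # Same effect, but just a local snapshot and a finally block.
    depth = len(stack)
    try:
        pass
    finally:
        del stack[depth:]

print("context manager:", timeit.timeit(call_with_context_manager, number=1_000_000))
print("try/finally:    ", timeit.timeit(call_with_try_finally, number=1_000_000))
```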
69 changes: 31 additions & 38 deletions torch/csrc/functorch/init.cpp
@@ -23,9 +23,7 @@

// This file contains functorch's Python bindings.

namespace torch {
namespace functorch {
namespace impl {
namespace torch::functorch::impl {

using namespace at::functorch;

@@ -378,7 +376,7 @@ static int64_t currentLevel() {
static std::optional<int64_t> maybe_current_level() {
auto maybe_layer = maybeCurrentDynamicLayer();
if (maybe_layer.has_value()) {
int current_level = maybe_layer->layerId();
int64_t current_level = maybe_layer->layerId();
return current_level;
}
return nullopt;
@@ -405,36 +403,29 @@ static void dump_local_tls() {

namespace {

// An RAII to save and restore the DynamicLayer stack.
struct PreserveDynamicLayerStack {
size_t m_oldDepth;

~PreserveDynamicLayerStack() {
while (at::functorch::getDynamicLayerStack().size() > m_oldDepth) {
const auto& top = at::functorch::getDynamicLayerStack().back();
switch (top.key()) {
case at::functorch::TransformType::Vmap:
_vmap_decrement_nesting();
break;
case at::functorch::TransformType::Grad:
_grad_decrement_nesting();
break;
case at::functorch::TransformType::Jvp:
_jvp_decrement_nesting();
break;
case at::functorch::TransformType::Functionalize:
_func_decrement_nesting();
break;
case at::functorch::TransformType::Torch:
popDynamicLayerAndDeleteMetadata();
break;
}
// Pop the DynamicLayer stack until it's at the given depth.
void popDynamicLayerStackToDepth(size_t depth) {
while (at::functorch::getDynamicLayerStack().size() > depth) {
const auto top = popDynamicLayer();
switch (top.key()) {
case at::functorch::TransformType::Vmap:
_vmap_decrement_nesting();
break;
case at::functorch::TransformType::Grad:
_grad_decrement_nesting();
break;
case at::functorch::TransformType::Jvp:
_jvp_decrement_nesting();
break;
case at::functorch::TransformType::Functionalize:
_func_decrement_nesting();
break;
case at::functorch::TransformType::Torch:
popDynamicLayerAndDeleteMetadata();
break;
}
}

PreserveDynamicLayerStack()
: m_oldDepth(at::functorch::getDynamicLayerStack().size()) {}
};
}

} // anonymous namespace

@@ -540,6 +531,7 @@ void initFuncTorchBindings(PyObject* module) {
return c10::nullopt;
}
std::vector<Interpreter> result;
result.reserve(stack.size());
for (auto i : stack) {
result.push_back(i.interpreter());
}
@@ -553,6 +545,12 @@
auto result = stack.back().interpreter();
return result;
});
m.def("get_dynamic_layer_stack_depth", []() -> size_t {
return getDynamicLayerStack().size();
});
m.def(
"pop_dynamic_layer_stack_and_undo_to_depth",
&popDynamicLayerStackToDepth);
m.def("pop_dynamic_layer_stack", &popDynamicLayer);
m.def("push_dynamic_layer_stack", [](DynamicLayer layer) -> int64_t {
return pushDynamicLayer(std::move(layer));
@@ -598,11 +596,6 @@
.def(
"functionalizeAddBackViews",
&FunctionalizeInterpreterPtr::functionalizeAddBackViews);

torch::impl::py_context_manager<PreserveDynamicLayerStack>(
m, "_PreserveDynamicLayerStack");
}

} // namespace impl
} // namespace functorch
} // namespace torch
} // namespace torch::functorch::impl
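
Taken together, the new bindings are meant to be used as a save/restore pair from Python. A minimal usage sketch follows; the two `torch._C._functorch` functions come from the diff above, while the surrounding `torch.func.grad` call is only an illustrative workload and is not part of this commit.
```
import torch
from torch._C import _functorch

depth = _functorch.get_dynamic_layer_stack_depth()
try:
    # Anything here may push functorch layers (e.g. via torch.func transforms)
    # and may raise before it gets the chance to pop them again.
    torch.func.grad(lambda x: (x * x).sum())(torch.randn(3))
finally:
    # Pop back down to the snapshot, decrementing the matching nesting counters
    # (Vmap/Grad/Jvp/Functionalize), as popDynamicLayerStackToDepth does above.
    _functorch.pop_dynamic_layer_stack_and_undo_to_depth(depth)
```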
