Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 20 additions & 0 deletions test/dynamo/test_repros.py
Original file line number Diff line number Diff line change
Expand Up @@ -6432,6 +6432,26 @@ def bad(*args, **kwargs):
with mock.patch("torch._dynamo.eval_frame._maybe_set_eval_frame", bad):
fn(torch.ones(3))

@parametrize("fullgraph", [True, False])
def test_skip_frame_recursive_on_empty_graph(self, fullgraph):
def k(x):
return x

def g(x):
return k(x)

def f(x):
return g(x)

# TODO clear this on all tests
torch._dynamo.eval_frame.dynamo_tls.traced_frame_infos.clear()

opt_f = torch.compile(f, backend="eager", fullgraph=fullgraph)
opt_f(torch.randn(3))
self.assertEqual(len(torch._dynamo.eval_frame.dynamo_tls.traced_frame_infos), 1)
opt_f(torch.randn(3))
self.assertEqual(len(torch._dynamo.eval_frame.dynamo_tls.traced_frame_infos), 2)

def test_torchname(self):
def fn(obj):
return torch.typename(obj)
Expand Down
2 changes: 2 additions & 0 deletions torch/_C/_dynamo/eval_frame.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,11 @@ from torch._dynamo.types import DynamoCallback, DynamoGuardHook
# For typechecking
SkipCodeRecursiveFlag = NewType("SkipCodeRecursiveFlag", object)
CacheLimitHitFlag = NewType("CacheLimitHitFlag", object)
SkipFrameRecursiveFlag = NewType("SkipFrameRecursiveFlag", object)
# Flag returned by Dynamo tracer to indicate to Dynamo eval frame that we should skip frames recursively.
skip_code_recursive_flag: SkipCodeRecursiveFlag
cache_limit_hit_flag: CacheLimitHitFlag
skip_frame_recursive_flag: SkipFrameRecursiveFlag

def set_eval_frame(callback: DynamoCallback) -> DynamoCallback: ...
def set_skip_guard_eval_unsafe(value: bool) -> bool: ...
Expand Down
62 changes: 40 additions & 22 deletions torch/_dynamo/convert_frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,7 @@
RecompileLimitExceeded,
ShortenTraceback,
SkipCodeRecursiveException,
SkipFrameRecursiveException,
TorchRuntimeError,
UncapturedHigherOrderOpError,
unimplemented,
Expand Down Expand Up @@ -785,6 +786,10 @@ def log_bytecode(
)
if one_graph:
log.debug("No graph captured with one_graph=True")
if isinstance(e, exc.EmptyGraph):
# Signal to Dynamo eval frame to skip the current frame and any recursive calls.
# Future invocations of the code object will still be traced.
raise SkipFrameRecursiveException from e
return None

assert (
Expand Down Expand Up @@ -1040,6 +1045,8 @@ def format_guard_failures() -> str:
UncapturedHigherOrderOpError,
BisectValidationException,
ShortenTraceback,
SkipCodeRecursiveException,
SkipFrameRecursiveException,
),
):
raise
Expand Down Expand Up @@ -1172,13 +1179,7 @@ def __call__(
hooks: Hooks,
frame_state: Dict[str, Union[int, FrameStateSizeEntry]],
skip: int = 0,
) -> Optional[
Union[
GuardedCode,
torch._C._dynamo.eval_frame.SkipCodeRecursiveFlag,
torch._C._dynamo.eval_frame.CacheLimitHitFlag,
]
]:
) -> Optional[GuardedCode]:
counters["frames"]["total"] += 1
try:
result = self._inner_convert(
Expand Down Expand Up @@ -1252,15 +1253,6 @@ def __call__(
else:
log.warning(error_msg, exc_info=True)

# If we encounter SkipCodeRecursiveException, return skip_code_recursive_flag
# to signal to Dynamo eval frame to skip the current frame and any recursive calls.
if isinstance(e, SkipCodeRecursiveException):
return torch._C._dynamo.eval_frame.skip_code_recursive_flag
elif isinstance(e, RecompileLimitExceeded):
# signal to Dynamo to run this frame on run-only mode, skipping recursively if
# no valid cache entry is found.
return torch._C._dynamo.eval_frame.cache_limit_hit_flag

return None


Expand Down Expand Up @@ -1334,7 +1326,14 @@ def __call__(
frame: DynamoFrameType,
cache_entry: Optional[CacheEntry],
frame_state: Dict[str, Union[int, FrameStateSizeEntry]],
) -> Optional[GuardedCode]:
) -> Optional[
Union[
GuardedCode,
torch._C._dynamo.eval_frame.SkipCodeRecursiveFlag,
torch._C._dynamo.eval_frame.CacheLimitHitFlag,
torch._C._dynamo.eval_frame.SkipFrameRecursiveFlag,
]
]:
assert frame_state is not None

is_skipfile = trace_rules.check(frame.f_code)
Expand Down Expand Up @@ -1395,11 +1394,30 @@ def __call__(
frame, cache_entry, self.hooks, frame_state
)

with compile_lock, _disable_current_modes():
# skip=1: skip this frame
return self._torchdynamo_orig_callable(
frame, cache_entry, self.hooks, frame_state, skip=1
)
try:
with compile_lock, _disable_current_modes():
# skip=1: skip this frame
return self._torchdynamo_orig_callable(
frame, cache_entry, self.hooks, frame_state, skip=1
)
except Exception as e:
# top-level convert_frame exception handler to handle flags to pass back to
# Dynamo eval frame
if isinstance(e, SkipCodeRecursiveException):
# If we encounter SkipCodeRecursiveException, return skip_code_recursive_flag
# to signal to Dynamo eval frame to skip the current frame and any recursive calls.
# Future invocations of the code object will also be skipped recursively
return torch._C._dynamo.eval_frame.skip_code_recursive_flag
elif isinstance(e, RecompileLimitExceeded):
# signal to Dynamo to run this frame on run-only mode, skipping recursively if
# no valid cache entry is found.
return torch._C._dynamo.eval_frame.cache_limit_hit_flag
elif isinstance(e, SkipFrameRecursiveException):
# If we encounter SkipFrameRecursive, return skip_frame_recursive_flag
# to signal to Dynamo eval frame to skip the current frame and any recursive calls.
# Future invocations of the code object will still be traced.
return torch._C._dynamo.eval_frame.skip_frame_recursive_flag
raise


def catch_errors_wrapper(
Expand Down
8 changes: 8 additions & 0 deletions torch/_dynamo/exc.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,10 @@ class SkipFrame(TorchDynamoException):
pass


class EmptyGraph(SkipFrame):
pass


class TorchRuntimeError(TorchDynamoException):
pass

Expand Down Expand Up @@ -213,6 +217,10 @@ class RecompileLimitExceeded(Unsupported):
pass


class SkipFrameRecursiveException(TorchDynamoException):
pass


class UnsafeScriptObjectError(TorchDynamoException):
pass

Expand Down
2 changes: 1 addition & 1 deletion torch/_dynamo/symbolic_convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -3061,7 +3061,7 @@ def _return(self, inst):
and not self.symbolic_locals_contain_module_class()
and not self.export
):
raise exc.SkipFrame("because no content in function call")
raise exc.EmptyGraph("because no content in function call")
self.instruction_pointer = None
_step_logger()(
logging.INFO,
Expand Down
23 changes: 23 additions & 0 deletions torch/csrc/dynamo/eval_frame.c
Original file line number Diff line number Diff line change
Expand Up @@ -559,6 +559,7 @@ static PyObject* dynamo__custom_eval_frame_shim(

static PyObject* skip_code_recursive_flag;
static PyObject* cache_limit_hit_flag;
static PyObject* skip_frame_recursive_flag;
bool is_skip_guard_eval_unsafe = false;

// NOTE: In 3.12+, the frame evaluation function (callee) is responsible for
Expand Down Expand Up @@ -770,6 +771,18 @@ static PyObject* dynamo__custom_eval_frame(
// Re-enable custom behavior
eval_frame_callback_set(callback);
return r;
} else if (result == skip_frame_recursive_flag) {
// Dynamo returned skip_frame_recursive_flag, so we should recursively skip
// frame, but only for this frame, NOT the code object.
// The difference from skip_code_recursive_flag is that when we attempt to
// trace the code object again, skip_code_recursive_flag will cause Dynamo
// to skip tracing, whereas with skip_frame_recursive_flag, we will attempt
// to trace again.
DEBUG_TRACE("skip frame recursive %s", get_frame_name(frame));
PyObject* r = dynamo_eval_frame_default(tstate, frame, throw_flag);
// Re-enable custom behavior
eval_frame_callback_set(callback);
return r;
} else if (result != Py_None) {
DEBUG_TRACE("create cache %s", get_frame_name(frame));

Expand Down Expand Up @@ -1042,5 +1055,15 @@ PyObject* torch_c_dynamo_eval_frame_init(void) {
return NULL;
}

skip_frame_recursive_flag = PyObject_New(PyObject, &PyBaseObject_Type);
if (skip_frame_recursive_flag == NULL) {
return NULL;
}
if (PyModule_AddObject(
module, "skip_frame_recursive_flag", skip_frame_recursive_flag) !=
0) {
return NULL;
}

return module;
}
Loading