Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use fast traceback for symbolic shapes #107439

Closed
wants to merge 7 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
9 changes: 7 additions & 2 deletions torch/_guards.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,8 +126,7 @@ class GuardBuilderBase:

class ShapeGuard(NamedTuple):
expr: sympy.Expr
# TODO: store this in slightly less formatted form
stack: str
stack: CapturedTraceback
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we love it



@dataclasses.dataclass
Expand Down Expand Up @@ -694,6 +693,12 @@ def tracing(context: TracingContext):
e.real_stack = context.extract_stack() # type: ignore[attr-defined]
raise
finally:
if (
context is not None
and context.fake_mode is not None
and context.fake_mode.shape_env is not None
):
context.fake_mode.shape_env.cleanup()
_TLS.tracing_context = old_context


Expand Down
66 changes: 40 additions & 26 deletions torch/fx/experimental/symbolic_shapes.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
import operator
import re
import sys
import textwrap
import threading
import traceback
from collections import defaultdict
Expand Down Expand Up @@ -36,7 +35,7 @@
from torch.utils._sympy.functions import FloorDiv, LShift, Mod, RShift
from torch.utils._sympy.solve import try_solve
from torch.utils._sympy.value_ranges import bound_sympy, SymPyValueRangeAnalysis, ValueRanges, ValueRangeError
from torch.utils._traceback import format_frame
from torch.utils._traceback import format_frame, CapturedTraceback
from torch._utils_internal import signpost_event

InputList = List
Expand Down Expand Up @@ -1059,11 +1058,6 @@ def error():
raise AssertionError("shouldn't be hit")


def get_debugging_stack(num_frames_to_cut=2):
    """Return the current call stack rendered as a single string.

    Args:
        num_frames_to_cut: Number of innermost frames to drop from the end
            of the stack. The default of 2 drops this function's own frame
            and its caller's frame. Passing 0 keeps every frame, including
            this function's own.

    Returns:
        The remaining frames, formatted by ``traceback.format_list`` and
        concatenated; an empty string if no frames remain.
    """
    frames = traceback.extract_stack()
    # Use an explicit end index rather than a negative slice:
    # ``frames[:-0]`` is ``frames[:0]``, which would silently drop the whole
    # stack when the caller asks to cut nothing.
    keep = max(len(frames) - num_frames_to_cut, 0)
    return ''.join(traceback.format_list(frames[:keep]))


def floor_ceil_helper(a, fn):
if isinstance(a, sympy.Mul):
aa = a.args
Expand Down Expand Up @@ -2028,7 +2022,7 @@ def __init__(
# for N < 2. Therefore, it will be too strict to assert N=2 at runtime.
self.runtime_var_to_range: Dict[sympy.Symbol, ValueRanges] = {}
self.var_to_sources: Dict[sympy.Symbol, List[Source]] = {}
self.var_to_stack: Dict[sympy.Symbol, traceback.StackSummary] = {}
self.var_to_stack: Dict[sympy.Symbol, CapturedTraceback] = {}
# Maps symbolic ints to the guards that refine their lower/upper
# bound. If one of them is None, it means that there are no guards
# that refine that respective bound.
Expand Down Expand Up @@ -2373,7 +2367,7 @@ def create_symboolnode(self, sym: "sympy.Expr"):
def create_unbacked_symfloat(self):
symbol: sympy.Symbol = sympy.Symbol(f"f{next(self.unbacked_symfloat_counter)}")
self.counter["create_unbacked_symbol"] += 1
self.var_to_stack[symbol] = traceback.extract_stack()[:-1]
self.var_to_stack[symbol] = CapturedTraceback.extract(skip=1)
self.var_to_range[symbol] = ValueRanges.unknown()

# Create a new FX placeholder and Z3 variable for 'symbol'.
Expand All @@ -2384,7 +2378,7 @@ def create_unbacked_symfloat(self):
def create_unbacked_symint(self):
symbol: sympy.Symbol = sympy.Symbol(f"i{next(self.unbacked_symint_counter)}", integer=True)
self.counter["create_unbacked_symbol"] += 1
self.var_to_stack[symbol] = traceback.extract_stack()[:-1]
self.var_to_stack[symbol] = CapturedTraceback.extract(skip=1)
self.var_to_range[symbol] = self._default_unspecified_value_range()

# Create a new FX placeholder and Z3 variable for 'symbol'.
Expand All @@ -2395,7 +2389,7 @@ def create_unbacked_symint(self):
def create_unbacked_symbool(self):
symbol: sympy.Symbol = sympy.Symbol(f"i{next(self.unbacked_symint_counter)}", integer=True)
self.counter["create_unbacked_symbol"] += 1
self.var_to_stack[symbol] = traceback.extract_stack()[:-1]
self.var_to_stack[symbol] = CapturedTraceback.extract(skip=1)
self.var_to_range[symbol] = ValueRanges(0, 1)

# Create a new FX placeholder and Z3 variable for 'symbol'.
Expand Down Expand Up @@ -2819,7 +2813,7 @@ def issue_guard(guard: ShapeGuard) -> None:
else:
raise AssertionError(f"unrecognized constraint {c}")
except Exception:
self.log.warning("Failing guard allocated at: \n%s", guard.stack)
self.log.warning("Failing guard allocated at: \n%s", ''.join(guard.stack.format()))
raise

# First, issue all the non-trivial guards.
Expand Down Expand Up @@ -3005,7 +2999,7 @@ def format_guards(self, verbose=False):
def format_tb(tb):
if not verbose:
return ""
return f"\n Guarded at:\n{textwrap.indent(tb, ' ')}"
return f"\n Guarded at:\n{''.join(' ' + l for l in tb.format())}"

return '\n'.join(f" - {guard.expr}{format_tb(guard.stack)}" for guard in self.guards)

Expand Down Expand Up @@ -3180,7 +3174,7 @@ def _make_data_dependent_error(self, expr, unhinted_expr):
# TODO: in a Dynamo context, having user code, and having the
# name of the local, will be much better
for s in expr.free_symbols:
stacktrace = ''.join(traceback.format_list(self.var_to_stack[s]))
stacktrace = ''.join(self.var_to_stack[s].format())
self.log.debug("Data dependent variable '%s' allocated at:\n%s", s, stacktrace)
return GuardOnDataDependentSymNode(
"It appears that you're trying to get a value out of symbolic int/float "
Expand Down Expand Up @@ -3322,11 +3316,22 @@ def _check_frozen(self, expr, concrete_val):
log.warning("Ignored guard %s == %s, this could result in accuracy problems", expr, concrete_val)


def _log_guard(self, prefix: str, g, tb):
def _log_guard(self, prefix: str, g):
if self.log.isEnabledFor(logging.INFO):
for frame in reversed(tb):
if frame.filename not in uninteresting_files():
break
fsummary = None
frame = inspect.currentframe()
try:
while frame is not None:
if frame.f_code.co_filename not in uninteresting_files():
fsummary = traceback.FrameSummary(
frame.f_code.co_filename,
frame.f_lineno,
frame.f_code.co_name,
)
break
frame = frame.f_back
finally:
del frame

# NB: this stack is truncated, but it's fine because the main
# stack_info will give you the rest of the info you need
Expand All @@ -3347,7 +3352,7 @@ def _log_guard(self, prefix: str, g, tb):
"eval %s [guard added]%s (%s)%s",
g,
maybe_user_loc,
format_frame(frame),
format_frame(fsummary),
maybe_extra_debug,
stack_info=is_debug,
)
Expand Down Expand Up @@ -3450,8 +3455,7 @@ def evaluate_expr(self, orig_expr: "sympy.Expr", hint=None, fx_node=None):
g = sympy.Eq(expr, concrete_val) # type: ignore[arg-type]

if not self._suppress_guards_tls():
tb = traceback.extract_stack()[:-1]
stack = ''.join(traceback.format_list(tb))
stack = CapturedTraceback.extract(skip=1)
guard = ShapeGuard(g, stack)
self.guards.append(guard)
except Exception:
Expand All @@ -3460,16 +3464,27 @@ def evaluate_expr(self, orig_expr: "sympy.Expr", hint=None, fx_node=None):
else:
if not self._suppress_guards_tls():
assert guard is not None
assert tb is not None

self.refine_ranges(guard)

self._log_guard("eval", g, tb)
self._log_guard("eval", g)
else:
self.log.debug("eval %s [guard suppressed]", g)

return concrete_val

def cleanup(self):
    """Call ``cleanup()`` on every stack this object holds, to break reference cycles.

    Visits, in order: the ``stack`` of each entry in ``self.guards``, each
    value of ``self.var_to_stack``, and the ``stack`` of each runtime assert
    in the lists stored in ``self.deferred_runtime_asserts``.

    This destroys the stacks. Keeping them would need some other way to
    break the references on code objects.
    """
    def _held_stacks():
        # Yield lazily so each stack is cleaned up as soon as it is reached,
        # in the same order as three consecutive loops would.
        for shape_guard in self.guards:
            yield shape_guard.stack
        yield from self.var_to_stack.values()
        for assert_list in self.deferred_runtime_asserts.values():
            for runtime_assert in assert_list:
                yield runtime_assert.stack

    for held in _held_stacks():
        held.cleanup()

def defer_runtime_assert(self, orig_expr: "sympy.Expr", msg, fx_node=None):
expr = orig_expr

Expand Down Expand Up @@ -3501,8 +3516,7 @@ def defer_runtime_assert(self, orig_expr: "sympy.Expr", msg, fx_node=None):
# here)

if not self._suppress_guards_tls():
tb = traceback.extract_stack()[:-1]
stack = ''.join(traceback.format_list(tb))
stack = CapturedTraceback.extract(skip=1)
ra = RuntimeAssert(expr, msg, stack)
# TODO: Do this in a way that is less janky than int(s.name[1:])
cands = sorted([s for s in expr.free_symbols if s.name.startswith("i")], key=lambda s: int(s.name[1:]))
Expand All @@ -3514,7 +3528,7 @@ def defer_runtime_assert(self, orig_expr: "sympy.Expr", msg, fx_node=None):
# in ranges. For example, i0 <= s0 is un-rangeable, because
# we can't put s0 in the range. So this is not very high
# priority at the moment.
self._log_guard("runtime_assert", expr, tb)
self._log_guard("runtime_assert", expr)
else:
self.log.debug("runtime_assert %s [guard suppressed]", expr)

Expand Down