Skip to content

Commit

Permalink
Use fast traceback for symbolic shapes
Browse files Browse the repository at this point in the history
ghstack-source-id: 983923ae03fa18682e44f650e58d7c8e2caa91da
Pull Request resolved: #107439
  • Loading branch information
ezyang committed Aug 18, 2023
1 parent 3994dc7 commit 7cead35
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 28 deletions.
3 changes: 1 addition & 2 deletions torch/_guards.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,8 +113,7 @@ class GuardBuilderBase:

class ShapeGuard(NamedTuple):
expr: sympy.Expr
# TODO: store this in slightly less formatted form
stack: str
stack: CapturedTraceback


@dataclasses.dataclass
Expand Down
54 changes: 28 additions & 26 deletions torch/fx/experimental/symbolic_shapes.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
import operator
import re
import sys
import textwrap
import threading
import traceback
import weakref
Expand Down Expand Up @@ -37,7 +36,7 @@
from torch.utils._sympy.functions import FloorDiv, LShift, Mod, RShift
from torch.utils._sympy.solve import try_solve
from torch.utils._sympy.value_ranges import bound_sympy, SymPyValueRangeAnalysis, ValueRanges, ValueRangeError
from torch.utils._traceback import format_frame
from torch.utils._traceback import format_frame, CapturedTraceback
from torch._utils_internal import signpost_event

InputList = List
Expand Down Expand Up @@ -1078,11 +1077,6 @@ def error():
raise AssertionError("shouldn't be hit")


def get_debugging_stack(num_frames_to_cut=2):
# cut this frame and the caller's frame by default
return ''.join(traceback.format_list(traceback.extract_stack()[:-num_frames_to_cut]))


def floor_ceil_helper(a, fn):
if isinstance(a, sympy.Mul):
aa = a.args
Expand Down Expand Up @@ -2057,7 +2051,7 @@ def __init__(
# for N < 2. Therefore, it will be too strict to assert N=2 at runtime.
self.runtime_var_to_range: Dict[sympy.Symbol, ValueRanges] = {}
self.var_to_sources: Dict[sympy.Symbol, List[Source]] = {}
self.var_to_stack: Dict[sympy.Symbol, traceback.StackSummary] = {}
self.var_to_stack: Dict[sympy.Symbol, CapturedTraceback] = {}
# Maps symbolic ints to the guards that refine their lower/upper
# bound. If one of them is None, it means that there are no guards
# that refine that respective bound.
Expand Down Expand Up @@ -2408,7 +2402,7 @@ def create_symboolnode(self, sym: "sympy.Expr"):
def create_unbacked_symfloat(self):
symbol: sympy.Symbol = sympy.Symbol(f"f{next(self.unbacked_symfloat_counter)}")
self.counter["create_unbacked_symbol"] += 1
self.var_to_stack[symbol] = traceback.extract_stack()[:-1]
self.var_to_stack[symbol] = CapturedTraceback.extract(skip=1)
self.var_to_range[symbol] = ValueRanges.unknown()

# Create a new FX placeholder and Z3 variable for 'symbol'.
Expand All @@ -2419,7 +2413,7 @@ def create_unbacked_symfloat(self):
def create_unbacked_symint(self):
symbol: sympy.Symbol = sympy.Symbol(f"i{next(self.unbacked_symint_counter)}", integer=True)
self.counter["create_unbacked_symbol"] += 1
self.var_to_stack[symbol] = traceback.extract_stack()[:-1]
self.var_to_stack[symbol] = CapturedTraceback.extract(skip=1)
self.var_to_range[symbol] = self._default_unspecified_value_range()

# Create a new FX placeholder and Z3 variable for 'symbol'.
Expand All @@ -2430,7 +2424,7 @@ def create_unbacked_symint(self):
def create_unbacked_symbool(self):
symbol: sympy.Symbol = sympy.Symbol(f"i{next(self.unbacked_symint_counter)}", integer=True)
self.counter["create_unbacked_symbol"] += 1
self.var_to_stack[symbol] = traceback.extract_stack()[:-1]
self.var_to_stack[symbol] = CapturedTraceback.extract(skip=1)
self.var_to_range[symbol] = ValueRanges(0, 1)

# Create a new FX placeholder and Z3 variable for 'symbol'.
Expand Down Expand Up @@ -2854,7 +2848,7 @@ def issue_guard(guard: ShapeGuard) -> None:
else:
raise AssertionError(f"unrecognized constraint {c}")
except Exception:
self.log.warning("Failing guard allocated at: \n%s", guard.stack)
self.log.warning("Failing guard allocated at: \n%s", ''.join(guard.stack.format()))
raise

# First, issue all the non-trivial guards.
Expand Down Expand Up @@ -3040,7 +3034,7 @@ def format_guards(self, verbose=False):
def format_tb(tb):
if not verbose:
return ""
return f"\n Guarded at:\n{textwrap.indent(tb, ' ')}"
return f"\n Guarded at:\n{''.join(' ' + l for l in tb.format())}"

return '\n'.join(f" - {guard.expr}{format_tb(guard.stack)}" for guard in self.guards)

Expand Down Expand Up @@ -3215,7 +3209,7 @@ def _make_data_dependent_error(self, expr, unhinted_expr):
# TODO: in a Dynamo context, having user code, and having the
# name of the local, will be much better
for s in expr.free_symbols:
stacktrace = ''.join(traceback.format_list(self.var_to_stack[s]))
stacktrace = ''.join(self.var_to_stack[s].format())
self.log.debug("Data dependent variable '%s' allocated at:\n%s", s, stacktrace)
return GuardOnDataDependentSymNode(
"It appears that you're trying to get a value out of symbolic int/float "
Expand Down Expand Up @@ -3357,11 +3351,22 @@ def _check_frozen(self, expr, concrete_val):
log.warning("Ignored guard %s == %s, this could result in accuracy problems", expr, concrete_val)


def _log_guard(self, prefix: str, g, tb):
def _log_guard(self, prefix: str, g):
if self.log.isEnabledFor(logging.INFO):
for frame in reversed(tb):
if frame.filename not in uninteresting_files():
break
fsummary = None
frame = inspect.currentframe()
try:
while frame is not None:
if frame.f_code.co_filename not in uninteresting_files():
fsummary = traceback.FrameSummary(
frame.f_code.co_filename,
frame.f_lineno,
frame.f_code.co_name,
)
break
frame = frame.f_back
finally:
del frame

# NB: this stack is truncated, but it's fine because the main
# stack_info will give you the rest of the info you need
Expand All @@ -3382,7 +3387,7 @@ def _log_guard(self, prefix: str, g, tb):
"eval %s [guard added]%s (%s)%s",
g,
maybe_user_loc,
format_frame(frame),
format_frame(fsummary),
maybe_extra_debug,
stack_info=is_debug,
)
Expand Down Expand Up @@ -3485,8 +3490,7 @@ def evaluate_expr(self, orig_expr: "sympy.Expr", hint=None, fx_node=None):
g = sympy.Eq(expr, concrete_val) # type: ignore[arg-type]

if not self._suppress_guards_tls():
tb = traceback.extract_stack()[:-1]
stack = ''.join(traceback.format_list(tb))
stack = CapturedTraceback.extract(skip=1)
guard = ShapeGuard(g, stack)
self.guards.append(guard)
except Exception:
Expand All @@ -3495,11 +3499,10 @@ def evaluate_expr(self, orig_expr: "sympy.Expr", hint=None, fx_node=None):
else:
if not self._suppress_guards_tls():
assert guard is not None
assert tb is not None

self.refine_ranges(guard)

self._log_guard("eval", g, tb)
self._log_guard("eval", g)
else:
self.log.debug("eval %s [guard suppressed]", g)

Expand Down Expand Up @@ -3536,8 +3539,7 @@ def defer_runtime_assert(self, orig_expr: "sympy.Expr", msg, fx_node=None):
# here)

if not self._suppress_guards_tls():
tb = traceback.extract_stack()[:-1]
stack = ''.join(traceback.format_list(tb))
stack = CapturedTraceback.extract(skip=1)
ra = RuntimeAssert(expr, msg, stack)
# TODO: Do this in a way that is less janky than int(s.name[1:])
cands = sorted([s for s in expr.free_symbols if s.name.startswith("i")], key=lambda s: int(s.name[1:]))
Expand All @@ -3549,7 +3551,7 @@ def defer_runtime_assert(self, orig_expr: "sympy.Expr", msg, fx_node=None):
# in ranges. For example, i0 <= s0 is un-rangeable, because
# we can't put s0 in the range. So this is not very high
# priority at the moment.
self._log_guard("runtime_assert", expr, tb)
self._log_guard("runtime_assert", expr)
else:
self.log.debug("runtime_assert %s [guard suppressed]", expr)

Expand Down

0 comments on commit 7cead35

Please sign in to comment.