Skip to content

Commit

Permalink
Add symbolic_shape_specialization structured trace (#126450)
Browse files Browse the repository at this point in the history
This is typically the information you want when diagnosing why something
overspecialized in dynamic shapes.

Signed-off-by: Edward Z. Yang <ezyang@meta.com>
Pull Request resolved: #126450
Approved by: https://github.com/albanD
  • Loading branch information
ezyang authored and ZelboK committed May 19, 2024
1 parent 22b4b22 commit 6fc8524
Showing 1 changed file with 20 additions and 7 deletions.
27 changes: 20 additions & 7 deletions torch/fx/experimental/symbolic_shapes.py
Original file line number Diff line number Diff line change
Expand Up @@ -4397,6 +4397,9 @@ def _set_replacement(self, a: "sympy.Symbol", tgt: "sympy.Expr", msg: str) -> No
Use this instead of `self.replacements[a] = tgt`.
"""

if tgt == self.replacements.get(a, None):
return

# Precondition: a == tgt
assert isinstance(a, sympy.Symbol)

Expand Down Expand Up @@ -4487,14 +4490,24 @@ def issubset(x, y):
"[%s not subset of %s (size-oblivious conditions)]", a, tgt, msg, tgt_bound_so, src_bound_so)
return

if config.print_specializations and isinstance(tgt, (sympy.Integer, sympy.Float)):
# specializing to a constant, which is likely unexpected
if isinstance(tgt, (sympy.Integer, sympy.Float)):
# specializing to a constant, which is likely unexpected (unless
# you specified dynamic=True)

user_tb = TracingContext.extract_stack()
trace_structured(
"symbolic_shape_specialization",
metadata_fn=lambda: {
"symbol": repr(a),
"sources": [s.name() for s in self.var_to_sources[a]],
"value": repr(tgt),
"reason": msg,
"stack": structured.from_traceback(CapturedTraceback.extract(skip=1).summary()),
"user_stack": structured.from_traceback(user_tb) if user_tb else None,
}
)

# NOTE(avik): It is possible that we try logging the same specialization multiple times, e.g.,
# when adding a to self.replacements, and again when simplifying an expression containing a.
# Thus to avoid duplication, checking whether a is in self.replacements isn't enough; if it is,
# it must not already map to `tgt`. Fortunately this check is cheap because `tgt` is a constant.
if a not in self.replacements or tgt != self.replacements[a]:
if config.print_specializations:
self.log.warning("Specializing %s to %s", self.var_to_sources[a][0].name(), tgt)
self.log.debug("SPECIALIZATION", stack_info=True)
log.info("set_replacement %s = %s (%s) %s", a, tgt, msg, tgt_bound)
Expand Down

0 comments on commit 6fc8524

Please sign in to comment.