From 6fc852410374dadff93d8da22955dfc30ebd7419 Mon Sep 17 00:00:00 2001 From: "Edward Z. Yang" Date: Thu, 16 May 2024 13:27:23 -0700 Subject: [PATCH] Add symbolic_shape_specialization structured trace (#126450) This is typically the information you want when diagnosing why something overspecialized in dynamic shapes. Signed-off-by: Edward Z. Yang Pull Request resolved: https://github.com/pytorch/pytorch/pull/126450 Approved by: https://github.com/albanD --- torch/fx/experimental/symbolic_shapes.py | 27 ++++++++++++++++++------ 1 file changed, 20 insertions(+), 7 deletions(-) diff --git a/torch/fx/experimental/symbolic_shapes.py b/torch/fx/experimental/symbolic_shapes.py index be1be24137f88..e310d490b77c9 100644 --- a/torch/fx/experimental/symbolic_shapes.py +++ b/torch/fx/experimental/symbolic_shapes.py @@ -4397,6 +4397,9 @@ def _set_replacement(self, a: "sympy.Symbol", tgt: "sympy.Expr", msg: str) -> No Use this instead of `self.replacements[a] = tgt`. """ + if tgt == self.replacements.get(a, None): + return + # Precondition: a == tgt assert isinstance(a, sympy.Symbol) @@ -4487,14 +4490,24 @@ def issubset(x, y): "[%s not subset of %s (size-oblivious conditions)]", a, tgt, msg, tgt_bound_so, src_bound_so) return - if config.print_specializations and isinstance(tgt, (sympy.Integer, sympy.Float)): - # specializing to a constant, which is likely unexpected + if isinstance(tgt, (sympy.Integer, sympy.Float)): + # specializing to a constant, which is likely unexpected (unless + # you specified dynamic=True) + + user_tb = TracingContext.extract_stack() + trace_structured( + "symbolic_shape_specialization", + metadata_fn=lambda: { + "symbol": repr(a), + "sources": [s.name() for s in self.var_to_sources[a]], + "value": repr(tgt), + "reason": msg, + "stack": structured.from_traceback(CapturedTraceback.extract(skip=1).summary()), + "user_stack": structured.from_traceback(user_tb) if user_tb else None, + } + ) - # NOTE(avik): It is possible that we try logging the same specialization multiple times, e.g., - # when adding a to self.replacements, and again when simplifying an expression containing a. - # Thus to avoid duplication, checking whether a is in self.replacements isn't enough; if it is, - # it must not already map to `tgt`. Fortunately this check is cheap because `tgt` is a constant. - if a not in self.replacements or tgt != self.replacements[a]: + if config.print_specializations: self.log.warning("Specializing %s to %s", self.var_to_sources[a][0].name(), tgt) self.log.debug("SPECIALIZATION", stack_info=True) log.info("set_replacement %s = %s (%s) %s", a, tgt, msg, tgt_bound)