Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make DimConstraints create actionable message #100103

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
39 changes: 36 additions & 3 deletions test/test_dynamic_shapes.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,10 @@
import contextlib
import copy
import itertools
import inspect
import math
import operator
import re

import sympy
import torch
Expand Down Expand Up @@ -1755,21 +1757,52 @@ def test_dim_constraints_solve_full(self):
self.assertEqual(dim_constraints._static_results, {
"L['x3'].size()[0] == 8",
"L['x4'].size()[0] == 8",
"L['x1'].size()[2] = 96",
"L['x1'].size()[2] == 96",
"L['x11'].size()[1] == 1",
"L['x7'].size()[3] == 96",
"L['x12'].size()[2] == 3",
"L['x8'].size()[1] == 22",
"L['x2'].size()[0] == 8",
"L['x5'].size()[1] = 22",
"L['x0'].size()[0] = 8",
"L['x5'].size()[1] == 22",
"L['x0'].size()[0] == 8",
})
self.assertEqual(dim_constraints._dynamic_results, {
"dynamic_dim(L['x10'], 1) == dynamic_dim(L['x6'], 1)",
"2 <= dynamic_dim(L['x6'], 1)",
"dynamic_dim(L['x9'], 1) == dynamic_dim(L['x6'], 1)",
})

def dummy_f(x0, x1, x2, x3, x4, x5, x6, x7, x8, x9, x11, x12):
pass

action_code = dim_constraints.prettify_results(inspect.signature(dummy_f))
static_code, dynamic_code = re.findall(r"```(.*?)```", action_code, re.DOTALL)
print(static_code)
expected_static = '''
def specializations(x0, x1, x2, x3, x4, x5, x6, x7, x8, x9, x11, x12):
return (x0.size()[0] == 8 and
x1.size()[2] == 96 and
x11.size()[1] == 1 and
x12.size()[2] == 3 and
x2.size()[0] == 8 and
x3.size()[0] == 8 and
x4.size()[0] == 8 and
x5.size()[1] == 22 and
x7.size()[3] == 96 and
x8.size()[1] == 22)
Comment on lines +1782 to +1792
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is to inform users that those dims are specialized. Users don't need to paste and use it anywhere explicitly right?

@tugsbayasgalan These are also turned into runtime check correct? If it's not stored in metadata, how do you find this info

'''
expected_dynamic = '''
def specify_constraints(x0, x1, x2, x3, x4, x5, x6, x7, x8, x9, x11, x12):
return [
2 <= dynamic_dim(x6, 1),
dynamic_dim(x10, 1) == dynamic_dim(x6, 1),
dynamic_dim(x9, 1) == dynamic_dim(x6, 1),
]
'''

self.assertEqual(static_code, expected_static)
self.assertEqual(dynamic_code, expected_dynamic)



if __name__ == '__main__':
Expand Down
9 changes: 9 additions & 0 deletions torch/_dynamo/eval_frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -827,6 +827,15 @@ def result_capturing_wrapper(*graph_inputs):
assert out_guards is not None, "Failed to produce guards during tracing"
assert fake_mode is not None

if (shape_env := getattr(fake_mode, "shape_env", None)) is not None:
dim_constraints = shape_env.dim_constraints
assert dim_constraints is not None
dim_constraints.solve()
log.warning(
"Summary of dimension constraints:%s",
dim_constraints.prettify_results(inspect.signature(f)),
)

matched_input_elements_positions = produce_matching(flat_args, graph_captured_input)

flat_results_traced, out_spec_traced = pytree.tree_flatten(result_traced)
Expand Down
46 changes: 28 additions & 18 deletions torch/fx/experimental/symbolic_shapes.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import math
import operator
import os
import re
import sys
import textwrap
import threading
Expand Down Expand Up @@ -1635,7 +1636,7 @@ def solve(self):
symbol, val = solution.args
assert symbol == s, f"Expected a constraint on {s} instead of on {symbol}"
# because this is univariate, the solution is a specialization
self._static_results.add(f"{self._dcp.symbol_to_source[s][0].name()} = {val}")
self._static_results.add(f"{self._dcp.symbol_to_source[s][0].name()} == {val}")
# add this as a substitution to simplify other constraints
self._substitutions[s] = val

Expand Down Expand Up @@ -1675,24 +1676,36 @@ def solve(self):
if s not in self._substitutions or not sympy.checksol(congruence, {s: self._substitutions[s]}):
self._dynamic_results.add(self._dcp.doprint(sympy.Eq(congruence, 0)))

def prettify_results(self):
def prettify_results(self, original_signature: inspect.Signature):
# Note: Model inputs are wrapped as LocalSource in dynamo.
# LocalSource.name() wraps the name with L[""]. We use regular
# expression to do the replacement to avoid traversing up
# the source hierarchy manually.
def unwrap_local_source(source_name):
return re.sub(r"L\['(.+?)'\]", r'\1', source_name)

buf = ""
indent = 4 * " "
if self._static_results:
sorted_static_results = [unwrap_local_source(res) for res in sorted(self._static_results)]
buf += "\nThe following dimensions have been specialized and CANNOT be dynamic."
buf += "\nNOTE: Specializations will happen by default with `assume_static_by_default=True`."
for result in self._static_results:
buf += f"\n\t{result}"
buf += "\n"
buf += f"\n```\ndef specializations{str(original_signature)}:"
buf += f"\n{indent}return (" + f" and\n{indent}".join(sorted_static_results) + ")"
buf += "\n```\n"
if self._dynamic_results:
sorted_dynamic_results = sorted(self._dynamic_results)
ydwu4 marked this conversation as resolved.
Show resolved Hide resolved
buf += "\nThe following dimensions CAN be dynamic."
buf += "\nYou can use the following code to specify the constraints they must satisfy:"
buf += "\n```\nconstraints=["
for result in self._dynamic_results:
buf += f"\n\t{result},"
buf += "\n]\n```"
buf += f"\n```\ndef specify_constraints{str(original_signature)}:"
buf += f"\n{indent}return ["
for result in sorted_dynamic_results:
buf += f"\n{indent*2}{unwrap_local_source(result)},"
buf += f"\n{indent}]\n```\n"
return buf



TLS = threading.local()


Expand Down Expand Up @@ -1770,6 +1783,7 @@ def __init__(
self.log = ShapeEnvLoggerAdapter(log, {'envid': env_id})
self.log.info("create_env")
self.frozen = False
self.dim_constraints: Optional[DimConstraints] = None

def freeze(self):
self.frozen = True
Expand Down Expand Up @@ -2200,7 +2214,7 @@ def hint():
# if we have an input (2, 3), we must show s0*2 == 2 and s1 == 3.
# This does a lot of work: it covers duck sizing and equality guards.
exprs = []
dim_constraints = DimConstraints(symbol_to_source, self.var_to_val)
self.dim_constraints = DimConstraints(symbol_to_source, self.var_to_val)

if not _simplified:
for source, expr in input_guards:
Expand All @@ -2221,7 +2235,7 @@ def hint():
continue

if is_dim(source):
dim_constraints.add_equality(source, expr)
self.dim_constraints.add_equality(source, expr)

sexpr = ShapeGuardPrinter(symbol_to_source, source_ref, self.var_to_sources).doprint(expr)
exprs.append(f"{source_ref(source)} == {sexpr}")
Expand All @@ -2239,7 +2253,7 @@ def hint():
g = self.simplify(g)
try:
if any(is_dim(source) for s in g.free_symbols for source in symbol_to_source[s]):
dim_constraints.add(g)
self.dim_constraints.add(g)
guard_expr = ShapeGuardPrinter(symbol_to_source, source_ref, self.var_to_sources).doprint(g)
exprs.append(guard_expr)
# A non-relational constraint on a single sizevar can violate
Expand Down Expand Up @@ -2301,7 +2315,7 @@ def hint():
bounds = []
if r.lower != -sympy.oo:
if any(is_dim(source) for source in sources):
dim_constraints.add(sympy.Ge(symbol, r.lower))
self.dim_constraints.add(sympy.Ge(symbol, r.lower))
bounds.append(str(r.lower))
bounds.append(source_ref(sources[0]))
# NB: This looks like an off-by-one error but it's not: the
Expand All @@ -2313,15 +2327,11 @@ def hint():
# the 64-bit limit.
if r.upper != sympy.oo and r.upper < sys.maxsize - 1:
if any(is_dim(source) for source in sources):
dim_constraints.add(sympy.Le(symbol, r.upper))
self.dim_constraints.add(sympy.Le(symbol, r.upper))
bounds.append(str(r.upper))
if len(bounds) > 1:
exprs.append(" <= ".join(bounds))

if torch._dynamo.config.dynamic_shapes and torch._dynamo.config.summarize_dim_constraints:
dim_constraints.solve()
log.warning("Summary of dimension constraints:%s", dim_constraints.prettify_results())

if constraint_violations:
warn_msgs = []
error_msgs = []
Expand Down