Skip to content

Commit

Permalink
Make DimConstraints create actionable message
Browse files Browse the repository at this point in the history
  • Loading branch information
ydwu4 committed Apr 27, 2023
1 parent 9bf2dfb commit 9e0d86a
Show file tree
Hide file tree
Showing 3 changed files with 73 additions and 21 deletions.
39 changes: 36 additions & 3 deletions test/test_dynamic_shapes.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,10 @@
import contextlib
import copy
import itertools
import inspect
import math
import operator
import re

import sympy
import torch
Expand Down Expand Up @@ -1755,21 +1757,52 @@ def test_dim_constraints_solve_full(self):
self.assertEqual(dim_constraints._static_results, {
"L['x3'].size()[0] == 8",
"L['x4'].size()[0] == 8",
"L['x1'].size()[2] = 96",
"L['x1'].size()[2] == 96",
"L['x11'].size()[1] == 1",
"L['x7'].size()[3] == 96",
"L['x12'].size()[2] == 3",
"L['x8'].size()[1] == 22",
"L['x2'].size()[0] == 8",
"L['x5'].size()[1] = 22",
"L['x0'].size()[0] = 8",
"L['x5'].size()[1] == 22",
"L['x0'].size()[0] == 8",
})
self.assertEqual(dim_constraints._dynamic_results, {
"dynamic_dim(L['x10'], 1) == dynamic_dim(L['x6'], 1)",
"2 <= dynamic_dim(L['x6'], 1)",
"dynamic_dim(L['x9'], 1) == dynamic_dim(L['x6'], 1)",
})

def dummy_f(x0, x1, x2, x3, x4, x5, x6, x7, x8, x9, x11, x12):
pass

action_code = dim_constraints.prettify_results(inspect.signature(dummy_f))
static_code, dynamic_code = re.findall(r"```(.*?)```", action_code, re.DOTALL)
print(static_code)
expected_static = '''
def specializations(x0, x1, x2, x3, x4, x5, x6, x7, x8, x9, x11, x12):
return (x0.size()[0] == 8 and
x1.size()[2] == 96 and
x11.size()[1] == 1 and
x12.size()[2] == 3 and
x2.size()[0] == 8 and
x3.size()[0] == 8 and
x4.size()[0] == 8 and
x5.size()[1] == 22 and
x7.size()[3] == 96 and
x8.size()[1] == 22)
'''
expected_dynamic = '''
def specify_constraints(x0, x1, x2, x3, x4, x5, x6, x7, x8, x9, x11, x12):
return [
2 <= dynamic_dim(x6, 1),
dynamic_dim(x10, 1) == dynamic_dim(x6, 1),
dynamic_dim(x9, 1) == dynamic_dim(x6, 1),
]
'''

self.assertEqual(static_code, expected_static)
self.assertEqual(dynamic_code, expected_dynamic)



if __name__ == '__main__':
Expand Down
9 changes: 9 additions & 0 deletions torch/_dynamo/eval_frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -827,6 +827,15 @@ def result_capturing_wrapper(*graph_inputs):
assert out_guards is not None, "Failed to produce guards during tracing"
assert fake_mode is not None

if (shape_env := getattr(fake_mode, "shape_env", None)) is not None:
dim_constraints = shape_env.dim_constraints
assert dim_constraints is not None
dim_constraints.solve()
log.warning(
"Summary of dimension constraints:%s",
dim_constraints.prettify_results(inspect.signature(f)),
)

matched_input_elements_positions = produce_matching(flat_args, graph_captured_input)

flat_results_traced, out_spec_traced = pytree.tree_flatten(result_traced)
Expand Down
46 changes: 28 additions & 18 deletions torch/fx/experimental/symbolic_shapes.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import math
import operator
import os
import re
import sys
import textwrap
import threading
Expand Down Expand Up @@ -1635,7 +1636,7 @@ def solve(self):
symbol, val = solution.args
assert symbol == s, f"Expected a constraint on {s} instead of on {symbol}"
# because this is univariate, the solution is a specialization
self._static_results.add(f"{self._dcp.symbol_to_source[s][0].name()} = {val}")
self._static_results.add(f"{self._dcp.symbol_to_source[s][0].name()} == {val}")
# add this as a substitution to simplify other constraints
self._substitutions[s] = val

Expand Down Expand Up @@ -1675,24 +1676,36 @@ def solve(self):
if s not in self._substitutions or not sympy.checksol(congruence, {s: self._substitutions[s]}):
self._dynamic_results.add(self._dcp.doprint(sympy.Eq(congruence, 0)))

def prettify_results(self):
def prettify_results(self, original_signature: inspect.Signature):
# Note: Model inputs are wrapped as LocalSource in dynamo.
# LocalSource.name() wraps the name with L[""]. We use regular
# expression to do the replacement to avoid traversing up
# the source hierarchy manually.
def unwrap_local_source(source_name):
return re.sub(r"L\['(.+?)'\]", r'\1', source_name)

buf = ""
indent = 4 * " "
if self._static_results:
sorted_static_results = [unwrap_local_source(res) for res in sorted(self._static_results)]
buf += "\nThe following dimensions have been specialized and CANNOT be dynamic."
buf += "\nNOTE: Specializations will happen by default with `assume_static_by_default=True`."
for result in self._static_results:
buf += f"\n\t{result}"
buf += "\n"
buf += f"\n```\ndef specializations{str(original_signature)}:"
buf += f"\n{indent}return (" + f" and\n{indent}".join(sorted_static_results) + ")"
buf += "\n```\n"
if self._dynamic_results:
sorted_dynamic_results = sorted(self._dynamic_results)
buf += "\nThe following dimensions CAN be dynamic."
buf += "\nYou can use the following code to specify the constraints they must satisfy:"
buf += "\n```\nconstraints=["
for result in self._dynamic_results:
buf += f"\n\t{result},"
buf += "\n]\n```"
buf += f"\n```\ndef specify_constraints{str(original_signature)}:"
buf += f"\n{indent}return ["
for result in sorted_dynamic_results:
buf += f"\n{indent*2}{unwrap_local_source(result)},"
buf += f"\n{indent}]\n```\n"
return buf



TLS = threading.local()


Expand Down Expand Up @@ -1770,6 +1783,7 @@ def __init__(
self.log = ShapeEnvLoggerAdapter(log, {'envid': env_id})
self.log.info("create_env")
self.frozen = False
self.dim_constraints: Optional[DimConstraints] = None

def freeze(self):
self.frozen = True
Expand Down Expand Up @@ -2200,7 +2214,7 @@ def hint():
# if we have an input (2, 3), we must show s0*2 == 2 and s1 == 3.
# This does a lot of work: it covers duck sizing and equality guards.
exprs = []
dim_constraints = DimConstraints(symbol_to_source, self.var_to_val)
self.dim_constraints = DimConstraints(symbol_to_source, self.var_to_val)

if not _simplified:
for source, expr in input_guards:
Expand All @@ -2221,7 +2235,7 @@ def hint():
continue

if is_dim(source):
dim_constraints.add_equality(source, expr)
self.dim_constraints.add_equality(source, expr)

sexpr = ShapeGuardPrinter(symbol_to_source, source_ref, self.var_to_sources).doprint(expr)
exprs.append(f"{source_ref(source)} == {sexpr}")
Expand All @@ -2239,7 +2253,7 @@ def hint():
g = self.simplify(g)
try:
if any(is_dim(source) for s in g.free_symbols for source in symbol_to_source[s]):
dim_constraints.add(g)
self.dim_constraints.add(g)
guard_expr = ShapeGuardPrinter(symbol_to_source, source_ref, self.var_to_sources).doprint(g)
exprs.append(guard_expr)
# A non-relational constraint on a single sizevar can violate
Expand Down Expand Up @@ -2301,7 +2315,7 @@ def hint():
bounds = []
if r.lower != -sympy.oo:
if any(is_dim(source) for source in sources):
dim_constraints.add(sympy.Ge(symbol, r.lower))
self.dim_constraints.add(sympy.Ge(symbol, r.lower))
bounds.append(str(r.lower))
bounds.append(source_ref(sources[0]))
# NB: This looks like an off-by-one error but it's not: the
Expand All @@ -2313,15 +2327,11 @@ def hint():
# the 64-bit limit.
if r.upper != sympy.oo and r.upper < sys.maxsize - 1:
if any(is_dim(source) for source in sources):
dim_constraints.add(sympy.Le(symbol, r.upper))
self.dim_constraints.add(sympy.Le(symbol, r.upper))
bounds.append(str(r.upper))
if len(bounds) > 1:
exprs.append(" <= ".join(bounds))

if torch._dynamo.config.dynamic_shapes and torch._dynamo.config.summarize_dim_constraints:
dim_constraints.solve()
log.warning("Summary of dimension constraints:%s", dim_constraints.prettify_results())

if constraint_violations:
warn_msgs = []
error_msgs = []
Expand Down

0 comments on commit 9e0d86a

Please sign in to comment.