Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix soundness bug with unsupported constraints #102897

Closed
Closed
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
24 changes: 24 additions & 0 deletions test/export/test_export.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,6 +213,30 @@ def f(x, y):
):
export(f, example_inputs, constraints)

def test_export_mod_constraints(self):
class BasicDynamiShapeModel(torch.nn.Module):
def forward(self, x: torch.Tensor) -> torch.Tensor:
return x.view(x.shape[0] - 1, -1)

m = BasicDynamiShapeModel()
a = torch.randn(3, 4)
constraints = [3 <= dynamic_dim(a, 0), dynamic_dim(a, 1)]
with self.assertRaisesRegex(
torch._dynamo.exc.UserError,
(
"Some dynamic dimensions need to be specialized because "
"the constraints inferred for them are too complex to specify"
".*\n.*\\[0\\], which was marked dynamic, must be specialized to 3"
".*\n.*\\[1\\], which was marked dynamic, must be specialized to 4"
),
):
torch._export.export(m, (a,), constraints=constraints)
em = torch._export.export(m, (a,)).add_runtime_assertions()
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think you need to rebase.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I recently landed a change that makes this private method.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, I'll wait for https://www.internalfb.com/diff/D46471787 to land internally before rebasing.

x = torch.randn(3, 5)
with self.assertRaisesRegex(RuntimeError, "\\[1\\] is specialized at 4"):
em(x)



if __name__ == '__main__':
run_tests()
5 changes: 3 additions & 2 deletions test/test_dynamic_shapes.py
Original file line number Diff line number Diff line change
Expand Up @@ -839,7 +839,7 @@ def test_dim_constraints_reduce_congruences_simple(self):
from torch.fx.experimental.symbolic_shapes import DimConstraints

s = Symbol("s", positive=True, integer=True)
dim_constraints = DimConstraints({}, {})
dim_constraints = DimConstraints({}, {}, set())
dim_constraints._congruences[s] = {
(s / 2) % 2,
(s / 2) % 8,
Expand Down Expand Up @@ -933,7 +933,8 @@ def test_dim_constraints_solve_full(self):
s6: [src6, src9, src10],
}
var_to_val = {s0: 8, s1: 96, s5: 22, s6: 21}
dim_constraints = DimConstraints(symbol_to_source, var_to_val)
marked_dynamic = {s0, s1, s5, s6}
dim_constraints = DimConstraints(symbol_to_source, var_to_val, marked_dynamic)
dim_constraints.add_equality(src2, s0)
dim_constraints.add_equality(src3, s0)
dim_constraints.add_equality(src4, s0)
Expand Down
18 changes: 14 additions & 4 deletions torch/_dynamo/eval_frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -955,15 +955,25 @@ def result_capturing_wrapper(*graph_inputs):
assert dim_constraints is not None
dim_constraints.solve()
msg = dim_constraints.prettify_results(original_signature)
forced_specializations = dim_constraints.forced_specializations()
if forced_specializations:
msg = (
"Some dynamic dimensions need to be specialized because "
"the constraints inferred for them are too complex to specify.\n"
f"{forced_specializations}\n{msg}"
)
if constraint_violation_error:
constraint_violation_error.args = (
constraint_violation_error.args[0] + msg,
)
else:
log.info(
"Summary of dimension constraints:%s",
msg,
)
if forced_specializations:
constraint_violation_error = ConstraintViolationError(msg)
else:
log.info(
"Summary of dimension constraints:%s",
msg,
)

# Error if we have any constraints on static values
for k in shape_env.var_to_range.keys():
Expand Down
83 changes: 67 additions & 16 deletions torch/fx/experimental/symbolic_shapes.py
Original file line number Diff line number Diff line change
Expand Up @@ -1499,7 +1499,7 @@ class DimConstraints:
Solutions are "static" values or simplified "dynamic" constraints.
"""

def __init__(self, symbol_to_source, var_to_val):
def __init__(self, symbol_to_source, var_to_val, marked_dynamic):
# We try to solve systems of inequalities with 1 free variable.
self._univariate_inequalities: Dict[sympy.Symbol, Set[sympy.Expr]] = defaultdict(set)
# Among them, we prioritize solving for a free variable that has equalities.
Expand Down Expand Up @@ -1538,6 +1538,9 @@ def __init__(self, symbol_to_source, var_to_val):
# inconsistencies found on substituting with concrete values / static solutions
self._inconsistencies: List[str] = []

# symbols that are marked dynamic
self._marked_dynamic = marked_dynamic

def rewrite_with_congruences(self, s, expr):
"""
Eliminate expressions of the form b // d and b % d while adding congruences of the form b % d == k.
Expand Down Expand Up @@ -1686,7 +1689,35 @@ def raise_inconsistencies(self):
self._inconsistencies.clear()
raise ValueError(f"The following inconsistencies were found:\n{msg}")

def solve(self):
def _force_specialization(self, s):
val = self._var_to_val[s]
self._static_results.add(f"{self._dcp.symbol_to_source[s][0].name()} == {val}")
self._substitutions[s] = val

def specialize_divisor_symbols(self):
for expr in self._multivariate_inequalities:
for atom in expr.atoms(FloorDiv, sympy.Mod):
_, divisor = atom.args
for s in divisor.free_symbols:
self._force_specialization(s)

multivariate_inequalities = self._multivariate_inequalities
self._multivariate_inequalities = set()
for expr in multivariate_inequalities:
self.add(expr.subs(self._substitutions))
self.raise_inconsistencies()
self._univariate_inequalities = {
s: exprs
for s, exprs in self._univariate_inequalities.items()
if s not in self._substitutions
}
self._congruences = {
s: congruences
for s, congruences in self._congruences.items()
if s not in self._substitutions
}

def solve(self, disable_congruences=True):
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If and when we support specifying % constraints, we can enable this flag.

self.raise_inconsistencies()
# as long as there are symbols with equalities, solve for them
# NOTE(avik): this is guaranteed to terminate (#iterations <= #symbols)
Expand All @@ -1711,11 +1742,20 @@ def solve(self):
self.add(expr.subs(s, self._substitutions[s]))
self.raise_inconsistencies()

# simplify symbolic equivalences: some of them will now become specializations!
symbolic_equivalences = self._symbolic_equivalences
self._symbolic_equivalences = []
for source, expr in symbolic_equivalences:
self.add_equality(source, expr.subs(s, self._substitutions[s]))
self.specialize_divisor_symbols()

# solve linear congruences
# NOTE(avik): We do not need to solve them for symbols that have already been specialized.
reduced_congruences = self.reduce_congruences()
for s, congruences in reduced_congruences.items():
for congruence in congruences:
# any congruence that cannot be checked becomes a dynamic constraint as well
if s not in self._substitutions or not sympy.checksol(congruence, {s: self._substitutions[s]}):
if disable_congruences:
self._force_specialization(s)
self._univariate_inequalities.pop(s, None)
else:
self._dynamic_results.add(self._dcp.doprint(sympy.Eq(congruence, 0)))

# remaining symbols have only pure inequalities (no equalities)
for s, exprs in self._univariate_inequalities.items():
Expand All @@ -1732,18 +1772,25 @@ def solve(self):
for expr in exprs:
self._dynamic_results.add(self._dcp.doprint(expr))

# simplify symbolic equivalences: some of them will now become specializations!
symbolic_equivalences = self._symbolic_equivalences
self._symbolic_equivalences = []
for source, expr in symbolic_equivalences:
self.add_equality(source, expr.subs(self._substitutions))

# remaining symbolic equivalences become dynamic equality constraints
for source, expr in self._symbolic_equivalences:
self._dynamic_results.add(f"{self._dcp.print_source(source)} == {self._dcp.doprint(expr)}")

# solve linear congruences
# NOTE(avik): We do not need to solve them for symbols that have already been specialized.
reduced_congruences = self.reduce_congruences()
for s, congruences in reduced_congruences.items():
for congruence in congruences:
# any congruence that cannot be checked becomes a dynamic constraint as well
if s not in self._substitutions or not sympy.checksol(congruence, {s: self._substitutions[s]}):
self._dynamic_results.add(self._dcp.doprint(sympy.Eq(congruence, 0)))
def forced_specializations(self):
return "\n".join([
(
f"\t{self._dcp.symbol_to_source[s][0].name()}, which was marked dynamic, "
f"must be specialized to {val}."
)
for s, val in self._substitutions.items()
if s in self._marked_dynamic
])

def prettify_results(self, original_signature: inspect.Signature):
# Note: Model inputs are wrapped as LocalSource in dynamo.
Expand Down Expand Up @@ -2366,7 +2413,11 @@ def hint(s):
# if we have an input (2, 3), we must show s0*2 == 2 and s1 == 3.
# This does a lot of work: it covers duck sizing and equality guards.
exprs = []
self.dim_constraints = DimConstraints(symbol_to_source, self.var_to_val)
self.dim_constraints = DimConstraints(
symbol_to_source,
self.var_to_val,
set(symbol_to_constraints.keys()),
)

if not _simplified:
for source, expr in input_guards:
Expand Down