Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

work around precision error in constraint solver #101607

Closed
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
15 changes: 0 additions & 15 deletions test/test_dynamic_shapes.py
Original file line number Diff line number Diff line change
Expand Up @@ -877,21 +877,6 @@ def test_dim_constraints_reduce_inequalities_simple(self):
solution = reduce_inequalities(exprs, s).as_set()
self.assertEqual(solution, {8})

def test_precision(self):
from sympy import Eq, Ne, Symbol
from torch.fx.experimental.symbolic_shapes import DimConstraints

x = Symbol("x", positive=True, integer=True)
y = Symbol("y", positive=True, integer=True)
var_to_val = {x: 296, y: 1155}

dim_constraints = DimConstraints({}, var_to_val)
dim_constraints.add(Eq(x / y, 0.256277056277056))
with self.assertRaisesRegex(AssertionError, "Ne\\(x/y, 296/1155\\) is inconsistent!"):
dim_constraints.add(Ne(x / y, 0.256277056277056))
dim_constraints.solve()
self.assertEqual(dim_constraints._dynamic_results, set())

def test_dim_constraints_solve_full(self):
from sympy import Eq, Integer, Mod, Ne, Symbol
from torch._dynamo.source import LocalSource, TensorProperty, TensorPropertySource
Expand Down
35 changes: 24 additions & 11 deletions torch/fx/experimental/symbolic_shapes.py
Original file line number Diff line number Diff line change
Expand Up @@ -1525,6 +1525,9 @@ def __init__(self, symbol_to_source, var_to_val):
# printer for solutions
self._dcp = DynamicDimConstraintPrinter(symbol_to_source)

# inconsistencies found on substituting with concrete values / static solutions
self._inconsistencies: List[str] = []

def rewrite_with_congruences(self, s, expr):
"""
Eliminate expressions of the form b // d and b % d while adding congruences of the form b % d == k.
Expand Down Expand Up @@ -1585,20 +1588,18 @@ def floor_div_handler(*args):
return expr

def add(self, expr):
free_symbols = expr.free_symbols
if isinstance(expr, sympy.Rel):
# It is possible that `expr` will fail the consistency check below
# because of precision errors, i.e., on substituting its free symbols
# with their concrete values, we might end up comparing floats. Thus
# we approximate floats with rationals using concrete values as hints.
constants = [self._var_to_val[s] for s in free_symbols]
expr = type(expr)(*(sympy.nsimplify(arg, constants) for arg in expr.args))
if expr == sympy.true:
return
# `expr` should be consistent with concrete values
orig_expr = expr
orig_reduced = orig_expr.subs(self._var_to_val)
assert orig_reduced != sympy.false, f"{orig_expr} is inconsistent!"
# TODO(avik): https://github.com/pytorch/pytorch/issues/101093
# It is possible that `expr` will fail the consistency check because of
# precision errors. Specifically, on substituting its free symbols with
# their concrete values, we might end up comparing floats. Until we have
# a fix for this issue, we delay raising such failures. See solve().
if orig_reduced == sympy.false:
self._inconsistencies.append(f"{orig_expr} is inconsistent!")
free_symbols = expr.free_symbols
assert free_symbols, f"Did not expect constraint with no free variables: {expr}"
if len(free_symbols) > 1:
# multivariate: record and move on
Expand All @@ -1610,7 +1611,11 @@ def add(self, expr):
expr = self.rewrite_with_congruences(s, expr)
if expr != sympy.true:
reduced = expr.subs(self._var_to_val)
assert reduced != sympy.false, f"{expr}, obtained by rewriting {orig_expr} with congruences, is inconsistent!"
if reduced == sympy.false:
self._inconsistencies.append(
f"{expr}, obtained by rewriting {orig_expr} with congruences, "
"is inconsistent!"
)
if isinstance(expr, sympy.Eq):
# special status for symbols that have equalities (see `solve` below)
self._symbols_with_equalities.add(s)
Expand Down Expand Up @@ -1665,7 +1670,14 @@ def reduce_congruences(self):

return reduced_congruences

def raise_inconsistencies(self):
if self._inconsistencies:
msg = "\n".join(self._inconsistencies)
self._inconsistencies.clear()
raise ValueError(f"The following inconsistencies were found:\n{msg}")

def solve(self):
self.raise_inconsistencies()
# as long as there are symbols with equalities, solve for them
# NOTE(avik): this is guaranteed to terminate (#iterations <= #symbols)
while(self._symbols_with_equalities):
Expand All @@ -1687,6 +1699,7 @@ def solve(self):
self._multivariate_inequalities = set()
for expr in multivariate_inequalities:
self.add(expr.subs(s, self._substitutions[s]))
self.raise_inconsistencies()

# simplify symbolic equivalences: some of them will now become specializations!
symbolic_equivalences = self._symbolic_equivalences
Expand Down