Skip to content

Commit

Permalink
work around precision error in constraint solver
Browse files Browse the repository at this point in the history
In #101307 we tried to fix #101093 using `nsimplify` to convert floats into rationals, but the fix is not reliable: it is possible for `nsimplify` to pick constants that don't work.

Currently, constraint solving is only used by `export`, but constraints are added in all modes. This means that we can hit this issue even in non-`export` modes. This diff works around this issue for such modes by delaying raising such failures until constraint solving.

Differential Revision: [D45922797](https://our.internmc.facebook.com/intern/diff/D45922797/)

ghstack-source-id: 189321831
Pull Request resolved: #101607
  • Loading branch information
avikchaudhuri committed May 16, 2023
1 parent 1f69e48 commit 8d6997a
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 26 deletions.
15 changes: 0 additions & 15 deletions test/test_dynamic_shapes.py
Original file line number Diff line number Diff line change
Expand Up @@ -877,21 +877,6 @@ def test_dim_constraints_reduce_inequalities_simple(self):
solution = reduce_inequalities(exprs, s).as_set()
self.assertEqual(solution, {8})

def test_precision(self):
from sympy import Eq, Ne, Symbol
from torch.fx.experimental.symbolic_shapes import DimConstraints

x = Symbol("x", positive=True, integer=True)
y = Symbol("y", positive=True, integer=True)
var_to_val = {x: 296, y: 1155}

dim_constraints = DimConstraints({}, var_to_val)
dim_constraints.add(Eq(x / y, 0.256277056277056))
with self.assertRaisesRegex(AssertionError, "Ne\\(x/y, 296/1155\\) is inconsistent!"):
dim_constraints.add(Ne(x / y, 0.256277056277056))
dim_constraints.solve()
self.assertEqual(dim_constraints._dynamic_results, set())

def test_dim_constraints_solve_full(self):
from sympy import Eq, Integer, Mod, Ne, Symbol
from torch._dynamo.source import LocalSource, TensorProperty, TensorPropertySource
Expand Down
35 changes: 24 additions & 11 deletions torch/fx/experimental/symbolic_shapes.py
Original file line number Diff line number Diff line change
Expand Up @@ -1525,6 +1525,9 @@ def __init__(self, symbol_to_source, var_to_val):
# printer for solutions
self._dcp = DynamicDimConstraintPrinter(symbol_to_source)

# inconsistencies found on substituting with concrete values / static solutions
self._inconsistencies: List[str] = []

def rewrite_with_congruences(self, s, expr):
"""
Eliminate expressions of the form b // d and b % d while adding congruences of the form b % d == k.
Expand Down Expand Up @@ -1585,20 +1588,18 @@ def floor_div_handler(*args):
return expr

def add(self, expr):
free_symbols = expr.free_symbols
if isinstance(expr, sympy.Rel):
# It is possible that `expr` will fail the consistency check below
# because of precision errors, i.e., on substituting its free symbols
# with their concrete values, we might end up comparing floats. Thus
# we approximate floats with rationals using concrete values as hints.
constants = [self._var_to_val[s] for s in free_symbols]
expr = type(expr)(*(sympy.nsimplify(arg, constants) for arg in expr.args))
if expr == sympy.true:
return
# `expr` should be consistent with concrete values
orig_expr = expr
orig_reduced = orig_expr.subs(self._var_to_val)
assert orig_reduced != sympy.false, f"{orig_expr} is inconsistent!"
# TODO(avik): https://github.com/pytorch/pytorch/issues/101093
# It is possible that `expr` will fail the consistency check because of
# precision errors. Specifically, on substituting its free symbols with
# their concrete values, we might end up comparing floats. Until we have
# a fix for this issue, we delay raising such failures. See solve().
if orig_reduced == sympy.false:
self._inconsistencies.append(f"{orig_expr} is inconsistent!")
free_symbols = expr.free_symbols
assert free_symbols, f"Did not expect constraint with no free variables: {expr}"
if len(free_symbols) > 1:
# multivariate: record and move on
Expand All @@ -1610,7 +1611,11 @@ def add(self, expr):
expr = self.rewrite_with_congruences(s, expr)
if expr != sympy.true:
reduced = expr.subs(self._var_to_val)
assert reduced != sympy.false, f"{expr}, obtained by rewriting {orig_expr} with congruences, is inconsistent!"
if reduced == sympy.false:
self._inconsistencies.append(
f"{expr}, obtained by rewriting {orig_expr} with congruences, "
"is inconsistent!"
)
if isinstance(expr, sympy.Eq):
# special status for symbols that have equalities (see `solve` below)
self._symbols_with_equalities.add(s)
Expand Down Expand Up @@ -1665,7 +1670,14 @@ def reduce_congruences(self):

return reduced_congruences

def raise_inconsistencies(self):
if self._inconsistencies:
msg = "\n".join(self._inconsistencies)
self._inconsistencies.clear()
raise ValueError(f"The following inconsistencies were found:\n{msg}")

def solve(self):
self.raise_inconsistencies()
# as long as there are symbols with equalities, solve for them
# NOTE(avik): this is guaranteed to terminate (#iterations <= #symbols)
while(self._symbols_with_equalities):
Expand All @@ -1687,6 +1699,7 @@ def solve(self):
self._multivariate_inequalities = set()
for expr in multivariate_inequalities:
self.add(expr.subs(s, self._substitutions[s]))
self.raise_inconsistencies()

# simplify symbolic equivalences: some of them will now become specializations!
symbolic_equivalences = self._symbolic_equivalences
Expand Down

0 comments on commit 8d6997a

Please sign in to comment.