Skip to content

Commit

Permalink
fix soundness bug with unsupported constraints
Browse files Browse the repository at this point in the history
Pull Request resolved: #102897


ghstack-source-id: 191208582

Differential Revision: [D46415786](https://our.internmc.facebook.com/intern/diff/D46415786/)
  • Loading branch information
avikchaudhuri committed Jun 5, 2023
1 parent 6b8e68c commit 4dae454
Show file tree
Hide file tree
Showing 4 changed files with 108 additions and 22 deletions.
24 changes: 24 additions & 0 deletions test/export/test_export.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,6 +213,30 @@ def f(x, y):
):
export(f, example_inputs, constraints)

def test_export_mod_constraints(self):
class BasicDynamiShapeModel(torch.nn.Module):
def forward(self, x: torch.Tensor) -> torch.Tensor:
return x.view(x.shape[0] - 1, -1)

m = BasicDynamiShapeModel()
a = torch.randn(3, 4)
constraints = [3 <= dynamic_dim(a, 0), dynamic_dim(a, 1)]
with self.assertRaisesRegex(
torch._dynamo.exc.UserError,
(
"Some dynamic dimensions need to be specialized because "
"the constraints inferred for them are too complex to specify"
".*\n.*\\[0\\], which was marked dynamic, must be specialized to 3"
".*\n.*\\[1\\], which was marked dynamic, must be specialized to 4"
),
):
torch._export.export(m, (a,), constraints=constraints)
em = torch._export.export(m, (a,)).add_runtime_assertions()
x = torch.randn(3, 5)
with self.assertRaisesRegex(RuntimeError, "\\[1\\] is specialized at 4"):
em(x)



if __name__ == '__main__':
run_tests()
5 changes: 3 additions & 2 deletions test/test_dynamic_shapes.py
Original file line number Diff line number Diff line change
Expand Up @@ -839,7 +839,7 @@ def test_dim_constraints_reduce_congruences_simple(self):
from torch.fx.experimental.symbolic_shapes import DimConstraints

s = Symbol("s", positive=True, integer=True)
dim_constraints = DimConstraints({}, {})
dim_constraints = DimConstraints({}, {}, set())
dim_constraints._congruences[s] = {
(s / 2) % 2,
(s / 2) % 8,
Expand Down Expand Up @@ -933,7 +933,8 @@ def test_dim_constraints_solve_full(self):
s6: [src6, src9, src10],
}
var_to_val = {s0: 8, s1: 96, s5: 22, s6: 21}
dim_constraints = DimConstraints(symbol_to_source, var_to_val)
marked_dynamic = {s0, s1, s5, s6}
dim_constraints = DimConstraints(symbol_to_source, var_to_val, marked_dynamic)
dim_constraints.add_equality(src2, s0)
dim_constraints.add_equality(src3, s0)
dim_constraints.add_equality(src4, s0)
Expand Down
18 changes: 14 additions & 4 deletions torch/_dynamo/eval_frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -955,15 +955,25 @@ def result_capturing_wrapper(*graph_inputs):
assert dim_constraints is not None
dim_constraints.solve()
msg = dim_constraints.prettify_results(original_signature)
forced_specializations = dim_constraints.forced_specializations()
if forced_specializations:
msg = (
"Some dynamic dimensions need to be specialized because "
"the constraints inferred for them are too complex to specify.\n"
f"{forced_specializations}\n{msg}"
)
if constraint_violation_error:
constraint_violation_error.args = (
constraint_violation_error.args[0] + msg,
)
else:
log.info(
"Summary of dimension constraints:%s",
msg,
)
if forced_specializations:
constraint_violation_error = ConstraintViolationError(msg)
else:
log.info(
"Summary of dimension constraints:%s",
msg,
)

# Error if we have any constraints on static values
for k in shape_env.var_to_range.keys():
Expand Down
83 changes: 67 additions & 16 deletions torch/fx/experimental/symbolic_shapes.py
Original file line number Diff line number Diff line change
Expand Up @@ -1499,7 +1499,7 @@ class DimConstraints:
Solutions are "static" values or simplified "dynamic" constraints.
"""

def __init__(self, symbol_to_source, var_to_val):
def __init__(self, symbol_to_source, var_to_val, marked_dynamic):
# We try to solve systems of inequalities with 1 free variable.
self._univariate_inequalities: Dict[sympy.Symbol, Set[sympy.Expr]] = defaultdict(set)
# Among them, we prioritize solving for a free variable that has equalities.
Expand Down Expand Up @@ -1538,6 +1538,9 @@ def __init__(self, symbol_to_source, var_to_val):
# inconsistencies found on substituting with concrete values / static solutions
self._inconsistencies: List[str] = []

# symbols that are marked dynamic
self._marked_dynamic = marked_dynamic

def rewrite_with_congruences(self, s, expr):
"""
Eliminate expressions of the form b // d and b % d while adding congruences of the form b % d == k.
Expand Down Expand Up @@ -1686,7 +1689,35 @@ def raise_inconsistencies(self):
self._inconsistencies.clear()
raise ValueError(f"The following inconsistencies were found:\n{msg}")

def solve(self):
def _force_specialization(self, s):
val = self._var_to_val[s]
self._static_results.add(f"{self._dcp.symbol_to_source[s][0].name()} == {val}")
self._substitutions[s] = val

def specialize_divisor_symbols(self):
for expr in self._multivariate_inequalities:
for atom in expr.atoms(FloorDiv, sympy.Mod):
_, divisor = atom.args
for s in divisor.free_symbols:
self._force_specialization(s)

multivariate_inequalities = self._multivariate_inequalities
self._multivariate_inequalities = set()
for expr in multivariate_inequalities:
self.add(expr.subs(self._substitutions))
self.raise_inconsistencies()
self._univariate_inequalities = {
s: exprs
for s, exprs in self._univariate_inequalities.items()
if s not in self._substitutions
}
self._congruences = {
s: congruences
for s, congruences in self._congruences.items()
if s not in self._substitutions
}

def solve(self, disable_congruences=True):
self.raise_inconsistencies()
# as long as there are symbols with equalities, solve for them
# NOTE(avik): this is guaranteed to terminate (#iterations <= #symbols)
Expand All @@ -1711,11 +1742,20 @@ def solve(self):
self.add(expr.subs(s, self._substitutions[s]))
self.raise_inconsistencies()

# simplify symbolic equivalences: some of them will now become specializations!
symbolic_equivalences = self._symbolic_equivalences
self._symbolic_equivalences = []
for source, expr in symbolic_equivalences:
self.add_equality(source, expr.subs(s, self._substitutions[s]))
self.specialize_divisor_symbols()

# solve linear congruences
# NOTE(avik): We do not need to solve them for symbols that have already been specialized.
reduced_congruences = self.reduce_congruences()
for s, congruences in reduced_congruences.items():
for congruence in congruences:
# any congruence that cannot be checked becomes a dynamic constraint as well
if s not in self._substitutions or not sympy.checksol(congruence, {s: self._substitutions[s]}):
if disable_congruences:
self._force_specialization(s)
self._univariate_inequalities.pop(s, None)
else:
self._dynamic_results.add(self._dcp.doprint(sympy.Eq(congruence, 0)))

# remaining symbols have only pure inequalities (no equalities)
for s, exprs in self._univariate_inequalities.items():
Expand All @@ -1732,18 +1772,25 @@ def solve(self):
for expr in exprs:
self._dynamic_results.add(self._dcp.doprint(expr))

# simplify symbolic equivalences: some of them will now become specializations!
symbolic_equivalences = self._symbolic_equivalences
self._symbolic_equivalences = []
for source, expr in symbolic_equivalences:
self.add_equality(source, expr.subs(self._substitutions))

# remaining symbolic equivalences become dynamic equality constraints
for source, expr in self._symbolic_equivalences:
self._dynamic_results.add(f"{self._dcp.print_source(source)} == {self._dcp.doprint(expr)}")

# solve linear congruences
# NOTE(avik): We do not need to solve them for symbols that have already been specialized.
reduced_congruences = self.reduce_congruences()
for s, congruences in reduced_congruences.items():
for congruence in congruences:
# any congruence that cannot be checked becomes a dynamic constraint as well
if s not in self._substitutions or not sympy.checksol(congruence, {s: self._substitutions[s]}):
self._dynamic_results.add(self._dcp.doprint(sympy.Eq(congruence, 0)))
def forced_specializations(self):
return "\n".join([
(
f"\t{self._dcp.symbol_to_source[s][0].name()}, which was marked dynamic, "
f"must be specialized to {val}."
)
for s, val in self._substitutions.items()
if s in self._marked_dynamic
])

def prettify_results(self, original_signature: inspect.Signature):
# Note: Model inputs are wrapped as LocalSource in dynamo.
Expand Down Expand Up @@ -2347,7 +2394,11 @@ def hint(s):
# if we have an input (2, 3), we must show s0*2 == 2 and s1 == 3.
# This does a lot of work: it covers duck sizing and equality guards.
exprs = []
self.dim_constraints = DimConstraints(symbol_to_source, self.var_to_val)
self.dim_constraints = DimConstraints(
symbol_to_source,
self.var_to_val,
set(symbol_to_constraints.keys()),
)

if not _simplified:
for source, expr in input_guards:
Expand Down

0 comments on commit 4dae454

Please sign in to comment.