Skip to content

Commit

Permalink
fix soundness bug with unsupported constraints
Browse files Browse the repository at this point in the history
Differential Revision: [D46415786](https://our.internmc.facebook.com/intern/diff/D46415786/)

ghstack-source-id: 191106011
Pull Request resolved: #102897
  • Loading branch information
avikchaudhuri committed Jun 3, 2023
1 parent 87c976b commit f0202b0
Show file tree
Hide file tree
Showing 2 changed files with 71 additions and 18 deletions.
25 changes: 21 additions & 4 deletions torch/_dynamo/eval_frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -955,15 +955,32 @@ def result_capturing_wrapper(*graph_inputs):
assert dim_constraints is not None
dim_constraints.solve()
msg = dim_constraints.prettify_results(original_signature)
forced_specializations = "\n".join([
(
f"\t{shape_env.var_to_sources[var][0].name()}, which was marked dynamic, "
f"must be specialized to {dim_constraints._substitutions[var]}."
)
for var in shape_env.var_to_range.keys()
if var in dim_constraints._substitutions
])
if forced_specializations:
msg = (
"Some dynamic dimensions need to be specialized because "
"the constraints inferred for them are too complex to specify.\n"
f"{forced_specializations}\n{msg}"
)
if constraint_violation_error:
constraint_violation_error.args = (
constraint_violation_error.args[0] + msg,
)
else:
log.info(
"Summary of dimension constraints:%s",
msg,
)
if forced_specializations:
constraint_violation_error = ConstraintViolationError(msg)
else:
log.info(
"Summary of dimension constraints:%s",
msg,
)

# Error if we have any constraints on static values
for k in shape_env.var_to_range.keys():
Expand Down
64 changes: 50 additions & 14 deletions torch/fx/experimental/symbolic_shapes.py
Original file line number Diff line number Diff line change
Expand Up @@ -1686,7 +1686,35 @@ def raise_inconsistencies(self):
self._inconsistencies.clear()
raise ValueError(f"The following inconsistencies were found:\n{msg}")

def solve(self):
def _force_specialization(self, s):
val = self._var_to_val[s]
self._static_results.add(f"{self._dcp.symbol_to_source[s][0].name()} == {val}")
self._substitutions[s] = val

def specialize_divisor_symbols(self):
for expr in self._multivariate_inequalities:
for atom in expr.atoms(FloorDiv, sympy.Mod):
_, divisor = atom.args
for s in divisor.free_symbols:
self._force_specialization(s)

multivariate_inequalities = self._multivariate_inequalities
self._multivariate_inequalities = set()
for expr in multivariate_inequalities:
self.add(expr.subs(self._substitutions))
self.raise_inconsistencies()
self._univariate_inequalities = {
s: exprs
for s, exprs in self._univariate_inequalities.items()
if s not in self._substitutions
}
self._congruences = {
s: congruences
for s, congruences in self._congruences.items()
if s not in self._substitutions
}

def solve(self, disable_congruences=True):
self.raise_inconsistencies()
# as long as there are symbols with equalities, solve for them
# NOTE(avik): this is guaranteed to terminate (#iterations <= #symbols)
Expand All @@ -1711,11 +1739,20 @@ def solve(self):
self.add(expr.subs(s, self._substitutions[s]))
self.raise_inconsistencies()

# simplify symbolic equivalences: some of them will now become specializations!
symbolic_equivalences = self._symbolic_equivalences
self._symbolic_equivalences = []
for source, expr in symbolic_equivalences:
self.add_equality(source, expr.subs(s, self._substitutions[s]))
self.specialize_divisor_symbols()

# solve linear congruences
# NOTE(avik): We do not need to solve them for symbols that have already been specialized.
reduced_congruences = self.reduce_congruences()
for s, congruences in reduced_congruences.items():
for congruence in congruences:
# any congruence that cannot be checked becomes a dynamic constraint as well
if s not in self._substitutions or not sympy.checksol(congruence, {s: self._substitutions[s]}):
if disable_congruences:
self._force_specialization(s)
self._univariate_inequalities.pop(s, None)
else:
self._dynamic_results.add(self._dcp.doprint(sympy.Eq(congruence, 0)))

# remaining symbols have only pure inequalities (no equalities)
for s, exprs in self._univariate_inequalities.items():
Expand All @@ -1732,18 +1769,16 @@ def solve(self):
for expr in exprs:
self._dynamic_results.add(self._dcp.doprint(expr))

# simplify symbolic equivalences: some of them will now become specializations!
symbolic_equivalences = self._symbolic_equivalences
self._symbolic_equivalences = []
for source, expr in symbolic_equivalences:
self.add_equality(source, expr.subs(self._substitutions))

# remaining symbolic equivalences become dynamic equality constraints
for source, expr in self._symbolic_equivalences:
self._dynamic_results.add(f"{self._dcp.print_source(source)} == {self._dcp.doprint(expr)}")

# solve linear congruences
# NOTE(avik): We do not need to solve them for symbols that have already been specialized.
reduced_congruences = self.reduce_congruences()
for s, congruences in reduced_congruences.items():
for congruence in congruences:
# any congruence that cannot be checked becomes a dynamic constraint as well
if s not in self._substitutions or not sympy.checksol(congruence, {s: self._substitutions[s]}):
self._dynamic_results.add(self._dcp.doprint(sympy.Eq(congruence, 0)))

def prettify_results(self, original_signature: inspect.Signature):
# Note: Model inputs are wrapped as LocalSource in dynamo.
Expand Down Expand Up @@ -2491,6 +2526,7 @@ def hint(s):
elif len(warn_msgs) > 0:
log.debug("%s Warning only constraints violated", len(warn_msgs))

print(f"symbol_to_constraints={symbol_to_constraints}")
return exprs

def evaluate_guards_for_args(self, placeholders, args, *, ignore_static=True):
Expand Down

0 comments on commit f0202b0

Please sign in to comment.