Skip to content

Commit

Permalink
[export] disable forced specializations, even when solvable with sing…
Browse files Browse the repository at this point in the history
…le var (#126925)

Summary:
Previously #124949 added the ability to disable forced specializations on dynamic shapes for export, keeping dynamism for complex guards instead of specializing, allowing unsoundness by having the user fail at runtime.

It avoided disabling one case: single-variable equality guards, where a variable is specified as dynamic but can be solvable for a concrete value, suggesting the correct behavior is specialization. For example, guard : Eq(s0 // 4, 400) suggests s0 should specialize to 1600.

In debugging, some users (e.g. APS) would like to keep this dynamic, and defer to failing at runtime instead. This PR adds this, so now all forced specializations should be turned off. Mostly this should be used for debugging, since it produces unsoundness, and lets the user proceed with (probably) incorrect dynamism.

Test Plan: export tests

Differential Revision: D57698601

Pull Request resolved: #126925
Approved by: https://github.com/angelayi
  • Loading branch information
pianpwk authored and pytorchmergebot committed May 23, 2024
1 parent 6eac3f4 commit 2db1363
Showing 1 changed file with 6 additions and 4 deletions.
10 changes: 6 additions & 4 deletions torch/fx/experimental/symbolic_shapes.py
Original file line number Diff line number Diff line change
Expand Up @@ -1774,10 +1774,12 @@ def solve(
assert isinstance(solution, sympy.Eq), f"Expected an equality constraint for {s}, got {solution}"
symbol, val = solution.args
assert symbol == s, f"Expected a constraint on {s} instead of on {symbol}"
# because this is univariate, the solution is a specialization
self._static_results.add(f"{self._dcp.symbol_to_source[s][0].name()} == {val}")
# add this as a substitution to simplify other constraints
self._substitutions[s] = val
# really don't force specializations here
if not (_disable_forced_specializations and s in self._marked_dynamic):
# because this is univariate, the solution is a specialization
self._static_results.add(f"{self._dcp.symbol_to_source[s][0].name()} == {val}")
# add this as a substitution to simplify other constraints
self._substitutions[s] = val

# simplify multivariate inequalities: some of them will now become univariate!
multivariate_inequalities = self._multivariate_inequalities
Expand Down

0 comments on commit 2db1363

Please sign in to comment.