Skip to content

Commit

Permalink
Simplify guards using info from previous guards (#121463)
Browse files Browse the repository at this point in the history
Let me see what CI thinks about this one. Will add tests tomorrow.

Fixes #119917
Pull Request resolved: #121463
Approved by: https://github.com/ezyang
  • Loading branch information
lezcano authored and pytorchmergebot committed Mar 12, 2024
1 parent 703e83e commit 86a2d67
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 15 deletions.
8 changes: 5 additions & 3 deletions test/test_dynamic_shapes.py
Original file line number Diff line number Diff line change
Expand Up @@ -500,13 +500,15 @@ def test_expect_true_with_s0(self):
shape_env = ShapeEnv()
s0 = create_symint(shape_env, 5)
i0 = shape_env.create_unbacked_symint()
self.assertTrue(expect_true(i0 <= s0))
self.assertTrue(expect_true(i0 < s0))
self.assertExpectedInline(
str([ra.expr for ra in shape_env.deferred_runtime_asserts[i0.node.expr]]),
"""[-s0 + u0 <= 0]"""
"""[-s0 + u0 < 0]"""
)
self.assertTrue(i0 <= s0)
self.assertTrue(i0 < s0)
self.assertTrue(i0 != s0)
self.assertFalse(i0 > s0)
self.assertFalse(i0 >= s0)

def test_expect_true_prefer_later(self):
shape_env = ShapeEnv()
Expand Down
35 changes: 23 additions & 12 deletions torch/fx/experimental/symbolic_shapes.py
Original file line number Diff line number Diff line change
Expand Up @@ -1927,7 +1927,7 @@ def _init(
self.var_to_stack: Dict[sympy.Symbol, CapturedTraceback] = {}
# Maps from sympy ints to expressions representing them
# Populated from equality guards (i.e. a.shape[0] == b.shape[0])
self.replacements: Dict[sympy.Symbol, sympy.Expr] = {} #
self.replacements: Dict[sympy.Symbol, sympy.Expr] = {}
# Set holds a % b expressions that evaluate to 0.
self.divisible: Set[sympy.Expr] = set()
# Set that holds "size-like" symbols. When we perform
Expand Down Expand Up @@ -3486,20 +3486,31 @@ def _maybe_evaluate_static(
# Unbacked symints only
if s in self.var_to_val:
continue

subst = {}
for ra in self.deferred_runtime_asserts.get(s, ()):

def add_expr(expr):
# Expr and negation
subst[canonicalize_bool_expr(expr)] = sympy.true
subst[canonicalize_bool_expr(sympy.Not(expr))] = sympy.false
if isinstance(expr, sympy.Rel):
# multiplying by -1 changes the direction of the inequality
dual = type(expr)(-expr.rhs, -expr.lhs)
subst[canonicalize_bool_expr(dual)] = sympy.true
subst[canonicalize_bool_expr(sympy.Not(dual))] = sympy.false

for e in itertools.chain(self.guards, self.deferred_runtime_asserts.get(s, ())):
e = e.expr
if compute_hint:
e = canonicalize_bool_expr(ra.expr.xreplace(self.var_to_val))
else:
e = ra.expr
# e is already canonical
subst[e] = sympy.true
subst[canonicalize_bool_expr(sympy.Not(e))] = sympy.false
e = canonicalize_bool_expr(e.xreplace(self.var_to_val))
add_expr(e)
# Other relational expressions this expression implies
if isinstance(e, sympy.Eq):
subst[sympy.Le(e.lhs, e.rhs)] = sympy.true
subst[sympy.Le(-e.lhs, -e.rhs)] = sympy.true
subst[sympy.Lt(e.lhs, e.rhs)] = sympy.false
subst[sympy.Lt(-e.lhs, -e.rhs)] = sympy.false
add_expr(sympy.Le(e.lhs, e.rhs))
add_expr(sympy.Ge(e.lhs, e.rhs))
elif isinstance(e, sympy.Lt):
add_expr(sympy.Le(e.lhs, e.rhs))
add_expr(sympy.Ne(e.lhs, e.rhs))

# NB: this helps us deal with And/Or connectives
expr = expr.subs(subst)
Expand Down

0 comments on commit 86a2d67

Please sign in to comment.