diff --git a/test/test_dynamic_shapes.py b/test/test_dynamic_shapes.py index ba8818da42981..68ec33d863301 100644 --- a/test/test_dynamic_shapes.py +++ b/test/test_dynamic_shapes.py @@ -500,13 +500,15 @@ def test_expect_true_with_s0(self): shape_env = ShapeEnv() s0 = create_symint(shape_env, 5) i0 = shape_env.create_unbacked_symint() - self.assertTrue(expect_true(i0 <= s0)) + self.assertTrue(expect_true(i0 < s0)) self.assertExpectedInline( str([ra.expr for ra in shape_env.deferred_runtime_asserts[i0.node.expr]]), - """[-s0 + u0 <= 0]""" + """[-s0 + u0 < 0]""" ) - self.assertTrue(i0 <= s0) + self.assertTrue(i0 < s0) + self.assertTrue(i0 != s0) self.assertFalse(i0 > s0) + self.assertFalse(i0 >= s0) def test_expect_true_prefer_later(self): shape_env = ShapeEnv() diff --git a/torch/fx/experimental/symbolic_shapes.py b/torch/fx/experimental/symbolic_shapes.py index 34c8f14379416..e9b7f56f97488 100644 --- a/torch/fx/experimental/symbolic_shapes.py +++ b/torch/fx/experimental/symbolic_shapes.py @@ -1927,7 +1927,7 @@ def _init( self.var_to_stack: Dict[sympy.Symbol, CapturedTraceback] = {} # Maps from sympy ints to expressions representing them # Populated from equality guards (i.e. a.shape[0] == b.shape[0]) - self.replacements: Dict[sympy.Symbol, sympy.Expr] = {} # + self.replacements: Dict[sympy.Symbol, sympy.Expr] = {} # Set holds a % b expressions that evaluate to 0. self.divisible: Set[sympy.Expr] = set() # Set that holds "size-like" symbols. When we perform @@ -3486,20 +3486,31 @@ def _maybe_evaluate_static( # Unbacked symints only if s in self.var_to_val: continue + subst = {} - for ra in self.deferred_runtime_asserts.get(s, ()): + + def add_expr(expr): + # Expr and negation + subst[canonicalize_bool_expr(expr)] = sympy.true + subst[canonicalize_bool_expr(sympy.Not(expr))] = sympy.false + if isinstance(expr, sympy.Rel): + # multiplying by -1 changes the direction of the inequality + dual = type(expr)(-expr.rhs, -expr.lhs) + subst[canonicalize_bool_expr(dual)] = sympy.true + subst[canonicalize_bool_expr(sympy.Not(dual))] = sympy.false + + for e in itertools.chain(self.guards, self.deferred_runtime_asserts.get(s, ())): + e = e.expr if compute_hint: - e = canonicalize_bool_expr(ra.expr.xreplace(self.var_to_val)) - else: - e = ra.expr - # e is already canonical - subst[e] = sympy.true - subst[canonicalize_bool_expr(sympy.Not(e))] = sympy.false + e = canonicalize_bool_expr(e.xreplace(self.var_to_val)) + add_expr(e) + # Other relational expressions this expression implies if isinstance(e, sympy.Eq): - subst[sympy.Le(e.lhs, e.rhs)] = sympy.true - subst[sympy.Le(-e.lhs, -e.rhs)] = sympy.true - subst[sympy.Lt(e.lhs, e.rhs)] = sympy.false - subst[sympy.Lt(-e.lhs, -e.rhs)] = sympy.false + add_expr(sympy.Le(e.lhs, e.rhs)) + add_expr(sympy.Ge(e.lhs, e.rhs)) + elif isinstance(e, sympy.Lt): + add_expr(sympy.Le(e.lhs, e.rhs)) + add_expr(sympy.Ne(e.lhs, e.rhs)) # NB: this helps us deal with And/Or connectives expr = expr.subs(subst)