From de026a9fa4abf3d04c72caac8ee28f27457d56c4 Mon Sep 17 00:00:00 2001 From: Yan Yanchii Date: Fri, 10 Jan 2025 13:18:13 +0100 Subject: [PATCH 1/9] Evaluate constant comparisons in `fold_compare` --- Python/ast_opt.c | 85 ++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 85 insertions(+) diff --git a/Python/ast_opt.c b/Python/ast_opt.c index 01e208b88eca8b..2f60738c90e6c1 100644 --- a/Python/ast_opt.c +++ b/Python/ast_opt.c @@ -639,6 +639,91 @@ fold_compare(expr_ty node, PyArena *arena, _PyASTOptimizeState *state) return 0; } } + + static const int richcompare_table[] = { + [Eq] = Py_EQ, + [NotEq] = Py_NE, + [Gt] = Py_GT, + [Lt] = Py_LT, + [GtE] = Py_GE, + [LtE] = Py_LE, + }; + + if (node->v.Compare.left->kind == Constant_kind) { + PyObject *lhs = node->v.Compare.left->v.Constant.value; + expr_ty curr_expr; + + for (int i=0; i < asdl_seq_LEN(args); i++) { + curr_expr = (expr_ty)asdl_seq_GET(args, i); + + if (curr_expr->kind != Constant_kind) { + goto exit; + } + + PyObject *rhs = curr_expr->v.Constant.value; + int op = asdl_seq_GET(ops, i); + int res; + + switch (op) { + case Eq: case NotEq: + case Gt: case Lt: + case GtE: case LtE: + { + res = PyObject_RichCompareBool(lhs, rhs, richcompare_table[op]); + if (res < 0) { + /* error */ + if (PyErr_Occurred()) { + return make_const(node, NULL, arena); + } + return 0; + } + if (!res) { + /* shortcut, whole expression is False */ + return make_const(node, Py_False, arena); + } + break; + } + case In: + case NotIn: + { + res = PySequence_Contains(rhs, lhs); + if (res < 0) { + /* error */ + if (PyErr_Occurred()) { + return make_const(node, NULL, arena); + } + return 0; + } + if (op == NotIn) { + res = !res; + } + if (!res) { + /* shortcut, whole expression is False */ + return make_const(node, Py_False, arena); + } + break; + } + case Is: + case IsNot: + { + res = Py_Is(lhs, rhs); + if (op == IsNot) { + res = !res; + } + if (!res) { + /* shortcut, whole expression is False */ + return make_const(node, 
Py_False, arena); + } + break; + } + } + lhs = rhs; + } + /* whole expression is True */ + return make_const(node, Py_True, arena); + } + +exit: return 1; } From 22bcb100f69f700cec05b58aa132ea96dbfeb2d6 Mon Sep 17 00:00:00 2001 From: Yan Yanchii Date: Fri, 10 Jan 2025 17:02:22 +0100 Subject: [PATCH 2/9] do not fold is/isnot --- Python/ast_opt.c | 37 ++++++++++++++++++------------------- 1 file changed, 18 insertions(+), 19 deletions(-) diff --git a/Python/ast_opt.c b/Python/ast_opt.c index 2f60738c90e6c1..ceba593c1e12ce 100644 --- a/Python/ast_opt.c +++ b/Python/ast_opt.c @@ -650,26 +650,38 @@ fold_compare(expr_ty node, PyArena *arena, _PyASTOptimizeState *state) }; if (node->v.Compare.left->kind == Constant_kind) { + PyObject *lhs = node->v.Compare.left->v.Constant.value; - expr_ty curr_expr; for (int i=0; i < asdl_seq_LEN(args); i++) { - curr_expr = (expr_ty)asdl_seq_GET(args, i); + + expr_ty curr_expr = (expr_ty)asdl_seq_GET(args, i); if (curr_expr->kind != Constant_kind) { + /* try to fold only if every comparator is constant */ goto exit; } - PyObject *rhs = curr_expr->v.Constant.value; int op = asdl_seq_GET(ops, i); - int res; + + if (op == Is || op == IsNot) { + /* + Do not fold expression for now if "is"/"is not" is present. + It breaks expected syntax warnings. For example: + >>> 1 is 1 + :1: SyntaxWarning: "is" with 'int' literal. Did you mean "=="? 
+ */ + goto exit; + } + + PyObject *rhs = curr_expr->v.Constant.value; switch (op) { case Eq: case NotEq: case Gt: case Lt: case GtE: case LtE: { - res = PyObject_RichCompareBool(lhs, rhs, richcompare_table[op]); + int res = PyObject_RichCompareBool(lhs, rhs, richcompare_table[op]); if (res < 0) { /* error */ if (PyErr_Occurred()) { @@ -686,7 +698,7 @@ fold_compare(expr_ty node, PyArena *arena, _PyASTOptimizeState *state) case In: case NotIn: { - res = PySequence_Contains(rhs, lhs); + int res = PySequence_Contains(rhs, lhs); if (res < 0) { /* error */ if (PyErr_Occurred()) { @@ -703,19 +715,6 @@ fold_compare(expr_ty node, PyArena *arena, _PyASTOptimizeState *state) } break; } - case Is: - case IsNot: - { - res = Py_Is(lhs, rhs); - if (op == IsNot) { - res = !res; - } - if (!res) { - /* shortcut, whole expression is False */ - return make_const(node, Py_False, arena); - } - break; - } } lhs = rhs; } From 713fde5a78c901abb76d398ca3f36298b35c2941 Mon Sep 17 00:00:00 2001 From: Yan Yanchii Date: Fri, 10 Jan 2025 17:38:03 +0100 Subject: [PATCH 3/9] fix failing tests due to unexpected comparison folding --- Lib/test/test_ast/test_ast.py | 18 +++++++++++------- 1 file changed, 11 insertions(+), 7 deletions(-) diff --git a/Lib/test/test_ast/test_ast.py b/Lib/test/test_ast/test_ast.py index c268a1f00f938e..a392179856ec58 100644 --- a/Lib/test/test_ast/test_ast.py +++ b/Lib/test/test_ast/test_ast.py @@ -3180,7 +3180,8 @@ def create_unaryop(operand): self.assert_ast(result_code, non_optimized_target, optimized_target) def test_folding_not(self): - code = "not (1 %s (1,))" + # use list as left-hand side to avoid folding constant expression to True/False + code = "not ([] %s (1,))" operators = { "in": ast.In(), "is": ast.Is(), @@ -3192,7 +3193,7 @@ def test_folding_not(self): def create_notop(operand): return ast.UnaryOp(op=ast.Not(), operand=ast.Compare( - left=ast.Constant(value=1), + left=ast.List(), ops=[operators[operand]], 
comparators=[ast.Tuple(elts=[ast.Constant(value=1)])] )) @@ -3201,7 +3202,7 @@ def create_notop(operand): result_code = code % op non_optimized_target = self.wrap_expr(create_notop(op)) optimized_target = self.wrap_expr( - ast.Compare(left=ast.Constant(1), ops=[opt_operators[op]], comparators=[ast.Constant(value=(1,))]) + ast.Compare(left=ast.List(), ops=[opt_operators[op]], comparators=[ast.Constant(value=(1,))]) ) with self.subTest( @@ -3239,8 +3240,11 @@ def test_folding_tuple(self): self.assert_ast(code, non_optimized_target, optimized_target) - def test_folding_comparator(self): - code = "1 %s %s1%s" + def test_folding_comparator_list_set_subst(self): + """Test substitution of list/set with tuple/frozenset in expressions like "1 in [1]" or "1 in {1}" """ + + # use list as left-hand side to avoid folding constant comparison expression to True/False + code = "[] %s %s1%s" operators = [("in", ast.In()), ("not in", ast.NotIn())] braces = [ ("[", "]", ast.List, (1,)), @@ -3249,11 +3253,11 @@ def test_folding_comparator(self): for left, right, non_optimized_comparator, optimized_comparator in braces: for op, node in operators: non_optimized_target = self.wrap_expr(ast.Compare( - left=ast.Constant(1), ops=[node], + left=ast.List(), ops=[node], comparators=[non_optimized_comparator(elts=[ast.Constant(1)])] )) optimized_target = self.wrap_expr(ast.Compare( - left=ast.Constant(1), ops=[node], + left=ast.List(), ops=[node], comparators=[ast.Constant(value=optimized_comparator)] )) self.assert_ast(code % (op, left, right), non_optimized_target, optimized_target) From 285f266b64ffc816f3d5b3a1319d53792031a636 Mon Sep 17 00:00:00 2001 From: Yan Yanchii Date: Fri, 10 Jan 2025 18:40:12 +0100 Subject: [PATCH 4/9] add tests for comparison folding --- Lib/test/test_ast/test_ast.py | 47 +++++++++++++++++++++++++++++++++++ 1 file changed, 47 insertions(+) diff --git a/Lib/test/test_ast/test_ast.py b/Lib/test/test_ast/test_ast.py index a392179856ec58..2bd8aadadfe31a 100644 --- 
a/Lib/test/test_ast/test_ast.py +++ b/Lib/test/test_ast/test_ast.py @@ -3240,6 +3240,53 @@ def test_folding_tuple(self): self.assert_ast(code, non_optimized_target, optimized_target) + def test_folding_compare(self): + true = self.wrap_expr(ast.Constant(value=True)) + false = self.wrap_expr(ast.Constant(value=False)) + + folded_cases = ( + ("3 > 2 > 1", (ast.Constant(3), [ast.Gt(), ast.Gt()], [ast.Constant(value=2), ast.Constant(value=1)]), true), + ("3 > 4 > 1", (ast.Constant(3), [ast.Gt(), ast.Gt()], [ast.Constant(value=4), ast.Constant(value=1)]), false), + ("3 >= 3 >= 1", (ast.Constant(3), [ast.GtE(), ast.GtE()], [ast.Constant(value=3), ast.Constant(value=1)]), true), + ("3 >= 4 >= 1", (ast.Constant(3), [ast.GtE(), ast.GtE()], [ast.Constant(value=4), ast.Constant(value=1)]), false), + ("1 < 2 < 3", (ast.Constant(1), [ast.Lt(), ast.Lt()], [ast.Constant(value=2), ast.Constant(value=3)]), true), + ("1 < 0 < 3", (ast.Constant(1), [ast.Lt(), ast.Lt()], [ast.Constant(value=0), ast.Constant(value=3)]), false), + ("1 <= 2 <= 3", (ast.Constant(1), [ast.LtE(), ast.LtE()], [ast.Constant(value=2), ast.Constant(value=3)]), true), + ("1 <= 0 <= 3", (ast.Constant(1), [ast.LtE(), ast.LtE()], [ast.Constant(value=0), ast.Constant(value=3)]), false), + ("1 == 1.0 == True", (ast.Constant(1), [ast.Eq(), ast.Eq()], [ast.Constant(value=1.0), ast.Constant(value=True)]), true), + ("1 == 2 == True", (ast.Constant(1), [ast.Eq(), ast.Eq()], [ast.Constant(value=2), ast.Constant(value=True)]), false), + ("1 != 2 != 3", (ast.Constant(1), [ast.NotEq(), ast.NotEq()], [ast.Constant(value=2), ast.Constant(value=3)]), true), + ("1 != 1 != 3", (ast.Constant(1), [ast.NotEq(), ast.NotEq()], [ast.Constant(value=1), ast.Constant(value=3)]), false), + ("1 in [1, 2]", (ast.Constant(1), [ast.In()], [ast.List(elts=[ast.Constant(1), ast.Constant(2)])]), true), + ("1 in [2, 2]", (ast.Constant(1), [ast.In()], [ast.List(elts=[ast.Constant(2), ast.Constant(2)])]), false), + ("1 not in [1, 2]", 
(ast.Constant(1), [ast.NotIn()], [ast.List(elts=[ast.Constant(1), ast.Constant(2)])]), false), + ("1 not in [2, 2]", (ast.Constant(1), [ast.NotIn()], [ast.List(elts=[ast.Constant(2), ast.Constant(2)])]), true), + ) + + for code, original, folded in folded_cases: + left, ops, comparators = original + unfolded = self.wrap_expr(ast.Compare(left=left, ops=ops, comparators=comparators)) + self.assert_ast(code=code, non_optimized_target=unfolded, optimized_target=folded) + + # these should stay as they were + unfolded_cases = ( + ("3 > 2 > []", ast.Compare(left=ast.Constant(3), ops=[ast.Gt(), ast.Gt()], comparators=[ast.Constant(2), ast.List()])), + ("1 > [] > 0", ast.Compare(left=ast.Constant(1), ops=[ast.Gt(), ast.Gt()], comparators=[ast.List(), ast.Constant(0)])), + ("1 >= [] >= 0", ast.Compare(left=ast.Constant(1), ops=[ast.GtE(), ast.GtE()], comparators=[ast.List(), ast.Constant(0)])), + ("1 < [] < 0", ast.Compare(left=ast.Constant(1), ops=[ast.Lt(), ast.Lt()], comparators=[ast.List(), ast.Constant(0)])), + ("1 <= [] <= 0", ast.Compare(left=ast.Constant(1), ops=[ast.LtE(), ast.LtE()], comparators=[ast.List(), ast.Constant(0)])), + ("1 == [] == 0", ast.Compare(left=ast.Constant(1), ops=[ast.Eq(), ast.Eq()], comparators=[ast.List(), ast.Constant(0)])), + ("1 != [] != 0", ast.Compare(left=ast.Constant(1), ops=[ast.NotEq(), ast.NotEq()], comparators=[ast.List(), ast.Constant(0)])), + ("1 is 1", ast.Compare(left=ast.Constant(1), ops=[ast.Is()], comparators=[ast.Constant(1)])), + ("1 is not 1", ast.Compare(left=ast.Constant(1), ops=[ast.IsNot()], comparators=[ast.Constant(1)])), + # invalid also should stay as they were + ("1 in 1", ast.Compare(left=ast.Constant(1), ops=[ast.In()], comparators=[ast.Constant(1)])), + ("1 not in 1", ast.Compare(left=ast.Constant(1), ops=[ast.NotIn()], comparators=[ast.Constant(1)])), + ) + + for code, expected in unfolded_cases: + self.assertTrue(ast.compare(ast.parse(code), self.wrap_expr(expected))) + def 
test_folding_comparator_list_set_subst(self): """Test substitution of list/set with tuple/frozenset in expressions like "1 in [1]" or "1 in {1}" """ From c3abf4a9f4b86b79feeb8c96bd045db3d324e7a5 Mon Sep 17 00:00:00 2001 From: Yan Yanchii Date: Fri, 10 Jan 2025 18:43:38 +0100 Subject: [PATCH 5/9] add news entry --- .../2025-01-10-18-42-57.gh-issue-128706.4T2XSt.rst | 1 + 1 file changed, 1 insertion(+) create mode 100644 Misc/NEWS.d/next/Core_and_Builtins/2025-01-10-18-42-57.gh-issue-128706.4T2XSt.rst diff --git a/Misc/NEWS.d/next/Core_and_Builtins/2025-01-10-18-42-57.gh-issue-128706.4T2XSt.rst b/Misc/NEWS.d/next/Core_and_Builtins/2025-01-10-18-42-57.gh-issue-128706.4T2XSt.rst new file mode 100644 index 00000000000000..0f6ae40c08b38b --- /dev/null +++ b/Misc/NEWS.d/next/Core_and_Builtins/2025-01-10-18-42-57.gh-issue-128706.4T2XSt.rst @@ -0,0 +1 @@ +Add constant folding for constant comparisons From 5cb7cd87b525fe9d697796eaefc8d9af71146b4e Mon Sep 17 00:00:00 2001 From: Yan Yanchii Date: Sat, 11 Jan 2025 13:34:54 +0100 Subject: [PATCH 6/9] address comments --- Python/ast_opt.c | 79 ++++++++++++++++-------------------------------- 1 file changed, 26 insertions(+), 53 deletions(-) diff --git a/Python/ast_opt.c b/Python/ast_opt.c index ceba593c1e12ce..edf1edd63de000 100644 --- a/Python/ast_opt.c +++ b/Python/ast_opt.c @@ -650,79 +650,52 @@ fold_compare(expr_ty node, PyArena *arena, _PyASTOptimizeState *state) }; if (node->v.Compare.left->kind == Constant_kind) { - PyObject *lhs = node->v.Compare.left->v.Constant.value; - - for (int i=0; i < asdl_seq_LEN(args); i++) { - + for (Py_ssize_t i=0; i < asdl_seq_LEN(args); i++) { expr_ty curr_expr = (expr_ty)asdl_seq_GET(args, i); - if (curr_expr->kind != Constant_kind) { /* try to fold only if every comparator is constant */ - goto exit; + return 1; } - int op = asdl_seq_GET(ops, i); - if (op == Is || op == IsNot) { - /* - Do not fold expression for now if "is"/"is not" is present. - It breaks expected syntax warnings. 
For example: - >>> 1 is 1 - :1: SyntaxWarning: "is" with 'int' literal. Did you mean "=="? - */ - goto exit; + /* Do not fold "is" and "is not" expressions since this breaks + expected syntax warnings. For example: + >>> 1 is 1 + :1: SyntaxWarning: "is" with 'int' literal. Did you mean "=="? + */ + return 1; } - PyObject *rhs = curr_expr->v.Constant.value; - + int res; switch (op) { - case Eq: case NotEq: - case Gt: case Lt: - case GtE: case LtE: - { - int res = PyObject_RichCompareBool(lhs, rhs, richcompare_table[op]); - if (res < 0) { - /* error */ - if (PyErr_Occurred()) { - return make_const(node, NULL, arena); - } - return 0; - } - if (!res) { - /* shortcut, whole expression is False */ - return make_const(node, Py_False, arena); - } + case Eq: + case NotEq: + case Gt: + case Lt: + case GtE: + case LtE: + res = PyObject_RichCompareBool(lhs, rhs, richcompare_table[op]); break; - } case In: case NotIn: - { - int res = PySequence_Contains(rhs, lhs); - if (res < 0) { - /* error */ - if (PyErr_Occurred()) { - return make_const(node, NULL, arena); - } - return 0; - } - if (op == NotIn) { - res = !res; - } - if (!res) { - /* shortcut, whole expression is False */ - return make_const(node, Py_False, arena); - } + res = PySequence_Contains(rhs, lhs); + if (op == NotIn && res >= 0) res = !res; break; - } + default: + Py_UNREACHABLE(); + } + if (res == 0) { + /* shortcut, whole expression is False */ + return make_const(node, Py_False, arena); + } else if (res < 0) { + return make_const(node, NULL, arena); } lhs = rhs; } /* whole expression is True */ return make_const(node, Py_True, arena); } - -exit: return 1; } From 65093fb10d4037ab21aad058d187845c132b614e Mon Sep 17 00:00:00 2001 From: Yan Yanchii Date: Sat, 11 Jan 2025 14:38:58 +0100 Subject: [PATCH 7/9] address comments --- .../2025-01-10-18-42-57.gh-issue-128706.4T2XSt.rst | 2 +- Python/ast_opt.c | 13 +++++++++---- 2 files changed, 10 insertions(+), 5 deletions(-) diff --git 
a/Misc/NEWS.d/next/Core_and_Builtins/2025-01-10-18-42-57.gh-issue-128706.4T2XSt.rst b/Misc/NEWS.d/next/Core_and_Builtins/2025-01-10-18-42-57.gh-issue-128706.4T2XSt.rst index 0f6ae40c08b38b..1949a649d9f748 100644 --- a/Misc/NEWS.d/next/Core_and_Builtins/2025-01-10-18-42-57.gh-issue-128706.4T2XSt.rst +++ b/Misc/NEWS.d/next/Core_and_Builtins/2025-01-10-18-42-57.gh-issue-128706.4T2XSt.rst @@ -1 +1 @@ -Add constant folding for constant comparisons +Add constant folding for constant comparisons. diff --git a/Python/ast_opt.c b/Python/ast_opt.c index edf1edd63de000..5914ed396aa954 100644 --- a/Python/ast_opt.c +++ b/Python/ast_opt.c @@ -674,21 +674,26 @@ fold_compare(expr_ty node, PyArena *arena, _PyASTOptimizeState *state) case Gt: case Lt: case GtE: - case LtE: + case LtE: { res = PyObject_RichCompareBool(lhs, rhs, richcompare_table[op]); break; + } case In: - case NotIn: + case NotIn: { res = PySequence_Contains(rhs, lhs); - if (op == NotIn && res >= 0) res = !res; + if (op == NotIn && res >= 0) { + res = !res; + } break; + } default: Py_UNREACHABLE(); } if (res == 0) { /* shortcut, whole expression is False */ return make_const(node, Py_False, arena); - } else if (res < 0) { + } + else if (res < 0) { return make_const(node, NULL, arena); } lhs = rhs; From 891bff5564694b9a19ff2f1eee59f4768de90ce6 Mon Sep 17 00:00:00 2001 From: Yan Yanchii Date: Sat, 11 Jan 2025 14:39:29 +0100 Subject: [PATCH 8/9] add whats new entry --- Doc/whatsnew/3.14.rst | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/Doc/whatsnew/3.14.rst b/Doc/whatsnew/3.14.rst index 72abfebd46f2b9..36d50ef10e8296 100644 --- a/Doc/whatsnew/3.14.rst +++ b/Doc/whatsnew/3.14.rst @@ -210,6 +210,10 @@ configuration mechanisms). Other language changes ====================== +* Constant comparison expressions are now folded and evaluated before runtime. + For example, expressions like: ``"str" in ("str",)`` or ``1 == 1.0 == True`` + are now pre-evaluated. + (Contributed by Yan Yanchii in :gh:`128706`.) 
* The :func:`map` built-in now has an optional keyword-only *strict* flag like :func:`zip` to check that all the iterables are of equal length. From 4f5f1b5d4abf355107857e8632eb4e2b27409541 Mon Sep 17 00:00:00 2001 From: Yan Yanchii Date: Sat, 11 Jan 2025 14:40:33 +0100 Subject: [PATCH 9/9] address comments --- Python/ast_opt.c | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Python/ast_opt.c b/Python/ast_opt.c index 5914ed396aa954..06034f2d805832 100644 --- a/Python/ast_opt.c +++ b/Python/ast_opt.c @@ -651,7 +651,7 @@ fold_compare(expr_ty node, PyArena *arena, _PyASTOptimizeState *state) if (node->v.Compare.left->kind == Constant_kind) { PyObject *lhs = node->v.Compare.left->v.Constant.value; - for (Py_ssize_t i=0; i < asdl_seq_LEN(args); i++) { + for (Py_ssize_t i = 0; i < asdl_seq_LEN(args); i++) { expr_ty curr_expr = (expr_ty)asdl_seq_GET(args, i); if (curr_expr->kind != Constant_kind) { /* try to fold only if every comparator is constant */