Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

bpo-45843: Optimize constant comparisons/contains #29639

Closed
@@ -0,0 +1 @@
Optimize constant comparisons/contains. Patch by Jeremiah Vivian.
205 changes: 187 additions & 18 deletions Python/ast_opt.c
Expand Up @@ -575,7 +575,7 @@ fold_subscr(expr_ty node, PyArena *arena, _PyASTOptimizeState *state)
}

/* Change literal list or set of constants into constant
tuple or frozenset respectively. Change literal list of
non-constants into tuple.
Used for right operand of "in" and "not in" tests and for iterable
in "for" loop and comprehensions.
Expand All @@ -585,11 +585,12 @@ fold_iter(expr_ty arg, PyArena *arena, _PyASTOptimizeState *state)
{
PyObject *newval;
if (arg->kind == List_kind) {
/* First change a list into tuple. */
/* First change a list into tuple.
We don't need a check here, because make_const_tuple
returns NULL and make_const returns 1 if its second
argument is NULL.
*/
asdl_expr_seq *elts = arg->v.List.elts;
if (has_starred(elts)) {
return 1;
}
expr_context_ty ctx = arg->v.List.ctx;
arg->kind = Tuple_kind;
arg->v.Tuple.elts = elts;
Expand All @@ -609,25 +610,193 @@ fold_iter(expr_ty arg, PyArena *arena, _PyASTOptimizeState *state)
return make_const(arg, newval, arena);
}

/* Fold all constant comparisons and contains by pointer
logic and PyObject_RichCompareBool (PySequence_Contains for
"in" or "not in"). Short circuit all remaining values
if the comparison returns 0 (False).
If there are constant comparisons, set node to a "and" chain
containing all the non-constant comparisons.
If there are no non-constant comparisons or no non-constant
comparisons are found before a constant comparison returns False,
set a boolean instead.
*/
static int
fold_compare(expr_ty node, PyArena *arena, _PyASTOptimizeState *state)
{
asdl_int_seq *ops;
asdl_expr_seq *args;
Py_ssize_t i;

ops = node->v.Compare.ops;
args = node->v.Compare.comparators;
/* TODO: optimize cases with literal arguments. */
/* Change literal list or set in 'in' or 'not in' into
tuple or frozenset respectively. */
i = asdl_seq_LEN(ops) - 1;
int op = asdl_seq_GET(ops, i);
if (op == In || op == NotIn) {
if (!fold_iter((expr_ty)asdl_seq_GET(args, i), arena, state)) {
asdl_int_seq *cut_ops;
asdl_expr_seq *result, *cut_args;
expr_ty left, right, node_copy, bool_const;
expr_ty *args;
Py_ssize_t i, j, real_index = 0, unchanged = 0;
Py_ssize_t ops_length, args_size, ops_size;
int res, op, has_unchanged;
int *ops;

ops = node->v.Compare.ops->typed_elements;
args = node->v.Compare.comparators->typed_elements;
assert(args != NULL && ops != NULL);
ops_length = asdl_seq_LEN(node->v.Compare.ops);
args_size = ops_size = ops_length;
result = _Py_asdl_expr_seq_new(ops_length, arena);
/* Iterate over each comparison. */
for (i = 0; i < ops_length; i++) {
right = args[0];
op = (ops_size == ops_length) ? ops[i] : ops[0];
if (i != 0) {
left = (args_size == ops_length) ? args[i] : args[-1];
}
else {
left = node->v.Compare.left;
}
/* If possible and the operator is "in" or "not in",
convert lists or sets to constant tuples or frozensets.
*/
if (op == In || op == NotIn) {
if (!fold_iter(right, arena, state)) {
return 0;
}
}
if (left->kind != Constant_kind || right->kind != Constant_kind) {
unchanged++;
result->size++;
continue;
}
switch (op) {
case In:
case NotIn:
res = PySequence_Contains(
right->v.Constant.value,
left->v.Constant.value
) ^ (op == NotIn);
break;
case Is:
case IsNot:
res = Py_Is(
left->v.Constant.value,
right->v.Constant.value
) ^ (op == IsNot);
break;
case Eq:
res = PyObject_RichCompareBool(
left->v.Constant.value,
right->v.Constant.value,
Py_EQ
);
break;
case NotEq:
res = PyObject_RichCompareBool(
left->v.Constant.value,
right->v.Constant.value,
Py_NE
);
break;
case Lt:
res = PyObject_RichCompareBool(
left->v.Constant.value,
right->v.Constant.value,
Py_LT
);
break;
case Gt:
res = PyObject_RichCompareBool(
left->v.Constant.value,
right->v.Constant.value,
Py_GT
);
break;
case LtE:
res = PyObject_RichCompareBool(
left->v.Constant.value,
right->v.Constant.value,
Py_LE
);
break;
case GtE:
res = PyObject_RichCompareBool(
left->v.Constant.value,
right->v.Constant.value,
Py_GE
);
break;
}
if (res == -1) {
return 0;
}
else if (res == 0) {
/* Short circuit. If there is a non-constant
value, make False a part of the "and" chain
and break out of the loop, disregarding all the
other comparisons after.
If all values before are constants, set node to
False and return.
*/
if (has_unchanged) {
bool_const = _PyAST_Constant(
Py_False, NULL,
node->lineno, node->col_offset,
node->end_lineno, node->end_col_offset,
arena
);
asdl_seq_SET(result, real_index, bool_const);
real_index++;
break;
}
else {
Py_INCREF(Py_False);
return make_const(node, Py_False, arena);
}
}
else {
/* Handle non-constant values. */
if (unchanged != 0) {
cut_args = _Py_asdl_expr_seq_new(unchanged, arena);
cut_ops = _Py_asdl_int_seq_new(unchanged, arena);
j = (args_size == ops_length);
if (j == 1) {
asdl_seq_SET(cut_args, 0, node->v.Compare.left);
asdl_seq_SET(cut_ops, 0, ops[0]);
}
for (; j < unchanged; j++) {
asdl_seq_SET(
cut_args,
j,
args[j-1]
);
asdl_seq_SET(cut_ops, j, ops[j]);
}
node_copy = _PyAST_Compare(
left, cut_ops,
cut_args, node->lineno,
node->col_offset, node->end_lineno,
node->end_col_offset, arena
);
asdl_seq_SET(result, real_index, node_copy);
result->size++;
real_index++;
has_unchanged = 1;
}
unchanged++;
args += unchanged;
args_size -= unchanged;
ops += unchanged;
ops_size -= unchanged;
unchanged = 0;
}
}
if (!has_unchanged) {
/* All comparisons are constants and no
short-circuiting happened, so set node to True.
*/
Py_INCREF(Py_True);
return make_const(node, Py_True, arena);
}
if (real_index == 0) {
return 1;
}
result->size = real_index;
node->kind = BoolOp_kind;
node->v.BoolOp.op = And;
node->v.BoolOp.values = result;
return 1;
}

Expand Down