Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
128 changes: 65 additions & 63 deletions mypyc/irbuild/expression.py
Original file line number Diff line number Diff line change
Expand Up @@ -701,24 +701,70 @@ def transform_comparison_expr(builder: IRBuilder, e: ComparisonExpr) -> Value:
# x in (...)/[...]
# x not in (...)/[...]
first_op = e.operators[0]
if (
first_op in ["in", "not in"]
and len(e.operators) == 1
and isinstance(e.operands[1], (TupleExpr, ListExpr))
):
items = e.operands[1].items
if first_op in ["in", "not in"] and len(e.operators) == 1:
result = try_specialize_in_expr(builder, first_op, e.operands[0], e.operands[1], e.line)
if result is not None:
return result

if len(e.operators) == 1:
# Special some common simple cases
if first_op in ("is", "is not"):
right_expr = e.operands[1]
if isinstance(right_expr, NameExpr) and right_expr.fullname == "builtins.None":
# Special case 'is None' / 'is not None'.
return translate_is_none(builder, e.operands[0], negated=first_op != "is")
left_expr = e.operands[0]
if is_int_rprimitive(builder.node_type(left_expr)):
right_expr = e.operands[1]
if is_int_rprimitive(builder.node_type(right_expr)):
if first_op in int_borrow_friendly_op:
borrow_left = is_borrow_friendly_expr(builder, right_expr)
left = builder.accept(left_expr, can_borrow=borrow_left)
right = builder.accept(right_expr, can_borrow=True)
return builder.binary_op(left, right, first_op, e.line)

# TODO: Don't produce an expression when used in conditional context
# All of the trickiness here is due to support for chained conditionals
# (`e1 < e2 > e3`, etc). `e1 < e2 > e3` is approximately equivalent to
# `e1 < e2 and e2 > e3` except that `e2` is only evaluated once.
expr_type = builder.node_type(e)

# go(i, prev) generates code for `ei opi e{i+1} op{i+1} ... en`,
# assuming that prev contains the value of `ei`.
def go(i: int, prev: Value) -> Value:
if i == len(e.operators) - 1:
return transform_basic_comparison(
builder, e.operators[i], prev, builder.accept(e.operands[i + 1]), e.line
)

next = builder.accept(e.operands[i + 1])
return builder.builder.shortcircuit_helper(
"and",
expr_type,
lambda: transform_basic_comparison(builder, e.operators[i], prev, next, e.line),
lambda: go(i + 1, next),
e.line,
)

return go(0, builder.accept(e.operands[0]))


def try_specialize_in_expr(
builder: IRBuilder, op: str, lhs: Expression, rhs: Expression, line: int
) -> Value | None:
if isinstance(rhs, (TupleExpr, ListExpr)):
items = rhs.items
n_items = len(items)
# x in y -> x == y[0] or ... or x == y[n]
# x not in y -> x != y[0] and ... and x != y[n]
# 16 is arbitrarily chosen to limit code size
if 1 < n_items < 16:
if e.operators[0] == "in":
if op == "in":
bin_op = "or"
cmp_op = "=="
else:
bin_op = "and"
cmp_op = "!="
lhs = e.operands[0]
mypy_file = builder.graph["builtins"].tree
assert mypy_file is not None
info = mypy_file.names["bool"].node
Expand All @@ -738,78 +784,34 @@ def transform_comparison_expr(builder: IRBuilder, e: ComparisonExpr) -> Value:
# x in [y]/(y) -> x == y
# x not in [y]/(y) -> x != y
elif n_items == 1:
if e.operators[0] == "in":
if op == "in":
cmp_op = "=="
else:
cmp_op = "!="
e.operators = [cmp_op]
e.operands[1] = items[0]
left = builder.accept(lhs)
right = builder.accept(items[0])
return transform_basic_comparison(builder, cmp_op, left, right, line)
# x in []/() -> False
# x not in []/() -> True
elif n_items == 0:
if e.operators[0] == "in":
if op == "in":
return builder.false()
else:
return builder.true()

# x in {...}
# x not in {...}
if (
first_op in ("in", "not in")
and len(e.operators) == 1
and isinstance(e.operands[1], SetExpr)
):
set_literal = precompute_set_literal(builder, e.operands[1])
if isinstance(rhs, SetExpr):
set_literal = precompute_set_literal(builder, rhs)
if set_literal is not None:
lhs = e.operands[0]
result = builder.builder.primitive_op(
set_in_op, [builder.accept(lhs), set_literal], e.line, bool_rprimitive
set_in_op, [builder.accept(lhs), set_literal], line, bool_rprimitive
)
if first_op == "not in":
return builder.unary_op(result, "not", e.line)
if op == "not in":
return builder.unary_op(result, "not", line)
return result

if len(e.operators) == 1:
# Special some common simple cases
if first_op in ("is", "is not"):
right_expr = e.operands[1]
if isinstance(right_expr, NameExpr) and right_expr.fullname == "builtins.None":
# Special case 'is None' / 'is not None'.
return translate_is_none(builder, e.operands[0], negated=first_op != "is")
left_expr = e.operands[0]
if is_int_rprimitive(builder.node_type(left_expr)):
right_expr = e.operands[1]
if is_int_rprimitive(builder.node_type(right_expr)):
if first_op in int_borrow_friendly_op:
borrow_left = is_borrow_friendly_expr(builder, right_expr)
left = builder.accept(left_expr, can_borrow=borrow_left)
right = builder.accept(right_expr, can_borrow=True)
return builder.binary_op(left, right, first_op, e.line)

# TODO: Don't produce an expression when used in conditional context
# All of the trickiness here is due to support for chained conditionals
# (`e1 < e2 > e3`, etc). `e1 < e2 > e3` is approximately equivalent to
# `e1 < e2 and e2 > e3` except that `e2` is only evaluated once.
expr_type = builder.node_type(e)

# go(i, prev) generates code for `ei opi e{i+1} op{i+1} ... en`,
# assuming that prev contains the value of `ei`.
def go(i: int, prev: Value) -> Value:
if i == len(e.operators) - 1:
return transform_basic_comparison(
builder, e.operators[i], prev, builder.accept(e.operands[i + 1]), e.line
)

next = builder.accept(e.operands[i + 1])
return builder.builder.shortcircuit_helper(
"and",
expr_type,
lambda: transform_basic_comparison(builder, e.operators[i], prev, next, e.line),
lambda: go(i + 1, next),
e.line,
)

return go(0, builder.accept(e.operands[0]))
return None


def translate_is_none(builder: IRBuilder, expr: Expression, negated: bool) -> Value:
Expand Down
30 changes: 30 additions & 0 deletions mypyc/test-data/irbuild-tuple.test
Original file line number Diff line number Diff line change
Expand Up @@ -209,6 +209,36 @@ L5:
L6:
return r3

[case testTupleOperatorInFinalTuple]
from typing import Final

tt: Final = (1, 2)

def f(x: int) -> bool:
return x in tt
[out]
def f(x):
x :: int
r0 :: tuple[int, int]
r1 :: bool
r2, r3 :: object
r4 :: i32
r5 :: bit
r6 :: bool
L0:
r0 = __main__.tt :: static
if is_error(r0) goto L1 else goto L2
L1:
r1 = raise NameError('value for final name "tt" was not set')
unreachable
L2:
r2 = box(int, x)
r3 = box(tuple[int, int], r0)
r4 = PySequence_Contains(r3, r2)
r5 = r4 >= 0 :: signed
r6 = truncate r4: i32 to builtins.bool
return r6

[case testTupleBuiltFromList]
def f(val: int) -> bool:
return val % 2 == 0
Expand Down
Loading