diff --git a/mypyc/irbuild/for_helpers.py b/mypyc/irbuild/for_helpers.py index 20440d4a26f4..82ce782646d7 100644 --- a/mypyc/irbuild/for_helpers.py +++ b/mypyc/irbuild/for_helpers.py @@ -16,6 +16,7 @@ DictionaryComprehension, Expression, GeneratorExpr, + IntExpr, ListExpr, Lvalue, MemberExpr, @@ -239,38 +240,58 @@ def sequence_from_generator_preallocate_helper( line = gen.line sequence_expr = gen.sequences[0] rtype = builder.node_type(sequence_expr) - if not (is_sequence_rprimitive(rtype) or isinstance(rtype, RTuple)): - return None - sequence = builder.accept(sequence_expr) - length = get_expr_length_value(builder, sequence_expr, sequence, line, use_pyssize_t=True) - if isinstance(rtype, RTuple): - # If input is RTuple, box it to tuple_rprimitive for generic iteration - # TODO: this can be optimized a bit better with an unrolled ForRTuple helper - proper_type = get_proper_type(builder.types[sequence_expr]) - assert isinstance(proper_type, TupleType), proper_type - - get_item_ops = [ - ( - LoadLiteral(typ.value, object_rprimitive) - if isinstance(typ, LiteralType) - else TupleGet(sequence, i, line) - ) - for i, typ in enumerate(get_proper_types(proper_type.items)) - ] - items = list(map(builder.add, get_item_ops)) - sequence = builder.new_tuple(items, line) + if is_sequence_rprimitive(rtype) or isinstance(rtype, RTuple): + sequence = builder.accept(sequence_expr) + length = get_expr_length_value( + builder, sequence_expr, sequence, line, use_pyssize_t=True + ) + if isinstance(rtype, RTuple): + # If input is RTuple, box it to tuple_rprimitive for generic iteration + # TODO: this can be optimized a bit better with an unrolled ForRTuple helper + proper_type = get_proper_type(builder.types[sequence_expr]) + assert isinstance(proper_type, TupleType), proper_type + + get_item_ops = [ + ( + LoadLiteral(typ.value, object_rprimitive) + if isinstance(typ, LiteralType) + else TupleGet(sequence, i, line) + ) + for i, typ in enumerate(get_proper_types(proper_type.items)) + ] + items = list(map(builder.add, get_item_ops)) + sequence = builder.new_tuple(items, line) + + target_op = empty_op_llbuilder(length, line) + + def set_item(item_index: Value) -> None: + e = builder.accept(gen.left_expr) + builder.call_c(set_item_op, [target_op, item_index, e], line) + + for_loop_helper_with_index( + builder, gen.indices[0], sequence_expr, sequence, set_item, line, length + ) - target_op = empty_op_llbuilder(length, line) + return target_op - def set_item(item_index: Value) -> None: - e = builder.accept(gen.left_expr) - builder.call_c(set_item_op, [target_op, item_index, e], line) + elif (expr_length := get_expr_length(sequence_expr)) is not None: + item_index = Register(int_rprimitive) + builder.assign(item_index, Integer(0), line) - for_loop_helper_with_index( - builder, gen.indices[0], sequence_expr, sequence, set_item, line, length - ) + def set_item_noindex() -> None: + e = builder.accept(gen.left_expr) + builder.call_c(set_item_op, [target_op, item_index, e], line) + builder.assign( + item_index, builder.binary_op(item_index, Integer(1), "+", line), line + ) + + length = Integer(expr_length, c_pyssize_t_rprimitive, line) + target_op = empty_op_llbuilder(length, line) + for_loop_helper( + builder, gen.indices[0], sequence_expr, set_item_noindex, None, False, line + ) + return target_op - return target_op return None @@ -1223,10 +1244,37 @@ def get_expr_length(expr: Expression) -> int | None: and expr.node.has_explicit_value ): return len(expr.node.final_value) + elif ( + isinstance(expr, CallExpr) + and isinstance(callee := expr.callee, NameExpr) + and all(kind == ARG_POS for kind in expr.arg_kinds) + ): + fullname = callee.fullname + if ( + fullname + in ( + "builtins.list", + "builtins.tuple", + "builtins.enumerate", + "builtins.sorted", + "builtins.reversed", + ) + and len(expr.args) == 1 + ): + return get_expr_length(expr.args[0]) + elif fullname == "builtins.map" and len(expr.args) == 2: + return get_expr_length(expr.args[1]) + elif fullname == "builtins.zip" and expr.args: + arg_lengths = [get_expr_length(arg) for arg in expr.args] + if all(arg is not None for arg in arg_lengths): + return min(arg_lengths) # type: ignore [type-var] + elif fullname == "builtins.range" and all(isinstance(arg, IntExpr) for arg in expr.args): + return len(range(*(arg.value for arg in expr.args))) # type: ignore [attr-defined] + # TODO: extend this, passing length of listcomp and genexp should have worthwhile # performance boost and can be (sometimes) figured out pretty easily. set and dict # comps *can* be done as well but will need special logic to consider the possibility - # of key conflicts. Range, enumerate, zip are all simple logic. + # of key conflicts. return None diff --git a/mypyc/test-data/fixtures/ir.py b/mypyc/test-data/fixtures/ir.py index 4d0aaba12cab..b62ac6b74f9e 100644 --- a/mypyc/test-data/fixtures/ir.py +++ b/mypyc/test-data/fixtures/ir.py @@ -310,6 +310,11 @@ def __iter__(self) -> Iterator[int]: pass def __len__(self) -> int: pass def __next__(self) -> int: pass +class map(Iterator[_S]): + def __init__(self, func: Callable[[_T], _S], iterable: Iterable[_T]) -> None: pass + def __iter__(self) -> Self: pass + def __next__(self) -> _S: pass + class property: def __init__(self, fget: Optional[Callable[[Any], Any]] = ..., fset: Optional[Callable[[Any, Any], None]] = ..., diff --git a/mypyc/test-data/irbuild-tuple.test b/mypyc/test-data/irbuild-tuple.test index 3613c5f0101d..f2c36e5aa012 100644 --- a/mypyc/test-data/irbuild-tuple.test +++ b/mypyc/test-data/irbuild-tuple.test @@ -860,6 +860,149 @@ L4: a = r1 return 1 +[case testTupleBuiltFromLengthCheckable_64bit] +from typing import Tuple + +def f(val: bool) -> bool: + return not val + +def test() -> None: + # this tuple is created from a very complex genexp but we can still compute the length and preallocate the tuple + # r1 = PyTuple_New(5) the shorter input to the zip(...) has len 5 + a = tuple( + x + for x + in zip( + map(str, range(5)), + enumerate(sorted(reversed(tuple("abcdefg")))) + ) + ) +[out] +def f(val): + val, r0 :: bool +L0: + r0 = val ^ 1 + return r0 +def test(): + r0 :: int + r1 :: tuple + r2, r3, r4 :: object + r5 :: object[1] + r6 :: object_ptr + r7 :: object + r8 :: range + r9 :: object + r10 :: str + r11 :: object + r12 :: object[2] + r13 :: object_ptr + r14 :: object + r15 :: str + r16 :: tuple + r17 :: object + r18 :: str + r19 :: object + r20 :: object[1] + r21 :: object_ptr + r22 :: object + r23 :: list + r24 :: object + r25 :: str + r26 :: object + r27 :: object[1] + r28 :: object_ptr + r29, r30 :: object + r31 :: str + r32 :: object + r33 :: object[2] + r34 :: object_ptr + r35, r36, r37 :: object + r38, x :: tuple[str, tuple[int, str]] + r39 :: native_int + r40 :: bit + r41, r42 :: native_int + r43 :: ptr + r44 :: c_ptr + r45 :: i64 + r46 :: object + r47 :: int + r48 :: bit + a :: tuple +L0: + r0 = 0 + r1 = PyTuple_New(5) + r2 = load_address PyUnicode_Type + r3 = load_address PyRange_Type + r4 = object 5 + r5 = [r4] + r6 = load_address r5 + r7 = PyObject_Vectorcall(r3, r6, 1, 0) + keep_alive r4 + r8 = cast(range, r7) + r9 = builtins :: module + r10 = 'map' + r11 = CPyObject_GetAttr(r9, r10) + r12 = [r2, r8] + r13 = load_address r12 + r14 = PyObject_Vectorcall(r11, r13, 2, 0) + keep_alive r2, r8 + r15 = 'abcdefg' + r16 = PySequence_Tuple(r15) + r17 = builtins :: module + r18 = 'reversed' + r19 = CPyObject_GetAttr(r17, r18) + r20 = [r16] + r21 = load_address r20 + r22 = PyObject_Vectorcall(r19, r21, 1, 0) + keep_alive r16 + r23 = CPySequence_Sort(r22) + r24 = builtins :: module + r25 = 'enumerate' + r26 = CPyObject_GetAttr(r24, r25) + r27 = [r23] + r28 = load_address r27 + r29 = PyObject_Vectorcall(r26, r28, 1, 0) + keep_alive r23 + r30 = builtins :: module + r31 = 'zip' + r32 = CPyObject_GetAttr(r30, r31) + r33 = [r14, r29] + r34 = load_address r33 + r35 = PyObject_Vectorcall(r32, r34, 2, 0) + keep_alive r14, r29 + r36 = PyObject_GetIter(r35) +L1: + r37 = PyIter_Next(r36) + if is_error(r37) goto L7 else goto L2 +L2: + r38 = unbox(tuple[str, tuple[int, str]], r37) + x = r38 + r39 = r0 & 1 + r40 = r39 == 0 + if r40 goto L3 else goto L4 :: bool +L3: + r41 = r0 >> 1 + r42 = r41 + goto L5 +L4: + r43 = r0 ^ 1 + r44 = r43 + r45 = CPyLong_AsInt64(r44) + r42 = r45 + keep_alive r0 +L5: + r46 = box(tuple[str, tuple[int, str]], x) + CPySequenceTuple_SetItemUnsafe(r1, r42, r46) + r47 = CPyTagged_Add(r0, 2) + r0 = r47 +L6: + goto L1 +L7: + r48 = CPy_NoErrOccurred() +L8: + a = r1 + return 1 + [case testTupleBuiltFromStars] from typing import Final