diff --git a/mypyc/ir/ops.py b/mypyc/ir/ops.py index 2153d47e6874..072892f776df 100644 --- a/mypyc/ir/ops.py +++ b/mypyc/ir/ops.py @@ -37,6 +37,7 @@ class to enable the new behavior. Sometimes adding a new abstract RStruct, RTuple, RType, + RUnion, RVoid, bit_rprimitive, bool_rprimitive, @@ -44,6 +45,7 @@ class to enable the new behavior. Sometimes adding a new abstract float_rprimitive, int_rprimitive, is_bool_or_bit_rprimitive, + is_fixed_width_rtype, is_int_rprimitive, is_none_rprimitive, is_pointer_rprimitive, @@ -688,7 +690,7 @@ class PrimitiveDescription: Primitives get lowered into lower-level ops before code generation. If c_function_name is provided, a primitive will be lowered into a CallC op. - Otherwise custom logic will need to be implemented to transform the + Otherwise, custom logic will need to be implemented to transform the primitive into lower-level ops. """ @@ -737,11 +739,24 @@ def __init__( # Capsule that needs to imported and configured to call the primitive # (name of the target module, e.g. "librt.base64"). self.capsule = capsule + # Native integer types such as u8 can cause ambiguity in primitive + # matching, since these are assignable to plain int *and* vice versa. + # If this flag is set, the primitive has native integer types and must + # be matched using more complex rules. + self.is_ambiguous = any(has_fixed_width_int(t) for t in arg_types) def __repr__(self) -> str: return f"" +def has_fixed_width_int(t: RType) -> bool: + if isinstance(t, RTuple): + return any(has_fixed_width_int(t) for t in t.types) + elif isinstance(t, RUnion): + return any(has_fixed_width_int(t) for t in t.items) + return is_fixed_width_rtype(t) + + @final class PrimitiveOp(RegisterOp): """A higher-level primitive operation. diff --git a/mypyc/irbuild/ll_builder.py b/mypyc/irbuild/ll_builder.py index 3949e585aefb..6baff44f2f82 100644 --- a/mypyc/irbuild/ll_builder.py +++ b/mypyc/irbuild/ll_builder.py @@ -2209,8 +2209,14 @@ def matching_primitive_op( args: list[Value], line: int, result_type: RType | None = None, + *, can_borrow: bool = False, + strict: bool = True, ) -> Value | None: + """Find primitive operation that is compatible with types of args. + + Return None if none of them match. + """ matching: PrimitiveDescription | None = None for desc in candidates: if len(desc.arg_types) != len(args): @@ -2219,7 +2225,7 @@ def matching_primitive_op( continue if all( # formal is not None and # TODO - is_subtype(actual.type, formal) + is_subtype(actual.type, formal, relaxed=not strict) for actual, formal in zip(args, desc.arg_types) ) and (not desc.is_borrowed or can_borrow): if matching: @@ -2232,6 +2238,12 @@ def matching_primitive_op( matching = desc if matching: return self.primitive_op(matching, args, line=line, result_type=result_type) + if strict and any(prim.is_ambiguous for prim in candidates): + # Also try a non-exact match if any primitives have ambiguous types. + return self.matching_primitive_op( + candidates, args, line, result_type, can_borrow=can_borrow, strict=False + ) + return None def int_op(self, type: RType, lhs: Value, rhs: Value, op: int, line: int = -1) -> Value: diff --git a/mypyc/subtype.py b/mypyc/subtype.py index 726a48d7a01d..3d5fa1fb4fca 100644 --- a/mypyc/subtype.py +++ b/mypyc/subtype.py @@ -23,18 +23,21 @@ ) -def is_subtype(left: RType, right: RType) -> bool: +def is_subtype(left: RType, right: RType, *, relaxed: bool = False) -> bool: if is_object_rprimitive(right): return True elif isinstance(right, RUnion): if isinstance(left, RUnion): for left_item in left.items: - if not any(is_subtype(left_item, right_item) for right_item in right.items): + if not any( + is_subtype(left_item, right_item, relaxed=relaxed) + for right_item in right.items + ): return False return True else: - return any(is_subtype(left, item) for item in right.items) - return left.accept(SubtypeVisitor(right)) + return any(is_subtype(left, item, relaxed=relaxed) for item in right.items) + return left.accept(SubtypeVisitor(right, relaxed=relaxed)) class SubtypeVisitor(RTypeVisitor[bool]): @@ -44,14 +47,15 @@ class SubtypeVisitor(RTypeVisitor[bool]): is_subtype and don't need to be covered here. """ - def __init__(self, right: RType) -> None: + def __init__(self, right: RType, *, relaxed: bool = False) -> None: self.right = right + self.relaxed = relaxed def visit_rinstance(self, left: RInstance) -> bool: return isinstance(self.right, RInstance) and self.right.class_ir in left.class_ir.mro def visit_runion(self, left: RUnion) -> bool: - return all(is_subtype(item, self.right) for item in left.items) + return all(is_subtype(item, self.right, relaxed=self.relaxed) for item in left.items) def visit_rprimitive(self, left: RPrimitive) -> bool: right = self.right @@ -64,6 +68,11 @@ def visit_rprimitive(self, left: RPrimitive) -> bool: elif is_short_int_rprimitive(left): if is_int_rprimitive(right): return True + if self.relaxed and is_fixed_width_rtype(right): + return True + elif is_int_rprimitive(left): + if self.relaxed and is_fixed_width_rtype(right): + return True elif is_fixed_width_rtype(left): if is_int_rprimitive(right): return True @@ -74,7 +83,8 @@ def visit_rtuple(self, left: RTuple) -> bool: return True if isinstance(self.right, RTuple): return len(self.right.types) == len(left.types) and all( - is_subtype(t1, t2) for t1, t2 in zip(left.types, self.right.types) + is_subtype(t1, t2, relaxed=self.relaxed) + for t1, t2 in zip(left.types, self.right.types) ) return False diff --git a/mypyc/test-data/irbuild-librt-strings.test b/mypyc/test-data/irbuild-librt-strings.test index 133bca8c7cb3..21aba382cdd8 100644 --- a/mypyc/test-data/irbuild-librt-strings.test +++ b/mypyc/test-data/irbuild-librt-strings.test @@ -1,13 +1,12 @@ -[case testLibrtStrings_experimental] +[case testLibrtStrings_experimental_64bit] from librt.strings import BytesWriter from mypy_extensions import u8, i64 def bytes_writer_basics() -> bytes: b = BytesWriter() - x: u8 = 1 - b.append(x) + b.append(1) b.write(b'foo') - n: i64 = 2 + n = 2 b.truncate(n) return b.getvalue() def bytes_writer_len(b: BytesWriter) -> i64: @@ -15,24 +14,42 @@ def bytes_writer_len(b: BytesWriter) -> i64: [out] def bytes_writer_basics(): r0, b :: librt.strings.BytesWriter - x :: u8 r1 :: None r2 :: bytes r3 :: None - n :: i64 - r4 :: None - r5 :: bytes + n :: int + r4 :: native_int + r5 :: bit + r6, r7 :: i64 + r8 :: ptr + r9 :: c_ptr + r10 :: i64 + r11 :: None + r12 :: bytes L0: r0 = LibRTStrings_BytesWriter_internal() b = r0 - x = 1 - r1 = LibRTStrings_BytesWriter_append_internal(b, x) + r1 = LibRTStrings_BytesWriter_append_internal(b, 1) r2 = b'foo' r3 = LibRTStrings_BytesWriter_write_internal(b, r2) - n = 2 - r4 = LibRTStrings_BytesWriter_truncate_internal(b, n) - r5 = LibRTStrings_BytesWriter_getvalue_internal(b) - return r5 + n = 4 + r4 = n & 1 + r5 = r4 == 0 + if r5 goto L1 else goto L2 :: bool +L1: + r6 = n >> 1 + r7 = r6 + goto L3 +L2: + r8 = n ^ 1 + r9 = r8 + r10 = CPyLong_AsInt64(r9) + r7 = r10 + keep_alive n +L3: + r11 = LibRTStrings_BytesWriter_truncate_internal(b, r7) + r12 = LibRTStrings_BytesWriter_getvalue_internal(b) + return r12 def bytes_writer_len(b): b :: librt.strings.BytesWriter r0 :: short_int