Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 16 additions & 1 deletion mypyc/ir/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,13 +37,15 @@ class to enable the new behavior. Sometimes adding a new abstract
RStruct,
RTuple,
RType,
RUnion,
RVoid,
bit_rprimitive,
bool_rprimitive,
cstring_rprimitive,
float_rprimitive,
int_rprimitive,
is_bool_or_bit_rprimitive,
is_fixed_width_rtype,
is_int_rprimitive,
is_none_rprimitive,
is_pointer_rprimitive,
Expand Down Expand Up @@ -688,7 +690,7 @@ class PrimitiveDescription:
Primitives get lowered into lower-level ops before code generation.

If c_function_name is provided, a primitive will be lowered into a CallC op.
Otherwise custom logic will need to be implemented to transform the
Otherwise, custom logic will need to be implemented to transform the
primitive into lower-level ops.
"""

Expand Down Expand Up @@ -737,11 +739,24 @@ def __init__(
# Capsule that needs to imported and configured to call the primitive
# (name of the target module, e.g. "librt.base64").
self.capsule = capsule
# Native integer types such as u8 can cause ambiguity in primitive
# matching, since these are assignable to plain int *and* vice versa.
# If this flag is set, the primitive has native integer types and must
# be matched using more complex rules.
self.is_ambiguous = any(has_fixed_width_int(t) for t in arg_types)

def __repr__(self) -> str:
return f"<PrimitiveDescription {self.name!r}: {self.arg_types}>"


def has_fixed_width_int(t: RType) -> bool:
if isinstance(t, RTuple):
return any(has_fixed_width_int(t) for t in t.types)
elif isinstance(t, RUnion):
return any(has_fixed_width_int(t) for t in t.items)
return is_fixed_width_rtype(t)


@final
class PrimitiveOp(RegisterOp):
"""A higher-level primitive operation.
Expand Down
14 changes: 13 additions & 1 deletion mypyc/irbuild/ll_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -2209,8 +2209,14 @@ def matching_primitive_op(
args: list[Value],
line: int,
result_type: RType | None = None,
*,
can_borrow: bool = False,
strict: bool = True,
) -> Value | None:
"""Find primitive operation that is compatible with types of args.

Return None if none of them match.
"""
matching: PrimitiveDescription | None = None
for desc in candidates:
if len(desc.arg_types) != len(args):
Expand All @@ -2219,7 +2225,7 @@ def matching_primitive_op(
continue
if all(
# formal is not None and # TODO
is_subtype(actual.type, formal)
is_subtype(actual.type, formal, relaxed=not strict)
for actual, formal in zip(args, desc.arg_types)
) and (not desc.is_borrowed or can_borrow):
if matching:
Expand All @@ -2232,6 +2238,12 @@ def matching_primitive_op(
matching = desc
if matching:
return self.primitive_op(matching, args, line=line, result_type=result_type)
if strict and any(prim.is_ambiguous for prim in candidates):
# Also try a non-exact match if any primitives have ambiguous types.
return self.matching_primitive_op(
candidates, args, line, result_type, can_borrow=can_borrow, strict=False
)

return None

def int_op(self, type: RType, lhs: Value, rhs: Value, op: int, line: int = -1) -> Value:
Expand Down
24 changes: 17 additions & 7 deletions mypyc/subtype.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,18 +23,21 @@
)


def is_subtype(left: RType, right: RType) -> bool:
def is_subtype(left: RType, right: RType, *, relaxed: bool = False) -> bool:
if is_object_rprimitive(right):
return True
elif isinstance(right, RUnion):
if isinstance(left, RUnion):
for left_item in left.items:
if not any(is_subtype(left_item, right_item) for right_item in right.items):
if not any(
is_subtype(left_item, right_item, relaxed=relaxed)
for right_item in right.items
):
return False
return True
else:
return any(is_subtype(left, item) for item in right.items)
return left.accept(SubtypeVisitor(right))
return any(is_subtype(left, item, relaxed=relaxed) for item in right.items)
return left.accept(SubtypeVisitor(right, relaxed=relaxed))


class SubtypeVisitor(RTypeVisitor[bool]):
Expand All @@ -44,14 +47,15 @@ class SubtypeVisitor(RTypeVisitor[bool]):
is_subtype and don't need to be covered here.
"""

def __init__(self, right: RType) -> None:
def __init__(self, right: RType, *, relaxed: bool = False) -> None:
self.right = right
self.relaxed = relaxed

def visit_rinstance(self, left: RInstance) -> bool:
return isinstance(self.right, RInstance) and self.right.class_ir in left.class_ir.mro

def visit_runion(self, left: RUnion) -> bool:
return all(is_subtype(item, self.right) for item in left.items)
return all(is_subtype(item, self.right, relaxed=self.relaxed) for item in left.items)

def visit_rprimitive(self, left: RPrimitive) -> bool:
right = self.right
Expand All @@ -64,6 +68,11 @@ def visit_rprimitive(self, left: RPrimitive) -> bool:
elif is_short_int_rprimitive(left):
if is_int_rprimitive(right):
return True
if self.relaxed and is_fixed_width_rtype(right):
return True
elif is_int_rprimitive(left):
if self.relaxed and is_fixed_width_rtype(right):
return True
elif is_fixed_width_rtype(left):
if is_int_rprimitive(right):
return True
Expand All @@ -74,7 +83,8 @@ def visit_rtuple(self, left: RTuple) -> bool:
return True
if isinstance(self.right, RTuple):
return len(self.right.types) == len(left.types) and all(
is_subtype(t1, t2) for t1, t2 in zip(left.types, self.right.types)
is_subtype(t1, t2, relaxed=self.relaxed)
for t1, t2 in zip(left.types, self.right.types)
)
return False

Expand Down
45 changes: 31 additions & 14 deletions mypyc/test-data/irbuild-librt-strings.test
Original file line number Diff line number Diff line change
@@ -1,38 +1,55 @@
[case testLibrtStrings_experimental]
[case testLibrtStrings_experimental_64bit]
from librt.strings import BytesWriter
from mypy_extensions import u8, i64

def bytes_writer_basics() -> bytes:
b = BytesWriter()
x: u8 = 1
b.append(x)
b.append(1)
b.write(b'foo')
n: i64 = 2
n = 2
b.truncate(n)
return b.getvalue()
def bytes_writer_len(b: BytesWriter) -> i64:
return len(b)
[out]
def bytes_writer_basics():
r0, b :: librt.strings.BytesWriter
x :: u8
r1 :: None
r2 :: bytes
r3 :: None
n :: i64
r4 :: None
r5 :: bytes
n :: int
r4 :: native_int
r5 :: bit
r6, r7 :: i64
r8 :: ptr
r9 :: c_ptr
r10 :: i64
r11 :: None
r12 :: bytes
L0:
r0 = LibRTStrings_BytesWriter_internal()
b = r0
x = 1
r1 = LibRTStrings_BytesWriter_append_internal(b, x)
r1 = LibRTStrings_BytesWriter_append_internal(b, 1)
r2 = b'foo'
r3 = LibRTStrings_BytesWriter_write_internal(b, r2)
n = 2
r4 = LibRTStrings_BytesWriter_truncate_internal(b, n)
r5 = LibRTStrings_BytesWriter_getvalue_internal(b)
return r5
n = 4
r4 = n & 1
r5 = r4 == 0
if r5 goto L1 else goto L2 :: bool
L1:
r6 = n >> 1
r7 = r6
goto L3
L2:
r8 = n ^ 1
r9 = r8
r10 = CPyLong_AsInt64(r9)
r7 = r10
keep_alive n
L3:
r11 = LibRTStrings_BytesWriter_truncate_internal(b, r7)
r12 = LibRTStrings_BytesWriter_getvalue_internal(b)
return r12
def bytes_writer_len(b):
b :: librt.strings.BytesWriter
r0 :: short_int
Expand Down