diff --git a/mypyc/irbuild/specialize.py b/mypyc/irbuild/specialize.py index 29820787d10c..2e92c05802e0 100644 --- a/mypyc/irbuild/specialize.py +++ b/mypyc/irbuild/specialize.py @@ -14,7 +14,7 @@ from __future__ import annotations -from typing import Callable, Final, Optional +from typing import Callable, Final, Optional, cast from mypy.nodes import ( ARG_NAMED, @@ -40,6 +40,7 @@ Call, Extend, Integer, + PrimitiveDescription, RaiseStandardError, Register, Truncate, @@ -587,26 +588,81 @@ def translate_isinstance(builder: IRBuilder, expr: CallExpr, callee: RefExpr) -> if not (len(expr.args) == 2 and expr.arg_kinds == [ARG_POS, ARG_POS]): return None - if isinstance(expr.args[1], (RefExpr, TupleExpr)): - builder.types[expr.args[0]] = AnyType(TypeOfAny.from_error) + obj_expr = expr.args[0] + type_expr = expr.args[1] - irs = builder.flatten_classes(expr.args[1]) + if isinstance(type_expr, TupleExpr) and not type_expr.items: + # we can compile this case to a noop + return builder.false() + + if isinstance(type_expr, (RefExpr, TupleExpr)): + builder.types[obj_expr] = AnyType(TypeOfAny.from_error) + + irs = builder.flatten_classes(type_expr) if irs is not None: can_borrow = all( ir.is_ext_class and not ir.inherits_python and not ir.allow_interpreted_subclasses for ir in irs ) - obj = builder.accept(expr.args[0], can_borrow=can_borrow) + obj = builder.accept(obj_expr, can_borrow=can_borrow) return builder.builder.isinstance_helper(obj, irs, expr.line) - if isinstance(expr.args[1], RefExpr): - node = expr.args[1].node + if isinstance(type_expr, RefExpr): + node = type_expr.node if node: desc = isinstance_primitives.get(node.fullname) if desc: - obj = builder.accept(expr.args[0]) + obj = builder.accept(obj_expr) return builder.primitive_op(desc, [obj], expr.line) + elif isinstance(type_expr, TupleExpr): + node_names: list[str] = [] + for item in type_expr.items: + if not isinstance(item, RefExpr): + return None + if item.node is None: + return None + if item.node.fullname not in node_names: + node_names.append(item.node.fullname) + + descs = [isinstance_primitives.get(fullname) for fullname in node_names] + if None in descs: + # not all types are primitive types, abort + return None + + obj = builder.accept(obj_expr) + + retval = Register(bool_rprimitive) + pass_block = BasicBlock() + fail_block = BasicBlock() + exit_block = BasicBlock() + + # Chain the checks: if any succeed, jump to pass_block; else, continue + for i, desc in enumerate(descs): + is_last = i == len(descs) - 1 + next_block = fail_block if is_last else BasicBlock() + builder.add_bool_branch( + builder.primitive_op(cast(PrimitiveDescription, desc), [obj], expr.line), + pass_block, + next_block, + ) + if not is_last: + builder.activate_block(next_block) + + # If any check passed + builder.activate_block(pass_block) + builder.assign(retval, builder.true(), expr.line) + builder.goto(exit_block) + + # If all checks failed + builder.activate_block(fail_block) + builder.assign(retval, builder.false(), expr.line) + builder.goto(exit_block) + + # Return the result + builder.activate_block(exit_block) + return retval + return None diff --git a/mypyc/test-data/irbuild-isinstance.test b/mypyc/test-data/irbuild-isinstance.test index 0df9448b819f..36a9300350bd 100644 --- a/mypyc/test-data/irbuild-isinstance.test +++ b/mypyc/test-data/irbuild-isinstance.test @@ -189,3 +189,31 @@ def is_tuple(x): L0: r0 = PyTuple_Check(x) return r0 + +[case testTupleOfPrimitives] +from typing import Any + +def is_instance(x: Any) -> bool: + return isinstance(x, (str, int, bytes)) + +[out] +def is_instance(x): + x :: object + r0, r1, r2 :: bit + r3 :: bool +L0: + r0 = PyUnicode_Check(x) + if r0 goto L3 else goto L1 :: bool +L1: + r1 = PyLong_Check(x) + if r1 goto L3 else goto L2 :: bool +L2: + r2 = PyBytes_Check(x) + if r2 goto L3 else goto L4 :: bool +L3: + r3 = 1 + goto L5 +L4: + r3 = 0 +L5: + return r3 diff --git a/mypyc/test-data/run-misc.test b/mypyc/test-data/run-misc.test index 129946a4c330..1074906357ee 100644 --- a/mypyc/test-data/run-misc.test +++ b/mypyc/test-data/run-misc.test @@ -1173,3 +1173,26 @@ def test_dummy_context() -> None: with c: assert c.c == 1 assert c.c == 0 + +[case testIsInstanceTuple] +from typing import Any + +def isinstance_empty(x: Any) -> bool: + return isinstance(x, ()) +def isinstance_single(x: Any) -> bool: + return isinstance(x, (str,)) +def isinstance_multi(x: Any) -> bool: + return isinstance(x, (str, int)) + +def test_isinstance_empty() -> None: + assert isinstance_empty("a") is False + assert isinstance_empty(1) is False + assert isinstance_empty(None) is False +def test_isinstance_single() -> None: + assert isinstance_single("a") is True + assert isinstance_single(1) is False + assert isinstance_single(None) is False +def test_isinstance_multi() -> None: + assert isinstance_multi("a") is True + assert isinstance_multi(1) is True + assert isinstance_multi(None) is False