diff --git a/mypyc/analysis/ircheck.py b/mypyc/analysis/ircheck.py index b141784ef9ff2..e96c640fa8a15 100644 --- a/mypyc/analysis/ircheck.py +++ b/mypyc/analysis/ircheck.py @@ -252,6 +252,15 @@ def check_tuple_items_valid_literals(self, op: LoadLiteral, t: tuple[object, ... if isinstance(x, tuple): self.check_tuple_items_valid_literals(op, x) + def check_frozenset_items_valid_literals(self, op: LoadLiteral, s: frozenset[object]) -> None: + for x in s: + if x is None or isinstance(x, (str, bytes, bool, int, float, complex)): + pass + elif isinstance(x, tuple): + self.check_tuple_items_valid_literals(op, x) + else: + self.fail(op, f"Invalid type for item of frozenset literal: {type(x)})") + def visit_load_literal(self, op: LoadLiteral) -> None: expected_type = None if op.value is None: @@ -271,6 +280,11 @@ def visit_load_literal(self, op: LoadLiteral) -> None: elif isinstance(op.value, tuple): expected_type = "builtins.tuple" self.check_tuple_items_valid_literals(op, op.value) + elif isinstance(op.value, frozenset): + # There's no frozenset_rprimitive type since it'd be pretty useless so we just pretend + # it's a set (when it's really a frozenset). + expected_type = "builtins.set" + self.check_frozenset_items_valid_literals(op, op.value) assert expected_type is not None, "Missed a case for LoadLiteral check" diff --git a/mypyc/codegen/emitmodule.py b/mypyc/codegen/emitmodule.py index 5dacaf6acab6c..9f65aa77c47f4 100644 --- a/mypyc/codegen/emitmodule.py +++ b/mypyc/codegen/emitmodule.py @@ -669,6 +669,9 @@ def generate_literal_tables(self) -> None: # Descriptions of tuple literals init_tuple = c_array_initializer(literals.encoded_tuple_values()) self.declare_global("const int []", "CPyLit_Tuple", initializer=init_tuple) + # Descriptions of frozenset literals + init_frozenset = c_array_initializer(literals.encoded_frozenset_values()) + self.declare_global("const int []", "CPyLit_FrozenSet", initializer=init_frozenset) def generate_export_table(self, decl_emitter: Emitter, code_emitter: Emitter) -> None: """Generate the declaration and definition of the group's export struct. @@ -839,7 +842,7 @@ def generate_globals_init(self, emitter: Emitter) -> None: for symbol, fixup in self.simple_inits: emitter.emit_line(f"{symbol} = {fixup};") - values = "CPyLit_Str, CPyLit_Bytes, CPyLit_Int, CPyLit_Float, CPyLit_Complex, CPyLit_Tuple" + values = "CPyLit_Str, CPyLit_Bytes, CPyLit_Int, CPyLit_Float, CPyLit_Complex, CPyLit_Tuple, CPyLit_FrozenSet" emitter.emit_lines( f"if (CPyStatics_Initialize(CPyStatics, {values}) < 0) {{", "return -1;", "}" ) diff --git a/mypyc/codegen/literals.py b/mypyc/codegen/literals.py index 29957d52101c4..784a8ed27c4e4 100644 --- a/mypyc/codegen/literals.py +++ b/mypyc/codegen/literals.py @@ -1,12 +1,13 @@ from __future__ import annotations -from typing import Any, Tuple, Union, cast +from typing import Any, Dict, FrozenSet, List, Tuple, Union, cast from typing_extensions import Final -# Supported Python literal types. All tuple items must have supported +# Supported Python literal types. All tuple / frozenset items must have supported # literal types as well, but we can't represent the type precisely. -LiteralValue = Union[str, bytes, int, bool, float, complex, Tuple[object, ...], None] - +LiteralValue = Union[ + str, bytes, int, bool, float, complex, Tuple[object, ...], FrozenSet[object], None +] # Some literals are singletons and handled specially (None, False and True) NUM_SINGLETONS: Final = 3 @@ -23,6 +24,7 @@ def __init__(self) -> None: self.float_literals: dict[float, int] = {} self.complex_literals: dict[complex, int] = {} self.tuple_literals: dict[tuple[object, ...], int] = {} + self.frozenset_literals: dict[frozenset[object], int] = {} def record_literal(self, value: LiteralValue) -> None: """Ensure that the literal value is available in generated code.""" @@ -55,6 +57,12 @@ def record_literal(self, value: LiteralValue) -> None: for item in value: self.record_literal(cast(Any, item)) tuple_literals[value] = len(tuple_literals) + elif isinstance(value, frozenset): + frozenset_literals = self.frozenset_literals + if value not in frozenset_literals: + for item in value: + self.record_literal(cast(Any, item)) + frozenset_literals[value] = len(frozenset_literals) else: assert False, "invalid literal: %r" % value @@ -86,6 +94,9 @@ def literal_index(self, value: LiteralValue) -> int: n += len(self.complex_literals) if isinstance(value, tuple): return n + self.tuple_literals[value] + n += len(self.tuple_literals) + if isinstance(value, frozenset): + return n + self.frozenset_literals[value] assert False, "invalid literal: %r" % value def num_literals(self) -> int: @@ -98,6 +109,7 @@ def num_literals(self) -> int: + len(self.float_literals) + len(self.complex_literals) + len(self.tuple_literals) + + len(self.frozenset_literals) ) # The following methods return the C encodings of literal values @@ -119,24 +131,32 @@ def encoded_complex_values(self) -> list[str]: return _encode_complex_values(self.complex_literals) def encoded_tuple_values(self) -> list[str]: - """Encode tuple values into a C array. + return self._encode_collection_values(self.tuple_literals) + + def encoded_frozenset_values(self) -> List[str]: + return self._encode_collection_values(self.frozenset_literals) + + def _encode_collection_values( + self, values: dict[tuple[object, ...], int] | dict[frozenset[object], int] + ) -> list[str]: + """Encode tuple/frozenset values into a C array. The format of the result is like this: - - + + ... - + ... """ - values = self.tuple_literals - value_by_index = {index: value for value, index in values.items()} + # FIXME: https://github.com/mypyc/mypyc/issues/965 + value_by_index = {index: value for value, index in cast(Dict[Any, int], values).items()} result = [] - num = len(values) - result.append(str(num)) - for i in range(num): + count = len(values) + result.append(str(count)) + for i in range(count): value = value_by_index[i] result.append(str(len(value))) for item in value: diff --git a/mypyc/ir/ops.py b/mypyc/ir/ops.py index 1f79ba829d767..cc6f542c3e234 100644 --- a/mypyc/ir/ops.py +++ b/mypyc/ir/ops.py @@ -39,6 +39,7 @@ ) if TYPE_CHECKING: + from mypyc.codegen.literals import LiteralValue from mypyc.ir.class_ir import ClassIR from mypyc.ir.func_ir import FuncDecl, FuncIR @@ -588,7 +589,7 @@ class LoadLiteral(RegisterOp): This is used to load a static PyObject * value corresponding to a literal of one of the supported types. - Tuple literals must contain only valid literal values as items. + Tuple / frozenset literals must contain only valid literal values as items. NOTE: You can use this to load boxed (Python) int objects. Use Integer to load unboxed, tagged integers or fixed-width, @@ -603,11 +604,7 @@ class LoadLiteral(RegisterOp): error_kind = ERR_NEVER is_borrowed = True - def __init__( - self, - value: None | str | bytes | bool | int | float | complex | tuple[object, ...], - rtype: RType, - ) -> None: + def __init__(self, value: LiteralValue, rtype: RType) -> None: self.value = value self.type = rtype diff --git a/mypyc/ir/pprint.py b/mypyc/ir/pprint.py index a9324a8608e40..cb9e4a2d25418 100644 --- a/mypyc/ir/pprint.py +++ b/mypyc/ir/pprint.py @@ -106,7 +106,18 @@ def visit_load_literal(self, op: LoadLiteral) -> str: # it explicit that this is a Python object. if isinstance(op.value, int): prefix = "object " - return self.format("%r = %s%s", op, prefix, repr(op.value)) + + rvalue = repr(op.value) + if isinstance(op.value, frozenset): + # We need to generate a string representation that won't vary + # run-to-run because sets are unordered, otherwise we may get + # spurious irbuild test failures. + # + # Sorting by the item's string representation is a bit of a + # hack, but it's stable and won't cause TypeErrors. + formatted_items = [repr(i) for i in sorted(op.value, key=str)] + rvalue = "frozenset({" + ", ".join(formatted_items) + "})" + return self.format("%r = %s%s", op, prefix, rvalue) def visit_get_attr(self, op: GetAttr) -> str: return self.format("%r = %s%r.%s", op, self.borrow_prefix(op), op.obj, op.attr) diff --git a/mypyc/irbuild/builder.py b/mypyc/irbuild/builder.py index 792697970785a..c24207ac64ecb 100644 --- a/mypyc/irbuild/builder.py +++ b/mypyc/irbuild/builder.py @@ -118,7 +118,7 @@ AssignmentTargetRegister, AssignmentTargetTuple, ) -from mypyc.irbuild.util import is_constant +from mypyc.irbuild.util import bytes_from_str, is_constant from mypyc.options import CompilerOptions from mypyc.primitives.dict_ops import dict_get_item_op, dict_set_item_op from mypyc.primitives.generic_ops import iter_op, next_op, py_setattr_op @@ -296,8 +296,7 @@ def load_bytes_from_str_literal(self, value: str) -> Value: are stored in BytesExpr.value, whose type is 'str' not 'bytes'. Thus we perform a special conversion here. """ - bytes_value = bytes(value, "utf8").decode("unicode-escape").encode("raw-unicode-escape") - return self.builder.load_bytes(bytes_value) + return self.builder.load_bytes(bytes_from_str(value)) def load_int(self, value: int) -> Value: return self.builder.load_int(value) @@ -886,7 +885,7 @@ def get_dict_base_type(self, expr: Expression) -> Instance: This is useful for dict subclasses like SymbolTable. """ target_type = get_proper_type(self.types[expr]) - assert isinstance(target_type, Instance) + assert isinstance(target_type, Instance), target_type dict_base = next(base for base in target_type.type.mro if base.fullname == "builtins.dict") return map_instance_to_supertype(target_type, dict_base) diff --git a/mypyc/irbuild/expression.py b/mypyc/irbuild/expression.py index b007435957b00..3f5b795a14363 100644 --- a/mypyc/irbuild/expression.py +++ b/mypyc/irbuild/expression.py @@ -6,7 +6,7 @@ from __future__ import annotations -from typing import Callable, cast +from typing import Callable, Sequence, cast from mypy.nodes import ( ARG_POS, @@ -55,6 +55,7 @@ ComparisonOp, Integer, LoadAddress, + LoadLiteral, RaiseStandardError, Register, TupleGet, @@ -63,12 +64,14 @@ ) from mypyc.ir.rtypes import ( RTuple, + bool_rprimitive, int_rprimitive, is_fixed_width_rtype, is_int_rprimitive, is_list_rprimitive, is_none_rprimitive, object_rprimitive, + set_rprimitive, ) from mypyc.irbuild.ast_helpers import is_borrow_friendly_expr, process_conditional from mypyc.irbuild.builder import IRBuilder, int_borrow_friendly_op @@ -86,6 +89,7 @@ tokenizer_printf_style, ) from mypyc.irbuild.specialize import apply_function_specialization, apply_method_specialization +from mypyc.irbuild.util import bytes_from_str from mypyc.primitives.bytes_ops import bytes_slice_op from mypyc.primitives.dict_ops import dict_get_item_op, dict_new_op, dict_set_item_op from mypyc.primitives.generic_ops import iter_op @@ -93,7 +97,7 @@ from mypyc.primitives.list_ops import list_append_op, list_extend_op, list_slice_op from mypyc.primitives.misc_ops import ellipsis_op, get_module_dict_op, new_slice_op, type_op from mypyc.primitives.registry import CFunctionDescription, builtin_names -from mypyc.primitives.set_ops import set_add_op, set_update_op +from mypyc.primitives.set_ops import set_add_op, set_in_op, set_update_op from mypyc.primitives.str_ops import str_slice_op from mypyc.primitives.tuple_ops import list_tuple_op, tuple_slice_op @@ -613,6 +617,54 @@ def transform_conditional_expr(builder: IRBuilder, expr: ConditionalExpr) -> Val return target +def set_literal_values(builder: IRBuilder, items: Sequence[Expression]) -> list[object] | None: + values: list[object] = [] + for item in items: + const_value = constant_fold_expr(builder, item) + if const_value is not None: + values.append(const_value) + continue + + if isinstance(item, RefExpr): + if item.fullname == "builtins.None": + values.append(None) + elif item.fullname == "builtins.True": + values.append(True) + elif item.fullname == "builtins.False": + values.append(False) + elif isinstance(item, (BytesExpr, FloatExpr, ComplexExpr)): + # constant_fold_expr() doesn't handle these (yet?) + v = bytes_from_str(item.value) if isinstance(item, BytesExpr) else item.value + values.append(v) + elif isinstance(item, TupleExpr): + tuple_values = set_literal_values(builder, item.items) + if tuple_values is not None: + values.append(tuple(tuple_values)) + + if len(values) != len(items): + # Bail if not all items can be converted into values. + return None + return values + + +def precompute_set_literal(builder: IRBuilder, s: SetExpr) -> Value | None: + """Try to pre-compute a frozenset literal during module initialization. + + Return None if it's not possible. + + Supported items: + - Anything supported by irbuild.constant_fold.constant_fold_expr() + - None, True, and False + - Float, byte, and complex literals + - Tuple literals with only items listed above + """ + values = set_literal_values(builder, s.items) + if values is not None: + return builder.add(LoadLiteral(frozenset(values), set_rprimitive)) + + return None + + def transform_comparison_expr(builder: IRBuilder, e: ComparisonExpr) -> Value: # x in (...)/[...] # x not in (...)/[...] @@ -666,6 +718,23 @@ def transform_comparison_expr(builder: IRBuilder, e: ComparisonExpr) -> Value: else: return builder.true() + # x in {...} + # x not in {...} + if ( + first_op in ("in", "not in") + and len(e.operators) == 1 + and isinstance(e.operands[1], SetExpr) + ): + set_literal = precompute_set_literal(builder, e.operands[1]) + if set_literal is not None: + lhs = e.operands[0] + result = builder.builder.call_c( + set_in_op, [builder.accept(lhs), set_literal], e.line, bool_rprimitive + ) + if first_op == "not in": + return builder.unary_op(result, "not", e.line) + return result + if len(e.operators) == 1: # Special some common simple cases if first_op in ("is", "is not"): diff --git a/mypyc/irbuild/for_helpers.py b/mypyc/irbuild/for_helpers.py index fc67178af5de6..61dbbe960eb23 100644 --- a/mypyc/irbuild/for_helpers.py +++ b/mypyc/irbuild/for_helpers.py @@ -17,6 +17,7 @@ Lvalue, MemberExpr, RefExpr, + SetExpr, TupleExpr, TypeAlias, ) @@ -469,12 +470,22 @@ def make_for_loop_generator( for_dict_gen.init(expr_reg, target_type) return for_dict_gen + iterable_expr_reg: Value | None = None + if isinstance(expr, SetExpr): + # Special case "for x in ". + from mypyc.irbuild.expression import precompute_set_literal + + set_literal = precompute_set_literal(builder, expr) + if set_literal is not None: + iterable_expr_reg = set_literal + # Default to a generic for loop. - expr_reg = builder.accept(expr) + if iterable_expr_reg is None: + iterable_expr_reg = builder.accept(expr) for_obj = ForIterable(builder, index, body_block, loop_exit, line, nested) item_type = builder._analyze_iterable_item_type(expr) item_rtype = builder.type_to_rtype(item_type) - for_obj.init(expr_reg, item_rtype) + for_obj.init(iterable_expr_reg, item_rtype) return for_obj diff --git a/mypyc/irbuild/util.py b/mypyc/irbuild/util.py index f50241b96cb3e..ed01a59d1214e 100644 --- a/mypyc/irbuild/util.py +++ b/mypyc/irbuild/util.py @@ -177,3 +177,13 @@ def is_constant(e: Expression) -> bool: ) ) ) + + +def bytes_from_str(value: str) -> bytes: + """Convert a string representing bytes into actual bytes. + + This is needed because the literal characters of BytesExpr (the + characters inside b'') are stored in BytesExpr.value, whose type is + 'str' not 'bytes'. + """ + return bytes(value, "utf8").decode("unicode-escape").encode("raw-unicode-escape") diff --git a/mypyc/lib-rt/CPy.h b/mypyc/lib-rt/CPy.h index 7ee914a037dc4..befa397051efb 100644 --- a/mypyc/lib-rt/CPy.h +++ b/mypyc/lib-rt/CPy.h @@ -597,7 +597,8 @@ int CPyStatics_Initialize(PyObject **statics, const char * const *ints, const double *floats, const double *complex_numbers, - const int *tuples); + const int *tuples, + const int *frozensets); PyObject *CPy_Super(PyObject *builtins, PyObject *self); PyObject *CPy_CallReverseOpMethod(PyObject *left, PyObject *right, const char *op, _Py_Identifier *method); diff --git a/mypyc/lib-rt/misc_ops.c b/mypyc/lib-rt/misc_ops.c index 25f33c5f56c79..5fda78704bbc1 100644 --- a/mypyc/lib-rt/misc_ops.c +++ b/mypyc/lib-rt/misc_ops.c @@ -535,7 +535,8 @@ int CPyStatics_Initialize(PyObject **statics, const char * const *ints, const double *floats, const double *complex_numbers, - const int *tuples) { + const int *tuples, + const int *frozensets) { PyObject **result = statics; // Start with some hard-coded values *result++ = Py_None; @@ -635,6 +636,24 @@ int CPyStatics_Initialize(PyObject **statics, *result++ = obj; } } + if (frozensets) { + int num = *frozensets++; + while (num-- > 0) { + int num_items = *frozensets++; + PyObject *obj = PyFrozenSet_New(NULL); + if (obj == NULL) { + return -1; + } + for (int i = 0; i < num_items; i++) { + PyObject *item = statics[*frozensets++]; + Py_INCREF(item); + if (PySet_Add(obj, item) == -1) { + return -1; + } + } + *result++ = obj; + } + } return 0; } diff --git a/mypyc/primitives/set_ops.py b/mypyc/primitives/set_ops.py index 801fdad34ea4d..fcfb7847dc7d5 100644 --- a/mypyc/primitives/set_ops.py +++ b/mypyc/primitives/set_ops.py @@ -54,7 +54,7 @@ ) # item in set -binary_op( +set_in_op = binary_op( name="in", arg_types=[object_rprimitive, set_rprimitive], return_type=c_int_rprimitive, diff --git a/mypyc/test-data/fixtures/ir.py b/mypyc/test-data/fixtures/ir.py index 0e437f4597eae..2f3c18e9c7311 100644 --- a/mypyc/test-data/fixtures/ir.py +++ b/mypyc/test-data/fixtures/ir.py @@ -221,12 +221,14 @@ def clear(self) -> None: pass def pop(self) -> T: pass def update(self, x: Iterable[S]) -> None: pass def __or__(self, s: Union[Set[S], FrozenSet[S]]) -> Set[Union[T, S]]: ... + def __xor__(self, s: Union[Set[S], FrozenSet[S]]) -> Set[Union[T, S]]: ... class frozenset(Generic[T]): def __init__(self, i: Optional[Iterable[T]] = None) -> None: pass def __iter__(self) -> Iterator[T]: pass def __len__(self) -> int: pass def __or__(self, s: Union[Set[S], FrozenSet[S]]) -> FrozenSet[Union[T, S]]: ... + def __xor__(self, s: Union[Set[S], FrozenSet[S]]) -> FrozenSet[Union[T, S]]: ... class slice: pass diff --git a/mypyc/test-data/irbuild-set.test b/mypyc/test-data/irbuild-set.test index fec76751c9154..c567422abac76 100644 --- a/mypyc/test-data/irbuild-set.test +++ b/mypyc/test-data/irbuild-set.test @@ -655,3 +655,185 @@ L0: r12 = PySet_Add(r0, r11) r13 = r12 >= 0 :: signed return r0 + +[case testOperatorInSetLiteral] +from typing_extensions import Final + +CONST: Final = "daylily" +non_const = 10 + +def precomputed(i: object) -> bool: + return i in {1, 2.0, 1 +2, 4j, "foo", b"bar", CONST, (None, (27,)), (), False} +def not_precomputed_non_final_name(i: int) -> bool: + return i in {non_const} +def not_precomputed_nested_set(i: int) -> bool: + return i in {frozenset({1}), 2} +[out] +def precomputed(i): + i :: object + r0 :: set + r1 :: int32 + r2 :: bit + r3 :: bool +L0: + r0 = frozenset({(), (None, (27,)), 1, 2.0, 3, 4j, False, b'bar', 'daylily', 'foo'}) + r1 = PySet_Contains(r0, i) + r2 = r1 >= 0 :: signed + r3 = truncate r1: int32 to builtins.bool + return r3 +def not_precomputed_non_final_name(i): + i :: int + r0 :: dict + r1 :: str + r2 :: object + r3 :: int + r4 :: set + r5 :: object + r6 :: int32 + r7 :: bit + r8 :: object + r9 :: int32 + r10 :: bit + r11 :: bool +L0: + r0 = __main__.globals :: static + r1 = 'non_const' + r2 = CPyDict_GetItem(r0, r1) + r3 = unbox(int, r2) + r4 = PySet_New(0) + r5 = box(int, r3) + r6 = PySet_Add(r4, r5) + r7 = r6 >= 0 :: signed + r8 = box(int, i) + r9 = PySet_Contains(r4, r8) + r10 = r9 >= 0 :: signed + r11 = truncate r9: int32 to builtins.bool + return r11 +def not_precomputed_nested_set(i): + i :: int + r0 :: set + r1 :: object + r2 :: int32 + r3 :: bit + r4 :: object + r5 :: set + r6 :: int32 + r7 :: bit + r8 :: object + r9 :: int32 + r10 :: bit + r11 :: object + r12 :: int32 + r13 :: bit + r14 :: bool +L0: + r0 = PySet_New(0) + r1 = object 1 + r2 = PySet_Add(r0, r1) + r3 = r2 >= 0 :: signed + r4 = PyFrozenSet_New(r0) + r5 = PySet_New(0) + r6 = PySet_Add(r5, r4) + r7 = r6 >= 0 :: signed + r8 = object 2 + r9 = PySet_Add(r5, r8) + r10 = r9 >= 0 :: signed + r11 = box(int, i) + r12 = PySet_Contains(r5, r11) + r13 = r12 >= 0 :: signed + r14 = truncate r12: int32 to builtins.bool + return r14 + +[case testForSetLiteral] +from typing_extensions import Final + +CONST: Final = 10 +non_const = 20 + +def precomputed() -> None: + for _ in {"None", "True", "False"}: + pass + +def precomputed2() -> None: + for _ in {None, False, 1, 2.0, "4", b"5", (6,), 7j, CONST, CONST + 1}: + pass + +def not_precomputed() -> None: + for not_optimized in {non_const}: + pass + +[out] +def precomputed(): + r0 :: set + r1, r2 :: object + r3 :: str + _ :: object + r4 :: bit +L0: + r0 = frozenset({'False', 'None', 'True'}) + r1 = PyObject_GetIter(r0) +L1: + r2 = PyIter_Next(r1) + if is_error(r2) goto L4 else goto L2 +L2: + r3 = cast(str, r2) + _ = r3 +L3: + goto L1 +L4: + r4 = CPy_NoErrOccured() +L5: + return 1 +def precomputed2(): + r0 :: set + r1, r2, _ :: object + r3 :: bit +L0: + r0 = frozenset({(6,), 1, 10, 11, 2.0, '4', 7j, False, None, b'5'}) + r1 = PyObject_GetIter(r0) +L1: + r2 = PyIter_Next(r1) + if is_error(r2) goto L4 else goto L2 +L2: + _ = r2 +L3: + goto L1 +L4: + r3 = CPy_NoErrOccured() +L5: + return 1 +def not_precomputed(): + r0 :: dict + r1 :: str + r2 :: object + r3 :: int + r4 :: set + r5 :: object + r6 :: int32 + r7 :: bit + r8, r9 :: object + r10, not_optimized :: int + r11 :: bit +L0: + r0 = __main__.globals :: static + r1 = 'non_const' + r2 = CPyDict_GetItem(r0, r1) + r3 = unbox(int, r2) + r4 = PySet_New(0) + r5 = box(int, r3) + r6 = PySet_Add(r4, r5) + r7 = r6 >= 0 :: signed + r8 = PyObject_GetIter(r4) +L1: + r9 = PyIter_Next(r8) + if is_error(r9) goto L4 else goto L2 +L2: + r10 = unbox(int, r9) + not_optimized = r10 +L3: + goto L1 +L4: + r11 = CPy_NoErrOccured() +L5: + return 1 + diff --git a/mypyc/test-data/run-sets.test b/mypyc/test-data/run-sets.test index 98ac92d569b70..56c946933fac4 100644 --- a/mypyc/test-data/run-sets.test +++ b/mypyc/test-data/run-sets.test @@ -115,3 +115,36 @@ from native import update s = {1, 2, 3} update(s, [5, 4, 3]) assert s == {1, 2, 3, 4, 5} + +[case testPrecomputedFrozenSets] +from typing import Any +from typing_extensions import Final + +CONST: Final = "CONST" +non_const = "non_const" + +def main_set(item: Any) -> bool: + return item in {None, False, 1, 2.0, "3", b"4", 5j, (6,), ((7,),), (), CONST} + +def main_negated_set(item: Any) -> bool: + return item not in {None, False, 1, 2.0, "3", b"4", 5j, (6,), ((7,),), (), CONST} + +def non_final_name_set(item: Any) -> bool: + return item in {non_const} + +s = set() +for i in {None, False, 1, 2.0, "3", b"4", 5j, (6,), CONST}: + s.add(i) + +def test_in_set() -> None: + for item in (None, False, 1, 2.0, "3", b"4", 5j, (6,), ((7,),), (), CONST): + assert main_set(item), f"{item!r} should be in set_main" + assert not main_negated_set(item), item + + assert non_final_name_set(non_const) + global non_const + non_const = "updated" + assert non_final_name_set("updated") + +def test_for_set() -> None: + assert not s ^ {None, False, 1, 2.0, "3", b"4", 5j, (6,), CONST}, s