diff --git a/mypyc/ir/rtypes.py b/mypyc/ir/rtypes.py index 77b4a1d7c721..6bd947b6df5d 100644 --- a/mypyc/ir/rtypes.py +++ b/mypyc/ir/rtypes.py @@ -399,6 +399,16 @@ def deserialize(cls, data: JsonDict, ctx: 'DeserMaps') -> 'RTuple': # Exception tuple: (exception class, exception instance, traceback object) exc_rtuple = RTuple([object_rprimitive, object_rprimitive, object_rprimitive]) +# Dictionary iterator tuple: (should continue, internal offset, key, value) +# See mypyc.irbuild.for_helpers.ForDictionaryCommon for more details. +dict_next_rtuple_pair = RTuple( + [bool_rprimitive, int_rprimitive, object_rprimitive, object_rprimitive] +) +# Same as above but just for key or value. +dict_next_rtuple_single = RTuple( + [bool_rprimitive, int_rprimitive, object_rprimitive] +) + class RInstance(RType): """Instance of user-defined class (compiled to C extension class). diff --git a/mypyc/irbuild/builder.py b/mypyc/irbuild/builder.py index 673c05cda88b..71bc272eccd0 100644 --- a/mypyc/irbuild/builder.py +++ b/mypyc/irbuild/builder.py @@ -24,6 +24,7 @@ from mypy.types import ( Type, Instance, TupleType, UninhabitedType, get_proper_type ) +from mypy.maptype import map_instance_to_supertype from mypy.visitor import ExpressionVisitor, StatementVisitor from mypy.util import split_target @@ -604,6 +605,30 @@ def get_sequence_type(self, expr: Expression) -> RType: else: return self.type_to_rtype(target_type.args[0]) + def get_dict_base_type(self, expr: Expression) -> Instance: + """Find dict type of a dict-like expression. + + This is useful for dict subclasses like SymbolTable. + """ + target_type = get_proper_type(self.types[expr]) + assert isinstance(target_type, Instance) + dict_base = next(base for base in target_type.type.mro + if base.fullname == 'builtins.dict') + return map_instance_to_supertype(target_type, dict_base) + + def get_dict_key_type(self, expr: Expression) -> RType: + dict_base_type = self.get_dict_base_type(expr) + return self.type_to_rtype(dict_base_type.args[0]) + + def get_dict_value_type(self, expr: Expression) -> RType: + dict_base_type = self.get_dict_base_type(expr) + return self.type_to_rtype(dict_base_type.args[1]) + + def get_dict_item_type(self, expr: Expression) -> RType: + key_type = self.get_dict_key_type(expr) + value_type = self.get_dict_value_type(expr) + return RTuple([key_type, value_type]) + def _analyze_iterable_item_type(self, expr: Expression) -> Type: """Return the item type given by 'expr' in an iterable context.""" # This logic is copied from mypy's TypeChecker.analyze_iterable_item_type. diff --git a/mypyc/irbuild/for_helpers.py b/mypyc/irbuild/for_helpers.py index 7c1ad443150f..700a868d2039 100644 --- a/mypyc/irbuild/for_helpers.py +++ b/mypyc/irbuild/for_helpers.py @@ -6,13 +6,22 @@ """ from typing import Union, List, Optional, Tuple, Callable +from typing_extensions import Type, ClassVar -from mypy.nodes import Lvalue, Expression, TupleExpr, CallExpr, RefExpr, GeneratorExpr, ARG_POS +from mypy.nodes import ( + Lvalue, Expression, TupleExpr, CallExpr, RefExpr, GeneratorExpr, ARG_POS, MemberExpr +) from mypyc.ir.ops import ( - Value, BasicBlock, LoadInt, Branch, Register, AssignmentTarget + Value, BasicBlock, LoadInt, Branch, Register, AssignmentTarget, TupleGet, + AssignmentTargetTuple, TupleSet, OpDescription ) from mypyc.ir.rtypes import ( - RType, is_short_int_rprimitive, is_list_rprimitive, is_sequence_rprimitive + RType, is_short_int_rprimitive, is_list_rprimitive, is_sequence_rprimitive, + RTuple, is_dict_rprimitive +) +from mypyc.primitives.dict_ops import ( + dict_next_key_op, dict_next_value_op, dict_next_item_op, dict_check_size_op, + dict_key_iter_op, dict_value_iter_op, dict_item_iter_op ) from mypyc.primitives.int_ops import unsafe_short_add from mypyc.primitives.list_ops import new_list_op, list_append_op, list_get_item_unsafe_op @@ -170,6 +179,15 @@ def make_for_loop_generator(builder: IRBuilder, for_list.init(expr_reg, target_type, reverse=False) return for_list + if is_dict_rprimitive(rtyp): + # Special case "for k in ". + expr_reg = builder.accept(expr) + target_type = builder.get_dict_key_type(expr) + + for_dict = ForDictionaryKeys(builder, index, body_block, loop_exit, line, nested) + for_dict.init(expr_reg, target_type) + return for_dict + if (isinstance(expr, CallExpr) and isinstance(expr.callee, RefExpr)): if (expr.callee.fullname == 'builtins.range' @@ -233,6 +251,27 @@ def make_for_loop_generator(builder: IRBuilder, for_list = ForSequence(builder, index, body_block, loop_exit, line, nested) for_list.init(expr_reg, target_type, reverse=True) return for_list + if (isinstance(expr, CallExpr) + and isinstance(expr.callee, MemberExpr) + and not expr.args): + # Special cases for dictionary iterator methods, like dict.items(). + rtype = builder.node_type(expr.callee.expr) + if (is_dict_rprimitive(rtype) + and expr.callee.name in ('keys', 'values', 'items')): + expr_reg = builder.accept(expr.callee.expr) + for_dict_type = None # type: Optional[Type[ForGenerator]] + if expr.callee.name == 'keys': + target_type = builder.get_dict_key_type(expr.callee.expr) + for_dict_type = ForDictionaryKeys + elif expr.callee.name == 'values': + target_type = builder.get_dict_value_type(expr.callee.expr) + for_dict_type = ForDictionaryValues + else: + target_type = builder.get_dict_item_type(expr.callee.expr) + for_dict_type = ForDictionaryItems + for_dict_gen = for_dict_type(builder, index, body_block, loop_exit, line, nested) + for_dict_gen.init(expr_reg, target_type) + return for_dict_gen # Default to a generic for loop. expr_reg = builder.accept(expr) @@ -292,6 +331,14 @@ def gen_step(self) -> None: def gen_cleanup(self) -> None: """Generate post-loop cleanup (if needed).""" + def load_len(self, expr: Union[Value, AssignmentTarget]) -> Value: + """A helper to get collection length, used by several subclasses.""" + return self.builder.builder.builtin_call( + [self.builder.read(expr, self.line)], + 'builtins.len', + self.line, + ) + class ForIterable(ForGenerator): """Generate IR for a for loop over an arbitrary iterable (the normal case).""" @@ -371,17 +418,11 @@ def init(self, expr_reg: Value, target_type: RType, reverse: bool) -> None: if not reverse: index_reg = builder.add(LoadInt(0)) else: - index_reg = builder.binary_op(self.load_len(), builder.add(LoadInt(1)), '-', self.line) + index_reg = builder.binary_op(self.load_len(self.expr_target), + builder.add(LoadInt(1)), '-', self.line) self.index_target = builder.maybe_spill_assignable(index_reg) self.target_type = target_type - def load_len(self) -> Value: - return self.builder.builder.builtin_call( - [self.builder.read(self.expr_target, self.line)], - 'builtins.len', - self.line, - ) - def gen_condition(self) -> None: builder = self.builder line = self.line @@ -398,7 +439,7 @@ def gen_condition(self) -> None: builder.activate_block(second_check) # For compatibility with python semantics we recalculate the length # at every iteration. - len_reg = self.load_len() + len_reg = self.load_len(self.expr_target) comparison = builder.binary_op(builder.read(self.index_target, line), len_reg, '<', line) builder.add_bool_branch(comparison, self.body_block, self.loop_exit) @@ -430,6 +471,136 @@ def gen_step(self) -> None: builder.add(LoadInt(step))], line), line) +class ForDictionaryCommon(ForGenerator): + """Generate optimized IR for a for loop over dictionary keys/values. + + The logic is pretty straightforward, we use PyDict_Next() API wrapped in + a tuple, so that we can modify only a single register. The layout of the tuple: + * f0: are there more items (bool) + * f1: current offset (int) + * f2: next key (object) + * f3: next value (object) + For more info see https://docs.python.org/3/c-api/dict.html#c.PyDict_Next. + + Note that for subclasses we fall back to generic PyObject_GetIter() logic, + since they may override some iteration methods in subtly incompatible manner. + The fallback logic is implemented in CPy.h via dynamic type check. + """ + dict_next_op = None # type: ClassVar[OpDescription] + dict_iter_op = None # type: ClassVar[OpDescription] + + def need_cleanup(self) -> bool: + # Technically, a dict subclass can raise an unrelated exception + # in __next__(), so we need this. + return True + + def init(self, expr_reg: Value, target_type: RType) -> None: + builder = self.builder + self.target_type = target_type + + # We add some variables to environment class, so they can be read across yield. + self.expr_target = builder.maybe_spill(expr_reg) + offset_reg = builder.add(LoadInt(0)) + self.offset_target = builder.maybe_spill_assignable(offset_reg) + self.size = builder.maybe_spill(self.load_len(self.expr_target)) + + # For dict class (not a subclass) this is the dictionary itself. + iter_reg = builder.primitive_op(self.dict_iter_op, [expr_reg], self.line) + self.iter_target = builder.maybe_spill(iter_reg) + + def gen_condition(self) -> None: + """Get next key/value pair, set new offset, and check if we should continue.""" + builder = self.builder + line = self.line + self.next_tuple = self.builder.primitive_op( + self.dict_next_op, [builder.read(self.iter_target, line), + builder.read(self.offset_target, line)], line) + + # Do this here instead of in gen_step() to minimize variables in environment. + new_offset = builder.add(TupleGet(self.next_tuple, 1, line)) + builder.assign(self.offset_target, new_offset, line) + + should_continue = builder.add(TupleGet(self.next_tuple, 0, line)) + builder.add( + Branch(should_continue, self.body_block, self.loop_exit, Branch.BOOL_EXPR) + ) + + def gen_step(self) -> None: + """Check that dictionary didn't change size during iteration. + + Raise RuntimeError if it is not the case to match CPython behavior. + """ + builder = self.builder + line = self.line + # Technically, we don't need a new primitive for this, but it is simpler. + builder.primitive_op(dict_check_size_op, + [builder.read(self.expr_target, line), + builder.read(self.size, line)], line) + + def gen_cleanup(self) -> None: + # Same as for generic ForIterable. + self.builder.primitive_op(no_err_occurred_op, [], self.line) + + +class ForDictionaryKeys(ForDictionaryCommon): + """Generate optimized IR for a for loop over dictionary keys.""" + dict_next_op = dict_next_key_op + dict_iter_op = dict_key_iter_op + + def begin_body(self) -> None: + builder = self.builder + line = self.line + + # Key is stored at the third place in the tuple. + key = builder.add(TupleGet(self.next_tuple, 2, line)) + builder.assign(builder.get_assignment_target(self.index), + builder.coerce(key, self.target_type, line), line) + + +class ForDictionaryValues(ForDictionaryCommon): + """Generate optimized IR for a for loop over dictionary values.""" + dict_next_op = dict_next_value_op + dict_iter_op = dict_value_iter_op + + def begin_body(self) -> None: + builder = self.builder + line = self.line + + # Value is stored at the third place in the tuple. + value = builder.add(TupleGet(self.next_tuple, 2, line)) + builder.assign(builder.get_assignment_target(self.index), + builder.coerce(value, self.target_type, line), line) + + +class ForDictionaryItems(ForDictionaryCommon): + """Generate optimized IR for a for loop over dictionary items.""" + dict_next_op = dict_next_item_op + dict_iter_op = dict_item_iter_op + + def begin_body(self) -> None: + builder = self.builder + line = self.line + + key = builder.add(TupleGet(self.next_tuple, 2, line)) + value = builder.add(TupleGet(self.next_tuple, 3, line)) + + # Coerce just in case e.g. key is itself a tuple to be unpacked. + assert isinstance(self.target_type, RTuple) + key = builder.coerce(key, self.target_type.types[0], line) + value = builder.coerce(value, self.target_type.types[1], line) + + target = builder.get_assignment_target(self.index) + if isinstance(target, AssignmentTargetTuple): + # Simpler code for common case: for k, v in d.items(). + if len(target.items) != 2: + builder.error("Expected a pair for dict item iteration", line) + builder.assign(target.items[0], key, line) + builder.assign(target.items[1], value, line) + else: + rvalue = builder.add(TupleSet([key, value], line)) + builder.assign(target, rvalue, line) + + class ForRange(ForGenerator): """Generate optimized IR for a for loop over an integer range.""" diff --git a/mypyc/lib-rt/CPy.h b/mypyc/lib-rt/CPy.h index 6a6c5837b975..5e25eb9ca35a 100644 --- a/mypyc/lib-rt/CPy.h +++ b/mypyc/lib-rt/CPy.h @@ -1377,6 +1377,183 @@ static tuple_T3OOO CPy_GetExcInfo(void) { return ret; } +static PyObject *CPyDict_GetKeysIter(PyObject *dict) { + if (PyDict_CheckExact(dict)) { + // Return dict itself to indicate we can use fast path instead. + Py_INCREF(dict); + return dict; + } + return PyObject_GetIter(dict); +} + +static PyObject *CPyDict_GetItemsIter(PyObject *dict) { + if (PyDict_CheckExact(dict)) { + // Return dict itself to indicate we can use fast path instead. + Py_INCREF(dict); + return dict; + } + PyObject *view = PyObject_CallMethod(dict, "items", NULL); + if (view == NULL) { + return NULL; + } + PyObject *iter = PyObject_GetIter(view); + Py_DECREF(view); + return iter; +} + +static PyObject *CPyDict_GetValuesIter(PyObject *dict) { + if (PyDict_CheckExact(dict)) { + // Return dict itself to indicate we can use fast path instead. + Py_INCREF(dict); + return dict; + } + PyObject *view = PyObject_CallMethod(dict, "values", NULL); + if (view == NULL) { + return NULL; + } + PyObject *iter = PyObject_GetIter(view); + Py_DECREF(view); + return iter; +} + +// Our return tuple wrapper for dictionary iteration helper. +#ifndef MYPYC_DECLARED_tuple_T3CIO +#define MYPYC_DECLARED_tuple_T3CIO +typedef struct tuple_T3CIO { + char f0; // Should continue? + CPyTagged f1; // Last dict offset + PyObject *f2; // Next dictionary key or value +} tuple_T3CIO; +static tuple_T3CIO tuple_undefined_T3CIO = { 2, CPY_INT_TAG, NULL }; +#endif + +// Same as above but for both key and value. +#ifndef MYPYC_DECLARED_tuple_T4CIOO +#define MYPYC_DECLARED_tuple_T4CIOO +typedef struct tuple_T4CIOO { + char f0; // Should continue? + CPyTagged f1; // Last dict offset + PyObject *f2; // Next dictionary key + PyObject *f3; // Next dictionary value +} tuple_T4CIOO; +static tuple_T4CIOO tuple_undefined_T4CIOO = { 2, CPY_INT_TAG, NULL, NULL }; +#endif + +static void _CPyDict_FromNext(tuple_T3CIO *ret, PyObject *dict_iter) { + // Get next item from iterator and set "should continue" flag. + ret->f2 = PyIter_Next(dict_iter); + if (ret->f2 == NULL) { + ret->f0 = 0; + Py_INCREF(Py_None); + ret->f2 = Py_None; + } else { + ret->f0 = 1; + } +} + +// Helpers for fast dictionary iteration, return a single tuple +// instead of writing to multiple registers, for exact dicts use +// the fast path, and fall back to generic iterator logic for subclasses. +static tuple_T3CIO CPyDict_NextKey(PyObject *dict_or_iter, CPyTagged offset) { + tuple_T3CIO ret; + Py_ssize_t py_offset = CPyTagged_AsSsize_t(offset); + PyObject *dummy; + + if (PyDict_CheckExact(dict_or_iter)) { + ret.f0 = PyDict_Next(dict_or_iter, &py_offset, &ret.f2, &dummy); + if (ret.f0) { + ret.f1 = CPyTagged_FromSsize_t(py_offset); + } else { + // Set key to None, so mypyc can manage refcounts. + ret.f1 = 0; + ret.f2 = Py_None; + } + // PyDict_Next() returns borrowed references. + Py_INCREF(ret.f2); + } else { + // offset is dummy in this case, just use the old value. + ret.f1 = offset; + _CPyDict_FromNext(&ret, dict_or_iter); + } + return ret; +} + +static tuple_T3CIO CPyDict_NextValue(PyObject *dict_or_iter, CPyTagged offset) { + tuple_T3CIO ret; + Py_ssize_t py_offset = CPyTagged_AsSsize_t(offset); + PyObject *dummy; + + if (PyDict_CheckExact(dict_or_iter)) { + ret.f0 = PyDict_Next(dict_or_iter, &py_offset, &dummy, &ret.f2); + if (ret.f0) { + ret.f1 = CPyTagged_FromSsize_t(py_offset); + } else { + // Set value to None, so mypyc can manage refcounts. + ret.f1 = 0; + ret.f2 = Py_None; + } + // PyDict_Next() returns borrowed references. + Py_INCREF(ret.f2); + } else { + // offset is dummy in this case, just use the old value. + ret.f1 = offset; + _CPyDict_FromNext(&ret, dict_or_iter); + } + return ret; +} + +static tuple_T4CIOO CPyDict_NextItem(PyObject *dict_or_iter, CPyTagged offset) { + tuple_T4CIOO ret; + Py_ssize_t py_offset = CPyTagged_AsSsize_t(offset); + + if (PyDict_CheckExact(dict_or_iter)) { + ret.f0 = PyDict_Next(dict_or_iter, &py_offset, &ret.f2, &ret.f3); + if (ret.f0) { + ret.f1 = CPyTagged_FromSsize_t(py_offset); + } else { + // Set key and value to None, so mypyc can manage refcounts. + ret.f1 = 0; + ret.f2 = Py_None; + ret.f3 = Py_None; + } + } else { + ret.f1 = offset; + PyObject *item = PyIter_Next(dict_or_iter); + if (item == NULL || !PyTuple_Check(item) || PyTuple_GET_SIZE(item) != 2) { + if (item != NULL) { + PyErr_SetString(PyExc_TypeError, "a tuple of length 2 expected"); + } + ret.f0 = 0; + ret.f2 = Py_None; + ret.f3 = Py_None; + } else { + ret.f0 = 1; + ret.f2 = PyTuple_GET_ITEM(item, 0); + ret.f3 = PyTuple_GET_ITEM(item, 1); + Py_DECREF(item); + } + } + // PyDict_Next() returns borrowed references. + Py_INCREF(ret.f2); + Py_INCREF(ret.f3); + return ret; +} + +// Check that dictionary didn't change size during iteration. +static inline char CPyDict_CheckSize(PyObject *dict, CPyTagged size) { + if (!PyDict_CheckExact(dict)) { + // Dict subclasses will be checked by Python runtime. + return 1; + } + Py_ssize_t py_size = CPyTagged_AsSsize_t(size); + Py_ssize_t dict_size = PyDict_Size(dict); + if (py_size != dict_size) { + PyErr_SetString(PyExc_RuntimeError, "dictionary changed size during iteration"); + return 0; + } + return 1; +} + void CPy_Init(void); diff --git a/mypyc/primitives/dict_ops.py b/mypyc/primitives/dict_ops.py index cf55e462258c..8e9d23d2c489 100644 --- a/mypyc/primitives/dict_ops.py +++ b/mypyc/primitives/dict_ops.py @@ -3,7 +3,10 @@ from typing import List from mypyc.ir.ops import EmitterInterface, ERR_FALSE, ERR_MAGIC, ERR_NEVER -from mypyc.ir.rtypes import dict_rprimitive, object_rprimitive, bool_rprimitive, int_rprimitive +from mypyc.ir.rtypes import ( + dict_rprimitive, object_rprimitive, bool_rprimitive, int_rprimitive, + dict_next_rtuple_single, dict_next_rtuple_pair +) from mypyc.primitives.registry import ( name_ref_op, method_op, binary_op, func_op, custom_op, @@ -136,3 +139,61 @@ def emit_len(emitter: EmitterInterface, args: List[str], dest: str) -> None: result_type=int_rprimitive, error_kind=ERR_NEVER, emit=emit_len) + +# PyDict_Next() fast iteration +dict_key_iter_op = custom_op( + name='key_iter', + arg_types=[dict_rprimitive], + result_type=object_rprimitive, + error_kind=ERR_MAGIC, + emit=call_emit('CPyDict_GetKeysIter'), +) + +dict_value_iter_op = custom_op( + name='value_iter', + arg_types=[dict_rprimitive], + result_type=object_rprimitive, + error_kind=ERR_MAGIC, + emit=call_emit('CPyDict_GetValuesIter'), +) + +dict_item_iter_op = custom_op( + name='item_iter', + arg_types=[dict_rprimitive], + result_type=object_rprimitive, + error_kind=ERR_MAGIC, + emit=call_emit('CPyDict_GetItemsIter'), +) + +dict_next_key_op = custom_op( + arg_types=[object_rprimitive, int_rprimitive], + result_type=dict_next_rtuple_single, + error_kind=ERR_NEVER, + emit=call_emit('CPyDict_NextKey'), + format_str='{dest} = next_key {args[0]}, offset={args[1]}', +) + +dict_next_value_op = custom_op( + arg_types=[object_rprimitive, int_rprimitive], + result_type=dict_next_rtuple_single, + error_kind=ERR_NEVER, + emit=call_emit('CPyDict_NextValue'), + format_str='{dest} = next_value {args[0]}, offset={args[1]}', +) + +dict_next_item_op = custom_op( + arg_types=[object_rprimitive, int_rprimitive], + result_type=dict_next_rtuple_pair, + error_kind=ERR_NEVER, + emit=call_emit('CPyDict_NextItem'), + format_str='{dest} = next_item {args[0]}, offset={args[1]}', +) + +# check that len(dict) == const during iteration +dict_check_size_op = custom_op( + arg_types=[dict_rprimitive, int_rprimitive], + result_type=bool_rprimitive, + error_kind=ERR_FALSE, + emit=call_emit('CPyDict_CheckSize'), + format_str='{dest} = assert size({args[0]}) == {args[1]}', +) diff --git a/mypyc/test-data/fixtures/ir.py b/mypyc/test-data/fixtures/ir.py index d5517b3f9d8a..9218734b0e76 100644 --- a/mypyc/test-data/fixtures/ir.py +++ b/mypyc/test-data/fixtures/ir.py @@ -136,7 +136,10 @@ def update(self, __m: Iterable[Tuple[K, V]], **kwargs: V) -> None: ... @overload def update(self, **kwargs: V) -> None: ... def pop(self, x: int) -> K: pass - def keys(self) -> List[K]: pass + def keys(self) -> Iterable[K]: pass + def values(self) -> Iterable[V]: pass + def items(self) -> Iterable[Tuple[K, V]]: pass + def clear(self) -> None: pass class set(Generic[T]): def __init__(self, i: Optional[Iterable[T]] = None) -> None: pass diff --git a/mypyc/test-data/irbuild-dict.test b/mypyc/test-data/irbuild-dict.test index c2ef9b789e63..1dd0b315dc82 100644 --- a/mypyc/test-data/irbuild-dict.test +++ b/mypyc/test-data/irbuild-dict.test @@ -151,29 +151,43 @@ def increment(d: Dict[str, int]) -> Dict[str, int]: [out] def increment(d): d :: dict - r0, r1 :: object - r2, k :: str + r0, r1 :: short_int + r2 :: int r3 :: object - r4 :: short_int - r5, r6 :: object - r7, r8 :: bool + r4 :: tuple[bool, int, object] + r5 :: int + r6 :: bool + r7 :: object + k, r8 :: str + r9 :: object + r10 :: short_int + r11, r12 :: object + r13, r14, r15 :: bool L0: - r0 = iter d :: object + r0 = 0 + r1 = r0 + r2 = len d :: dict + r3 = key_iter d :: dict L1: - r1 = next r0 :: object - if is_error(r1) goto L4 else goto L2 + r4 = next_key r3, offset=r1 + r5 = r4[1] + r1 = r5 + r6 = r4[0] + if r6 goto L2 else goto L4 :: bool L2: - r2 = cast(str, r1) - k = r2 - r3 = d[k] :: dict - r4 = 1 - r5 = box(short_int, r4) - r6 = r3 += r5 - r7 = d.__setitem__(k, r6) :: dict + r7 = r4[2] + r8 = cast(str, r7) + k = r8 + r9 = d[k] :: dict + r10 = 1 + r11 = box(short_int, r10) + r12 = r9 += r11 + r13 = d.__setitem__(k, r12) :: dict L3: + r14 = assert size(d) == r2 goto L1 L4: - r8 = no_err_occurred + r15 = no_err_occurred L5: return d @@ -204,3 +218,97 @@ L0: r7 = r4.__setitem__(r1, r6) :: dict return r4 +[case testDictIterationMethods] +from typing import Dict +def print_dict_methods(d1: Dict[int, int], d2: Dict[int, int]) -> None: + for v in d1.values(): + if v in d2: + return + for k, v in d2.items(): + d2[k] += v +[out] +def print_dict_methods(d1, d2): + d1, d2 :: dict + r0, r1 :: short_int + r2 :: int + r3 :: object + r4 :: tuple[bool, int, object] + r5 :: int + r6 :: bool + r7 :: object + v, r8 :: int + r9 :: object + r10 :: bool + r11 :: None + r12, r13 :: bool + r14, r15 :: short_int + r16 :: int + r17 :: object + r18 :: tuple[bool, int, object, object] + r19 :: int + r20 :: bool + r21, r22 :: object + r23, r24, k :: int + r25, r26, r27, r28, r29 :: object + r30, r31, r32 :: bool + r33 :: None +L0: + r0 = 0 + r1 = r0 + r2 = len d1 :: dict + r3 = value_iter d1 :: dict +L1: + r4 = next_value r3, offset=r1 + r5 = r4[1] + r1 = r5 + r6 = r4[0] + if r6 goto L2 else goto L6 :: bool +L2: + r7 = r4[2] + r8 = unbox(int, r7) + v = r8 + r9 = box(int, v) + r10 = r9 in d2 :: dict + if r10 goto L3 else goto L4 :: bool +L3: + r11 = None + return r11 +L4: +L5: + r12 = assert size(d1) == r2 + goto L1 +L6: + r13 = no_err_occurred +L7: + r14 = 0 + r15 = r14 + r16 = len d2 :: dict + r17 = item_iter d2 :: dict +L8: + r18 = next_item r17, offset=r15 + r19 = r18[1] + r15 = r19 + r20 = r18[0] + if r20 goto L9 else goto L11 :: bool +L9: + r21 = r18[2] + r22 = r18[3] + r23 = unbox(int, r21) + r24 = unbox(int, r22) + k = r23 + v = r24 + r25 = box(int, k) + r26 = d2[r25] :: dict + r27 = box(int, v) + r28 = r26 += r27 + r29 = box(int, k) + r30 = d2.__setitem__(r29, r28) :: dict +L10: + r31 = assert size(d2) == r16 + goto L8 +L11: + r32 = no_err_occurred +L12: + r33 = None + return r33 + diff --git a/mypyc/test-data/irbuild-statements.test b/mypyc/test-data/irbuild-statements.test index add20669ed38..18e829a2eeb2 100644 --- a/mypyc/test-data/irbuild-statements.test +++ b/mypyc/test-data/irbuild-statements.test @@ -296,30 +296,44 @@ def f(d: Dict[int, int]) -> None: [out] def f(d): d :: dict - r0, r1 :: object - r2, key :: int - r3, r4 :: object + r0, r1 :: short_int + r2 :: int + r3 :: object + r4 :: tuple[bool, int, object] r5 :: int r6 :: bool - r7 :: None + r7 :: object + key, r8 :: int + r9, r10 :: object + r11 :: int + r12, r13 :: bool + r14 :: None L0: - r0 = iter d :: object + r0 = 0 + r1 = r0 + r2 = len d :: dict + r3 = key_iter d :: dict L1: - r1 = next r0 :: object - if is_error(r1) goto L4 else goto L2 + r4 = next_key r3, offset=r1 + r5 = r4[1] + r1 = r5 + r6 = r4[0] + if r6 goto L2 else goto L4 :: bool L2: - r2 = unbox(int, r1) - key = r2 - r3 = box(int, key) - r4 = d[r3] :: dict - r5 = unbox(int, r4) + r7 = r4[2] + r8 = unbox(int, r7) + key = r8 + r9 = box(int, key) + r10 = d[r9] :: dict + r11 = unbox(int, r10) L3: + r12 = assert size(d) == r2 goto L1 L4: - r6 = no_err_occurred + r13 = no_err_occurred L5: - r7 = None - return r7 + r14 = None + return r14 [case testForDictContinue] from typing import Dict @@ -336,47 +350,61 @@ def sum_over_even_values(d): d :: dict r0 :: short_int s :: int - r1, r2 :: object - r3, key :: int - r4, r5 :: object + r1, r2 :: short_int + r3 :: int + r4 :: object + r5 :: tuple[bool, int, object] r6 :: int - r7 :: short_int - r8 :: int - r9 :: short_int - r10 :: bool - r11, r12 :: object - r13, r14 :: int - r15 :: bool + r7 :: bool + r8 :: object + key, r9 :: int + r10, r11 :: object + r12 :: int + r13 :: short_int + r14 :: int + r15 :: short_int + r16 :: bool + r17, r18 :: object + r19, r20 :: int + r21, r22 :: bool L0: r0 = 0 s = r0 - r1 = iter d :: object + r1 = 0 + r2 = r1 + r3 = len d :: dict + r4 = key_iter d :: dict L1: - r2 = next r1 :: object - if is_error(r2) goto L6 else goto L2 + r5 = next_key r4, offset=r2 + r6 = r5[1] + r2 = r6 + r7 = r5[0] + if r7 goto L2 else goto L6 :: bool L2: - r3 = unbox(int, r2) - key = r3 - r4 = box(int, key) - r5 = d[r4] :: dict - r6 = unbox(int, r5) - r7 = 2 - r8 = r6 % r7 :: int - r9 = 0 - r10 = r8 != r9 :: int - if r10 goto L3 else goto L4 :: bool + r8 = r5[2] + r9 = unbox(int, r8) + key = r9 + r10 = box(int, key) + r11 = d[r10] :: dict + r12 = unbox(int, r11) + r13 = 2 + r14 = r12 % r13 :: int + r15 = 0 + r16 = r14 != r15 :: int + if r16 goto L3 else goto L4 :: bool L3: goto L5 L4: - r11 = box(int, key) - r12 = d[r11] :: dict - r13 = unbox(int, r12) - r14 = s + r13 :: int - s = r14 + r17 = box(int, key) + r18 = d[r17] :: dict + r19 = unbox(int, r18) + r20 = s + r19 :: int + s = r20 L5: + r21 = assert size(d) == r3 goto L1 L6: - r15 = no_err_occurred + r22 = no_err_occurred L7: return s diff --git a/mypyc/test-data/refcount.test b/mypyc/test-data/refcount.test index 8e9e7547cfa8..017811a1b0b4 100644 --- a/mypyc/test-data/refcount.test +++ b/mypyc/test-data/refcount.test @@ -788,36 +788,54 @@ def f(d: Dict[int, int]) -> None: [out] def f(d): d :: dict - r0, r1 :: object - r2, key :: int - r3, r4 :: object + r0, r1 :: short_int + r2 :: int + r3 :: object + r4 :: tuple[bool, int, object] r5 :: int r6 :: bool - r7 :: None + r7 :: object + key, r8 :: int + r9, r10 :: object + r11 :: int + r12, r13 :: bool + r14 :: None L0: - r0 = iter d :: object + r0 = 0 + r1 = r0 + r2 = len d :: dict + r3 = key_iter d :: dict L1: - r1 = next r0 :: object - if is_error(r1) goto L5 else goto L2 + r4 = next_key r3, offset=r1 + r5 = r4[1] + r1 = r5 + r6 = r4[0] + if r6 goto L2 else goto L6 :: bool L2: - r2 = unbox(int, r1) - dec_ref r1 - key = r2 - r3 = box(int, key) - r4 = d[r3] :: dict - dec_ref r3 - r5 = unbox(int, r4) + r7 = r4[2] dec_ref r4 - dec_ref r5 :: int - goto L1 + r8 = unbox(int, r7) + dec_ref r7 + key = r8 + r9 = box(int, key) + r10 = d[r9] :: dict + dec_ref r9 + r11 = unbox(int, r10) + dec_ref r10 + dec_ref r11 :: int L3: - r6 = no_err_occurred + r12 = assert size(d) == r2 + goto L1 L4: - r7 = None - return r7 + r13 = no_err_occurred L5: - dec_ref r0 - goto L3 + r14 = None + return r14 +L6: + dec_ref r2 :: int + dec_ref r3 + dec_ref r4 + goto L4 [case testBorrowRefs] def make_garbage(arg: object) -> None: diff --git a/mypyc/test-data/run.test b/mypyc/test-data/run.test index 64bbadfeb382..ffb3c0db2826 100644 --- a/mypyc/test-data/run.test +++ b/mypyc/test-data/run.test @@ -1065,6 +1065,106 @@ assert d == dict(object.__dict__) assert u(10) == 10 +[case testDictIterationMethodsRun] +from typing import Dict +def print_dict_methods(d1: Dict[int, int], + d2: Dict[int, int], + d3: Dict[int, int]) -> None: + for k in d1.keys(): + print(k) + for k, v in d2.items(): + print(k) + print(v) + for v in d3.values(): + print(v) + +def clear_during_iter(d: Dict[int, int]) -> None: + for k in d: + d.clear() + +class Custom(Dict[int, int]): pass +[file driver.py] +from native import print_dict_methods, Custom, clear_during_iter +from collections import OrderedDict +print_dict_methods({}, {}, {}) +print_dict_methods({1: 2}, {3: 4, 5: 6}, {7: 8}) +print('==') +c = Custom({0: 1}) +print_dict_methods(c, c, c) +print('==') +d = OrderedDict([(1, 2), (3, 4)]) +print_dict_methods(d, d, d) +print('==') +d.move_to_end(1) +print_dict_methods(d, d, d) +clear_during_iter({}) # OK +try: + clear_during_iter({1: 2, 3: 4}) +except RuntimeError as e: + assert str(e) == "dictionary changed size during iteration" +else: + assert False +try: + clear_during_iter(d) +except RuntimeError as e: + assert str(e) == "OrderedDict changed size during iteration" +else: + assert False + +class CustomMad(dict): + def __iter__(self): + return self + def __next__(self): + raise ValueError +m = CustomMad() +try: + clear_during_iter(m) +except ValueError: + pass +else: + assert False + +class CustomBad(dict): + def items(self): + return [(1, 2, 3)] # Oops +b = CustomBad() +try: + print_dict_methods(b, b, b) +except TypeError as e: + assert str(e) == "a tuple of length 2 expected" +else: + assert False +[out] +1 +3 +4 +5 +6 +8 +== +0 +0 +1 +1 +== +1 +3 +1 +2 +3 +4 +2 +4 +== +3 +1 +3 +4 +1 +2 +4 +2 + [case testPyMethodCall] from typing import List def f(x: List[int]) -> int: @@ -3517,7 +3617,7 @@ else: assert False [case testYield] -from typing import Generator, Iterable, Union, Tuple +from typing import Generator, Iterable, Union, Tuple, Dict def yield_three_times() -> Iterable[int]: yield 1 @@ -3578,6 +3678,17 @@ def yield_with_default(x: bool = False) -> Iterable[int]: if x: yield 0 +def yield_dict_methods(d1: Dict[int, int], + d2: Dict[int, int], + d3: Dict[int, int]) -> Iterable[int]: + for k in d1.keys(): + yield k + for k, v in d2.items(): + yield k + yield v + for v in d3.values(): + yield v + def three() -> int: return 3 @@ -3593,8 +3704,20 @@ def return_tuple() -> Generator[int, None, Tuple[int, int]]: return 1, 2 [file driver.py] -from native import yield_three_times, yield_twice_and_return, yield_while_loop, yield_for_loop, yield_with_except, complex_yield, yield_with_default, A, return_tuple +from native import ( + yield_three_times, + yield_twice_and_return, + yield_while_loop, + yield_for_loop, + yield_with_except, + complex_yield, + yield_with_default, + A, + return_tuple, + yield_dict_methods, +) from testutil import run_generator +from collections import defaultdict assert run_generator(yield_three_times()) == ((1, 2, 3), None) assert run_generator(yield_twice_and_return()) == ((1, 2), 4) @@ -3605,6 +3728,10 @@ assert run_generator(complex_yield(5, 'foo', 1.0)) == (('2 foo', 3, '4 foo'), 1. assert run_generator(yield_with_default()) == ((), None) assert run_generator(A(0).generator()) == ((0,), None) assert run_generator(return_tuple()) == ((0,), (1, 2)) +assert run_generator(yield_dict_methods({}, {}, {})) == ((), None) +assert run_generator(yield_dict_methods({1: 2}, {3: 4}, {5: 6})) == ((1, 3, 4, 6), None) +dd = defaultdict(int, {0: 1}) +assert run_generator(yield_dict_methods(dd, dd, dd)) == ((0, 0, 1, 1), None) for i in yield_twice_and_return(): print(i)