Skip to content

Commit

Permalink
[mypyc] Support iterating over a TypedDict (#14747)
Browse files Browse the repository at this point in the history
An optimization to make iterating over dict.keys(), dict.values() and
dict.items() faster caused mypyc to crash while compiling a TypedDict.
This commit fixes `Builder.get_dict_base_type` to properly handle
`TypedDictType`.

Fixes mypyc/mypyc#869.
  • Loading branch information
ichard26 committed Mar 20, 2023
1 parent 1a8ea61 commit 9944d5f
Show file tree
Hide file tree
Showing 4 changed files with 112 additions and 5 deletions.
9 changes: 7 additions & 2 deletions mypyc/irbuild/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@
ProperType,
TupleType,
Type,
TypedDictType,
TypeOfAny,
UninhabitedType,
UnionType,
Expand Down Expand Up @@ -913,8 +914,12 @@ def get_dict_base_type(self, expr: Expression) -> list[Instance]:

dict_types = []
for t in types:
assert isinstance(t, Instance), t
dict_base = next(base for base in t.type.mro if base.fullname == "builtins.dict")
if isinstance(t, TypedDictType):
t = t.fallback
dict_base = next(base for base in t.type.mro if base.fullname == "typing.Mapping")
else:
assert isinstance(t, Instance), t
dict_base = next(base for base in t.type.mro if base.fullname == "builtins.dict")
dict_types.append(map_instance_to_supertype(t, dict_base))
return dict_types

Expand Down
69 changes: 69 additions & 0 deletions mypyc/test-data/irbuild-dict.test
Original file line number Diff line number Diff line change
Expand Up @@ -219,6 +219,12 @@ L0:

[case testDictIterationMethods]
from typing import Dict, Union
from typing_extensions import TypedDict

class Person(TypedDict):
name: str
age: int

def print_dict_methods(d1: Dict[int, int], d2: Dict[int, int]) -> None:
for v in d1.values():
if v in d2:
Expand All @@ -229,6 +235,10 @@ def union_of_dicts(d: Union[Dict[str, int], Dict[str, str]]) -> None:
new = {}
for k, v in d.items():
new[k] = int(v)
def typeddict(d: Person) -> None:
for k, v in d.items():
if k == "name":
name = v
[out]
def print_dict_methods(d1, d2):
d1, d2 :: dict
Expand Down Expand Up @@ -370,6 +380,65 @@ L4:
r19 = CPy_NoErrOccured()
L5:
return 1
def typeddict(d):
d :: dict
r0 :: short_int
r1 :: native_int
r2 :: short_int
r3 :: object
r4 :: tuple[bool, short_int, object, object]
r5 :: short_int
r6 :: bool
r7, r8 :: object
r9, k :: str
v :: object
r10 :: str
r11 :: int32
r12 :: bit
r13 :: object
r14, r15, r16 :: bit
name :: object
r17, r18 :: bit
L0:
r0 = 0
r1 = PyDict_Size(d)
r2 = r1 << 1
r3 = CPyDict_GetItemsIter(d)
L1:
r4 = CPyDict_NextItem(r3, r0)
r5 = r4[1]
r0 = r5
r6 = r4[0]
if r6 goto L2 else goto L9 :: bool
L2:
r7 = r4[2]
r8 = r4[3]
r9 = cast(str, r7)
k = r9
v = r8
r10 = 'name'
r11 = PyUnicode_Compare(k, r10)
r12 = r11 == -1
if r12 goto L3 else goto L5 :: bool
L3:
r13 = PyErr_Occurred()
r14 = r13 != 0
if r14 goto L4 else goto L5 :: bool
L4:
r15 = CPy_KeepPropagating()
L5:
r16 = r11 == 0
if r16 goto L6 else goto L7 :: bool
L6:
name = v
L7:
L8:
r17 = CPyDict_CheckSize(d, r2)
goto L1
L9:
r18 = CPy_NoErrOccured()
L10:
return 1

[case testDictLoadAddress]
def f() -> None:
Expand Down
34 changes: 32 additions & 2 deletions mypyc/test-data/run-dicts.test
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,13 @@ assert get_content_set(od) == ({1, 3}, {2, 4}, {(1, 2), (3, 4)})
[typing fixtures/typing-full.pyi]

[case testDictIterationMethodsRun]
from typing import Dict
from typing import Dict, Union
from typing_extensions import TypedDict

class ExtensionDict(TypedDict):
python: str
c: str

def print_dict_methods(d1: Dict[int, int],
d2: Dict[int, int],
d3: Dict[int, int]) -> None:
Expand All @@ -107,13 +113,27 @@ def print_dict_methods(d1: Dict[int, int],
for v in d3.values():
print(v)

def print_dict_methods_special(d1: Union[Dict[int, int], Dict[str, str]],
d2: ExtensionDict) -> None:
for k in d1.keys():
print(k)
for k, v in d1.items():
print(k)
print(v)
for v2 in d2.values():
print(v2)
for k2, v2 in d2.items():
print(k2)
print(v2)


def clear_during_iter(d: Dict[int, int]) -> None:
for k in d:
d.clear()

class Custom(Dict[int, int]): pass
[file driver.py]
from native import print_dict_methods, Custom, clear_during_iter
from native import print_dict_methods, print_dict_methods_special, Custom, clear_during_iter
from collections import OrderedDict
print_dict_methods({}, {}, {})
print_dict_methods({1: 2}, {3: 4, 5: 6}, {7: 8})
Expand All @@ -124,6 +144,7 @@ print('==')
d = OrderedDict([(1, 2), (3, 4)])
print_dict_methods(d, d, d)
print('==')
print_dict_methods_special({1: 2}, {"python": ".py", "c": ".c"})
d.move_to_end(1)
print_dict_methods(d, d, d)
clear_during_iter({}) # OK
Expand Down Expand Up @@ -185,6 +206,15 @@ else:
2
4
==
1
1
2
.py
.c
python
.py
c
.c
3
1
3
Expand Down
5 changes: 4 additions & 1 deletion test-data/unit/lib-stub/typing_extensions.pyi
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import typing
from typing import Any, Mapping, Iterator, NoReturn as NoReturn, Dict, Type
from typing import Any, Mapping, Iterable, Iterator, NoReturn as NoReturn, Dict, Tuple, Type
from typing import TYPE_CHECKING as TYPE_CHECKING
from typing import NewType as NewType, overload as overload

Expand Down Expand Up @@ -50,6 +50,9 @@ class _TypedDict(Mapping[str, object]):
# Mypy expects that 'default' has a type variable type.
def pop(self, k: NoReturn, default: _T = ...) -> object: ...
def update(self: _T, __m: _T) -> None: ...
def items(self) -> Iterable[Tuple[str, object]]: ...
def keys(self) -> Iterable[str]: ...
def values(self) -> Iterable[object]: ...
if sys.version_info < (3, 0):
def has_key(self, k: str) -> bool: ...
def __delitem__(self, k: NoReturn) -> None: ...
Expand Down

0 comments on commit 9944d5f

Please sign in to comment.