Skip to content

Commit

Permalink
Support selecting TypedDicts from unions (#7184)
Browse files Browse the repository at this point in the history
It is a relatively common pattern to narrow down typed dicts from unions with non-dict types using `isinstance(x, dict)`. Currently mypy infers `Dict[Any, Any]` after such checks which is suboptimal.

I propose to special-case this in `narrow_declared_type()` and `restrict_subtype_away()`. Using this opportunity I factored out special cases from the latter in a separate helper function.

Using this opportunity I also fix an old type erasure bug in `isinstance()` checks (type should be erased after mapping to supertype, not before).
  • Loading branch information
ilevkivskyi committed Jul 9, 2019
1 parent fc4baa6 commit e479b6d
Show file tree
Hide file tree
Showing 3 changed files with 98 additions and 22 deletions.
10 changes: 10 additions & 0 deletions mypy/meet.py
Expand Up @@ -54,6 +54,12 @@ def narrow_declared_type(declared: Type, narrowed: Type) -> Type:
return TypeType.make_normalized(narrow_declared_type(declared.item, narrowed.item))
elif isinstance(declared, (Instance, TupleType, TypeType, LiteralType)):
return meet_types(declared, narrowed)
elif isinstance(declared, TypedDictType) and isinstance(narrowed, Instance):
# Special case useful for selecting TypedDicts from unions using isinstance(x, dict).
if (narrowed.type.fullname() == 'builtins.dict' and
all(isinstance(t, AnyType) for t in narrowed.args)):
return declared
return meet_types(declared, narrowed)
return narrowed


Expand Down Expand Up @@ -478,6 +484,8 @@ def visit_instance(self, t: Instance) -> Type:
return meet_types(t, self.s)
elif isinstance(self.s, LiteralType):
return meet_types(t, self.s)
elif isinstance(self.s, TypedDictType):
return meet_types(t, self.s)
return self.default(self.s)

def visit_callable_type(self, t: CallableType) -> Type:
Expand Down Expand Up @@ -555,6 +563,8 @@ def visit_typeddict_type(self, t: TypedDictType) -> Type:
fallback = self.s.create_anonymous_fallback(value_type=mapping_value_type)
required_keys = t.required_keys | self.s.required_keys
return TypedDictType(items, required_keys, fallback)
elif isinstance(self.s, Instance) and is_subtype(t, self.s):
return t
else:
return self.default(self.s)

Expand Down
67 changes: 47 additions & 20 deletions mypy/subtypes.py
Expand Up @@ -1007,58 +1007,81 @@ def unify_generic_callable(type: CallableType, target: CallableType,


def restrict_subtype_away(t: Type, s: Type, *, ignore_promotions: bool = False) -> Type:
"""Return t minus s.
"""Return t minus s for runtime type assertions.
If we can't determine a precise result, return a supertype of the
ideal result (just t is a valid result).
This is used for type inference of runtime type checks such as
isinstance.
Currently this just removes elements of a union type.
isinstance(). Currently this just removes elements of a union type.
"""
if isinstance(t, UnionType):
# Since runtime type checks will ignore type arguments, erase the types.
erased_s = erase_type(s)
# TODO: Implement more robust support for runtime isinstance() checks,
# see issue #3827
new_items = [item for item in t.relevant_items()
if (not (is_proper_subtype(erase_type(item), erased_s,
ignore_promotions=ignore_promotions) or
is_proper_subtype(item, erased_s,
ignore_promotions=ignore_promotions))
or isinstance(item, AnyType))]
if (isinstance(item, AnyType) or
not covers_at_runtime(item, s, ignore_promotions))]
return UnionType.make_union(new_items)
else:
return t


def is_proper_subtype(left: Type, right: Type, *, ignore_promotions: bool = False) -> bool:
def covers_at_runtime(item: Type, supertype: Type, ignore_promotions: bool) -> bool:
"""Will isinstance(item, supertype) always return True at runtime?"""
# Since runtime type checks will ignore type arguments, erase the types.
supertype = erase_type(supertype)
if is_proper_subtype(erase_type(item), supertype, ignore_promotions=ignore_promotions,
erase_instances=True):
return True
if isinstance(supertype, Instance) and supertype.type.is_protocol:
# TODO: Implement more robust support for runtime isinstance() checks, see issue #3827.
if is_proper_subtype(item, supertype, ignore_promotions=ignore_promotions):
return True
if isinstance(item, TypedDictType) and isinstance(supertype, Instance):
# Special case useful for selecting TypedDicts from unions using isinstance(x, dict).
if supertype.type.fullname() == 'builtins.dict':
return True
# TODO: Add more special cases.
return False


def is_proper_subtype(left: Type, right: Type, *, ignore_promotions: bool = False,
erase_instances: bool = False) -> bool:
"""Is left a proper subtype of right?
For proper subtypes, there's no need to rely on compatibility due to
Any types. Every usable type is a proper subtype of itself.
If erase_instances is True, erase left instance *after* mapping it to supertype
(this is useful for runtime isinstance() checks).
"""
if isinstance(right, UnionType) and not isinstance(left, UnionType):
return any([is_proper_subtype(left, item, ignore_promotions=ignore_promotions)
return any([is_proper_subtype(left, item, ignore_promotions=ignore_promotions,
erase_instances=erase_instances)
for item in right.items])
return left.accept(ProperSubtypeVisitor(right, ignore_promotions=ignore_promotions))
return left.accept(ProperSubtypeVisitor(right, ignore_promotions=ignore_promotions,
erase_instances=erase_instances))


class ProperSubtypeVisitor(TypeVisitor[bool]):
def __init__(self, right: Type, *, ignore_promotions: bool = False) -> None:
def __init__(self, right: Type, *,
ignore_promotions: bool = False,
erase_instances: bool = False) -> None:
self.right = right
self.ignore_promotions = ignore_promotions
self.erase_instances = erase_instances
self._subtype_kind = ProperSubtypeVisitor.build_subtype_kind(
ignore_promotions=ignore_promotions,
erase_instances=erase_instances,
)

@staticmethod
def build_subtype_kind(*, ignore_promotions: bool = False) -> SubtypeKind:
return (True, ignore_promotions)
def build_subtype_kind(*, ignore_promotions: bool = False,
erase_instances: bool = False) -> SubtypeKind:
return True, ignore_promotions, erase_instances

def _is_proper_subtype(self, left: Type, right: Type) -> bool:
return is_proper_subtype(left, right, ignore_promotions=self.ignore_promotions)
return is_proper_subtype(left, right,
ignore_promotions=self.ignore_promotions,
erase_instances=self.erase_instances)

def visit_unbound_type(self, left: UnboundType) -> bool:
# This can be called if there is a bad type annotation. The result probably
Expand Down Expand Up @@ -1107,6 +1130,10 @@ def check_argument(leftarg: Type, rightarg: Type, variance: int) -> bool:
return mypy.sametypes.is_same_type(leftarg, rightarg)
# Map left type to corresponding right instances.
left = map_instance_to_supertype(left, right.type)
if self.erase_instances:
erased = erase_type(left)
assert isinstance(erased, Instance)
left = erased

nominal = all(check_argument(ta, ra, tvar.variance) for ta, ra, tvar in
zip(left.args, right.args, right.type.defn.type_vars))
Expand Down
43 changes: 41 additions & 2 deletions test-data/unit/check-typeddict.test
Expand Up @@ -580,7 +580,6 @@ def g(x: X, y: M) -> None: pass
reveal_type(f(g)) # N: Revealed type is '<nothing>'
[builtins fixtures/dict.pyi]

# TODO: It would be more accurate for the meet to be TypedDict instead.
[case testMeetOfTypedDictWithCompatibleMappingSuperclassIsUninhabitedForNow]
# flags: --strict-optional
from mypy_extensions import TypedDict
Expand All @@ -590,7 +589,7 @@ I = Iterable[str]
T = TypeVar('T')
def f(x: Callable[[T, T], None]) -> T: pass
def g(x: X, y: I) -> None: pass
reveal_type(f(g)) # N: Revealed type is '<nothing>'
reveal_type(f(g)) # N: Revealed type is 'TypedDict('__main__.X', {'x': builtins.int})'
[builtins fixtures/dict.pyi]

[case testMeetOfTypedDictsWithNonTotal]
Expand Down Expand Up @@ -1838,3 +1837,43 @@ def func(x):
pass
[builtins fixtures/dict.pyi]
[typing fixtures/typing-full.pyi]

[case testTypedDictIsInstance]
from typing import TypedDict, Union

class User(TypedDict):
id: int
name: str

u: Union[str, User]
u2: User

if isinstance(u, dict):
reveal_type(u) # N: Revealed type is 'TypedDict('__main__.User', {'id': builtins.int, 'name': builtins.str})'
else:
reveal_type(u) # N: Revealed type is 'builtins.str'

assert isinstance(u2, dict)
reveal_type(u2) # N: Revealed type is 'TypedDict('__main__.User', {'id': builtins.int, 'name': builtins.str})'
[builtins fixtures/dict.pyi]
[typing fixtures/typing-full.pyi]

[case testTypedDictIsInstanceABCs]
from typing import TypedDict, Union, Mapping, Iterable

class User(TypedDict):
id: int
name: str

u: Union[int, User]
u2: User

if isinstance(u, Iterable):
reveal_type(u) # N: Revealed type is 'TypedDict('__main__.User', {'id': builtins.int, 'name': builtins.str})'
else:
reveal_type(u) # N: Revealed type is 'builtins.int'

assert isinstance(u2, Mapping)
reveal_type(u2) # N: Revealed type is 'TypedDict('__main__.User', {'id': builtins.int, 'name': builtins.str})'
[builtins fixtures/dict.pyi]
[typing fixtures/typing-full.pyi]

0 comments on commit e479b6d

Please sign in to comment.