Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Narrow types after 'in' operator #4072

Merged
merged 6 commits into from
Oct 13, 2017
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 47 additions & 0 deletions mypy/checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -2874,6 +2874,39 @@ def remove_optional(typ: Type) -> Type:
return typ


def builtin_item_type(tp: Type) -> Optional[Type]:
"""Get the item type of a builtin container.

If 'tp' is not one of the built containers (these includes NamedTuple and TypedDict)
or if the container is not parameterized (like List or List[Any])
return None. This function is used to narrow optional types in situations like this:

x: Optional[int]
if x in (1, 2, 3):
x + 42 # OK

Note: this is only OK for built-in containers, where we know the behavior
of __contains__.
"""
if isinstance(tp, Instance):
if tp.type.fullname() in ['builtins.list', 'builtins.tuple', 'builtins.dict',
'builtins.set', 'builtins.frozenset']:
if not tp.args:
# TODO: fix tuple in lib-stub/builtins.pyi (it should be generic).
return None
if not isinstance(tp.args[0], AnyType):
return tp.args[0]
elif isinstance(tp, TupleType) and all(not isinstance(it, AnyType) for it in tp.items):
return UnionType.make_simplified_union(tp.items) # this type is not externally visible
elif isinstance(tp, TypedDictType):
# TypedDict always has non-optional string keys.
if tp.fallback.type.fullname() == 'typing.Mapping':
return tp.fallback.args[0]
elif tp.fallback.type.bases[0].type.fullname() == 'typing.Mapping':
return tp.fallback.type.bases[0].args[0]
return None


def and_conditional_maps(m1: TypeMap, m2: TypeMap) -> TypeMap:
"""Calculate what information we can learn from the truth of (e1 and e2)
in terms of the information that we can learn from the truth of e1 and
Expand Down Expand Up @@ -3020,6 +3053,20 @@ def find_isinstance_check(node: Expression,
optional_expr = node.operands[1]
if is_overlapping_types(optional_type, comp_type):
return {optional_expr: remove_optional(optional_type)}, {}
elif node.operators in [['in'], ['not in']]:
expr = node.operands[0]
left_type = type_map[expr]
right_type = builtin_item_type(type_map[node.operands[1]])
right_ok = right_type and (not is_optional(right_type) and
(not isinstance(right_type, Instance) or
right_type.type.fullname() != 'builtins.object'))
if (right_type and right_ok and is_optional(left_type) and
literal(expr) == LITERAL_TYPE and not is_literal_none(expr) and
is_overlapping_types(left_type, right_type)):
if node.operators == ['in']:
return {expr: remove_optional(left_type)}, {}
if node.operators == ['not in']:
return {}, {expr: remove_optional(left_type)}
elif isinstance(node, RefExpr):
# Restrict the type of the variable to True-ish/False-ish in the if and else branches
# respectively
Expand Down
207 changes: 206 additions & 1 deletion test-data/unit/check-isinstance.test
Original file line number Diff line number Diff line change
Expand Up @@ -1757,7 +1757,6 @@ if isinstance(x, str, 1): # E: Too many arguments for "isinstance"
reveal_type(x) # E: Revealed type is 'builtins.int'
[builtins fixtures/isinstancelist.pyi]


[case testIsinstanceNarrowAny]
from typing import Any

Expand All @@ -1770,3 +1769,209 @@ def narrow_any_to_str_then_reassign_to_int() -> None:
reveal_type(v) # E: Revealed type is 'Any'

[builtins fixtures/isinstance.pyi]

[case testNarrowTypeAfterInList]
# flags: --strict-optional
from typing import List, Optional

x: List[int]
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Add test for optional item type (e.g. List[Optional[int]]).

y: Optional[int]

if y in x:
reveal_type(y) # E: Revealed type is 'builtins.int'
else:
reveal_type(y) # E: Revealed type is 'Union[builtins.int, builtins.None]'
if y not in x:
reveal_type(y) # E: Revealed type is 'Union[builtins.int, builtins.None]'
else:
reveal_type(y) # E: Revealed type is 'builtins.int'
[builtins fixtures/list.pyi]
[out]

[case testNarrowTypeAfterInListOfOptional]
# flags: --strict-optional
from typing import List, Optional

x: List[Optional[int]]
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Here's another special case: What if the container is a nested one, say List[List[x]], where x might be Any? How do we deal with various item types, such as List[int] and List[Any]?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

a) If the container is List[Any], then we do nothing (there is a test for this).
b) If the container is List[List[Any]] we narrow the type (provided there is an overlap, this is consistent with how == currently treated), for example:

x: Optional[int]
lst: Optional[List[int]]
nested: List[List[Any]]
if lst in nested:
    reveal_type(lst) # List[int]
if x in nested:
    reveal_type(x) # Optional[int]

There is already a test for non-overlapping items int vs str. I will add one more test for other (nested) types?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

OK, added one more test.

y: Optional[int]

if y not in x:
reveal_type(y) # E: Revealed type is 'Union[builtins.int, builtins.None]'
else:
reveal_type(y) # E: Revealed type is 'Union[builtins.int, builtins.None]'
[builtins fixtures/list.pyi]
[out]

[case testNarrowTypeAfterInListNonOverlapping]
# flags: --strict-optional
from typing import List, Optional

x: List[str]
y: Optional[int]

if y in x:
reveal_type(y) # E: Revealed type is 'Union[builtins.int, builtins.None]'
else:
reveal_type(y) # E: Revealed type is 'Union[builtins.int, builtins.None]'
[builtins fixtures/list.pyi]
[out]

[case testNarrowTypeAfterInListNested]
# flags: --strict-optional
from typing import List, Optional, Any

x: Optional[int]
lst: Optional[List[int]]
nested_any: List[List[Any]]

if lst in nested_any:
reveal_type(lst) # E: Revealed type is 'builtins.list[builtins.int]'
if x in nested_any:
reveal_type(x) # E: Revealed type is 'Union[builtins.int, builtins.None]'
[builtins fixtures/list.pyi]
[out]

[case testNarrowTypeAfterInTuple]
# flags: --strict-optional
from typing import Optional
class A: pass
class B(A): pass
class C(A): pass

y: Optional[B]
if y in (B(), C()):
reveal_type(y) # E: Revealed type is '__main__.B'
else:
reveal_type(y) # E: Revealed type is 'Union[__main__.B, builtins.None]'
[builtins fixtures/tuple.pyi]
[out]

[case testNarrowTypeAfterInNamedTuple]
# flags: --strict-optional
from typing import NamedTuple, Optional
class NT(NamedTuple):
x: int
y: int
nt: NT

y: Optional[int]
if y not in nt:
reveal_type(y) # E: Revealed type is 'Union[builtins.int, builtins.None]'
else:
reveal_type(y) # E: Revealed type is 'builtins.int'
[builtins fixtures/tuple.pyi]
[out]

[case testNarrowTypeAfterInDict]
# flags: --strict-optional
from typing import Dict, Optional
x: Dict[str, int]
y: Optional[str]

if y in x:
reveal_type(y) # E: Revealed type is 'builtins.str'
else:
reveal_type(y) # E: Revealed type is 'Union[builtins.str, builtins.None]'
if y not in x:
reveal_type(y) # E: Revealed type is 'Union[builtins.str, builtins.None]'
else:
reveal_type(y) # E: Revealed type is 'builtins.str'
[builtins fixtures/dict.pyi]
[out]

[case testNarrowTypeAfterInList_python2]
# flags: --strict-optional
from typing import List, Optional

x = [] # type: List[int]
y = None # type: Optional[int]

# TODO: Fix running tests on Python 2: "Iterator[int]" has no attribute "next"
if y in x: # type: ignore
reveal_type(y) # E: Revealed type is 'builtins.int'
else:
reveal_type(y) # E: Revealed type is 'Union[builtins.int, builtins.None]'
if y not in x: # type: ignore
reveal_type(y) # E: Revealed type is 'Union[builtins.int, builtins.None]'
else:
reveal_type(y) # E: Revealed type is 'builtins.int'

[builtins_py2 fixtures/python2.pyi]
[out]

[case testNarrowTypeAfterInNoAnyOrObject]
# flags: --strict-optional
from typing import Any, List, Optional
x: List[Any]
z: List[object]

y: Optional[int]
if y in x:
reveal_type(y) # E: Revealed type is 'Union[builtins.int, builtins.None]'
else:
reveal_type(y) # E: Revealed type is 'Union[builtins.int, builtins.None]'

if y not in z:
reveal_type(y) # E: Revealed type is 'Union[builtins.int, builtins.None]'
else:
reveal_type(y) # E: Revealed type is 'Union[builtins.int, builtins.None]'
[typing fixtures/typing-full.pyi]
[builtins fixtures/list.pyi]
[out]

[case testNarrowTypeAfterInUserDefined]
# flags: --strict-optional
from typing import Container, Optional

class C(Container[int]):
def __contains__(self, item: object) -> bool:
return item is 'surprise'

y: Optional[int]
# We never trust user defined types
if y in C():
reveal_type(y) # E: Revealed type is 'Union[builtins.int, builtins.None]'
else:
reveal_type(y) # E: Revealed type is 'Union[builtins.int, builtins.None]'
if y not in C():
reveal_type(y) # E: Revealed type is 'Union[builtins.int, builtins.None]'
else:
reveal_type(y) # E: Revealed type is 'Union[builtins.int, builtins.None]'
[typing fixtures/typing-full.pyi]
[builtins fixtures/list.pyi]
[out]

[case testNarrowTypeAfterInSet]
# flags: --strict-optional
from typing import Optional, Set
s: Set[str]

y: Optional[str]
if y in {'a', 'b', 'c'}:
reveal_type(y) # E: Revealed type is 'builtins.str'
else:
reveal_type(y) # E: Revealed type is 'Union[builtins.str, builtins.None]'
if y not in s:
reveal_type(y) # E: Revealed type is 'Union[builtins.str, builtins.None]'
else:
reveal_type(y) # E: Revealed type is 'builtins.str'
[builtins fixtures/set.pyi]
[out]

[case testNarrowTypeAfterInTypedDict]
# flags: --strict-optional
from typing import Optional
from mypy_extensions import TypedDict
class TD(TypedDict):
a: int
b: str
td: TD

def f() -> None:
x: Optional[str]
if x not in td:
return
reveal_type(x) # E: Revealed type is 'builtins.str'
[typing fixtures/typing-full.pyi]
[builtins fixtures/dict.pyi]
[out]
1 change: 1 addition & 0 deletions test-data/unit/fixtures/dict.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ class dict(Generic[KT, VT]):
def __getitem__(self, key: KT) -> VT: pass
def __setitem__(self, k: KT, v: VT) -> None: pass
def __iter__(self) -> Iterator[KT]: pass
def __contains__(self, item: object) -> bool: pass
def update(self, a: Mapping[KT, VT]) -> None: pass
@overload
def get(self, k: KT) -> Optional[VT]: pass
Expand Down
1 change: 1 addition & 0 deletions test-data/unit/fixtures/list.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ class list(Generic[T]):
@overload
def __init__(self, x: Iterable[T]) -> None: pass
def __iter__(self) -> Iterator[T]: pass
def __contains__(self, item: object) -> bool: pass
def __add__(self, x: list[T]) -> list[T]: pass
def __mul__(self, x: int) -> list[T]: pass
def __getitem__(self, x: int) -> T: pass
Expand Down
1 change: 1 addition & 0 deletions test-data/unit/fixtures/python2.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ class function: pass
class int: pass
class str: pass
class unicode: pass
class bool: pass

T = TypeVar('T')
class list(Iterable[T], Generic[T]): pass
Expand Down
2 changes: 2 additions & 0 deletions test-data/unit/fixtures/set.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,11 @@ class function: pass

class int: pass
class str: pass
class bool: pass

class set(Iterable[T], Generic[T]):
def __iter__(self) -> Iterator[T]: pass
def __contains__(self, item: object) -> bool: pass
def add(self, x: T) -> None: pass
def discard(self, x: T) -> None: pass
def update(self, x: Set[T]) -> None: pass
1 change: 1 addition & 0 deletions test-data/unit/fixtures/tuple.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ class type:
def __call__(self, *a) -> object: pass
class tuple(Sequence[Tco], Generic[Tco]):
def __iter__(self) -> Iterator[Tco]: pass
def __contains__(self, item: object) -> bool: pass
def __getitem__(self, x: int) -> Tco: pass
def count(self, obj: Any) -> int: pass
class function: pass
Expand Down
1 change: 1 addition & 0 deletions test-data/unit/fixtures/typing-full.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,7 @@ class Mapping(Iterable[T], Protocol[T, T_co]):
def get(self, k: T, default: Union[T_co, V]) -> Union[T_co, V]: pass
def values(self) -> Iterable[T_co]: pass # Approximate return type
def __len__(self) -> int: ...
def __contains__(self, arg: object) -> int: pass

@runtime
class MutableMapping(Mapping[T, U], Protocol):
Expand Down