Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
185 changes: 110 additions & 75 deletions mypy/checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from __future__ import annotations

import functools
import itertools
from collections import defaultdict
from collections.abc import Iterable, Iterator, Mapping, Sequence, Set as AbstractSet
Expand Down Expand Up @@ -47,7 +48,12 @@
from mypy.expandtype import expand_type
from mypy.literals import Key, extract_var_from_literal_hash, literal, literal_hash
from mypy.maptype import map_instance_to_supertype
from mypy.meet import is_overlapping_erased_types, is_overlapping_types, meet_types
from mypy.meet import (
is_overlapping_erased_types,
is_overlapping_types,
meet_types,
narrow_declared_type,
)
from mypy.message_registry import ErrorMessage
from mypy.messages import (
SUGGESTED_TEST_FIXTURES,
Expand Down Expand Up @@ -6237,65 +6243,89 @@ def is_type_call(expr: CallExpr) -> bool:

# exprs that are being passed into type
exprs_in_type_calls: list[Expression] = []
# type that is being compared to type(expr)
type_being_compared: list[TypeRange] | None = None
# whether the type being compared to is final
# all the types that an expression will have if the overall expression is truthy
target_types: list[list[TypeRange]] = []
# only a single type can be used when passed directly (eg "str")
fixed_type: Type | None = None
# is this single type final?
is_final = False

def update_fixed_type(new_fixed_type: Type, new_is_final: bool) -> bool:
"""Returns if the update succeeds"""
nonlocal fixed_type, is_final
if update := (fixed_type is None or (is_same_type(new_fixed_type, fixed_type))):
fixed_type = new_fixed_type
is_final = new_is_final
return update

for index in expr_indices:
expr = node.operands[index]
proper_type = get_proper_type(self.lookup_type(expr))

if isinstance(expr, CallExpr) and is_type_call(expr):
exprs_in_type_calls.append(expr.args[0])
else:
current_type = self.get_isinstance_type(expr)
if current_type is None:
continue
if type_being_compared is not None:
# It doesn't really make sense to have several types being
# compared to the output of type (like type(x) == int == str)
# because whether that's true is solely dependent on what the
# types being compared are, so we don't try to narrow types any
# further because we can't really get any information about the
# type of x from that check
return {}, {}
else:
if isinstance(expr, RefExpr) and isinstance(expr.node, TypeInfo):
is_final = expr.node.is_final
type_being_compared = current_type
arg = expr.args[0]
exprs_in_type_calls.append(arg)
elif (
isinstance(expr, OpExpr)
or isinstance(proper_type, TupleType)
or is_named_instance(proper_type, "builtins.tuple")
):
# not valid for type comparisons, but allowed for isinstance checks
fixed_type = UninhabitedType()
continue

type_range = self.get_isinstance_type(expr)
if type_range is not None:
target_types.append(type_range)
if (
isinstance(expr, RefExpr)
and isinstance(expr.node, TypeInfo)
and len(type_range) == 1
):
if not update_fixed_type(
Instance(
expr.node,
[AnyType(TypeOfAny.special_form)] * len(expr.node.defn.type_vars),
),
expr.node.is_final,
):
return None, {}

if not exprs_in_type_calls:
return {}, {}

if_maps: list[TypeMap] = []
else_maps: list[TypeMap] = []
if_maps = []
else_maps = []
for expr in exprs_in_type_calls:
current_if_type, current_else_type = self.conditional_types_with_intersection(
self.lookup_type(expr), type_being_compared, expr
)
current_if_map, current_else_map = conditional_types_to_typemaps(
expr, current_if_type, current_else_type
)
if_maps.append(current_if_map)
else_maps.append(current_else_map)
expr_type: Type = get_proper_type(self.lookup_type(expr))
for type_range in target_types:
restriction, _ = self.conditional_types_with_intersection(
expr_type, type_range, expr
)
if restriction is not None:
narrowed_type = get_proper_type(narrow_declared_type(expr_type, restriction))
# Cannot be guaranteed that this is unreachable, so use fallback type.
if isinstance(narrowed_type, UninhabitedType):
expr_type = restriction
else:
expr_type = narrowed_type
_, else_map = conditional_types_to_typemaps(
expr,
*self.conditional_types_with_intersection(
(self.lookup_type(expr)), (type_range), expr
),
)
else_maps.append(else_map)

def combine_maps(list_maps: list[TypeMap]) -> TypeMap:
"""Combine all typemaps in list_maps into one typemap"""
if all(m is None for m in list_maps):
return None
result_map = {}
for d in list_maps:
if d is not None:
result_map.update(d)
return result_map

if_map = combine_maps(if_maps)
# type(x) == T is only true when x has the same type as T, meaning
# that it can be false if x is an instance of a subclass of T. That means
# we can't do any narrowing in the else case unless T is final, in which
# case T can't be subclassed
if fixed_type and expr_type is not None:
expr_type = narrow_declared_type(expr_type, fixed_type)

if_map, _ = conditional_types_to_typemaps(expr, expr_type, None)
if_maps.append(if_map)

if_map = functools.reduce(and_conditional_maps, if_maps)
if is_final:
else_map = combine_maps(else_maps)
else_map = functools.reduce(or_conditional_maps, else_maps)
else:
else_map = {}
return if_map, else_map
Expand Down Expand Up @@ -7039,7 +7069,6 @@ def refine_away_none_in_comparison(
if_map, else_map = {}, {}

if not non_optional_types or (len(non_optional_types) != len(chain_indices)):

# Narrow e.g. `Optional[A] == "x"` or `Optional[A] is "x"` to `A` (which may be
# convenient but is strictly not type-safe):
for i in narrowable_operand_indices:
Expand Down Expand Up @@ -7961,35 +7990,41 @@ def get_isinstance_type(self, expr: Expression) -> list[TypeRange] | None:
return None
return left + right
all_types = get_proper_types(flatten_types(self.lookup_type(expr)))
types: list[TypeRange] = []
type_ranges: list[TypeRange] = []
for typ in all_types:
if isinstance(typ, FunctionLike) and typ.is_type_obj():
# If a type is generic, `isinstance` can only narrow its variables to Any.
any_parameterized = fill_typevars_with_any(typ.type_object())
# Tuples may have unattended type variables among their items
if isinstance(any_parameterized, TupleType):
erased_type = erase_typevars(any_parameterized)
else:
erased_type = any_parameterized
types.append(TypeRange(erased_type, is_upper_bound=False))
elif isinstance(typ, TypeType):
# Type[A] means "any type that is a subtype of A" rather than "precisely type A"
# we indicate this by setting is_upper_bound flag
is_upper_bound = True
if isinstance(typ.item, NoneType):
# except for Type[None], because "'NoneType' is not an acceptable base type"
is_upper_bound = False
types.append(TypeRange(typ.item, is_upper_bound=is_upper_bound))
elif isinstance(typ, Instance) and typ.type.fullname == "builtins.type":
object_type = Instance(typ.type.mro[-1], [])
types.append(TypeRange(object_type, is_upper_bound=True))
elif isinstance(typ, Instance) and typ.type.fullname == "types.UnionType" and typ.args:
types.append(TypeRange(UnionType(typ.args), is_upper_bound=False))
elif isinstance(typ, AnyType):
types.append(TypeRange(typ, is_upper_bound=False))
else: # we didn't see an actual type, but rather a variable with unknown value
type_range = self.isinstance_type_range(typ)
if type_range is None:
return None
return types
type_ranges.append(type_range)
return type_ranges

def isinstance_type_range(self, typ: ProperType) -> TypeRange | None:
if isinstance(typ, FunctionLike) and typ.is_type_obj():
# If a type is generic, `isinstance` can only narrow its variables to Any.
any_parameterized = fill_typevars_with_any(typ.type_object())
# Tuples may have unattended type variables among their items
if isinstance(any_parameterized, TupleType):
erased_type = erase_typevars(any_parameterized)
else:
erased_type = any_parameterized
return TypeRange(erased_type, is_upper_bound=False)
elif isinstance(typ, TypeType):
# Type[A] means "any type that is a subtype of A" rather than "precisely type A"
# we indicate this by setting is_upper_bound flag
is_upper_bound = True
if isinstance(typ.item, NoneType):
# except for Type[None], because "'NoneType' is not an acceptable base type"
is_upper_bound = False
return TypeRange(typ.item, is_upper_bound=is_upper_bound)
elif isinstance(typ, Instance) and typ.type.fullname == "builtins.type":
object_type = Instance(typ.type.mro[-1], [])
return TypeRange(object_type, is_upper_bound=True)
elif isinstance(typ, Instance) and typ.type.fullname == "types.UnionType" and typ.args:
return TypeRange(UnionType(typ.args), is_upper_bound=False)
elif isinstance(typ, AnyType):
return TypeRange(typ, is_upper_bound=False)
else: # we didn't see an actual type, but rather a variable with unknown value
return None

def is_literal_enum(self, n: Expression) -> bool:
"""Returns true if this expression (with the given type context) is an Enum literal.
Expand Down
23 changes: 23 additions & 0 deletions mypy/subtypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import mypy.applytype
import mypy.constraints
import mypy.typeops
from mypy.checker_shared import TypeRange
from mypy.checker_state import checker_state
from mypy.erasetype import erase_type
from mypy.expandtype import (
Expand Down Expand Up @@ -255,6 +256,28 @@ def is_equivalent(
)


def is_same_type_ranges(
a: list[TypeRange],
b: list[TypeRange],
ignore_promotions: bool = True,
subtype_context: SubtypeContext | None = None,
) -> bool:
return len(a) == len(b) and all(
is_same_type_range(x, y, ignore_promotions, subtype_context) for x, y in zip(a, b)
)


def is_same_type_range(
a: TypeRange,
b: TypeRange,
ignore_promotions: bool = True,
subtype_context: SubtypeContext | None = None,
) -> bool:
return a.is_upper_bound == b.is_upper_bound and is_same_type(
a.item, b.item, ignore_promotions, subtype_context
)


def is_same_type(
a: Type, b: Type, ignore_promotions: bool = True, subtype_context: SubtypeContext | None = None
) -> bool:
Expand Down
67 changes: 63 additions & 4 deletions test-data/unit/check-isinstance.test
Original file line number Diff line number Diff line change
Expand Up @@ -2716,23 +2716,81 @@ if type(x) == type(y) == int:
reveal_type(y) # N: Revealed type is "builtins.int"
reveal_type(x) # N: Revealed type is "builtins.int"

z: Any
if int == type(z) == int:
reveal_type(z) # N: Revealed type is "builtins.int"

[case testTypeEqualsCheckUsingIs]
from typing import Any

y: Any
if type(y) is int:
reveal_type(y) # N: Revealed type is "builtins.int"

[case testTypeEqualsCheckUsingImplicitTypes]
from typing import Any

x: str
y: Any
z: object
if type(y) is type(x):
reveal_type(x) # N: Revealed type is "builtins.str"
reveal_type(y) # N: Revealed type is "builtins.str"

if type(x) is type(z):
reveal_type(x) # N: Revealed type is "builtins.str"
reveal_type(z) # N: Revealed type is "builtins.str"

[case testTypeEqualsCheckUsingDifferentSpecializedTypes]
from collections import defaultdict

x: defaultdict
y: dict
z: object
if type(x) is type(y) is type(z):
reveal_type(x) # N: Revealed type is "collections.defaultdict[Any, Any]"
reveal_type(y) # N: Revealed type is "collections.defaultdict[Any, Any]"
reveal_type(z) # N: Revealed type is "collections.defaultdict[Any, Any]"

[case testUnionTypeEquality]
from typing import Any, reveal_type
# flags: --warn-unreachable

x: Any = ()
if type(x) == (int, str):
reveal_type(x) # E: Statement is unreachable

[builtins fixtures/tuple.pyi]

[case testTypeIntersectionWithConcreteTypes]
class X: x = 1
class Y: y = 1
class Z(X, Y): ...

z = Z()
x: X = z
y: Y = z
if type(x) is type(y):
reveal_type(x) # N: Revealed type is "__main__.<subclass of "__main__.X" and "__main__.Y">"
reveal_type(y) # N: Revealed type is "__main__.<subclass of "__main__.Y" and "__main__.X">"
x.y + y.x

if isinstance(x, type(y)) and isinstance(y, type(x)):
reveal_type(x) # N: Revealed type is "__main__.<subclass of "__main__.X" and "__main__.Y">"
reveal_type(y) # N: Revealed type is "__main__.<subclass of "__main__.X" and "__main__.Y">"
x.y + y.x

[builtins fixtures/isinstance.pyi]

[case testTypeEqualsCheckUsingIsNonOverlapping]
# flags: --warn-unreachable
from typing import Union

y: str
if type(y) is int: # E: Subclass of "str" and "int" cannot exist: would have incompatible method signatures
y # E: Statement is unreachable
if type(y) is int:
y
else:
reveal_type(y) # N: Revealed type is "builtins.str"
[builtins fixtures/isinstance.pyi]

[case testTypeEqualsCheckUsingIsNonOverlappingChild-xfail]
# flags: --warn-unreachable
Expand Down Expand Up @@ -2761,12 +2819,13 @@ else:

[case testTypeEqualsMultipleTypesShouldntNarrow]
# make sure we don't do any narrowing if there are multiple types being compared
# flags: --warn-unreachable

from typing import Union

x: Union[int, str]
if type(x) == int == str:
reveal_type(x) # N: Revealed type is "Union[builtins.int, builtins.str]"
reveal_type(x) # E: Statement is unreachable
else:
reveal_type(x) # N: Revealed type is "Union[builtins.int, builtins.str]"

Expand Down
2 changes: 1 addition & 1 deletion test-data/unit/check-narrowing.test
Original file line number Diff line number Diff line change
Expand Up @@ -1262,7 +1262,7 @@ def f(t: Type[C]) -> None:
if type(t) is M:
reveal_type(t) # N: Revealed type is "type[__main__.C]"
else:
reveal_type(t) # N: Revealed type is "type[__main__.C]"
reveal_type(t) # N: Revealed type is "type[__main__.C]"
if type(t) is not M:
reveal_type(t) # N: Revealed type is "type[__main__.C]"
else:
Expand Down