diff --git a/mypy/checker.py b/mypy/checker.py index 2e3258208e51..8614148c7298 100644 --- a/mypy/checker.py +++ b/mypy/checker.py @@ -33,7 +33,8 @@ from mypy.types import ( Type, AnyType, CallableType, Void, FunctionLike, Overloaded, TupleType, Instance, NoneTyp, ErrorType, strip_type, - UnionType, TypeVarId, TypeVarType, PartialType, DeletedType, UninhabitedType + UnionType, TypeVarId, TypeVarType, PartialType, DeletedType, UninhabitedType, + true_only, false_only ) from mypy.sametypes import is_same_type from mypy.messages import MessageBuilder @@ -2447,11 +2448,16 @@ def find_isinstance_check(node: Node, if is_not: if_vars, else_vars = else_vars, if_vars return if_vars, else_vars - elif isinstance(node, RefExpr) and experiments.STRICT_OPTIONAL: - # The type could be falsy, so we can't deduce anything new about the else branch + elif isinstance(node, RefExpr): + # Restrict the type of the variable to True-ish/False-ish in the if and else branches + # respectively vartype = type_map[node] - _, if_vars = conditional_type_map(node, vartype, NoneTyp(), weak=weak) - return if_vars, {} + if_type = true_only(vartype) + else_type = false_only(vartype) + ref = node # type: Node + if_map = {ref: if_type} if not isinstance(if_type, UninhabitedType) else None + else_map = {ref: else_type} if not isinstance(else_type, UninhabitedType) else None + return if_map, else_map elif isinstance(node, OpExpr) and node.op == 'and': left_if_vars, left_else_vars = find_isinstance_check( node.left, diff --git a/mypy/checkexpr.py b/mypy/checkexpr.py index 2047c8308bb0..c34e7550b15a 100644 --- a/mypy/checkexpr.py +++ b/mypy/checkexpr.py @@ -5,7 +5,8 @@ from mypy.types import ( Type, AnyType, CallableType, Overloaded, NoneTyp, Void, TypeVarDef, TupleType, Instance, TypeVarId, TypeVarType, ErasedType, UnionType, - PartialType, DeletedType, UnboundType, UninhabitedType, TypeType + PartialType, DeletedType, UnboundType, UninhabitedType, TypeType, + true_only, false_only ) from mypy.nodes import ( NameExpr, RefExpr, Var, FuncDef, OverloadedFuncDef, TypeInfo, CallExpr, @@ -1094,22 +1095,20 @@ def check_boolean_op(self, e: OpExpr, context: Context) -> Type: ctx = self.chk.type_context[-1] left_type = self.accept(e.left, ctx) + assert e.op in ('and', 'or') # Checked by visit_op_expr + if e.op == 'and': right_map, left_map = \ mypy.checker.find_isinstance_check(e.left, self.chk.type_map, self.chk.typing_mode_weak()) + restricted_left_type = false_only(left_type) + result_is_left = not left_type.can_be_true elif e.op == 'or': left_map, right_map = \ mypy.checker.find_isinstance_check(e.left, self.chk.type_map, self.chk.typing_mode_weak()) - else: - left_map = None - right_map = None - - if left_map and e.left in left_map: - # The type of expressions in left_map is the type they'll have if - # the left operand is the result of the operator. - left_type = left_map[e.left] + restricted_left_type = true_only(left_type) + result_is_left = not left_type.can_be_false with self.chk.binder.frame_context(): if right_map: @@ -1121,13 +1120,23 @@ def check_boolean_op(self, e: OpExpr, context: Context) -> Type: self.check_usable_type(left_type, context) self.check_usable_type(right_type, context) - # If either of the type maps is None that means that result cannot happen. - # If both of the type maps are None we just have no information. - if left_map is not None and right_map is None: + if right_map is None: + # The boolean expression is statically known to be the left value + assert left_map is not None # find_isinstance_check guarantees this return left_type - elif left_map is None and right_map is not None: + if left_map is None: + # The boolean expression is statically known to be the right value + assert right_map is not None # find_isinstance_check guarantees this return right_type - return UnionType.make_simplified_union([left_type, right_type]) + + if isinstance(restricted_left_type, UninhabitedType): + # The left operand can never be the result + return right_type + elif result_is_left: + # The left operand is always the result + return left_type + else: + return UnionType.make_simplified_union([restricted_left_type, right_type]) def check_list_multiply(self, e: OpExpr) -> Type: """Type check an expression of form '[...] * e'. diff --git a/mypy/join.py b/mypy/join.py index d6ca21448e41..b5de3fd1bafe 100644 --- a/mypy/join.py +++ b/mypy/join.py @@ -6,7 +6,7 @@ Type, AnyType, NoneTyp, Void, TypeVisitor, Instance, UnboundType, ErrorType, TypeVarType, CallableType, TupleType, ErasedType, TypeList, UnionType, FunctionLike, Overloaded, PartialType, DeletedType, - UninhabitedType, TypeType + UninhabitedType, TypeType, true_or_false ) from mypy.maptype import map_instance_to_supertype from mypy.subtypes import is_subtype, is_equivalent, is_subtype_ignoring_tvars @@ -17,6 +17,11 @@ def join_simple(declaration: Type, s: Type, t: Type) -> Type: """Return a simple least upper bound given the declared type.""" + if (s.can_be_true, s.can_be_false) != (t.can_be_true, t.can_be_false): + # if types are restricted in different ways, use the more general versions + s = true_or_false(s) + t = true_or_false(t) + if isinstance(s, AnyType): return s @@ -60,6 +65,11 @@ def join_types(s: Type, t: Type) -> Type: If the join does not exist, return an ErrorType instance. """ + if (s.can_be_true, s.can_be_false) != (t.can_be_true, t.can_be_false): + # if types are restricted in different ways, use the more general versions + s = true_or_false(s) + t = true_or_false(t) + if isinstance(s, AnyType): return s diff --git a/mypy/test/testtypes.py b/mypy/test/testtypes.py index 73154e2cd867..b2b6d0ab438a 100644 --- a/mypy/test/testtypes.py +++ b/mypy/test/testtypes.py @@ -3,15 +3,16 @@ from typing import List from mypy.myunit import ( - Suite, assert_equal, assert_true, assert_false + Suite, assert_equal, assert_true, assert_false, assert_type ) from mypy.erasetype import erase_type from mypy.expandtype import expand_type -from mypy.join import join_types +from mypy.join import join_types, join_simple from mypy.meet import meet_types from mypy.types import ( UnboundType, AnyType, Void, CallableType, TupleType, TypeVarDef, Type, - Instance, NoneTyp, ErrorType, Overloaded, TypeType, + Instance, NoneTyp, ErrorType, Overloaded, TypeType, UnionType, UninhabitedType, + true_only, false_only ) from mypy.nodes import ARG_POS, ARG_OPT, ARG_STAR, CONTRAVARIANT, INVARIANT, COVARIANT from mypy.subtypes import is_subtype, is_more_precise, is_proper_subtype @@ -232,6 +233,95 @@ def test_is_proper_subtype_invariance(self): assert_false(is_proper_subtype(fx.gb, fx.ga)) assert_false(is_proper_subtype(fx.ga, fx.gb)) + # can_be_true / can_be_false + + def test_empty_tuple_always_false(self): + tuple_type = self.tuple() + assert_true(tuple_type.can_be_false) + assert_false(tuple_type.can_be_true) + + def test_nonempty_tuple_always_true(self): + tuple_type = self.tuple(AnyType(), AnyType()) + assert_true(tuple_type.can_be_true) + assert_false(tuple_type.can_be_false) + + def test_union_can_be_true_if_any_true(self): + union_type = UnionType([self.fx.a, self.tuple()]) + assert_true(union_type.can_be_true) + + def test_union_can_not_be_true_if_none_true(self): + union_type = UnionType([self.tuple(), self.tuple()]) + assert_false(union_type.can_be_true) + + def test_union_can_be_false_if_any_false(self): + union_type = UnionType([self.fx.a, self.tuple()]) + assert_true(union_type.can_be_false) + + def test_union_can_not_be_false_if_none_false(self): + union_type = UnionType([self.tuple(self.fx.a), self.tuple(self.fx.d)]) + assert_false(union_type.can_be_false) + + # true_only / false_only + + def test_true_only_of_false_type_is_uninhabited(self): + to = true_only(NoneTyp()) + assert_type(UninhabitedType, to) + + def test_true_only_of_true_type_is_idempotent(self): + always_true = self.tuple(AnyType()) + to = true_only(always_true) + assert_true(always_true is to) + + def test_true_only_of_instance(self): + to = true_only(self.fx.a) + assert_equal(str(to), "A") + assert_true(to.can_be_true) + assert_false(to.can_be_false) + assert_type(Instance, to) + # The original class still can be false + assert_true(self.fx.a.can_be_false) + + def test_true_only_of_union(self): + tup_type = self.tuple(AnyType()) + # Union of something that is unknown, something that is always true, something + # that is always false + union_type = UnionType([self.fx.a, tup_type, self.tuple()]) + to = true_only(union_type) + assert_equal(len(to.items), 2) + assert_true(to.items[0].can_be_true) + assert_false(to.items[0].can_be_false) + assert_true(to.items[1] is tup_type) + + def test_false_only_of_true_type_is_uninhabited(self): + fo = false_only(self.tuple(AnyType())) + assert_type(UninhabitedType, fo) + + def test_false_only_of_false_type_is_idempotent(self): + always_false = NoneTyp() + fo = false_only(always_false) + assert_true(always_false is fo) + + def test_false_only_of_instance(self): + fo = false_only(self.fx.a) + assert_equal(str(fo), "A") + assert_false(fo.can_be_true) + assert_true(fo.can_be_false) + assert_type(Instance, fo) + # The original class still can be true + assert_true(self.fx.a.can_be_true) + + def test_false_only_of_union(self): + tup_type = self.tuple() + # Union of something that is unknown, something that is always true, something + # that is always false + union_type = UnionType([self.fx.a, self.tuple(AnyType()), tup_type]) + assert_equal(len(union_type.items), 3) + fo = false_only(union_type) + assert_equal(len(fo.items), 2) + assert_false(fo.items[0].can_be_true) + assert_true(fo.items[0].can_be_false) + assert_true(fo.items[1] is tup_type) + # Helpers def tuple(self, *a): @@ -343,6 +433,22 @@ def test_any_type(self): self.callable(self.fx.a, self.fx.b)]: self.assert_join(t, self.fx.anyt, self.fx.anyt) + def test_mixed_truth_restricted_type_simple(self): + # join_simple against differently restricted truthiness types drops restrictions. + true_a = true_only(self.fx.a) + false_o = false_only(self.fx.o) + j = join_simple(self.fx.o, true_a, false_o) + assert_true(j.can_be_true) + assert_true(j.can_be_false) + + def test_mixed_truth_restricted_type(self): + # join_types against differently restricted truthiness types drops restrictions. + true_any = true_only(AnyType()) + false_o = false_only(self.fx.o) + j = join_types(true_any, false_o) + assert_true(j.can_be_true) + assert_true(j.can_be_false) + def test_other_mixed_types(self): # In general, joining unrelated types produces object. for t1 in [self.fx.a, self.fx.t, self.tuple(), diff --git a/mypy/types.py b/mypy/types.py index 59ed4f9b09e8..d582079732b9 100644 --- a/mypy/types.py +++ b/mypy/types.py @@ -1,6 +1,7 @@ """Classes for representing mypy types.""" from abc import abstractmethod +import copy from typing import ( Any, TypeVar, Dict, List, Tuple, cast, Generic, Set, Sequence, Optional, Union ) @@ -20,6 +21,8 @@ class Type(mypy.nodes.Context): """Abstract base class for all types.""" line = 0 + can_be_true = True + can_be_false = True def __init__(self, line: int = -1) -> None: self.line = line @@ -259,6 +262,7 @@ class Void(Type): the result type of calling such callable. """ + can_be_true = False source = '' # May be None; function that generated this value def __init__(self, source: str = None, line: int = -1) -> None: @@ -294,6 +298,9 @@ class UninhabitedType(Type): is_subtype(UninhabitedType, T) = True """ + can_be_true = False + can_be_false = False + def __init__(self, line: int = -1) -> None: super().__init__(line) @@ -327,6 +334,8 @@ class NoneTyp(Type): of a function, where 'None' means Void. """ + can_be_true = False + def __init__(self, is_ret_type: bool = False, line: int = -1) -> None: super().__init__(line) self.is_ret_type = is_ret_type @@ -480,6 +489,8 @@ def deserialize(cls, data: JsonDict) -> 'TypeVarType': class FunctionLike(Type): """Abstract base class for function types.""" + can_be_false = False + @abstractmethod def is_type_obj(self) -> bool: pass @@ -741,6 +752,8 @@ def __init__(self, items: List[Type], fallback: Instance, line: int = -1, self.items = items self.fallback = fallback self.implicit = implicit + self.can_be_true = len(self.items) > 0 + self.can_be_false = len(self.items) == 0 super().__init__(line) def length(self) -> int: @@ -787,6 +800,8 @@ class UnionType(Type): def __init__(self, items: List[Type], line: int = -1) -> None: self.items = items + self.can_be_true = any(item.can_be_true for item in items) + self.can_be_false = any(item.can_be_false for item in items) super().__init__(line) @staticmethod @@ -817,10 +832,20 @@ def make_simplified_union(items: List[Type], line: int = -1) -> Type: from mypy.subtypes import is_subtype removed = set() # type: Set[int] - for i in range(len(items)): - if any(is_subtype(items[i], items[j]) for j in range(len(items)) - if j not in removed and j != i): - removed.add(i) + for i, ti in enumerate(items): + if i in removed: continue + # Keep track of the truishness info for deleted subtypes which can be relevant + cbt = cbf = False + for j, tj in enumerate(items): + if i != j and is_subtype(tj, ti): + removed.add(j) + cbt = cbt or tj.can_be_true + cbf = cbf or tj.can_be_false + # if deleted subtypes had more general truthiness, use that + if not ti.can_be_true and cbt: + items[i] = true_or_false(ti) + elif not ti.can_be_false and cbf: + items[i] = true_or_false(ti) simplified_set = [items[i] for i in range(len(items)) if i not in removed] return UnionType.make_union(simplified_set) @@ -1394,3 +1419,60 @@ def is_named_instance(t: Type, fullname: str) -> bool: return (isinstance(t, Instance) and t.type is not None and t.type.fullname() == fullname) + + +def copy_type(t: Type) -> Type: + """ + Build a copy of the type; used to mutate the copy with truthiness information + """ + return copy.copy(t) + + +def true_only(t: Type) -> Type: + """ + Restricted version of t with only True-ish values + """ + if not t.can_be_true: + # All values of t are False-ish, so there are no true values in it + return UninhabitedType(line=t.line) + elif not t.can_be_false: + # All values of t are already True-ish, so true_only is idempotent in this case + return t + elif isinstance(t, UnionType): + # The true version of a union type is the union of the true versions of its components + new_items = [true_only(item) for item in t.items] + return UnionType.make_simplified_union(new_items, line=t.line) + else: + new_t = copy_type(t) + new_t.can_be_false = False + return new_t + + +def false_only(t: Type) -> Type: + """ + Restricted version of t with only False-ish values + """ + if not t.can_be_false: + # All values of t are True-ish, so there are no false values in it + return UninhabitedType(line=t.line) + elif not t.can_be_true: + # All values of t are already False-ish, so false_only is idempotent in this case + return t + elif isinstance(t, UnionType): + # The false version of a union type is the union of the false versions of its components + new_items = [false_only(item) for item in t.items] + return UnionType.make_simplified_union(new_items, line=t.line) + else: + new_t = copy_type(t) + new_t.can_be_true = False + return new_t + + +def true_or_false(t: Type) -> Type: + """ + Unrestricted version of t with both True-ish and False-ish values + """ + new_t = copy_type(t) + new_t.can_be_true = type(new_t).can_be_true + new_t.can_be_false = type(new_t).can_be_false + return new_t diff --git a/test-data/unit/check-expressions.test b/test-data/unit/check-expressions.test index ff5814b78de0..ccd4ee9ef402 100644 --- a/test-data/unit/check-expressions.test +++ b/test-data/unit/check-expressions.test @@ -300,6 +300,28 @@ b = a and b # E: Incompatible types in assignment (expression has type "Union[A, b = b or a # E: Incompatible types in assignment (expression has type "Union[bool, A]", variable has type "bool") b = a or b # E: Incompatible types in assignment (expression has type "Union[A, bool]", variable has type "bool") class A: pass + +[builtins fixtures/bool.pyi] + +[case testRestrictedTypeAnd] + +b = None # type: bool +i = None # type: str +j = not b and i +if j: + reveal_type(j) # E: Revealed type is 'builtins.str' + + +[builtins fixtures/bool.pyi] + +[case testRestrictedTypeOr] + +b = None # type: bool +i = None # type: str +j = b or i +if not j: + reveal_type(j) # E: Revealed type is 'builtins.str' + [builtins fixtures/bool.pyi] [case testNonBooleanOr] diff --git a/test-data/unit/check-generics.test b/test-data/unit/check-generics.test index afbd489c07ef..2c699ca61fc8 100644 --- a/test-data/unit/check-generics.test +++ b/test-data/unit/check-generics.test @@ -408,7 +408,7 @@ class A(Generic[T]): if s: return p_s_s # E: Incompatible return value type (got p[S, S], expected p[S, T]) p_t_t = None # type: p[T, T] - if s: + if t: return p_t_t # E: Incompatible return value type (got p[T, T], expected p[S, T]) t = t s = s diff --git a/test-data/unit/check-modules.test b/test-data/unit/check-modules.test index d7001869bec8..1fe861e40dc6 100644 --- a/test-data/unit/check-modules.test +++ b/test-data/unit/check-modules.test @@ -61,7 +61,7 @@ class Bad: pass [case testImportWithinBlock] import typing -if None: +if 1: import m m.a = m.b # E: Incompatible types in assignment (expression has type "B", variable has type "A") m.a = m.a diff --git a/test-data/unit/check-statements.test b/test-data/unit/check-statements.test index 4093779f3b41..b0e9b26183e5 100644 --- a/test-data/unit/check-statements.test +++ b/test-data/unit/check-statements.test @@ -96,15 +96,17 @@ def f() -> Iterator[int]: [case testIfStatement] a = None # type: A +a2 = None # type: A +a3 = None # type: A b = None # type: bool if a: - a = b # Fail -elif a: - a = b # Fail -elif a: - a = b # Fail + a = b # E: Incompatible types in assignment (expression has type "bool", variable has type "A") +elif a2: + a = b # E: Incompatible types in assignment (expression has type "bool", variable has type "A") +elif a3: + a = b # E: Incompatible types in assignment (expression has type "bool", variable has type "A") else: - a = b # Fail + a = b # E: Incompatible types in assignment (expression has type "bool", variable has type "A") if b: pass elif b: @@ -114,11 +116,6 @@ if b: class A: pass [builtins fixtures/bool.pyi] -[out] -main:5: error: Incompatible types in assignment (expression has type "bool", variable has type "A") -main:7: error: Incompatible types in assignment (expression has type "bool", variable has type "A") -main:9: error: Incompatible types in assignment (expression has type "bool", variable has type "A") -main:11: error: Incompatible types in assignment (expression has type "bool", variable has type "A") -- Loops