diff --git a/docs/source/command_line.rst b/docs/source/command_line.rst index 1aa1924eea9f..fc523b959bcf 100644 --- a/docs/source/command_line.rst +++ b/docs/source/command_line.rst @@ -396,15 +396,17 @@ of the above sections. .. code-block:: python - from typing import Text + from typing import List, Text - text: Text - if b'some bytes' in text: # Error: non-overlapping check! + items: List[int] + if 'some string' in items: # Error: non-overlapping container check! ... - if text != b'other bytes': # Error: non-overlapping check! + + text: Text + if text != b'other bytes': # Error: non-overlapping equality check! ... - assert text is not None # OK, this special case is allowed. + assert text is not None # OK, check against None is allowed as a special case. ``--strict`` This flag mode enables all optional error checking flags. You can see the diff --git a/mypy/checkexpr.py b/mypy/checkexpr.py index bd817ac733ae..adf1680f2945 100644 --- a/mypy/checkexpr.py +++ b/mypy/checkexpr.py @@ -1938,7 +1938,8 @@ def visit_comparison_expr(self, e: ComparisonExpr) -> Type: self.msg.unsupported_operand_types('in', left_type, right_type, e) # Only show dangerous overlap if there are no other errors. elif (not local_errors.is_errors() and cont_type and - self.dangerous_comparison(left_type, cont_type)): + self.dangerous_comparison(left_type, cont_type, + original_container=right_type)): self.msg.dangerous_comparison(left_type, cont_type, 'container', e) else: self.msg.add_errors(local_errors) @@ -1951,8 +1952,13 @@ def visit_comparison_expr(self, e: ComparisonExpr) -> Type: # testCustomEqCheckStrictEquality for an example. if self.msg.errors.total_errors() == err_count and operator in ('==', '!='): right_type = self.accept(right) - if self.dangerous_comparison(left_type, right_type): - self.msg.dangerous_comparison(left_type, right_type, 'equality', e) + if (not custom_equality_method(left_type) and + not custom_equality_method(right_type)): + # We suppress the error if there is a custom __eq__() method on either + # side. User defined (or even standard library) classes can define this + # to return True for comparisons between non-overlapping types. + if self.dangerous_comparison(left_type, right_type): + self.msg.dangerous_comparison(left_type, right_type, 'equality', e) elif operator == 'is' or operator == 'is not': right_type = self.accept(right) # validate the right operand @@ -1974,9 +1980,13 @@ def visit_comparison_expr(self, e: ComparisonExpr) -> Type: assert result is not None return result - def dangerous_comparison(self, left: Type, right: Type) -> bool: + def dangerous_comparison(self, left: Type, right: Type, + original_container: Optional[Type] = None) -> bool: """Check for dangerous non-overlapping comparisons like 42 == 'no'. + The original_container is the original container type for 'in' checks + (and None for equality checks). + Rules: * X and None are overlapping even in strict-optional mode. This is to allow 'assert x is not None' for x defined as 'x = None # type: str' in class body @@ -1985,9 +1995,7 @@ def dangerous_comparison(self, left: Type, right: Type) -> bool: non-overlapping, although technically None is overlap, it is most likely an error. * Any overlaps with everything, i.e. always safe. - * Promotions are ignored, so both 'abc' == b'abc' and 1 == 1.0 - are errors. This is mostly needed for bytes vs unicode, and - int vs float are added just for consistency. + * Special case: b'abc' in b'cde' is safe. """ if not self.chk.options.strict_equality: return False @@ -1996,7 +2004,12 @@ def dangerous_comparison(self, left: Type, right: Type) -> bool: if isinstance(left, UnionType) and isinstance(right, UnionType): left = remove_optional(left) right = remove_optional(right) - return not is_overlapping_types(left, right, ignore_promotions=True) + if (original_container and has_bytes_component(original_container) and + has_bytes_component(left)): + # We need to special case bytes, because both 97 in b'abc' and b'a' in b'abc' + # return True (and we want to show the error only if the check can _never_ be True). + return False + return not is_overlapping_types(left, right, ignore_promotions=False) def get_operator_method(self, op: str) -> str: if op == '/' and self.chk.options.python_version[0] == 2: @@ -3809,3 +3822,33 @@ def is_expr_literal_type(node: Expression) -> bool: underlying = node.node return isinstance(underlying, TypeAlias) and isinstance(underlying.target, LiteralType) return False + + +def custom_equality_method(typ: Type) -> bool: + """Does this type have a custom __eq__() method?""" + if isinstance(typ, Instance): + method = typ.type.get_method('__eq__') + if method and method.info: + return not method.info.fullname().startswith('builtins.') + return False + if isinstance(typ, UnionType): + return any(custom_equality_method(t) for t in typ.items) + if isinstance(typ, TupleType): + return custom_equality_method(tuple_fallback(typ)) + if isinstance(typ, CallableType) and typ.is_type_obj(): + # Look up __eq__ on the metaclass for class objects. + return custom_equality_method(typ.fallback) + if isinstance(typ, AnyType): + # Avoid false positives in uncertain cases. + return True + # TODO: support other types (see ExpressionChecker.has_member())? + return False + + +def has_bytes_component(typ: Type) -> bool: + """Is this the builtin bytes type, or a union that contains it?""" + if isinstance(typ, UnionType): + return any(has_bytes_component(t) for t in typ.items) + if isinstance(typ, Instance) and typ.type.fullname() == 'builtins.bytes': + return True + return False diff --git a/test-data/unit/check-expressions.test b/test-data/unit/check-expressions.test index ee3ee71e1e88..001abb546d58 100644 --- a/test-data/unit/check-expressions.test +++ b/test-data/unit/check-expressions.test @@ -2024,7 +2024,23 @@ cb: Union[Container[A], Container[B]] [builtins fixtures/bool.pyi] [typing fixtures/typing-full.pyi] -[case testStrictEqualityNoPromote] +[case testStrictEqualityBytesSpecial] +# flags: --strict-equality +b'abc' in b'abcde' +[builtins fixtures/primitives.pyi] +[typing fixtures/typing-full.pyi] + +[case testStrictEqualityBytesSpecialUnion] +# flags: --strict-equality +from typing import Union +x: Union[bytes, str] + +b'abc' in x +x in b'abc' +[builtins fixtures/primitives.pyi] +[typing fixtures/typing-full.pyi] + +[case testStrictEqualityNoPromotePy3] # flags: --strict-equality 'a' == b'a' # E: Non-overlapping equality check (left operand type: "str", right operand type: "bytes") b'a' in 'abc' # E: Non-overlapping container check (element type: "bytes", container item type: "str") @@ -2035,6 +2051,16 @@ x != y # E: Non-overlapping equality check (left operand type: "str", right ope [builtins fixtures/primitives.pyi] [typing fixtures/typing-full.pyi] +[case testStrictEqualityOkPromote] +# flags: --strict-equality +from typing import Container +c: Container[int] + +1 == 1.0 # OK +1.0 in c # OK +[builtins fixtures/primitives.pyi] +[typing fixtures/typing-full.pyi] + [case testStrictEqualityAny] # flags: --strict-equality from typing import Any, Container @@ -2086,6 +2112,58 @@ class B: A() == B() # E: Unsupported operand types for == ("A" and "B") [builtins fixtures/bool.pyi] +[case testCustomEqCheckStrictEqualityOKInstance] +# flags: --strict-equality +class A: + def __eq__(self, other: object) -> bool: + ... +class B: + def __eq__(self, other: object) -> bool: + ... + +A() == int() # OK +int() != B() # OK +[builtins fixtures/bool.pyi] + +[case testCustomEqCheckStrictEqualityOKUnion] +# flags: --strict-equality +from typing import Union +class A: + def __eq__(self, other: object) -> bool: + ... + +x: Union[A, str] +x == int() +[builtins fixtures/bool.pyi] + +[case testCustomEqCheckStrictEqualityTuple] +# flags: --strict-equality +from typing import NamedTuple + +class Base(NamedTuple): + attr: int + +class Custom(Base): + def __eq__(self, other: object) -> bool: ... + +Base(int()) == int() # E: Non-overlapping equality check (left operand type: "Base", right operand type: "int") +Base(int()) == tuple() +Custom(int()) == int() +[builtins fixtures/bool.pyi] + +[case testCustomEqCheckStrictEqualityMeta] +# flags: --strict-equality +class CustomMeta(type): + def __eq__(self, other: object) -> bool: ... + +class Normal: ... +class Custom(metaclass=CustomMeta): ... + +Normal == int() # E: Non-overlapping equality check (left operand type: "Type[Normal]", right operand type: "int") +Normal == Normal +Custom == int() +[builtins fixtures/bool.pyi] + [case testCustomContainsCheckStrictEquality] # flags: --strict-equality class A: diff --git a/test-data/unit/fixtures/primitives.pyi b/test-data/unit/fixtures/primitives.pyi index 796196fa08c6..f2c0cd03acfc 100644 --- a/test-data/unit/fixtures/primitives.pyi +++ b/test-data/unit/fixtures/primitives.pyi @@ -23,7 +23,10 @@ class str(Sequence[str]): def __contains__(self, other: object) -> bool: pass def __getitem__(self, item: int) -> str: pass def format(self, *args) -> str: pass -class bytes: pass +class bytes(Sequence[int]): + def __iter__(self) -> Iterator[int]: pass + def __contains__(self, other: object) -> bool: pass + def __getitem__(self, item: int) -> int: pass class bytearray: pass class tuple(Generic[T]): pass class function: pass