Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Tweaks to --strict-equality based on user feedback #6674

Merged
merged 8 commits into from Apr 27, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
12 changes: 7 additions & 5 deletions docs/source/command_line.rst
Expand Up @@ -396,15 +396,17 @@ of the above sections.

.. code-block:: python

from typing import Text
from typing import List, Text

text: Text
if b'some bytes' in text: # Error: non-overlapping check!
items: List[int]
if 'some string' in items: # Error: non-overlapping container check!
...
if text != b'other bytes': # Error: non-overlapping check!

text: Text
if text != b'other bytes': # Error: non-overlapping equality check!
...

assert text is not None # OK, this special case is allowed.
assert text is not None # OK, check against None is allowed as a special case.

``--strict``
This flag mode enables all optional error checking flags. You can see the
Expand Down
59 changes: 51 additions & 8 deletions mypy/checkexpr.py
Expand Up @@ -1938,7 +1938,8 @@ def visit_comparison_expr(self, e: ComparisonExpr) -> Type:
self.msg.unsupported_operand_types('in', left_type, right_type, e)
# Only show dangerous overlap if there are no other errors.
elif (not local_errors.is_errors() and cont_type and
self.dangerous_comparison(left_type, cont_type)):
self.dangerous_comparison(left_type, cont_type,
original_container=right_type)):
self.msg.dangerous_comparison(left_type, cont_type, 'container', e)
else:
self.msg.add_errors(local_errors)
Expand All @@ -1951,8 +1952,13 @@ def visit_comparison_expr(self, e: ComparisonExpr) -> Type:
# testCustomEqCheckStrictEquality for an example.
if self.msg.errors.total_errors() == err_count and operator in ('==', '!='):
right_type = self.accept(right)
if self.dangerous_comparison(left_type, right_type):
self.msg.dangerous_comparison(left_type, right_type, 'equality', e)
if (not custom_equality_method(left_type) and
not custom_equality_method(right_type)):
# We suppress the error if there is a custom __eq__() method on either
# side. User defined (or even standard library) classes can define this
# to return True for comparisons between non-overlapping types.
if self.dangerous_comparison(left_type, right_type):
self.msg.dangerous_comparison(left_type, right_type, 'equality', e)

elif operator == 'is' or operator == 'is not':
right_type = self.accept(right) # validate the right operand
Expand All @@ -1974,9 +1980,13 @@ def visit_comparison_expr(self, e: ComparisonExpr) -> Type:
assert result is not None
return result

def dangerous_comparison(self, left: Type, right: Type) -> bool:
def dangerous_comparison(self, left: Type, right: Type,
original_container: Optional[Type] = None) -> bool:
"""Check for dangerous non-overlapping comparisons like 42 == 'no'.

The original_container is the original container type for 'in' checks
(and None for equality checks).

Rules:
* X and None are overlapping even in strict-optional mode. This is to allow
'assert x is not None' for x defined as 'x = None # type: str' in class body
Expand All @@ -1985,9 +1995,7 @@ def dangerous_comparison(self, left: Type, right: Type) -> bool:
non-overlapping, although technically None is overlap, it is most
likely an error.
* Any overlaps with everything, i.e. always safe.
* Promotions are ignored, so both 'abc' == b'abc' and 1 == 1.0
are errors. This is mostly needed for bytes vs unicode, and
int vs float are added just for consistency.
* Special case: b'abc' in b'cde' is safe.
"""
if not self.chk.options.strict_equality:
return False
Expand All @@ -1996,7 +2004,12 @@ def dangerous_comparison(self, left: Type, right: Type) -> bool:
if isinstance(left, UnionType) and isinstance(right, UnionType):
left = remove_optional(left)
right = remove_optional(right)
return not is_overlapping_types(left, right, ignore_promotions=True)
if (original_container and has_bytes_component(original_container) and
has_bytes_component(left)):
# We need to special case bytes, because both 97 in b'abc' and b'a' in b'abc'
# return True (and we want to show the error only if the check can _never_ be True).
return False
return not is_overlapping_types(left, right, ignore_promotions=False)

def get_operator_method(self, op: str) -> str:
if op == '/' and self.chk.options.python_version[0] == 2:
Expand Down Expand Up @@ -3809,3 +3822,33 @@ def is_expr_literal_type(node: Expression) -> bool:
underlying = node.node
return isinstance(underlying, TypeAlias) and isinstance(underlying.target, LiteralType)
return False


def custom_equality_method(typ: Type) -> bool:
"""Does this type have a custom __eq__() method?"""
if isinstance(typ, Instance):
method = typ.type.get_method('__eq__')
if method and method.info:
return not method.info.fullname().startswith('builtins.')
return False
if isinstance(typ, UnionType):
return any(custom_equality_method(t) for t in typ.items)
if isinstance(typ, TupleType):
return custom_equality_method(tuple_fallback(typ))
if isinstance(typ, CallableType) and typ.is_type_obj():
# Look up __eq__ on the metaclass for class objects.
return custom_equality_method(typ.fallback)
if isinstance(typ, AnyType):
# Avoid false positives in uncertain cases.
return True
# TODO: support other types (see ExpressionChecker.has_member())?
return False


def has_bytes_component(typ: Type) -> bool:
"""Is this the builtin bytes type, or a union that contains it?"""
if isinstance(typ, UnionType):
return any(has_bytes_component(t) for t in typ.items)
if isinstance(typ, Instance) and typ.type.fullname() == 'builtins.bytes':
return True
return False
80 changes: 79 additions & 1 deletion test-data/unit/check-expressions.test
Expand Up @@ -2024,7 +2024,23 @@ cb: Union[Container[A], Container[B]]
[builtins fixtures/bool.pyi]
[typing fixtures/typing-full.pyi]

[case testStrictEqualityNoPromote]
[case testStrictEqualityBytesSpecial]
# flags: --strict-equality
b'abc' in b'abcde'
[builtins fixtures/primitives.pyi]
[typing fixtures/typing-full.pyi]

[case testStrictEqualityBytesSpecialUnion]
# flags: --strict-equality
from typing import Union
x: Union[bytes, str]

b'abc' in x
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

b'a' in 'b' fails at runtime. Should this generate an error?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, but this is independent of this flag. The error message says "Non-overlapping container check ..." while in this example the check may return True. I think this can be tightened in typeshed, we can just define str.__contains__ as accepting str, because 42 in 'b' etc. all fail as well at runtime.

x in b'abc'
[builtins fixtures/primitives.pyi]
[typing fixtures/typing-full.pyi]

[case testStrictEqualityNoPromotePy3]
# flags: --strict-equality
'a' == b'a' # E: Non-overlapping equality check (left operand type: "str", right operand type: "bytes")
b'a' in 'abc' # E: Non-overlapping container check (element type: "bytes", container item type: "str")
Expand All @@ -2035,6 +2051,16 @@ x != y # E: Non-overlapping equality check (left operand type: "str", right ope
[builtins fixtures/primitives.pyi]
[typing fixtures/typing-full.pyi]

[case testStrictEqualityOkPromote]
# flags: --strict-equality
from typing import Container
c: Container[int]

1 == 1.0 # OK
1.0 in c # OK
[builtins fixtures/primitives.pyi]
[typing fixtures/typing-full.pyi]

[case testStrictEqualityAny]
# flags: --strict-equality
from typing import Any, Container
Expand Down Expand Up @@ -2086,6 +2112,58 @@ class B:
A() == B() # E: Unsupported operand types for == ("A" and "B")
[builtins fixtures/bool.pyi]

[case testCustomEqCheckStrictEqualityOKInstance]
# flags: --strict-equality
class A:
def __eq__(self, other: object) -> bool:
...
class B:
def __eq__(self, other: object) -> bool:
...

A() == int() # OK
int() != B() # OK
[builtins fixtures/bool.pyi]

[case testCustomEqCheckStrictEqualityOKUnion]
# flags: --strict-equality
from typing import Union
class A:
def __eq__(self, other: object) -> bool:
...

x: Union[A, str]
x == int()
[builtins fixtures/bool.pyi]

[case testCustomEqCheckStrictEqualityTuple]
# flags: --strict-equality
from typing import NamedTuple

class Base(NamedTuple):
attr: int

class Custom(Base):
def __eq__(self, other: object) -> bool: ...

Base(int()) == int() # E: Non-overlapping equality check (left operand type: "Base", right operand type: "int")
Base(int()) == tuple()
Custom(int()) == int()
[builtins fixtures/bool.pyi]

[case testCustomEqCheckStrictEqualityMeta]
# flags: --strict-equality
class CustomMeta(type):
def __eq__(self, other: object) -> bool: ...

class Normal: ...
class Custom(metaclass=CustomMeta): ...

Normal == int() # E: Non-overlapping equality check (left operand type: "Type[Normal]", right operand type: "int")
Normal == Normal
Custom == int()
[builtins fixtures/bool.pyi]

[case testCustomContainsCheckStrictEquality]
# flags: --strict-equality
class A:
Expand Down
5 changes: 4 additions & 1 deletion test-data/unit/fixtures/primitives.pyi
Expand Up @@ -23,7 +23,10 @@ class str(Sequence[str]):
def __contains__(self, other: object) -> bool: pass
def __getitem__(self, item: int) -> str: pass
def format(self, *args) -> str: pass
class bytes: pass
class bytes(Sequence[int]):
def __iter__(self) -> Iterator[int]: pass
def __contains__(self, other: object) -> bool: pass
def __getitem__(self, item: int) -> int: pass
class bytearray: pass
class tuple(Generic[T]): pass
class function: pass
Expand Down