Skip to content

Commit

Permalink
Make EqualsTester notice incorrect delegation (#3809)
Browse files Browse the repository at this point in the history
This extends EqualsTester to notice when someone incorrectly does not return NotImplemented in `__eq__` for objects of an unrecognized type.  Doing this can cause commutativity to fail when implementing other objects that should be equal to this object.  (It was worse in python 2, where depending on whether one inherited from object or not could cause this).

This also fixes places where Cirq code was doing this incorrectly.

Fixes #877

One final note is that the `eigen_gate_test` was changed because apparently Sympy does this incorrectly

```
>>> import sympy
>>> s = sympy.Symbol("a")
>>> class MyClass:
...   def __eq__(self, other):
...     return True
... 
>>> s.__eq__(MyClass())
False
```
  • Loading branch information
dabacon committed Mar 25, 2021
1 parent ce9fde6 commit 8fb3efe
Show file tree
Hide file tree
Showing 7 changed files with 125 additions and 16 deletions.
2 changes: 1 addition & 1 deletion cirq/contrib/routing/swap_network.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ def get_logical_operations(self) -> Iterable['cirq.Operation']:

def __eq__(self, other) -> bool:
if not isinstance(other, type(self)):
return False
return NotImplemented
return self.circuit == other.circuit and self.initial_mapping == other.initial_mapping

@property
Expand Down
4 changes: 3 additions & 1 deletion cirq/google/ops/calibration_tag.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,9 @@ def _json_dict_(self) -> Dict[str, Any]:
return protocols.obj_to_dict_helper(self, ['token'])

def __eq__(self, other) -> bool:
return isinstance(other, CalibrationTag) and self.token == other.token
if not isinstance(other, CalibrationTag):
return NotImplemented
return self.token == other.token

def __hash__(self) -> int:
return hash(self.token)
2 changes: 0 additions & 2 deletions cirq/ops/eigen_gate_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,8 +106,6 @@ def test_eq():

eq.add_equality_group(CExpZinGate(2.5))
eq.add_equality_group(CExpZinGate(2.25))
eq.make_equality_group(lambda: sympy.Symbol('a'))
eq.add_equality_group(sympy.Symbol('b'))

eq.add_equality_group(ZGateDef(exponent=0.5, global_shift=0.0))
eq.add_equality_group(ZGateDef(exponent=-0.5, global_shift=0.0))
Expand Down
8 changes: 6 additions & 2 deletions cirq/study/sweeps.py
Original file line number Diff line number Diff line change
Expand Up @@ -215,7 +215,9 @@ def __init__(self, *factors: Sweep) -> None:
self.factors = factors

def __eq__(self, other):
return isinstance(other, Product) and self.factors == other.factors
if not isinstance(other, Product):
return NotImplemented
return self.factors == other.factors

def __hash__(self):
return hash(tuple(self.factors))
Expand Down Expand Up @@ -279,7 +281,9 @@ def __init__(self, *sweeps: Sweep) -> None:
self.sweeps = sweeps

def __eq__(self, other):
return isinstance(other, Zip) and self.sweeps == other.sweeps
if not isinstance(other, Zip):
return NotImplemented
return self.sweeps == other.sweeps

def __hash__(self) -> int:
return hash(tuple(self.sweeps))
Expand Down
35 changes: 27 additions & 8 deletions cirq/testing/equals_tester.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,18 +65,14 @@ def _verify_equality_group(self, *group_items: Any):
assert same or v1 is not v2, f"{v1!r} isn't equal to itself!"
assert (
same
), "{!r} and {!r} can't be in the same equality group. They're not equal.".format(
v1, v2
)
), f"{v1!r} and {v2!r} can't be in the same equality group. They're not equal."

# Between-group items must be unequal.
for other_group in self._groups:
for v1, v2 in itertools.product(group_items, other_group):
assert not EqualsTester._eq_check(
v1, v2
), "{!r} and {!r} can't be in different equality groups. They're equal.".format(
v1, v2
)
), f"{v1!r} and {v2!r} can't be in different equality groups. They're equal."

# Check that group items hash to the same thing, or are all unhashable.
hashes = [hash(v) if isinstance(v, collections.abc.Hashable) else None for v in group_items]
Expand All @@ -89,8 +85,18 @@ def _verify_equality_group(self, *group_items: Any):
)
example = next(examples)
raise AssertionError(
'Items in the same group produced different hashes. '
'Example: hash({!r}) is {!r} but hash({!r}) is {!r}.'.format(*example)
"Items in the same group produced different hashes. "
f"Example: hash({example[0]!r}) is {example[1]!r} but "
f"hash({example[2]!r}) is {example[3]!r}."
)

# Test that the objects correctly returns NotImplemented when tested against classes
# that the object does not know the type of.
for v in group_items:
assert _TestsForNotImplemented(v) == v and v == _TestsForNotImplemented(v), (
"An item did not return NotImplemented when checking equality of this "
f"item against a different type than the item. Relevant item: {v!r}. "
"Common problem: returning NotImplementedError instead of NotImplemented. "
)

def add_equality_group(self, *group_items: Any):
Expand Down Expand Up @@ -144,3 +150,16 @@ def __ne__(self, other):

def __hash__(self):
return hash(_ClassUnknownToSubjects)


class _TestsForNotImplemented:
"""Used to test that objects return NotImplemented for equality with other types.
This class is equal to a specific instance or delegates by returning NotImplemented.
"""

def __init__(self, other):
self.other = other

def __eq__(self, other):
return True if other is self.other else NotImplemented
88 changes: 87 additions & 1 deletion cirq/testing/equals_tester_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,9 @@ def __init__(self, k, h):
self._h = h

def __eq__(self, other):
return isinstance(other, KeyHash) and self._k == other._k
if not isinstance(other, KeyHash):
return NotImplemented
return self._k == other._k

def __ne__(self, other):
return not self == other
Expand Down Expand Up @@ -250,3 +252,87 @@ def test_works_on_types():
eq.add_equality_group(object)
eq.add_equality_group(int)
eq.add_equality_group(object())


def test_returns_not_implemented_for_other_types():
# First we demonstrate an example of the problem.

# FirstClass is the class that is broken.
# It returns False when it should return NotImplemented when its __eq__ is called
# on a class it does not recognize.
class FirstClass:
def __init__(self, val):
self.val = val

def __eq__(self, other):
if not isinstance(other, FirstClass):
return False
return self.val == other.val

# So, for example, here is a class that we want to be equal to FirstClass.
class SecondClass:
def __init__(self, val):
self.val = val

def __eq__(self, other):
if isinstance(other, (FirstClass, SecondClass)):
return self.val == other.val
# Ignore coverage, this is just for illustrative purposes.
return NotImplemented # coverage: ignore

# But we see that this does not work because it fails commutativity of ==
assert SecondClass("a") == FirstClass("a")
assert FirstClass("a") != SecondClass("a")

# The problem is that in the second case FirstClass should return NotImplemented, which
# will then cause the == call to check whether SecondClass is equal to FirstClass.

# So if we had done this correctly we would have instead of FirstClass and SecondClass,
# ThirdClass and FourthClass, respectively.
class ThirdClass:
def __init__(self, val):
self.val = val

def __eq__(self, other):
if not isinstance(other, ThirdClass):
return NotImplemented
return self.val == other.val

class FourthClass:
def __init__(self, val):
self.val = val

def __eq__(self, other):
if isinstance(other, (ThirdClass, FourthClass)):
return self.val == other.val
# Ignore coverage, this is just for illustrative purposes.
return NotImplemented # coverage: ignore

# We see this is fixed:
assert ThirdClass("a") == FourthClass("a")
assert FourthClass("a") == ThirdClass("a")

# Now test that EqualsTester catches this.
eq = EqualsTester()

with pytest.raises(AssertionError, match="NotImplemented"):
eq.add_equality_group(FirstClass("a"), FirstClass("a"))

eq = EqualsTester()
eq.add_equality_group(ThirdClass("a"), ThirdClass("a"))


def test_not_implemented_error():
# Common bug is to return NotImplementedError instead of NotImplemented.
class NotImplementedErrorCase:
def __init__(self, val):
self.val = val

def __eq__(self, other):
if not isinstance(other, NotImplementedErrorCase):
return NotImplementedError
return self.val == other.val

eq = EqualsTester()
with pytest.raises(AssertionError, match="NotImplemented"):
eq.add_equality_group(NotImplementedErrorCase("a"), NotImplementedErrorCase("a"))
2 changes: 1 addition & 1 deletion cirq/work/observable_measurement_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -340,7 +340,7 @@ def hex_str_to_ndarray(hexstr):

def __eq__(self, other):
if not isinstance(other, BitstringAccumulator):
return False
return NotImplemented

if (
self.max_setting != other.max_setting
Expand Down

0 comments on commit 8fb3efe

Please sign in to comment.