Skip to content

Commit

Permalink
Merge pull request #13091 from danielwe/improve-cmp
Browse files Browse the repository at this point in the history
Improve rich comparison methods, including fix for #13078
  • Loading branch information
asmeurer committed Aug 8, 2017
2 parents f69d195 + bfdaec2 commit f523547
Show file tree
Hide file tree
Showing 23 changed files with 177 additions and 56 deletions.
6 changes: 3 additions & 3 deletions sympy/core/basic.py
Expand Up @@ -313,7 +313,7 @@ def __eq__(self, other):
try:
other = _sympify(other)
except SympifyError:
return False # sympy != other
return NotImplemented

if type(self) != type(other):
return False
Expand All @@ -329,7 +329,7 @@ def __ne__(self, other):
but faster
"""
return not self.__eq__(other)
return not self == other

def dummy_eq(self, other, symbol=None):
"""
Expand Down Expand Up @@ -1180,7 +1180,7 @@ def _has(self, pattern):

def _has_matcher(self):
"""Helper for .has()"""
return self.__eq__
return lambda other: self == other

def replace(self, query, value, map=False, simultaneous=True, exact=False):
"""
Expand Down
4 changes: 2 additions & 2 deletions sympy/core/exprtools.py
Expand Up @@ -797,7 +797,7 @@ def __eq__(self, other): # Factors
return self.factors == other.factors

def __ne__(self, other): # Factors
return not self.__eq__(other)
return not self == other


class Term(object):
Expand Down Expand Up @@ -909,7 +909,7 @@ def __eq__(self, other): # Term
self.denom == other.denom)

def __ne__(self, other): # Term
return not self.__eq__(other)
return not self == other


def _gcd_terms(terms, isprimitive=False, fraction=True):
Expand Down
30 changes: 15 additions & 15 deletions sympy/core/numbers.py
Expand Up @@ -1258,7 +1258,7 @@ def __eq__(self, other):
try:
other = _sympify(other)
except SympifyError:
return False # sympy != other --> not ==
return NotImplemented
if isinstance(other, NumberSymbol):
if other.is_irrational:
return False
Expand All @@ -1276,15 +1276,15 @@ def __eq__(self, other):
return False # Float != non-Number

def __ne__(self, other):
return not self.__eq__(other)
return not self == other

def __gt__(self, other):
try:
other = _sympify(other)
except SympifyError:
raise TypeError("Invalid comparison %s > %s" % (self, other))
if isinstance(other, NumberSymbol):
return other.__le__(self)
return other.__lt__(self)
if other.is_comparable:
other = other.evalf()
if isinstance(other, Number) and other is not S.NaN:
Expand All @@ -1298,7 +1298,7 @@ def __ge__(self, other):
except SympifyError:
raise TypeError("Invalid comparison %s >= %s" % (self, other))
if isinstance(other, NumberSymbol):
return other.__lt__(self)
return other.__le__(self)
if other.is_comparable:
other = other.evalf()
if isinstance(other, Number) and other is not S.NaN:
Expand All @@ -1312,7 +1312,7 @@ def __lt__(self, other):
except SympifyError:
raise TypeError("Invalid comparison %s < %s" % (self, other))
if isinstance(other, NumberSymbol):
return other.__ge__(self)
return other.__gt__(self)
if other.is_real and other.is_number:
other = other.evalf()
if isinstance(other, Number) and other is not S.NaN:
Expand All @@ -1326,7 +1326,7 @@ def __le__(self, other):
except SympifyError:
raise TypeError("Invalid comparison %s <= %s" % (self, other))
if isinstance(other, NumberSymbol):
return other.__gt__(self)
return other.__ge__(self)
if other.is_real and other.is_number:
other = other.evalf()
if isinstance(other, Number) and other is not S.NaN:
Expand Down Expand Up @@ -1719,7 +1719,7 @@ def __eq__(self, other):
try:
other = _sympify(other)
except SympifyError:
return False # sympy != other --> not ==
return NotImplemented
if isinstance(other, NumberSymbol):
if other.is_irrational:
return False
Expand All @@ -1734,15 +1734,15 @@ def __eq__(self, other):
return False

def __ne__(self, other):
return not self.__eq__(other)
return not self == other

def __gt__(self, other):
try:
other = _sympify(other)
except SympifyError:
raise TypeError("Invalid comparison %s > %s" % (self, other))
if isinstance(other, NumberSymbol):
return other.__le__(self)
return other.__lt__(self)
expr = self
if isinstance(other, Number):
if isinstance(other, Rational):
Expand All @@ -1760,7 +1760,7 @@ def __ge__(self, other):
except SympifyError:
raise TypeError("Invalid comparison %s >= %s" % (self, other))
if isinstance(other, NumberSymbol):
return other.__lt__(self)
return other.__le__(self)
expr = self
if isinstance(other, Number):
if isinstance(other, Rational):
Expand All @@ -1778,7 +1778,7 @@ def __lt__(self, other):
except SympifyError:
raise TypeError("Invalid comparison %s < %s" % (self, other))
if isinstance(other, NumberSymbol):
return other.__ge__(self)
return other.__gt__(self)
expr = self
if isinstance(other, Number):
if isinstance(other, Rational):
Expand All @@ -1797,7 +1797,7 @@ def __le__(self, other):
raise TypeError("Invalid comparison %s <= %s" % (self, other))
expr = self
if isinstance(other, NumberSymbol):
return other.__gt__(self)
return other.__ge__(self)
elif isinstance(other, Number):
if isinstance(other, Rational):
return _sympify(bool(self.p*other.q <= self.q*other.p))
Expand Down Expand Up @@ -2112,7 +2112,7 @@ def __eq__(self, other):
return Rational.__eq__(self, other)

def __ne__(self, other):
return not self.__eq__(other)
return not self == other

def __gt__(self, other):
try:
Expand Down Expand Up @@ -3339,7 +3339,7 @@ def __eq__(self, other):
try:
other = _sympify(other)
except SympifyError:
return False # sympy != other --> not ==
return NotImplemented
if self is other:
return True
if isinstance(other, Number) and self.is_irrational:
Expand All @@ -3348,7 +3348,7 @@ def __eq__(self, other):
return False # NumberSymbol != non-(Number|self)

def __ne__(self, other):
return not self.__eq__(other)
return not self == other

def __lt__(self, other):
try:
Expand Down
37 changes: 37 additions & 0 deletions sympy/core/tests/test_basic.py
Expand Up @@ -38,6 +38,43 @@ def test_equality():
assert Basic() != 0
assert not(Basic() == 0)

class Foo(object):
"""
Class that is unaware of Basic, and relies on both classes returning
the NotImplemented singleton for equivalence to evaluate to False.
"""

b = Basic()
foo = Foo()

assert b != foo
assert foo != b
assert not b == foo
assert not foo == b

class Bar(object):
"""
Class that considers itself equal to any instance of Basic, and relies
on Basic returning the NotImplemented singleton in order to achieve
a symmetric equivalence relation.
"""
def __eq__(self, other):
if isinstance(other, Basic):
return True
return NotImplemented

def __ne__(self, other):
return not self == other

bar = Bar()

assert b == bar
assert bar == b
assert not b != bar
assert not bar != b


def test_matches_basic():
instances = [Basic(b1, b1, b2), Basic(b1, b2, b1), Basic(b2, b1, b1),
Expand Down
84 changes: 84 additions & 0 deletions sympy/core/tests/test_numbers.py
Expand Up @@ -1653,3 +1653,87 @@ def test_mod_inverse():

def test_golden_ratio_rewrite_as_sqrt():
assert GoldenRatio.rewrite(sqrt) == S.Half + sqrt(5)*S.Half

def test_comparisons_with_unknown_type():
class Foo(object):
"""
Class that is unaware of Basic, and relies on both classes returning
the NotImplemented singleton for equivalence to evaluate to False.
"""

ni, nf, nr = Integer(3), Float(1.0), Rational(1, 3)
foo = Foo()

for n in ni, nf, nr, oo, -oo, zoo, nan:
assert n != foo
assert foo != n
assert not n == foo
assert not foo == n
raises(TypeError, lambda: n < foo)
raises(TypeError, lambda: foo > n)
raises(TypeError, lambda: n > foo)
raises(TypeError, lambda: foo < n)
raises(TypeError, lambda: n <= foo)
raises(TypeError, lambda: foo >= n)
raises(TypeError, lambda: n >= foo)
raises(TypeError, lambda: foo <= n)

class Bar(object):
"""
Class that considers itself equal to any instance of Number except
infinities and nans, and relies on sympy types returning the
NotImplemented singleton for symmetric equality relations.
"""
def __eq__(self, other):
if other in (oo, -oo, zoo, nan):
return False
if isinstance(other, Number):
return True
return NotImplemented

def __ne__(self, other):
return not self == other

bar = Bar()

for n in ni, nf, nr:
assert n == bar
assert bar == n
assert not n != bar
assert not bar != n

for n in oo, -oo, zoo, nan:
assert n != bar
assert bar != n
assert not n == bar
assert not bar == n

for n in ni, nf, nr, oo, -oo, zoo, nan:
raises(TypeError, lambda: n < bar)
raises(TypeError, lambda: bar > n)
raises(TypeError, lambda: n > bar)
raises(TypeError, lambda: bar < n)
raises(TypeError, lambda: n <= bar)
raises(TypeError, lambda: bar >= n)
raises(TypeError, lambda: n >= bar)
raises(TypeError, lambda: bar <= n)

def test_NumberSymbol_comparison():
rpi = Rational('905502432259640373/288230376151711744')
fpi = Float(float(pi))

assert (rpi == pi) == (pi == rpi)
assert (rpi != pi) == (pi != rpi)
assert (rpi < pi) == (pi > rpi)
assert (rpi <= pi) == (pi >= rpi)
assert (rpi > pi) == (pi < rpi)
assert (rpi >= pi) == (pi <= rpi)

assert (fpi == pi) == (pi == fpi)
assert (fpi != pi) == (pi != fpi)
assert (fpi < pi) == (pi > fpi)
assert (fpi <= pi) == (pi >= fpi)
assert (fpi > pi) == (pi < fpi)
assert (fpi >= pi) == (pi <= fpi)
2 changes: 1 addition & 1 deletion sympy/geometry/entity.py
Expand Up @@ -104,7 +104,7 @@ def __getnewargs__(self):

def __ne__(self, o):
"""Test inequality of two geometrical entities."""
return not self.__eq__(o)
return not self == o

def __new__(cls, *args, **kwargs):
# Points are sequences, but they should not
Expand Down
4 changes: 2 additions & 2 deletions sympy/physics/optics/medium.py
Expand Up @@ -183,10 +183,10 @@ def __lt__(self, other):
return self.refractive_index < other.refractive_index

def __gt__(self, other):
return not self.__lt__(other)
return not self < other

def __eq__(self, other):
return self.refractive_index == other.refractive_index

def __ne__(self, other):
return not self.__eq__(other)
return not self == other
2 changes: 1 addition & 1 deletion sympy/physics/vector/dyadic.py
Expand Up @@ -147,7 +147,7 @@ def __mul__(self, other):
return Dyadic(newlist)

def __ne__(self, other):
return not self.__eq__(other)
return not self == other

def __neg__(self):
return self * -1
Expand Down
2 changes: 1 addition & 1 deletion sympy/physics/vector/frame.py
Expand Up @@ -70,7 +70,7 @@ def __eq__(self, other):
return False

def __ne__(self, other):
return not self.__eq__(other)
return not self == other

def __hash__(self):
return tuple((self._id[0].__hash__(), self._id[1])).__hash__()
Expand Down
2 changes: 1 addition & 1 deletion sympy/physics/vector/vector.py
Expand Up @@ -166,7 +166,7 @@ def __mul__(self, other):
return Vector(newlist)

def __ne__(self, other):
return not self.__eq__(other)
return not self == other

def __neg__(self):
return self * -1
Expand Down
2 changes: 1 addition & 1 deletion sympy/polys/agca/modules.py
Expand Up @@ -250,7 +250,7 @@ def __eq__(self, om):
return self.eq(self.data, om.data)

def __ne__(self, om):
return not self.__eq__(om)
return not self == om

##########################################################################
## Free Modules ##########################################################
Expand Down
2 changes: 1 addition & 1 deletion sympy/polys/domains/domain.py
Expand Up @@ -343,7 +343,7 @@ def __eq__(self, other):

def __ne__(self, other):
"""Returns ``False`` if two domains are equivalent. """
return not self.__eq__(other)
return not self == other

def map(self, seq):
"""Rersively apply ``self`` to all elements of ``seq``. """
Expand Down
2 changes: 1 addition & 1 deletion sympy/polys/domains/expressiondomain.py
Expand Up @@ -119,7 +119,7 @@ def __eq__(f, g):
return f.ex == f.__class__(g).ex

def __ne__(f, g):
return not f.__eq__(g)
return not f == g

def __nonzero__(f):
return f.ex != 0
Expand Down
2 changes: 1 addition & 1 deletion sympy/polys/domains/pythonrational.py
Expand Up @@ -248,7 +248,7 @@ def __eq__(self, other):
return False

def __ne__(self, other):
return not self.__eq__(other)
return not self == other

def _cmp(self, other, op):
try:
Expand Down

0 comments on commit f523547

Please sign in to comment.