Skip to content

Commit

Permalink
Return NotImplemented, not False, upon rich comparison with unknown type
Browse files Browse the repository at this point in the history
Also, use _sympifyit decorator in rich comparison methods.

Closes sympy/sympy#13078 (test from sympy/sympy#13091)
  • Loading branch information
skirpichev committed Aug 6, 2017
1 parent f47d598 commit ee86772
Show file tree
Hide file tree
Showing 4 changed files with 119 additions and 102 deletions.
12 changes: 4 additions & 8 deletions diofant/core/basic.py
Expand Up @@ -4,6 +4,7 @@

from .cache import cacheit
from .compatibility import iterable, ordered
from .decorators import _sympifyit
from .singleton import S
from .sympify import SympifyError, _sympify, sympify

Expand Down Expand Up @@ -181,6 +182,7 @@ def inner_key(arg):
args = len(args), tuple(inner_key(arg) for arg in args)
return self.class_key(), args, S.One.sort_key(), S.One

@_sympifyit('other', NotImplemented)
def __eq__(self, other):
"""Return a boolean indicating whether a == b on the basis of
their symbolic trees.
Expand All @@ -206,13 +208,7 @@ def __eq__(self, other):
return True

if type(self) != type(other):
try:
other = _sympify(other)
except SympifyError:
return False # diofant != other

if type(self) != type(other):
return False
return False

return self._hashable_content() == other._hashable_content()

Expand Down Expand Up @@ -945,7 +941,7 @@ def _has(self, pattern):

def _has_matcher(self):
"""Helper for .has()"""
return self.__eq__
return lambda x: self == x

def replace(self, query, value, map=False, simultaneous=True, exact=False):
"""Replace matching subexpressions of ``self`` with ``value``.
Expand Down
122 changes: 28 additions & 94 deletions diofant/core/numbers.py
Expand Up @@ -967,8 +967,8 @@ def __eq__(self, other):
return bool(mlib.mpf_eq(self._mpf_, ompf))
return False # Float != non-Number

@_sympifyit('other', NotImplemented)
def __gt__(self, other):
other = _sympify(other)
if isinstance(other, NumberSymbol):
return other.__le__(self)
if other.is_comparable:
Expand All @@ -978,8 +978,8 @@ def __gt__(self, other):
mlib.mpf_gt(self._mpf_, other._as_mpf_val(self._prec))))
return Expr.__gt__(self, other)

@_sympifyit('other', NotImplemented)
def __ge__(self, other):
other = _sympify(other)
if isinstance(other, NumberSymbol):
return other.__lt__(self)
if other.is_comparable:
Expand All @@ -989,19 +989,19 @@ def __ge__(self, other):
mlib.mpf_ge(self._mpf_, other._as_mpf_val(self._prec))))
return Expr.__ge__(self, other)

@_sympifyit('other', NotImplemented)
def __lt__(self, other):
other = _sympify(other)
if isinstance(other, NumberSymbol):
return other.__ge__(self)
if other.is_extended_real and other.is_number:
other = other.evalf()
if isinstance(other, Number) and other is not S.NaN:
return _sympify(bool(
mlib.mpf_lt(self._mpf_, other._as_mpf_val(self._prec))))
return _sympify(bool(mlib.mpf_lt(self._mpf_,
other._as_mpf_val(self._prec))))
return Expr.__lt__(self, other)

@_sympifyit('other', NotImplemented)
def __le__(self, other):
other = _sympify(other)
if isinstance(other, NumberSymbol):
return other.__gt__(self)
if other.is_extended_real and other.is_number:
Expand Down Expand Up @@ -1305,11 +1305,8 @@ def __int__(self):
return -(-p//q)
return p//q

@_sympifyit('other', NotImplemented)
def __eq__(self, other):
try:
other = _sympify(other)
except SympifyError:
return False # diofant != other --> not ==
if isinstance(other, NumberSymbol):
if other.is_irrational:
return False
Expand All @@ -1323,11 +1320,8 @@ def __eq__(self, other):
return mlib.mpf_eq(self._as_mpf_val(other._prec), other._mpf_)
return False

@_sympifyit('other', NotImplemented)
def __gt__(self, other):
try:
other = _sympify(other)
except SympifyError:
raise TypeError("Invalid comparison %s > %s" % (self, other))
if isinstance(other, NumberSymbol):
return other.__le__(self)
expr = self
Expand All @@ -1341,11 +1335,8 @@ def __gt__(self, other):
expr, other = Integer(self.p), self.q*other
return Expr.__gt__(expr, other)

@_sympifyit('other', NotImplemented)
def __ge__(self, other):
try:
other = _sympify(other)
except SympifyError:
raise TypeError("Invalid comparison %s >= %s" % (self, other))
if isinstance(other, NumberSymbol):
return other.__lt__(self)
expr = self
Expand All @@ -1359,11 +1350,8 @@ def __ge__(self, other):
expr, other = Integer(self.p), self.q*other
return Expr.__ge__(expr, other)

@_sympifyit('other', NotImplemented)
def __lt__(self, other):
try:
other = _sympify(other)
except SympifyError:
raise TypeError("Invalid comparison %s < %s" % (self, other))
if isinstance(other, NumberSymbol):
return other.__ge__(self)
expr = self
Expand All @@ -1377,11 +1365,8 @@ def __lt__(self, other):
expr, other = Integer(self.p), self.q*other
return Expr.__lt__(expr, other)

@_sympifyit('other', NotImplemented)
def __le__(self, other):
try:
other = _sympify(other)
except SympifyError:
raise TypeError("Invalid comparison %s <= %s" % (self, other))
expr = self
if isinstance(other, NumberSymbol):
return other.__gt__(self)
Expand Down Expand Up @@ -1606,38 +1591,26 @@ def __eq__(self, other):
return (self.p == other.p)
return Rational.__eq__(self, other)

@_sympifyit('other', NotImplemented)
def __gt__(self, other):
try:
other = _sympify(other)
except SympifyError:
raise TypeError("Invalid comparison %s > %s" % (self, other))
if isinstance(other, Integer):
return _sympify(self.p > other.p)
return Rational.__gt__(self, other)

@_sympifyit('other', NotImplemented)
def __lt__(self, other):
try:
other = _sympify(other)
except SympifyError:
raise TypeError("Invalid comparison %s < %s" % (self, other))
if isinstance(other, Integer):
return _sympify(self.p < other.p)
return Rational.__lt__(self, other)

@_sympifyit('other', NotImplemented)
def __ge__(self, other):
try:
other = _sympify(other)
except SympifyError:
raise TypeError("Invalid comparison %s >= %s" % (self, other))
if isinstance(other, Integer):
return _sympify(self.p >= other.p)
return Rational.__ge__(self, other)

@_sympifyit('other', NotImplemented)
def __le__(self, other):
try:
other = _sympify(other)
except SympifyError:
raise TypeError("Invalid comparison %s <= %s" % (self, other))
if isinstance(other, Integer):
return _sympify(self.p <= other.p)
return Rational.__le__(self, other)
Expand Down Expand Up @@ -2391,20 +2364,14 @@ def __hash__(self):
def __eq__(self, other):
return other is S.Infinity

@_sympifyit('other', NotImplemented)
def __lt__(self, other):
try:
other = _sympify(other)
except SympifyError:
raise TypeError("Invalid comparison %s < %s" % (self, other))
if other.is_extended_real:
return S.false
return Expr.__lt__(self, other)

@_sympifyit('other', NotImplemented)
def __le__(self, other):
try:
other = _sympify(other)
except SympifyError:
raise TypeError("Invalid comparison %s <= %s" % (self, other))
if other.is_extended_real:
if other.is_finite or other is S.NegativeInfinity:
return S.false
Expand All @@ -2414,11 +2381,8 @@ def __le__(self, other):
return S.true
return Expr.__le__(self, other)

@_sympifyit('other', NotImplemented)
def __gt__(self, other):
try:
other = _sympify(other)
except SympifyError:
raise TypeError("Invalid comparison %s > %s" % (self, other))
if other.is_extended_real:
if other.is_finite or other is S.NegativeInfinity:
return S.true
Expand All @@ -2428,11 +2392,8 @@ def __gt__(self, other):
return S.false
return Expr.__gt__(self, other)

@_sympifyit('other', NotImplemented)
def __ge__(self, other):
try:
other = _sympify(other)
except SympifyError:
raise TypeError("Invalid comparison %s >= %s" % (self, other))
if other.is_extended_real:
return S.true
return Expr.__ge__(self, other)
Expand Down Expand Up @@ -2586,11 +2547,8 @@ def __hash__(self):
def __eq__(self, other):
return other is S.NegativeInfinity

@_sympifyit('other', NotImplemented)
def __lt__(self, other):
try:
other = _sympify(other)
except SympifyError:
raise TypeError("Invalid comparison %s < %s" % (self, other))
if other.is_extended_real:
if other.is_finite or other is S.Infinity:
return S.true
Expand All @@ -2600,29 +2558,20 @@ def __lt__(self, other):
return S.false
return Expr.__lt__(self, other)

@_sympifyit('other', NotImplemented)
def __le__(self, other):
try:
other = _sympify(other)
except SympifyError:
raise TypeError("Invalid comparison %s <= %s" % (self, other))
if other.is_extended_real:
return S.true
return Expr.__le__(self, other)

@_sympifyit('other', NotImplemented)
def __gt__(self, other):
try:
other = _sympify(other)
except SympifyError:
raise TypeError("Invalid comparison %s > %s" % (self, other))
if other.is_extended_real:
return S.false
return Expr.__gt__(self, other)

@_sympifyit('other', NotImplemented)
def __ge__(self, other):
try:
other = _sympify(other)
except SympifyError:
raise TypeError("Invalid comparison %s >= %s" % (self, other))
if other.is_extended_real:
if other.is_finite or other is S.Infinity:
return S.false
Expand Down Expand Up @@ -2815,52 +2764,37 @@ def approximation_interval(self, number_cls):
def _eval_evalf(self, prec):
return Float._new(self._as_mpf_val(prec), prec)

@_sympifyit('other', NotImplemented)
def __eq__(self, other):
try:
other = _sympify(other)
except SympifyError:
return False # diofant != other --> not ==
if self is other:
return True
if isinstance(other, Number) and self.is_irrational:
return False

return False # NumberSymbol != non-(Number|self)

@_sympifyit('other', NotImplemented)
def __lt__(self, other):
try:
other = _sympify(other)
except SympifyError:
raise TypeError("Invalid comparison %s < %s" % (self, other))
if self is other:
return S.false
return Expr.__lt__(self, other)

@_sympifyit('other', NotImplemented)
def __le__(self, other):
try:
other = _sympify(other)
except SympifyError:
raise TypeError("Invalid comparison %s <= %s" % (self, other))
if self is other:
return S.true
return Expr.__le__(self, other)

@_sympifyit('other', NotImplemented)
def __gt__(self, other):
try:
other = _sympify(other)
except SympifyError:
raise TypeError("Invalid comparison %s > %s" % (self, other))
r = _sympify((-self) < (-other))
if r in (S.true, S.false):
return r
else:
return Expr.__gt__(self, other)

@_sympifyit('other', NotImplemented)
def __ge__(self, other):
try:
other = _sympify(other)
except SympifyError:
raise TypeError("Invalid comparison %s >= %s" % (self, other))
r = _sympify((-self) <= (-other))
if r in (S.true, S.false):
return r
Expand Down

0 comments on commit ee86772

Please sign in to comment.