Skip to content

Commit

Permalink
Merge pull request #21699 from posita/posita/fix-integral-registration
Browse files Browse the repository at this point in the history
feat(core): make sympy.Integer compatible with Integral ABC
  • Loading branch information
oscarbenjamin committed Jul 8, 2021
2 parents 17f3457 + 35366c5 commit 7c4d578
Show file tree
Hide file tree
Showing 2 changed files with 196 additions and 1 deletion.
72 changes: 71 additions & 1 deletion sympy/core/numbers.py
Original file line number Diff line number Diff line change
Expand Up @@ -2422,6 +2422,76 @@ def __floordiv__(self, other):
def __rfloordiv__(self, other):
return Integer(Integer(other).p // self.p)

# These bitwise operations (__lshift__, __rlshift__, ..., __invert__) are defined
# for Integer only and not for general sympy expressions. This is to achieve
# compatibility with the numbers.Integral ABC which only defines these operations
# among instances of numbers.Integral. Therefore, these methods check explicitly for
# integer types rather than using sympify because they should not accept arbitrary
# symbolic expressions and there is no symbolic analogue of numbers.Integral's
# bitwise operations.
def __lshift__(self, other):
if isinstance(other, (int, Integer, numbers.Integral)):
return Integer(self.p << int(other))
else:
return NotImplemented

def __rlshift__(self, other):
if isinstance(other, (int, numbers.Integral)):
return Integer(int(other) << self.p)
else:
return NotImplemented

def __rshift__(self, other):
if isinstance(other, (int, Integer, numbers.Integral)):
return Integer(self.p >> int(other))
else:
return NotImplemented

def __rrshift__(self, other):
if isinstance(other, (int, numbers.Integral)):
return Integer(int(other) >> self.p)
else:
return NotImplemented

def __and__(self, other):
if isinstance(other, (int, Integer, numbers.Integral)):
return Integer(self.p & int(other))
else:
return NotImplemented

def __rand__(self, other):
if isinstance(other, (int, numbers.Integral)):
return Integer(int(other) & self.p)
else:
return NotImplemented

def __xor__(self, other):
if isinstance(other, (int, Integer, numbers.Integral)):
return Integer(self.p ^ int(other))
else:
return NotImplemented

def __rxor__(self, other):
if isinstance(other, (int, numbers.Integral)):
return Integer(int(other) ^ self.p)
else:
return NotImplemented

def __or__(self, other):
if isinstance(other, (int, Integer, numbers.Integral)):
return Integer(self.p | int(other))
else:
return NotImplemented

def __ror__(self, other):
if isinstance(other, (int, numbers.Integral)):
return Integer(int(other) | self.p)
else:
return NotImplemented

def __invert__(self):
return Integer(~self.p)

# Add sympify converters
converter[int] = Integer

Expand Down Expand Up @@ -4054,6 +4124,6 @@ def _register_classes():
numbers.Number.register(Number)
numbers.Real.register(Float)
numbers.Rational.register(Rational)
numbers.Rational.register(Integer)
numbers.Integral.register(Integer)

_register_classes()
125 changes: 125 additions & 0 deletions sympy/core/tests/test_numbers.py
Original file line number Diff line number Diff line change
Expand Up @@ -1226,6 +1226,127 @@ def test_powers_Float():
assert str((S('-1/10')**S('3/10')).n()) == str(Float(-.1)**(.3))


def test_lshift_Integer():
assert Integer(0) << Integer(2) == Integer(0)
assert Integer(0) << 2 == Integer(0)
assert 0 << Integer(2) == Integer(0)

assert Integer(0b11) << Integer(0) == Integer(0b11)
assert Integer(0b11) << 0 == Integer(0b11)
assert 0b11 << Integer(0) == Integer(0b11)

assert Integer(0b11) << Integer(2) == Integer(0b11 << 2)
assert Integer(0b11) << 2 == Integer(0b11 << 2)
assert 0b11 << Integer(2) == Integer(0b11 << 2)

assert Integer(-0b11) << Integer(2) == Integer(-0b11 << 2)
assert Integer(-0b11) << 2 == Integer(-0b11 << 2)
assert -0b11 << Integer(2) == Integer(-0b11 << 2)

raises(TypeError, lambda: Integer(2) << 0.0)
raises(TypeError, lambda: 0.0 << Integer(2))
raises(ValueError, lambda: Integer(1) << Integer(-1))


def test_rshift_Integer():
assert Integer(0) >> Integer(2) == Integer(0)
assert Integer(0) >> 2 == Integer(0)
assert 0 >> Integer(2) == Integer(0)

assert Integer(0b11) >> Integer(0) == Integer(0b11)
assert Integer(0b11) >> 0 == Integer(0b11)
assert 0b11 >> Integer(0) == Integer(0b11)

assert Integer(0b11) >> Integer(2) == Integer(0)
assert Integer(0b11) >> 2 == Integer(0)
assert 0b11 >> Integer(2) == Integer(0)

assert Integer(-0b11) >> Integer(2) == Integer(-1)
assert Integer(-0b11) >> 2 == Integer(-1)
assert -0b11 >> Integer(2) == Integer(-1)

assert Integer(0b1100) >> Integer(2) == Integer(0b1100 >> 2)
assert Integer(0b1100) >> 2 == Integer(0b1100 >> 2)
assert 0b1100 >> Integer(2) == Integer(0b1100 >> 2)

assert Integer(-0b1100) >> Integer(2) == Integer(-0b1100 >> 2)
assert Integer(-0b1100) >> 2 == Integer(-0b1100 >> 2)
assert -0b1100 >> Integer(2) == Integer(-0b1100 >> 2)

raises(TypeError, lambda: Integer(0b10) >> 0.0)
raises(TypeError, lambda: 0.0 >> Integer(2))
raises(ValueError, lambda: Integer(1) >> Integer(-1))


def test_and_Integer():
assert Integer(0b01010101) & Integer(0b10101010) == Integer(0)
assert Integer(0b01010101) & 0b10101010 == Integer(0)
assert 0b01010101 & Integer(0b10101010) == Integer(0)

assert Integer(0b01010101) & Integer(0b11011011) == Integer(0b01010001)
assert Integer(0b01010101) & 0b11011011 == Integer(0b01010001)
assert 0b01010101 & Integer(0b11011011) == Integer(0b01010001)

assert -Integer(0b01010101) & Integer(0b11011011) == Integer(-0b01010101 & 0b11011011)
assert Integer(-0b01010101) & 0b11011011 == Integer(-0b01010101 & 0b11011011)
assert -0b01010101 & Integer(0b11011011) == Integer(-0b01010101 & 0b11011011)

assert Integer(0b01010101) & -Integer(0b11011011) == Integer(0b01010101 & -0b11011011)
assert Integer(0b01010101) & -0b11011011 == Integer(0b01010101 & -0b11011011)
assert 0b01010101 & Integer(-0b11011011) == Integer(0b01010101 & -0b11011011)

raises(TypeError, lambda: Integer(2) & 0.0)
raises(TypeError, lambda: 0.0 & Integer(2))


def test_xor_Integer():
assert Integer(0b01010101) ^ Integer(0b11111111) == Integer(0b10101010)
assert Integer(0b01010101) ^ 0b11111111 == Integer(0b10101010)
assert 0b01010101 ^ Integer(0b11111111) == Integer(0b10101010)

assert Integer(0b01010101) ^ Integer(0b11011011) == Integer(0b10001110)
assert Integer(0b01010101) ^ 0b11011011 == Integer(0b10001110)
assert 0b01010101 ^ Integer(0b11011011) == Integer(0b10001110)

assert -Integer(0b01010101) ^ Integer(0b11011011) == Integer(-0b01010101 ^ 0b11011011)
assert Integer(-0b01010101) ^ 0b11011011 == Integer(-0b01010101 ^ 0b11011011)
assert -0b01010101 ^ Integer(0b11011011) == Integer(-0b01010101 ^ 0b11011011)

assert Integer(0b01010101) ^ -Integer(0b11011011) == Integer(0b01010101 ^ -0b11011011)
assert Integer(0b01010101) ^ -0b11011011 == Integer(0b01010101 ^ -0b11011011)
assert 0b01010101 ^ Integer(-0b11011011) == Integer(0b01010101 ^ -0b11011011)

raises(TypeError, lambda: Integer(2) ^ 0.0)
raises(TypeError, lambda: 0.0 ^ Integer(2))


def test_or_Integer():
assert Integer(0b01010101) | Integer(0b10101010) == Integer(0b11111111)
assert Integer(0b01010101) | 0b10101010 == Integer(0b11111111)
assert 0b01010101 | Integer(0b10101010) == Integer(0b11111111)

assert Integer(0b01010101) | Integer(0b11011011) == Integer(0b11011111)
assert Integer(0b01010101) | 0b11011011 == Integer(0b11011111)
assert 0b01010101 | Integer(0b11011011) == Integer(0b11011111)

assert -Integer(0b01010101) | Integer(0b11011011) == Integer(-0b01010101 | 0b11011011)
assert Integer(-0b01010101) | 0b11011011 == Integer(-0b01010101 | 0b11011011)
assert -0b01010101 | Integer(0b11011011) == Integer(-0b01010101 | 0b11011011)

assert Integer(0b01010101) | -Integer(0b11011011) == Integer(0b01010101 | -0b11011011)
assert Integer(0b01010101) | -0b11011011 == Integer(0b01010101 | -0b11011011)
assert 0b01010101 | Integer(-0b11011011) == Integer(0b01010101 | -0b11011011)

raises(TypeError, lambda: Integer(2) | 0.0)
raises(TypeError, lambda: 0.0 | Integer(2))


def test_invert_Integer():
assert ~Integer(0b01010101) == Integer(-0b01010110)
assert ~Integer(0b01010101) == Integer(~0b01010101)
assert ~(~Integer(0b01010101)) == Integer(0b01010101)


def test_abs1():
assert Rational(1, 6) != Rational(-1, 6)
assert abs(Rational(1, 6)) == abs(Rational(-1, 6))
Expand Down Expand Up @@ -2037,6 +2158,10 @@ def test_abc():
assert(isinstance(y, nums.Rational))
z = numbers.Integer(3)
assert(isinstance(z, nums.Number))
assert(isinstance(z, numbers.Number))
assert(isinstance(z, nums.Rational))
assert(isinstance(z, numbers.Rational))
assert(isinstance(z, nums.Integral))

def test_floordiv():
assert S(2)//S.Half == 4

0 comments on commit 7c4d578

Please sign in to comment.