Skip to content

Commit

Permalink
Merge pull request #26606 from haru-44/__pow__
Browse files Browse the repository at this point in the history
Enhanced the efficiency of power calculations
  • Loading branch information
oscarbenjamin committed May 20, 2024
2 parents 8519e0f + 4d3b7f5 commit 3cc42c4
Show file tree
Hide file tree
Showing 8 changed files with 92 additions and 54 deletions.
22 changes: 12 additions & 10 deletions sympy/combinatorics/free_groups.py
Original file line number Diff line number Diff line change
Expand Up @@ -487,19 +487,21 @@ def __str__(self):

def __pow__(self, n):
n = as_int(n)
group = self.group
result = self.group.identity
if n == 0:
return group.identity

return result
if n < 0:
n = -n
return (self.inverse())**n

result = self
for i in range(n - 1):
result = result*self
# this method can be improved instead of just returning the
# multiplication of elements
x = self.inverse()
else:
x = self
while True:
if n % 2:
result *= x
n >>= 1
if not n:
break
x *= x
return result

def __mul__(self, other):
Expand Down
6 changes: 6 additions & 0 deletions sympy/combinatorics/tests/test_free_groups.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,12 @@ def test_FreeGroupElm__mul__pow__():
assert x*(x**-1*y*z*y**-1) == y*z*y**-1
assert x**2*(x**-2*y**-1*z**2*y) == y**-1*z**2*y

a = F.identity
for n in range(10):
assert a == x**n
assert a**-1 == x**-n
a *= x


def test_FreeGroupElm__len__():
assert len(x**5*y*x**2*y**-4*x) == 13
Expand Down
42 changes: 23 additions & 19 deletions sympy/holonomic/holonomic.py
Original file line number Diff line number Diff line change
Expand Up @@ -309,20 +309,22 @@ def __truediv__(self, other):
def __pow__(self, n):
if n == 1:
return self
result = DifferentialOperator([self.parent.base.one], self.parent)
if n == 0:
return DifferentialOperator([self.parent.base.one], self.parent)

return result
# if self is `Dx`
if self.listofpoly == self.parent.derivative_operator.listofpoly:
sol = [self.parent.base.zero]*n + [self.parent.base.one]
return DifferentialOperator(sol, self.parent)

# the general case
if n % 2 == 1:
powreduce = self**(n - 1)
return powreduce * self
powreduce = self**(n // 2)
return powreduce * powreduce
x = self
while True:
if n % 2:
result *= x
n >>= 1
if not n:
break
x *= x
return result

def __str__(self):
listofpoly = self.listofpoly
Expand Down Expand Up @@ -1094,17 +1096,19 @@ def __pow__(self, n):
return HolonomicFunction(dd, self.x, self.x0, y0)
if n < 0:
raise NotHolonomicError("Negative Power on a Holonomic Function")
Dx = self.annihilator.parent.derivative_operator
result = HolonomicFunction(Dx, self.x, S.Zero, [S.One])
if n == 0:
Dx = self.annihilator.parent.derivative_operator
return HolonomicFunction(Dx, self.x, S.Zero, [S.One])
if n == 1:
return self
if n % 2 == 1:
powreduce = self**(n - 1)
return powreduce * self
if n % 2 == 0:
powreduce = self**(n / 2)
return powreduce * powreduce
return result
x = self
while True:
if n % 2:
result *= x
n >>= 1
if not n:
break
x *= x
return result

def degree(self):
"""
Expand Down
26 changes: 12 additions & 14 deletions sympy/holonomic/recurrence.py
Original file line number Diff line number Diff line change
Expand Up @@ -253,24 +253,22 @@ def __rsub__(self, other):
def __pow__(self, n):
if n == 1:
return self
result = RecurrenceOperator([self.parent.base.one], self.parent)
if n == 0:
return RecurrenceOperator([self.parent.base.one], self.parent)
return result
# if self is `Sn`
if self.listofpoly == self.parent.shift_operator.listofpoly:
sol = []
for i in range(0, n):
sol.append(self.parent.base.zero)
sol.append(self.parent.base.one)

sol = [self.parent.base.zero] * n + [self.parent.base.one]
return RecurrenceOperator(sol, self.parent)

else:
if n % 2 == 1:
powreduce = self**(n - 1)
return powreduce * self
elif n % 2 == 0:
powreduce = self**(n / 2)
return powreduce * powreduce
x = self
while True:
if n % 2:
result *= x
n >>= 1
if not n:
break
x *= x
return result

def __str__(self):
listofpoly = self.listofpoly
Expand Down
21 changes: 21 additions & 0 deletions sympy/holonomic/tests/test_holonomic.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,16 @@ def test_HolonomicFunction_multiplication():
assert p*q == r


def test_HolonomicFunction_power():
x = symbols('x')
R, Dx = DifferentialOperators(ZZ.old_poly_ring(x), 'Dx')
p = HolonomicFunction(Dx+x+x*Dx**2, x)
a = HolonomicFunction(Dx, x)
for n in range(10):
assert a == p**n
a *= p


def test_addition_initial_condition():
x = symbols('x')
R, Dx = DifferentialOperators(QQ.old_poly_ring(x), 'Dx')
Expand Down Expand Up @@ -814,6 +824,7 @@ def test_expr_in_power():

assert h1 == h2


def test_DifferentialOperatorEqPoly():
x = symbols('x', integer=True)
R, Dx = DifferentialOperators(QQ.old_poly_ring(x), 'Dx')
Expand All @@ -828,3 +839,13 @@ def test_DifferentialOperatorEqPoly():

p2 = do2.listofpoly[0]
assert not do2 == p2


def test_DifferentialOperatorPow():
x = symbols('x', integer=True)
R, _ = DifferentialOperators(QQ.old_poly_ring(x), 'Dx')
do = DifferentialOperator([x**2, R.base.zero, R.base.zero], R)
a = DifferentialOperator([R.base.one], R)
for n in range(10):
assert a == do**n
a *= do
12 changes: 12 additions & 0 deletions sympy/holonomic/tests/test_recurrence.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from sympy.core.symbol import symbols
from sympy.polys.domains.rationalfield import QQ


def test_RecurrenceOperator():
n = symbols('n', integer=True)
R, Sn = RecurrenceOperators(QQ.old_poly_ring(n), 'Sn')
Expand All @@ -13,6 +14,7 @@ def test_RecurrenceOperator():
117*n**2 + 324*n + 324)*Sn**6
assert p == q


def test_RecurrenceOperatorEqPoly():
n = symbols('n', integer=True)
R, Sn = RecurrenceOperators(QQ.old_poly_ring(n), 'Sn')
Expand All @@ -27,3 +29,13 @@ def test_RecurrenceOperatorEqPoly():

d2 = rr2.listofpoly[0]
assert not rr2 == d2


def test_RecurrenceOperatorPow():
n = symbols('n', integer=True)
R, _ = RecurrenceOperators(QQ.old_poly_ring(n), 'Sn')
rr = RecurrenceOperator([n**2, 0, 0], R)
a = RecurrenceOperator([R.base.one], R)
for m in range(10):
assert a == rr**m
a *= rr
13 changes: 2 additions & 11 deletions sympy/polys/monomials.py
Original file line number Diff line number Diff line change
Expand Up @@ -593,18 +593,9 @@ def __truediv__(self, other):

def __pow__(self, other):
n = int(other)

if not n:
return self.rebuild([0]*len(self))
elif n > 0:
exponents = self.exponents

for i in range(1, n):
exponents = monomial_mul(exponents, self.exponents)

return self.rebuild(exponents)
else:
if n < 0:
raise ValueError("a non-negative integer expected, got %s" % other)
return self.rebuild(monomial_pow(self.exponents, n))

def gcd(self, other):
"""Greatest common divisor of monomials. """
Expand Down
4 changes: 4 additions & 0 deletions sympy/polys/tests/test_monomials.py
Original file line number Diff line number Diff line change
Expand Up @@ -251,6 +251,10 @@ def test_Monomial():
assert m**1 == m
assert m**2 == Monomial((6, 8, 2))
assert m**3 == Monomial((9, 12, 3))
_a = Monomial((0, 0, 0))
for n in range(10):
assert _a == m**n
_a *= m

raises(ExactQuotientFailed, lambda: m/Monomial((5, 2, 0)))

Expand Down

0 comments on commit 3cc42c4

Please sign in to comment.